Compare commits


29 Commits

Author SHA1 Message Date
Bentlybro
69c9136060 Improve LLM registry consistency and frontend UX
Backend: Refactored LLM registry state updates to use atomic swaps for consistency, made Redis notification publishing async, and improved schema/discriminator mapping access to prevent external mutation. Added stricter slug validation for model creation. Frontend: Enhanced Edit and Delete Model modals to refresh data after actions and show error states, and wrapped the LLM Registry Dashboard in an error boundary for better error handling.
2026-01-12 12:52:40 +00:00
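A minimal sketch of the atomic-swap pattern this commit describes, with illustrative names rather than the actual registry internals: the full replacement snapshot is built first, then a single reference is rebound, so concurrent readers see either the old state or the new one, never a mix.

class RegistryState:
    def __init__(self, models: dict[str, dict]):
        self.models = models

_state = RegistryState(models={})  # current snapshot, replaced wholesale

async def refresh_registry(load_models) -> None:
    # Build the complete new snapshot before publishing it.
    new_state = RegistryState(models=await load_models())
    global _state
    _state = new_state  # one atomic rebind; no reader observes a partial update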
Bentlybro
6ed8bb4f14 Clarify custom pricing override for LLM migrations
Improved documentation and comments for the custom_credit_cost field in backend, frontend, and schema files to clarify its use as a billing override during LLM model migrations. Also removed unused LLM registry types and API methods from frontend code, and renamed useLlmRegistryPage.ts to getLlmRegistryPage.ts for consistency.
2026-01-12 11:40:49 +00:00
Bentlybro
6cf28e58d3 Improve LLM model default selection and admin actions
Backend logic for selecting the default LLM model now prioritizes the recommended model, with improved fallbacks and error handling when no models are enabled. The migration enforces a single recommended model at the database level. Frontend admin actions for LLM models and providers now correctly interpret form values for boolean fields, and the delete action's return type has been fixed.
2026-01-09 15:18:54 +00:00
Bentlybro
632ef24408 Add recommended LLM model feature to admin UI and API
Introduces the ability for admins to mark a model as the recommended default via a new boolean field `isRecommended` on LlmModel. Adds backend endpoints and logic to set, get, and persist the recommended model, including a migration and schema update. Updates the frontend admin UI to allow selecting and displaying the recommended model, and reflects the recommended status in model tables and dropdowns.
2026-01-07 19:43:16 +00:00
Bentlybro
6dc767aafa Improve admin LLM registry UX and error handling
Adds user feedback and error handling to LLM registry modals (add/edit creator, model, provider) in the admin UI, including loading states and error messages. Ensures atomic updates for model costs in the backend using transactions. Improves display of creator website URLs and handles the case where no LLM models are available in analytics config. Updates icon usage and removes unnecessary 'use server' directive.
2026-01-07 14:17:37 +00:00
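The "atomic updates for model costs" likely follow the standard prisma-client-py transaction pattern; a hedged sketch (the llmmodelcost accessor and field names are assumptions inferred from the diffs below):

from prisma import Prisma

async def replace_model_costs(db: Prisma, model_id: str, costs: list[dict]) -> None:
    # Delete and recreate inside one transaction so a mid-update failure
    # cannot leave a model with no cost rows.
    async with db.tx() as tx:
        await tx.llmmodelcost.delete_many(where={"llmModelId": model_id})
        for cost in costs:
            await tx.llmmodelcost.create(data={"llmModelId": model_id, **cost})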
Bentlybro
23e37fd163 Replace delete button with DeleteCreatorModal
Refactored the creator deletion flow in CreatorsTable to use a new DeleteCreatorModal component, providing a confirmation dialog and improved error handling. The previous DeleteCreatorButton was removed and replaced for better user experience and safety.
2026-01-06 14:22:21 +00:00
Bentlybro
63869fe710 format 2026-01-06 13:40:16 +00:00
Bentlybro
90ae75d475 Delete settings.local.json 2026-01-06 13:07:46 +00:00
Bentlybro
9b6dc3be12 prettier 2026-01-06 13:01:51 +00:00
Bentlybro
9b8b6252c5 Refactor LLM registry admin backend and frontend
Refactored backend imports and test mocks to use new admin LLM routes location. Cleaned up and reordered imports for clarity and consistency. Improved code formatting and readability across backend and frontend files. Renamed useLlmRegistryPage to getLlmRegistryPageData for clarity and updated all usages. No functional changes to business logic.
2026-01-06 12:57:33 +00:00
Bentlybro
0d321323f5 Add GPT-5.2 model and admin LLM endpoints
Introduces a migration to add the GPT-5.2 model and updates the O3 model slug in the database. Refactors backend LLM model registry usage for search and migration logic. Expands the OpenAPI spec with new admin endpoints for managing LLM models, providers, creators, and migrations.
2026-01-06 12:46:20 +00:00
Bentlybro
3ee3ea8f02 Merge branch 'dev' into add-llm-manager-ui 2026-01-06 10:28:43 +00:00
Bentlybro
7a842d35ae Refactor LLM admin to use generated API and types
Replaces usage of the custom BackendApi client and legacy types in admin LLM actions and components with generated OpenAPI endpoints and types. Updates API calls, error handling, and type imports throughout the admin LLM dashboard. Also corrects operationId fields in backend routes and OpenAPI spec for consistency.
2026-01-06 09:43:15 +00:00
Bentlybro
07e8568f57 Refactor LLM admin UI for improved consistency and API support
Refactored admin LLM actions and components to improve code organization, update color schemes to use design tokens, and enhance UI consistency. Updated API types and endpoints to support model creators and migrations, and switched tables to use shared Table components. Added and documented new API endpoints for model migrations, creators, and usage in openapi.json.
2026-01-05 17:10:04 +00:00
Bentlybro
13a0caa5d8 Improve model modal UX and credential provider selection
Add auto-selection of creator based on provider in AddModelModal for better usability. Update EditModelModal to use a select dropdown for credential provider, add helper text, and set credential_type as a hidden default input.
2026-01-05 16:01:36 +00:00
Bentlybro
664523a721 Refactor LLM model cost and update logic, remove 'Enabled' checkbox
Improves backend handling of LLM model cost updates by separating scalar and relation field updates, ensuring costs are deleted and recreated as needed. Optional cost fields are now only included if present, and metadata is handled as a Prisma Json type. On the frontend, removes the 'Enabled' checkbox from the EditModelModal component.
2026-01-05 15:56:45 +00:00
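A sketch of the "optional cost fields are only included if present" rule this commit describes (column names are assumptions; Json is the actual prisma-client-py wrapper type):

from prisma import Json

def build_cost_data(cost: dict) -> dict:
    data = {
        "creditCost": cost["credit_cost"],
        "credentialProvider": cost["credential_provider"],
    }
    # Optional scalars are added only when supplied, so a partial update
    # never overwrites an existing value with None.
    if cost.get("metadata") is not None:
        data["metadata"] = Json(cost["metadata"])  # stored as a Prisma Json type
    return data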
Bentlybro
33b103d09b Improve LLM model migration and add AgentNode index
Refactored model migration and revert logic for atomicity and consistency, including transactional node selection and updates. Enhanced revert API to support optional re-enabling of source models and reporting of nodes not reverted. Added a database index on AgentNode.constantInput->>'model' to optimize migration queries and performance.
2026-01-05 15:22:33 +00:00
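The new expression index serves lookups that filter nodes by the model slug stored in their JSON input. A hedged illustration of the kind of query it accelerates, using prisma-client-py's raw-query API (table and column names follow the Prisma schema naming visible elsewhere in this compare):

from prisma import Prisma

async def count_nodes_using_model(db: Prisma, slug: str) -> int:
    # With an index on ("constantInput"->>'model'), Postgres can answer this
    # without scanning every AgentNode row.
    rows = await db.query_raw(
        "SELECT COUNT(*) AS n FROM \"AgentNode\" WHERE \"constantInput\"->>'model' = $1",
        slug,
    )
    return rows[0]["n"]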
Bentlybro
2e3fc99caa Add LLM model creator support to registry and admin UI
Introduces the LlmModelCreator entity to distinguish model creators (e.g., OpenAI, Meta) from providers, with full CRUD API endpoints, database migration, and Prisma schema updates. Backend and frontend are updated to support associating models with creators, including admin UI for managing creators and selecting them when creating or editing models. Existing models are backfilled with known creators via migration.
2026-01-05 10:17:00 +00:00
Bently
52c7b223df Add migration management for LLM models
Introduced a new LlmModelMigration model to track migrations when disabling LLM models, allowing for revert capability. Updated the toggle model API to create migration records with optional reason and custom pricing. Added endpoints for listing and reverting migrations, along with corresponding frontend actions and UI components to manage migrations effectively. Enhanced the admin dashboard to display active migrations, improving overall usability and tracking of model changes.
2025-12-19 00:06:03 +00:00
Bently
24d86fde30 Enhance LLM model toggle functionality with migration support
Updated the toggle LLM model API to include an optional migration feature, allowing workflows to be migrated to a specified replacement model when disabling a model. Refactored related request and response models to accommodate this change. Improved error handling and logging for better debugging. Updated frontend actions and components to support the new migration parameter.
2025-12-18 23:32:41 +00:00
Bentlybro
df7be39724 Refactor add model/provider forms to modal dialogs
Replaces AddModelForm and AddProviderForm components with AddModelModal and AddProviderModal, converting the add model/provider flows to use modal dialogs instead of inline forms. Updates LlmRegistryDashboard to use the new modal components and removes dropdown/form selection logic for a cleaner UI.
2025-12-13 19:39:30 +00:00
Bentlybro
8c7b1af409 Refactor LLM registry to modular structure and improve admin UI
Moved LLM registry backend code into a dedicated llm_registry module with submodules for model types, notifications, schema utilities, and registry logic. Updated all backend imports to use the new structure. On the frontend, redesigned the admin LLM registry page with a dashboard layout, modularized data fetching, and improved forms for adding/editing providers and models. Updated UI components for better usability and maintainability.
2025-12-12 11:32:28 +00:00
Bentlybro
b6e2f05b63 Refactor LlmModel to support dynamic registry slugs
Replaces hardcoded LlmModel enum values with a dynamic approach that accepts any model slug from the registry. Updates block defaults to use a default_factory method that pulls the preferred model from the registry. Refactors model validation, migration, and admin analytics routes to use registry-based model lists, ensuring only enabled models are selectable and recommended. Adds get_default_model_slug to llm_registry for consistent default selection.
2025-12-09 15:49:44 +00:00
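The default_factory approach described here, as a minimal sketch (get_default_model_slug is named in the commit; the block input shape is an assumption):

from pydantic import BaseModel, Field

def get_default_model_slug() -> str:
    # The real implementation consults the registry; hardcoded for illustration.
    return "gpt-4o-mini"

class LlmBlockInput(BaseModel):
    # Evaluated at instantiation time, so a registry change (e.g. a newly
    # recommended model) takes effect without editing block code.
    model: str = Field(default_factory=get_default_model_slug)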
Bentlybro
7435739053 Add fallback logic for disabled LLM models
Introduces fallback selection for disabled LLM models in llm_call, preferring enabled models from the same provider. Updates registry utilities to support fallback lookup, model info retrieval, and validation of all known model slugs. Schema utilities now keep all known models in validation enums while showing only enabled models in UI options.
2025-12-08 11:29:31 +00:00
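A compact sketch of the fallback rule this commit describes, with illustrative data shapes:

def resolve_model(slug: str, registry: dict[str, dict]) -> str | None:
    info = registry.get(slug)
    if info and info["is_enabled"]:
        return slug  # model is usable as-is
    if info:
        # Prefer an enabled model from the same provider before giving up.
        for other_slug, other in registry.items():
            if other["is_enabled"] and other["provider"] == info["provider"]:
                return other_slug
    return None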
Bentlybro
a97fdba554 Restrict LLM model and provider listings to enabled items
Updated public LLM model and provider listing endpoints to only return enabled models and providers. Refactored database access functions to support filtering by enabled status, and improved transaction safety for model deletion. Adjusted tests and internal documentation to reflect these changes.
2025-12-04 15:56:25 +00:00
Bentlybro
ec705bbbcf format 2025-12-02 14:49:03 +00:00
Bentlybro
7fe6b576ae Add LLM model deletion and migration feature
Introduces backend and frontend support for deleting LLM models with automatic workflow migration to a replacement model. Adds API endpoints, database logic, response models, a frontend modal, and actions for safe deletion, including usage-count display and error handling. Updates table components to use the new modal and refactors table imports.
2025-12-02 14:41:13 +00:00
Bentlybro
dfc42003a1 Refactor LLM registry integration and schema updates
Moved LLM registry schema update logic to a shared utility (llm_schema_utils.py) and refactored block and credentials schema post-processing to use this helper. Extracted executor registry initialization and notification handling into llm_registry_init.py for better separation of concerns. Updated manager.py to use new initialization and subscription functions, improving maintainability and clarity of LLM registry refresh logic.
2025-12-01 17:55:43 +00:00
Bentlybro
6bbeb22943 Refactor LLM model registry to use database
Migrates LLM model metadata and cost configuration from static code to a dynamic database-driven registry. Adds new backend modules for LLM registry and model types, updates block and cost configuration logic to fetch model info and costs from the database, and ensures block schemas and UI options reflect enabled/disabled models. This enables dynamic management of LLM models and costs via the admin UI and database migrations.
2025-12-01 14:37:46 +00:00
333 changed files with 13556 additions and 17290 deletions

View File

@@ -1,9 +0,0 @@
{
"permissions": {
"allow": [
"Bash(ls:*)",
"WebFetch(domain:langfuse.com)",
"Bash(poetry install:*)"
]
}
}

View File

@@ -1,4 +1,4 @@
.PHONY: start-core stop-core logs-core format lint migrate run-backend stop-backend run-frontend load-store-agents backfill-store-embeddings
.PHONY: start-core stop-core logs-core format lint migrate run-backend run-frontend load-store-agents
# Run just Supabase + Redis + RabbitMQ
start-core:
@@ -34,14 +34,7 @@ migrate:
cd backend && poetry run prisma migrate deploy
cd backend && poetry run prisma generate
stop-backend:
@echo "Stopping backend processes..."
@cd backend && poetry run cli stop 2>/dev/null || true
@echo "Killing any processes using backend ports..."
@lsof -ti:8001,8002,8003,8004,8005,8006,8007 | xargs kill -9 2>/dev/null || true
@echo "Backend stopped"
run-backend: stop-backend
run-backend:
cd backend && poetry run app
run-frontend:
@@ -53,9 +46,6 @@ test-data:
load-store-agents:
cd backend && poetry run load-store-agents
backfill-store-embeddings:
cd backend && poetry run python -m backend.api.features.store.backfill_embeddings
help:
@echo "Usage: make <target>"
@echo "Targets:"
@@ -65,9 +55,7 @@ help:
@echo " logs-core - Tail the logs for core services"
@echo " format - Format & lint backend (Python) and frontend (TypeScript) code"
@echo " migrate - Run backend database migrations"
@echo " stop-backend - Stop any running backend processes"
@echo " run-backend - Run the backend FastAPI server (stops existing processes first)"
@echo " run-backend - Run the backend FastAPI server"
@echo " run-frontend - Run the frontend Next.js development server"
@echo " test-data - Run the test data creator"
@echo " load-store-agents - Load store agents from agents/ folder into test database"
@echo " backfill-store-embeddings - Generate embeddings for store agents that don't have them"
@echo " load-store-agents - Load store agents from agents/ folder into test database"

View File

@@ -58,13 +58,6 @@ V0_API_KEY=
OPEN_ROUTER_API_KEY=
NVIDIA_API_KEY=
# Langfuse Prompt Management
# Used for managing the CoPilot system prompt externally
# Get credentials from https://cloud.langfuse.com or your self-hosted instance
LANGFUSE_PUBLIC_KEY=
LANGFUSE_SECRET_KEY=
LANGFUSE_HOST=https://cloud.langfuse.com
# OAuth Credentials
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
# e.g. http://localhost:3000/auth/integrations/oauth_callback

View File

@@ -122,6 +122,24 @@ class ConnectionManager:
return len(connections)
async def broadcast_to_all(self, *, method: WSMethod, data: dict) -> int:
"""Broadcast a message to all active websocket connections."""
message = WSMessage(
method=method,
data=data,
).model_dump_json()
connections = tuple(self.active_connections)
if not connections:
return 0
await asyncio.gather(
*(connection.send_text(message) for connection in connections),
return_exceptions=True,
)
return len(connections)
async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
if channel_key not in self.subscriptions:
self.subscriptions[channel_key] = set()
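Passing return_exceptions=True to asyncio.gather means one dead socket cannot abort the whole broadcast: failed sends come back as exception objects in the gathered results instead of propagating. A hedged usage sketch (WSMethod.REGISTRY_REFRESH is a hypothetical member for illustration):

async def notify_registry_update(manager: ConnectionManager) -> None:
    sent = await manager.broadcast_to_all(
        method=WSMethod.REGISTRY_REFRESH,  # hypothetical enum member
        data={"event": "llm_registry_updated"},
    )
    logger.info("Broadcast reached %d connections", sent)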

View File

@@ -173,30 +173,64 @@ async def get_execution_analytics_config(
# Return with provider prefix for clarity
return f"{provider_name}: {model_name}"
# Include all LlmModel values (no more filtering by hardcoded list)
recommended_model = LlmModel.GPT4O_MINI.value
for model in LlmModel:
label = generate_model_label(model)
# Get all models from the registry (dynamic, not hardcoded enum)
from backend.data import llm_registry
from backend.server.v2.llm import db as llm_db
# Get the recommended model from the database (configurable via admin UI)
recommended_model_slug = await llm_db.get_recommended_model_slug()
# Build the available models list
first_enabled_slug = None
for registry_model in llm_registry.iter_dynamic_models():
# Only include enabled models in the list
if not registry_model.is_enabled:
continue
# Track first enabled model as fallback
if first_enabled_slug is None:
first_enabled_slug = registry_model.slug
model_enum = LlmModel(registry_model.slug) # Create enum instance from slug
label = generate_model_label(model_enum)
# Add "(Recommended)" suffix to the recommended model
if model.value == recommended_model:
if registry_model.slug == recommended_model_slug:
label += " (Recommended)"
available_models.append(
ModelInfo(
value=model.value,
value=registry_model.slug,
label=label,
provider=model.provider,
provider=registry_model.metadata.provider,
)
)
# Sort models by provider and name for better UX
available_models.sort(key=lambda x: (x.provider, x.label))
# Handle case where no models are available
if not available_models:
logger.warning(
"No enabled LLM models found in registry. "
"Ensure models are configured and enabled in the LLM Registry."
)
# Provide a placeholder entry so admins see meaningful feedback
available_models.append(
ModelInfo(
value="",
label="No models available - configure in LLM Registry",
provider="none",
)
)
# Use the DB recommended model, or fallback to first enabled model
final_recommended = recommended_model_slug or first_enabled_slug or ""
return ExecutionAnalyticsConfig(
available_models=available_models,
default_system_prompt=DEFAULT_SYSTEM_PROMPT,
default_user_prompt=DEFAULT_USER_PROMPT,
recommended_model=recommended_model,
recommended_model=final_recommended,
)

View File

@@ -0,0 +1,554 @@
import logging
import autogpt_libs.auth
import fastapi
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
from backend.server.v2.llm import db as llm_db
from backend.server.v2.llm import model as llm_model
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
prefix="/admin/llm",
tags=["llm", "admin"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_admin_user)],
)
async def _refresh_runtime_state() -> None:
"""Refresh the LLM registry and clear all related caches to ensure real-time updates."""
logger.info("Refreshing LLM registry runtime state...")
# Refresh registry from database
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
# Clear block schema caches so they're regenerated with updated model options
from backend.data.block import BlockSchema
BlockSchema.clear_all_schema_caches()
logger.info("Cleared all block schema caches")
# Clear the /blocks endpoint cache so frontend gets updated schemas
try:
from backend.api.features.v1 import _get_cached_blocks
_get_cached_blocks.cache_clear()
logger.info("Cleared /blocks endpoint cache")
except Exception as e:
logger.warning("Failed to clear /blocks cache: %s", e)
# Clear the v2 builder providers cache (if it exists)
try:
from backend.api.features.builder import db as builder_db
if hasattr(builder_db, "_get_all_providers"):
builder_db._get_all_providers.cache_clear()
logger.info("Cleared v2 builder providers cache")
except Exception as e:
logger.debug("Could not clear v2 builder cache: %s", e)
# Notify all executor services to refresh their registry cache
from backend.data.llm_registry import publish_registry_refresh_notification
await publish_registry_refresh_notification()
logger.info("Published registry refresh notification")
@router.get(
"/providers",
summary="List LLM providers",
response_model=llm_model.LlmProvidersResponse,
)
async def list_llm_providers(include_models: bool = True):
providers = await llm_db.list_providers(include_models=include_models)
return llm_model.LlmProvidersResponse(providers=providers)
@router.post(
"/providers",
summary="Create LLM provider",
response_model=llm_model.LlmProvider,
)
async def create_llm_provider(request: llm_model.UpsertLlmProviderRequest):
provider = await llm_db.upsert_provider(request=request)
await _refresh_runtime_state()
return provider
@router.patch(
"/providers/{provider_id}",
summary="Update LLM provider",
response_model=llm_model.LlmProvider,
)
async def update_llm_provider(
provider_id: str,
request: llm_model.UpsertLlmProviderRequest,
):
provider = await llm_db.upsert_provider(request=request, provider_id=provider_id)
await _refresh_runtime_state()
return provider
@router.get(
"/models",
summary="List LLM models",
response_model=llm_model.LlmModelsResponse,
)
async def list_llm_models(provider_id: str | None = fastapi.Query(default=None)):
models = await llm_db.list_models(provider_id=provider_id)
return llm_model.LlmModelsResponse(models=models)
@router.post(
"/models",
summary="Create LLM model",
response_model=llm_model.LlmModel,
)
async def create_llm_model(request: llm_model.CreateLlmModelRequest):
model = await llm_db.create_model(request=request)
await _refresh_runtime_state()
return model
@router.patch(
"/models/{model_id}",
summary="Update LLM model",
response_model=llm_model.LlmModel,
)
async def update_llm_model(
model_id: str,
request: llm_model.UpdateLlmModelRequest,
):
model = await llm_db.update_model(model_id=model_id, request=request)
await _refresh_runtime_state()
return model
@router.patch(
"/models/{model_id}/toggle",
summary="Toggle LLM model availability",
response_model=llm_model.ToggleLlmModelResponse,
)
async def toggle_llm_model(
model_id: str,
request: llm_model.ToggleLlmModelRequest,
):
"""
Toggle a model's enabled status, optionally migrating workflows when disabling.
If disabling a model and `migrate_to_slug` is provided, all workflows using
this model will be migrated to the specified replacement model before disabling.
A migration record is created which can be reverted later using the revert endpoint.
Optional fields:
- `migration_reason`: Reason for the migration (e.g., "Provider outage")
- `custom_credit_cost`: Custom pricing override for billing during migration
"""
try:
result = await llm_db.toggle_model(
model_id=model_id,
is_enabled=request.is_enabled,
migrate_to_slug=request.migrate_to_slug,
migration_reason=request.migration_reason,
custom_credit_cost=request.custom_credit_cost,
)
await _refresh_runtime_state()
if result.nodes_migrated > 0:
logger.info(
"Toggled model '%s' to %s and migrated %d nodes to '%s' (migration_id=%s)",
result.model.slug,
"enabled" if request.is_enabled else "disabled",
result.nodes_migrated,
result.migrated_to_slug,
result.migration_id,
)
return result
except ValueError as exc:
logger.warning("Model toggle validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to toggle LLM model %s: %s", model_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to toggle model availability",
) from exc
@router.get(
"/models/{model_id}/usage",
summary="Get model usage count",
response_model=llm_model.LlmModelUsageResponse,
)
async def get_llm_model_usage(model_id: str):
"""Get the number of workflow nodes using this model."""
try:
return await llm_db.get_model_usage(model_id=model_id)
except ValueError as exc:
raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to get model usage %s: %s", model_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get model usage",
) from exc
@router.delete(
"/models/{model_id}",
summary="Delete LLM model and migrate workflows",
response_model=llm_model.DeleteLlmModelResponse,
)
async def delete_llm_model(
model_id: str,
replacement_model_slug: str = fastapi.Query(
..., description="Slug of the model to migrate existing workflows to"
),
):
"""
Delete a model and automatically migrate all workflows using it to a replacement model.
This endpoint:
1. Validates the replacement model exists and is enabled
2. Counts how many workflow nodes use the model being deleted
3. Updates all AgentNode.constantInput->model fields to the replacement
4. Deletes the model record
5. Refreshes all caches and notifies executors
Example: DELETE /admin/llm/models/{id}?replacement_model_slug=gpt-4o
"""
try:
result = await llm_db.delete_model(
model_id=model_id, replacement_model_slug=replacement_model_slug
)
await _refresh_runtime_state()
logger.info(
"Deleted model '%s' and migrated %d nodes to '%s'",
result.deleted_model_slug,
result.nodes_migrated,
result.replacement_model_slug,
)
return result
except ValueError as exc:
# Validation errors (model not found, replacement invalid, etc.)
logger.warning("Model deletion validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to delete LLM model %s: %s", model_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to delete model and migrate workflows",
) from exc
# ============================================================================
# Migration Management Endpoints
# ============================================================================
@router.get(
"/migrations",
summary="List model migrations",
response_model=llm_model.LlmMigrationsResponse,
)
async def list_llm_migrations(
include_reverted: bool = fastapi.Query(
default=False, description="Include reverted migrations in the list"
),
):
"""
List all model migrations.
Migrations are created when disabling a model with the migrate_to_slug option.
They can be reverted to restore the original model configuration.
"""
try:
migrations = await llm_db.list_migrations(include_reverted=include_reverted)
return llm_model.LlmMigrationsResponse(migrations=migrations)
except Exception as exc:
logger.exception("Failed to list migrations: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to list migrations",
) from exc
@router.get(
"/migrations/{migration_id}",
summary="Get migration details",
response_model=llm_model.LlmModelMigration,
)
async def get_llm_migration(migration_id: str):
"""Get details of a specific migration."""
try:
migration = await llm_db.get_migration(migration_id)
if not migration:
raise fastapi.HTTPException(
status_code=404, detail=f"Migration '{migration_id}' not found"
)
return migration
except fastapi.HTTPException:
raise
except Exception as exc:
logger.exception("Failed to get migration %s: %s", migration_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get migration",
) from exc
@router.post(
"/migrations/{migration_id}/revert",
summary="Revert a model migration",
response_model=llm_model.RevertMigrationResponse,
)
async def revert_llm_migration(
migration_id: str,
request: llm_model.RevertMigrationRequest | None = None,
):
"""
Revert a model migration, restoring affected workflows to their original model.
This only reverts the specific nodes that were part of the migration.
The source model must exist for the revert to succeed.
Options:
- `re_enable_source_model`: Whether to re-enable the source model if disabled (default: True)
Response includes:
- `nodes_reverted`: Number of nodes successfully reverted
- `nodes_already_changed`: Number of nodes that were modified since migration (not reverted)
- `source_model_re_enabled`: Whether the source model was re-enabled
Requirements:
- Migration must not already be reverted
- Source model must exist
"""
try:
re_enable = request.re_enable_source_model if request else True
result = await llm_db.revert_migration(
migration_id,
re_enable_source_model=re_enable,
)
await _refresh_runtime_state()
logger.info(
"Reverted migration '%s': %d nodes restored from '%s' to '%s' "
"(%d already changed, source re-enabled=%s)",
migration_id,
result.nodes_reverted,
result.target_model_slug,
result.source_model_slug,
result.nodes_already_changed,
result.source_model_re_enabled,
)
return result
except ValueError as exc:
logger.warning("Migration revert validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to revert migration %s: %s", migration_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to revert migration",
) from exc
# ============================================================================
# Creator Management Endpoints
# ============================================================================
@router.get(
"/creators",
summary="List model creators",
response_model=llm_model.LlmCreatorsResponse,
)
async def list_llm_creators():
"""
List all model creators.
Creators are organizations that create/train models (e.g., OpenAI, Meta, Anthropic).
This is distinct from providers who host/serve the models (e.g., OpenRouter).
"""
try:
creators = await llm_db.list_creators()
return llm_model.LlmCreatorsResponse(creators=creators)
except Exception as exc:
logger.exception("Failed to list creators: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to list creators",
) from exc
@router.get(
"/creators/{creator_id}",
summary="Get creator details",
operation_id="getV2GetLlmCreatorDetails",
response_model=llm_model.LlmModelCreator,
)
async def get_llm_creator(creator_id: str):
"""Get details of a specific model creator."""
try:
creator = await llm_db.get_creator(creator_id)
if not creator:
raise fastapi.HTTPException(
status_code=404, detail=f"Creator '{creator_id}' not found"
)
return creator
except fastapi.HTTPException:
raise
except Exception as exc:
logger.exception("Failed to get creator %s: %s", creator_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get creator",
) from exc
@router.post(
"/creators",
summary="Create model creator",
response_model=llm_model.LlmModelCreator,
)
async def create_llm_creator(request: llm_model.UpsertLlmCreatorRequest):
"""
Create a new model creator.
A creator represents an organization that creates/trains AI models,
such as OpenAI, Anthropic, Meta, or Google.
"""
try:
creator = await llm_db.upsert_creator(request=request)
await _refresh_runtime_state()
logger.info("Created model creator '%s' (%s)", creator.display_name, creator.id)
return creator
except Exception as exc:
logger.exception("Failed to create creator: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to create creator",
) from exc
@router.patch(
"/creators/{creator_id}",
summary="Update model creator",
response_model=llm_model.LlmModelCreator,
)
async def update_llm_creator(
creator_id: str,
request: llm_model.UpsertLlmCreatorRequest,
):
"""Update an existing model creator."""
try:
creator = await llm_db.upsert_creator(request=request, creator_id=creator_id)
await _refresh_runtime_state()
logger.info("Updated model creator '%s' (%s)", creator.display_name, creator_id)
return creator
except Exception as exc:
logger.exception("Failed to update creator %s: %s", creator_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to update creator",
) from exc
@router.delete(
"/creators/{creator_id}",
summary="Delete model creator",
response_model=dict,
)
async def delete_llm_creator(creator_id: str):
"""
Delete a model creator.
This will remove the creator association from all models that reference it
(sets creatorId to NULL), but will not delete the models themselves.
"""
try:
await llm_db.delete_creator(creator_id)
await _refresh_runtime_state()
logger.info("Deleted model creator '%s'", creator_id)
return {"success": True, "message": f"Creator '{creator_id}' deleted"}
except ValueError as exc:
logger.warning("Creator deletion validation failed: %s", exc)
raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to delete creator %s: %s", creator_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to delete creator",
) from exc
# ============================================================================
# Recommended Model Endpoints
# ============================================================================
@router.get(
"/recommended-model",
summary="Get recommended model",
response_model=llm_model.RecommendedModelResponse,
)
async def get_recommended_model():
"""
Get the currently recommended LLM model.
The recommended model is shown to users as the default/suggested option
in model selection dropdowns.
"""
try:
model = await llm_db.get_recommended_model()
return llm_model.RecommendedModelResponse(
model=model,
slug=model.slug if model else None,
)
except Exception as exc:
logger.exception("Failed to get recommended model: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get recommended model",
) from exc
@router.post(
"/recommended-model",
summary="Set recommended model",
response_model=llm_model.SetRecommendedModelResponse,
)
async def set_recommended_model(request: llm_model.SetRecommendedModelRequest):
"""
Set a model as the recommended model.
This clears the recommended flag from any other model and sets it on
the specified model. The model must be enabled to be set as recommended.
The recommended model is displayed to users as the default/suggested
option in model selection dropdowns throughout the platform.
"""
try:
model, previous_slug = await llm_db.set_recommended_model(request.model_id)
await _refresh_runtime_state()
logger.info(
"Set recommended model to '%s' (previous: %s)",
model.slug,
previous_slug or "none",
)
return llm_model.SetRecommendedModelResponse(
model=model,
previous_recommended_slug=previous_slug,
message=f"Model '{model.display_name}' is now the recommended model",
)
except ValueError as exc:
logger.warning("Set recommended model validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to set recommended model: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to set recommended model",
) from exc
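A hedged sketch of driving the toggle endpoint from a script (base URL, port, and auth header are assumptions; the payload fields match ToggleLlmModelRequest as used above):

import httpx

async def disable_with_migration(model_id: str, token: str) -> dict:
    async with httpx.AsyncClient(base_url="http://localhost:8006") as client:
        resp = await client.patch(
            f"/admin/llm/models/{model_id}/toggle",
            json={
                "is_enabled": False,
                "migrate_to_slug": "gpt-4o-mini",
                "migration_reason": "Provider outage",
            },
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        return resp.json()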

View File

@@ -0,0 +1,405 @@
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from pytest_snapshot.plugin import Snapshot
import backend.api.features.admin.llm_routes as llm_routes
app = fastapi.FastAPI()
app.include_router(llm_routes.router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
"""Setup admin auth overrides for all tests in this module"""
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def test_list_llm_providers_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful listing of LLM providers"""
# Mock the database function
mock_providers = [
{
"id": "provider-1",
"name": "openai",
"display_name": "OpenAI",
"description": "OpenAI LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": True,
"metadata": {},
"models": [],
},
{
"id": "provider-2",
"name": "anthropic",
"display_name": "Anthropic",
"description": "Anthropic LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": True,
"metadata": {},
"models": [],
},
]
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.list_providers",
new=AsyncMock(return_value=mock_providers),
)
response = client.get("/admin/llm/providers")
assert response.status_code == 200
response_data = response.json()
assert len(response_data["providers"]) == 2
assert response_data["providers"][0]["name"] == "openai"
# Snapshot test the response
configured_snapshot.assert_match(response_data, "list_llm_providers_success.json")
def test_list_llm_models_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful listing of LLM models"""
# Mock the database function
mock_models = [
{
"id": "model-1",
"slug": "gpt-4o",
"display_name": "GPT-4o",
"description": "GPT-4 Optimized",
"provider_id": "provider-1",
"context_window": 128000,
"max_output_tokens": 16384,
"is_enabled": True,
"capabilities": {},
"metadata": {},
"costs": [
{
"id": "cost-1",
"credit_cost": 10,
"credential_provider": "openai",
"metadata": {},
}
],
}
]
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.list_models",
new=AsyncMock(return_value=mock_models),
)
response = client.get("/admin/llm/models")
assert response.status_code == 200
response_data = response.json()
assert len(response_data["models"]) == 1
assert response_data["models"][0]["slug"] == "gpt-4o"
# Snapshot test the response
configured_snapshot.assert_match(response_data, "list_llm_models_success.json")
def test_create_llm_provider_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful creation of LLM provider"""
mock_provider = {
"id": "new-provider-id",
"name": "groq",
"display_name": "Groq",
"description": "Groq LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": False,
"metadata": {},
}
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.upsert_provider",
new=AsyncMock(return_value=mock_provider),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {
"name": "groq",
"display_name": "Groq",
"description": "Groq LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": False,
"metadata": {},
}
response = client.post("/admin/llm/providers", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["name"] == "groq"
assert response_data["display_name"] == "Groq"
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response
configured_snapshot.assert_match(response_data, "create_llm_provider_success.json")
def test_create_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful creation of LLM model"""
mock_model = {
"id": "new-model-id",
"slug": "gpt-4.1-mini",
"display_name": "GPT-4.1 Mini",
"description": "Latest GPT-4.1 Mini model",
"provider_id": "provider-1",
"context_window": 128000,
"max_output_tokens": 16384,
"is_enabled": True,
"capabilities": {},
"metadata": {},
"costs": [
{
"id": "cost-id",
"credit_cost": 5,
"credential_provider": "openai",
"metadata": {},
}
],
}
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.create_model",
new=AsyncMock(return_value=mock_model),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {
"slug": "gpt-4.1-mini",
"display_name": "GPT-4.1 Mini",
"description": "Latest GPT-4.1 Mini model",
"provider_id": "provider-1",
"context_window": 128000,
"max_output_tokens": 16384,
"is_enabled": True,
"capabilities": {},
"metadata": {},
"costs": [
{
"credit_cost": 5,
"credential_provider": "openai",
"metadata": {},
}
],
}
response = client.post("/admin/llm/models", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["slug"] == "gpt-4.1-mini"
assert response_data["is_enabled"] is True
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response
configured_snapshot.assert_match(response_data, "create_llm_model_success.json")
def test_update_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful update of LLM model"""
mock_model = {
"id": "model-1",
"slug": "gpt-4o",
"display_name": "GPT-4o Updated",
"description": "Updated description",
"provider_id": "provider-1",
"context_window": 256000,
"max_output_tokens": 32768,
"is_enabled": True,
"capabilities": {},
"metadata": {},
"costs": [
{
"id": "cost-1",
"credit_cost": 15,
"credential_provider": "openai",
"metadata": {},
}
],
}
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.update_model",
new=AsyncMock(return_value=mock_model),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {
"display_name": "GPT-4o Updated",
"description": "Updated description",
"context_window": 256000,
"max_output_tokens": 32768,
}
response = client.patch("/admin/llm/models/model-1", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["display_name"] == "GPT-4o Updated"
assert response_data["context_window"] == 256000
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response
configured_snapshot.assert_match(response_data, "update_llm_model_success.json")
def test_toggle_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful toggling of LLM model enabled status"""
mock_model = {
"id": "model-1",
"slug": "gpt-4o",
"display_name": "GPT-4o",
"description": "GPT-4 Optimized",
"provider_id": "provider-1",
"context_window": 128000,
"max_output_tokens": 16384,
"is_enabled": False,
"capabilities": {},
"metadata": {},
"costs": [],
}
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.toggle_model",
new=AsyncMock(return_value=mock_model),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {"is_enabled": False}
response = client.patch("/admin/llm/models/model-1/toggle", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["is_enabled"] is False
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response
configured_snapshot.assert_match(response_data, "toggle_llm_model_success.json")
def test_delete_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful deletion of LLM model with migration"""
mock_response = {
"deleted_model_slug": "gpt-3.5-turbo",
"deleted_model_display_name": "GPT-3.5 Turbo",
"replacement_model_slug": "gpt-4o-mini",
"nodes_migrated": 42,
"message": "Successfully deleted model 'GPT-3.5 Turbo' (gpt-3.5-turbo) "
"and migrated 42 workflow node(s) to 'gpt-4o-mini'.",
}
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.delete_model",
new=AsyncMock(return_value=type("obj", (object,), mock_response)()),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
response = client.delete(
"/admin/llm/models/model-1?replacement_model_slug=gpt-4o-mini"
)
assert response.status_code == 200
response_data = response.json()
assert response_data["deleted_model_slug"] == "gpt-3.5-turbo"
assert response_data["nodes_migrated"] == 42
assert response_data["replacement_model_slug"] == "gpt-4o-mini"
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response
configured_snapshot.assert_match(response_data, "delete_llm_model_success.json")
def test_delete_llm_model_validation_error(
mocker: pytest_mock.MockFixture,
) -> None:
"""Test deletion fails with proper error when validation fails"""
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.delete_model",
new=AsyncMock(side_effect=ValueError("Replacement model 'invalid' not found")),
)
response = client.delete("/admin/llm/models/model-1?replacement_model_slug=invalid")
assert response.status_code == 400
assert "Replacement model 'invalid' not found" in response.json()["detail"]
def test_delete_llm_model_missing_replacement(
mocker: pytest_mock.MockFixture,
) -> None:
"""Test deletion fails when replacement_model_slug is not provided"""
response = client.delete("/admin/llm/models/model-1")
# FastAPI will return 422 for missing required query params
assert response.status_code == 422
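The delete test's type("obj", (object,), mock_response)() builds a throwaway object whose attributes mirror the dict, since the route reads fields like result.nodes_migrated as attributes rather than keys. types.SimpleNamespace is the idiomatic equivalent:

from types import SimpleNamespace

mock_response = {"deleted_model_slug": "gpt-3.5-turbo", "nodes_migrated": 42}
result = SimpleNamespace(**mock_response)
assert result.nodes_migrated == 42  # attribute access, as the route expects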

View File

@@ -9,7 +9,6 @@ import prisma.enums
import backend.api.features.store.cache as store_cache
import backend.api.features.store.db as store_db
import backend.api.features.store.embeddings as store_embeddings
import backend.api.features.store.model as store_model
import backend.util.json
@@ -151,54 +150,3 @@ async def admin_download_agent_file(
return fastapi.responses.FileResponse(
tmp_file.name, filename=file_name, media_type="application/json"
)
@router.get(
"/embeddings/stats",
summary="Get Embedding Statistics",
)
async def get_embedding_stats() -> dict[str, typing.Any]:
"""
Get statistics about embedding coverage for store listings.
Returns counts of total approved listings, listings with embeddings,
listings without embeddings, and coverage percentage.
"""
try:
stats = await store_embeddings.get_embedding_stats()
return stats
except Exception as e:
logger.exception("Error getting embedding stats: %s", e)
raise fastapi.HTTPException(
status_code=500,
detail="An error occurred while retrieving embedding stats",
)
@router.post(
"/embeddings/backfill",
summary="Backfill Missing Embeddings",
)
async def backfill_embeddings(
batch_size: int = 10,
) -> dict[str, typing.Any]:
"""
Trigger backfill of embeddings for approved listings that don't have them.
Args:
batch_size: Number of embeddings to generate in one call (default 10)
Returns:
Dict with processed count, success count, failure count, and message
"""
try:
result = await store_embeddings.backfill_missing_embeddings(
batch_size=batch_size
)
return result
except Exception as e:
logger.exception("Error backfilling embeddings: %s", e)
raise fastapi.HTTPException(
status_code=500,
detail="An error occurred while backfilling embeddings",
)

View File

@@ -15,6 +15,7 @@ from backend.blocks import load_all_blocks
from backend.blocks.llm import LlmModel
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
from backend.data.db import query_raw_with_schema
from backend.data.llm_registry import get_all_model_slugs_for_validation
from backend.integrations.providers import ProviderName
from backend.util.cache import cached
from backend.util.models import Pagination
@@ -31,7 +32,14 @@ from .model import (
)
logger = logging.getLogger(__name__)
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
def _get_llm_models() -> list[str]:
"""Get LLM model names for search matching from the registry."""
return [
slug.lower().replace("-", " ") for slug in get_all_model_slugs_for_validation()
]
MAX_LIBRARY_AGENT_RESULTS = 100
MAX_MARKETPLACE_AGENT_RESULTS = 100
@@ -496,8 +504,8 @@ async def _get_static_counts():
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
for field in schema_cls.model_fields.values():
if field.annotation == LlmModel:
# Check if query matches any value in llm_models
if any(query in name for name in llm_models):
# Check if query matches any value in llm_models from registry
if any(query in name for name in _get_llm_models()):
return True
return False
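The normalization above means a search such as "gpt 4o" matches the slug "gpt-4o". A small self-contained illustration of the substring check:

def matches(query: str, slugs: list[str]) -> bool:
    names = [slug.lower().replace("-", " ") for slug in slugs]
    return any(query in name for name in names)

assert matches("gpt 4o", ["gpt-4o", "claude-3-5-sonnet"])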

View File

@@ -12,11 +12,7 @@ class ChatConfig(BaseSettings):
# OpenAI API Configuration
model: str = Field(
default="anthropic/claude-opus-4.5", description="Default model to use"
)
title_model: str = Field(
default="openai/gpt-4o-mini",
description="Model to use for generating session titles (should be fast/cheap)",
default="qwen/qwen3-235b-a22b-2507", description="Default model to use"
)
api_key: str | None = Field(default=None, description="OpenAI API key")
base_url: str | None = Field(
@@ -45,13 +41,6 @@ class ChatConfig(BaseSettings):
default=3, description="Maximum number of agent schedules"
)
# Langfuse Prompt Management Configuration
# Note: Langfuse credentials are in Settings().secrets (settings.py)
langfuse_prompt_name: str = Field(
default="CoPilot Prompt",
description="Name of the prompt in Langfuse to fetch",
)
@field_validator("api_key", mode="before")
@classmethod
def get_api_key(cls, v):
@@ -83,31 +72,8 @@ class ChatConfig(BaseSettings):
v = "https://openrouter.ai/api/v1"
return v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",
"onboarding": "prompts/onboarding_system.md",
}
def get_system_prompt_for_type(
self, prompt_type: str = "default", **template_vars
) -> str:
"""Load and render a system prompt by type.
Args:
prompt_type: The type of prompt to load ("default" or "onboarding")
**template_vars: Variables to substitute in the template
Returns:
Rendered system prompt string
"""
prompt_path_str = self.PROMPT_PATHS.get(
prompt_type, self.PROMPT_PATHS["default"]
)
return self._load_prompt_from_path(prompt_path_str, **template_vars)
def get_system_prompt(self, **template_vars) -> str:
"""Load and render the default system prompt from file.
"""Load and render the system prompt from file.
Args:
**template_vars: Variables to substitute in the template
@@ -116,21 +82,9 @@ class ChatConfig(BaseSettings):
Rendered system prompt string
"""
return self._load_prompt_from_path(self.system_prompt_path, **template_vars)
def _load_prompt_from_path(self, prompt_path_str: str, **template_vars) -> str:
"""Load and render a system prompt from a given path.
Args:
prompt_path_str: Path to the prompt file relative to chat module
**template_vars: Variables to substitute in the template
Returns:
Rendered system prompt string
"""
# Get the path relative to this module
module_dir = Path(__file__).parent
prompt_path = module_dir / prompt_path_str
prompt_path = module_dir / self.system_prompt_path
# Check for .j2 extension first (Jinja2 template)
j2_path = Path(str(prompt_path) + ".j2")

View File

@@ -1,195 +0,0 @@
"""Database operations for chat sessions."""
import logging
from datetime import UTC, datetime
from typing import Any
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from prisma.types import ChatSessionUpdateInput
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
"""Get a chat session by ID from the database."""
session = await PrismaChatSession.prisma().find_unique(
where={"id": session_id},
include={"Messages": True},
)
if session and session.Messages:
# Sort messages by sequence in Python since Prisma doesn't support order_by in include
session.Messages.sort(key=lambda m: m.sequence)
return session
async def create_chat_session(
session_id: str,
user_id: str | None,
) -> PrismaChatSession:
"""Create a new chat session in the database."""
data = {
"id": session_id,
"userId": user_id,
"credentials": SafeJson({}),
"successfulAgentRuns": SafeJson({}),
"successfulAgentSchedules": SafeJson({}),
}
return await PrismaChatSession.prisma().create(
data=data,
include={"Messages": True},
)
async def update_chat_session(
session_id: str,
credentials: dict[str, Any] | None = None,
successful_agent_runs: dict[str, Any] | None = None,
successful_agent_schedules: dict[str, Any] | None = None,
total_prompt_tokens: int | None = None,
total_completion_tokens: int | None = None,
title: str | None = None,
) -> PrismaChatSession | None:
"""Update a chat session's metadata."""
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
if credentials is not None:
data["credentials"] = SafeJson(credentials)
if successful_agent_runs is not None:
data["successfulAgentRuns"] = SafeJson(successful_agent_runs)
if successful_agent_schedules is not None:
data["successfulAgentSchedules"] = SafeJson(successful_agent_schedules)
if total_prompt_tokens is not None:
data["totalPromptTokens"] = total_prompt_tokens
if total_completion_tokens is not None:
data["totalCompletionTokens"] = total_completion_tokens
if title is not None:
data["title"] = title
session = await PrismaChatSession.prisma().update(
where={"id": session_id},
data=data,
include={"Messages": True},
)
if session and session.Messages:
session.Messages.sort(key=lambda m: m.sequence)
return session
async def add_chat_message(
session_id: str,
role: str,
sequence: int,
content: str | None = None,
name: str | None = None,
tool_call_id: str | None = None,
refusal: str | None = None,
tool_calls: list[dict[str, Any]] | None = None,
function_call: dict[str, Any] | None = None,
) -> PrismaChatMessage:
"""Add a message to a chat session."""
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": role,
"sequence": sequence,
}
if content is not None:
data["content"] = content
if name is not None:
data["name"] = name
if tool_call_id is not None:
data["toolCallId"] = tool_call_id
if refusal is not None:
data["refusal"] = refusal
if tool_calls is not None:
data["toolCalls"] = SafeJson(tool_calls)
if function_call is not None:
data["functionCall"] = SafeJson(function_call)
# Update session's updatedAt timestamp
await PrismaChatSession.prisma().update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
)
return await PrismaChatMessage.prisma().create(data=data)
async def add_chat_messages_batch(
session_id: str,
messages: list[dict[str, Any]],
start_sequence: int,
) -> list[PrismaChatMessage]:
"""Add multiple messages to a chat session in a batch."""
if not messages:
return []
created_messages = []
for i, msg in enumerate(messages):
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": msg["role"],
"sequence": start_sequence + i,
}
if msg.get("content") is not None:
data["content"] = msg["content"]
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
created = await PrismaChatMessage.prisma().create(data=data)
created_messages.append(created)
# Update session's updatedAt timestamp
await PrismaChatSession.prisma().update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
)
return created_messages
async def get_user_chat_sessions(
user_id: str,
limit: int = 50,
offset: int = 0,
) -> list[PrismaChatSession]:
"""Get chat sessions for a user, ordered by most recent."""
return await PrismaChatSession.prisma().find_many(
where={"userId": user_id},
order={"updatedAt": "desc"},
take=limit,
skip=offset,
)
async def get_user_session_count(user_id: str) -> int:
"""Get the total number of chat sessions for a user."""
return await PrismaChatSession.prisma().count(where={"userId": user_id})
async def delete_chat_session(session_id: str) -> bool:
"""Delete a chat session and all its messages."""
try:
await PrismaChatSession.prisma().delete(where={"id": session_id})
return True
except Exception as e:
logger.error(f"Failed to delete chat session {session_id}: {e}")
return False
async def get_chat_session_message_count(session_id: str) -> int:
"""Get the number of messages in a chat session."""
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
return count

View File

@@ -16,15 +16,11 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
ChatCompletionMessageToolCallParam,
Function,
)
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from pydantic import BaseModel
from backend.data.redis_client import get_redis_async
from backend.util import json
from backend.util.exceptions import RedisError
from . import db as chat_db
from .config import ChatConfig
logger = logging.getLogger(__name__)
@@ -50,7 +46,6 @@ class Usage(BaseModel):
class ChatSession(BaseModel):
session_id: str
user_id: str | None
title: str | None = None
messages: list[ChatMessage]
usage: list[Usage]
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
@@ -64,7 +59,6 @@ class ChatSession(BaseModel):
return ChatSession(
session_id=str(uuid.uuid4()),
user_id=user_id,
title=None,
messages=[],
usage=[],
credentials={},
@@ -72,85 +66,6 @@ class ChatSession(BaseModel):
updated_at=datetime.now(UTC),
)
@staticmethod
def from_prisma(
prisma_session: PrismaChatSession,
prisma_messages: list[PrismaChatMessage] | None = None,
) -> "ChatSession":
"""Convert Prisma models to Pydantic ChatSession."""
messages = []
if prisma_messages:
for msg in prisma_messages:
tool_calls = None
if msg.toolCalls:
tool_calls = (
json.loads(msg.toolCalls)
if isinstance(msg.toolCalls, str)
else msg.toolCalls
)
function_call = None
if msg.functionCall:
function_call = (
json.loads(msg.functionCall)
if isinstance(msg.functionCall, str)
else msg.functionCall
)
messages.append(
ChatMessage(
role=msg.role,
content=msg.content,
name=msg.name,
tool_call_id=msg.toolCallId,
refusal=msg.refusal,
tool_calls=tool_calls,
function_call=function_call,
)
)
# Parse JSON fields from Prisma
credentials = (
json.loads(prisma_session.credentials)
if isinstance(prisma_session.credentials, str)
else prisma_session.credentials or {}
)
successful_agent_runs = (
json.loads(prisma_session.successfulAgentRuns)
if isinstance(prisma_session.successfulAgentRuns, str)
else prisma_session.successfulAgentRuns or {}
)
successful_agent_schedules = (
json.loads(prisma_session.successfulAgentSchedules)
if isinstance(prisma_session.successfulAgentSchedules, str)
else prisma_session.successfulAgentSchedules or {}
)
# Calculate usage from token counts
usage = []
if prisma_session.totalPromptTokens or prisma_session.totalCompletionTokens:
usage.append(
Usage(
prompt_tokens=prisma_session.totalPromptTokens or 0,
completion_tokens=prisma_session.totalCompletionTokens or 0,
total_tokens=(prisma_session.totalPromptTokens or 0)
+ (prisma_session.totalCompletionTokens or 0),
)
)
return ChatSession(
session_id=prisma_session.id,
user_id=prisma_session.userId,
title=prisma_session.title,
messages=messages,
usage=usage,
credentials=credentials,
started_at=prisma_session.createdAt,
updated_at=prisma_session.updatedAt,
successful_agent_runs=successful_agent_runs,
successful_agent_schedules=successful_agent_schedules,
)
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
messages = []
for message in self.messages:
@@ -240,234 +155,50 @@ class ChatSession(BaseModel):
return messages
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
"""Get a chat session from Redis cache."""
redis_key = f"chat:session:{session_id}"
async_redis = await get_redis_async()
raw_session: bytes | None = await async_redis.get(redis_key)
if raw_session is None:
return None
try:
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"Loading session {session_id} from cache: "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
except Exception as e:
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
raise RedisError(f"Corrupted session data for {session_id}") from e
async def _cache_session(session: ChatSession) -> None:
"""Cache a chat session in Redis."""
redis_key = f"chat:session:{session.session_id}"
async_redis = await get_redis_async()
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
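The two cache helpers above are a plain serialize/deserialize round-trip through Redis with a TTL. A self-contained sketch of that round-trip, with a dict standing in for Redis and a stand-in model in place of ChatSession:

```
# Round-trip sketch: Pydantic model -> JSON bytes -> model again. The Demo
# model and the dict-as-Redis are stand-ins, not the real ChatSession/Redis.
from pydantic import BaseModel

class Demo(BaseModel):
    session_id: str
    messages: list[str] = []

fake_redis: dict[str, bytes] = {}

s = Demo(session_id="abc", messages=["hi"])
fake_redis[f"chat:session:{s.session_id}"] = s.model_dump_json().encode()

restored = Demo.model_validate_json(fake_redis["chat:session:abc"])
assert restored == s
```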
async def _get_session_from_db(session_id: str) -> ChatSession | None:
"""Get a chat session from the database."""
prisma_session = await chat_db.get_chat_session(session_id)
if not prisma_session:
return None
messages = prisma_session.Messages
logger.info(
f"Loading session {session_id} from DB: "
f"has_messages={messages is not None}, "
f"message_count={len(messages) if messages else 0}, "
f"roles={[m.role for m in messages] if messages else []}"
)
return ChatSession.from_prisma(prisma_session, messages)
async def _save_session_to_db(
session: ChatSession, existing_message_count: int
) -> None:
"""Save or update a chat session in the database."""
# Check if session exists in DB
existing = await chat_db.get_chat_session(session.session_id)
if not existing:
# Create new session
await chat_db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
)
existing_message_count = 0
# Calculate total tokens from usage
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
# Update session metadata
await chat_db.update_chat_session(
session_id=session.session_id,
credentials=session.credentials,
successful_agent_runs=session.successful_agent_runs,
successful_agent_schedules=session.successful_agent_schedules,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
# Add new messages (only those after existing count)
new_messages = session.messages[existing_message_count:]
if new_messages:
messages_data = []
for msg in new_messages:
messages_data.append(
{
"role": msg.role,
"content": msg.content,
"name": msg.name,
"tool_call_id": msg.tool_call_id,
"refusal": msg.refusal,
"tool_calls": msg.tool_calls,
"function_call": msg.function_call,
}
)
logger.info(
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
await chat_db.add_chat_messages_batch(
session_id=session.session_id,
messages=messages_data,
start_sequence=existing_message_count,
)
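The incremental save above is driven entirely by `existing_message_count`: only messages past that index are inserted, and their DB sequence numbers continue from it. A tiny stand-alone illustration with plain data:

```
# Illustration of the slicing/sequencing used by _save_session_to_db.
messages = ["m0", "m1", "m2", "m3", "m4"]  # full in-memory history
existing_message_count = 3                  # rows already persisted

new_messages = messages[existing_message_count:]  # ["m3", "m4"]
for seq, msg in enumerate(new_messages, start=existing_message_count):
    print(f"INSERT sequence={seq} content={msg!r}")  # sequence=3, then 4
```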
async def get_chat_session(
session_id: str,
user_id: str | None,
) -> ChatSession | None:
"""Get a chat session by ID.
"""Get a chat session by ID."""
redis_key = f"chat:session:{session_id}"
async_redis = await get_redis_async()
Checks Redis cache first, falls back to database if not found.
Caches database results back to Redis.
"""
# Try cache first
try:
session = await _get_session_from_cache(session_id)
if session:
# Verify user ownership
if session.user_id is not None and session.user_id != user_id:
logger.warning(
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
)
return None
return session
except RedisError:
logger.warning(f"Cache error for session {session_id}, trying database")
except Exception as e:
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
raw_session: bytes | None = await async_redis.get(redis_key)
# Fall back to database
logger.info(f"Session {session_id} not in cache, checking database")
session = await _get_session_from_db(session_id)
if session is None:
logger.warning(f"Session {session_id} not found in cache or database")
if raw_session is None:
logger.warning(f"Session {session_id} not found in Redis")
return None
# Verify user ownership
try:
session = ChatSession.model_validate_json(raw_session)
except Exception as e:
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
raise RedisError(f"Corrupted session data for {session_id}") from e
if session.user_id is not None and session.user_id != user_id:
logger.warning(
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
)
return None
# Cache the session from DB
try:
await _cache_session(session)
logger.info(f"Cached session {session_id} from database")
except Exception as e:
logger.warning(f"Failed to cache session {session_id}: {e}")
return session
async def upsert_chat_session(
session: ChatSession,
) -> ChatSession:
"""Update a chat session in both cache and database."""
# Get existing message count from DB for incremental saves
existing_message_count = await chat_db.get_chat_session_message_count(
session.session_id
"""Update a chat session with the given messages."""
redis_key = f"chat:session:{session.session_id}"
async_redis = await get_redis_async()
resp = await async_redis.setex(
redis_key, config.session_ttl, session.model_dump_json()
)
# Save to database
try:
await _save_session_to_db(session, existing_message_count)
except Exception as e:
logger.error(f"Failed to save session {session.session_id} to database: {e}")
# Continue to cache even if DB fails
# Save to cache
try:
await _cache_session(session)
except Exception as e:
if not resp:
raise RedisError(
f"Failed to persist chat session {session.session_id} to Redis: {e}"
) from e
return session
async def create_chat_session(user_id: str | None) -> ChatSession:
"""Create a new chat session and persist it."""
session = ChatSession.new(user_id)
# Create in database first
try:
await chat_db.create_chat_session(
session_id=session.session_id,
user_id=user_id,
f"Failed to persist chat session {session.session_id} to Redis: {resp}"
)
except Exception as e:
logger.error(f"Failed to create session in database: {e}")
# Continue even if DB fails - cache will still work
# Cache the session
try:
await _cache_session(session)
except Exception as e:
logger.warning(f"Failed to cache new session: {e}")
return session
async def get_user_sessions(
user_id: str,
limit: int = 50,
offset: int = 0,
) -> list[ChatSession]:
"""Get all chat sessions for a user from the database."""
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
sessions = []
for prisma_session in prisma_sessions:
# Convert without messages for listing (lighter weight)
sessions.append(ChatSession.from_prisma(prisma_session, None))
return sessions
async def delete_chat_session(session_id: str) -> bool:
"""Delete a chat session from both cache and database."""
# Delete from cache
try:
redis_key = f"chat:session:{session_id}"
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
except Exception as e:
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
# Delete from database
return await chat_db.delete_chat_session(session_id)

View File

@@ -68,50 +68,3 @@ async def test_chatsession_redis_storage_user_id_mismatch():
s2 = await get_chat_session(s.session_id, None)
assert s2 is None
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_db_storage():
"""Test that messages are correctly saved to and loaded from DB (not cache)."""
from backend.data.redis_client import get_redis_async
# Create session with messages including assistant message
s = ChatSession.new(user_id=None)
s.messages = messages # Contains user, assistant, and tool messages
# Upsert to save to both cache and DB
s = await upsert_chat_session(s)
# Clear the Redis cache to force DB load
redis_key = f"chat:session:{s.session_id}"
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
# Load from DB (cache was cleared)
s2 = await get_chat_session(
session_id=s.session_id,
user_id=s.user_id,
)
assert s2 is not None, "Session not found after loading from DB"
assert len(s2.messages) == len(
s.messages
), f"Message count mismatch: expected {len(s.messages)}, got {len(s2.messages)}"
# Verify all roles are present
roles = [m.role for m in s2.messages]
assert "user" in roles, f"User message missing. Roles found: {roles}"
assert "assistant" in roles, f"Assistant message missing. Roles found: {roles}"
assert "tool" in roles, f"Tool message missing. Roles found: {roles}"
# Verify message content
for orig, loaded in zip(s.messages, s2.messages):
assert orig.role == loaded.role, f"Role mismatch: {orig.role} != {loaded.role}"
assert (
orig.content == loaded.content
), f"Content mismatch for {orig.role}: {orig.content} != {loaded.content}"
if orig.tool_calls:
assert (
loaded.tool_calls is not None
), f"Tool calls missing for {orig.role} message"
assert len(orig.tool_calls) == len(loaded.tool_calls)

View File

@@ -1,80 +1,12 @@
You are Otto, an AI Co-Pilot and Forward Deployed Engineer for AutoGPT, an AI Business Automation tool. Your mission is to help users quickly find, create, and set up AutoGPT agents to solve their business problems.
You are Otto, an AI Co-Pilot and Forward Deployed Engineer for AutoGPT, an AI Business Automation tool. Your mission is to help users quickly find and set up AutoGPT agents to solve their business problems.
Here are the functions available to you:
<functions>
**Understanding & Discovery:**
1. **add_understanding** - Save information about the user's business context (use this as you learn about them)
2. **find_agent** - Search the marketplace for pre-built agents that solve the user's problem
3. **find_library_agent** - Search the user's personal library of saved agents
4. **find_block** - Search for individual blocks (building components for agents)
5. **search_platform_docs** - Search AutoGPT documentation for help
**Agent Creation & Editing:**
6. **create_agent** - Create a new custom agent from scratch based on user requirements
7. **edit_agent** - Modify an existing agent (add/remove blocks, change configuration)
**Execution & Output:**
8. **run_agent** - Run or schedule an agent (automatically handles setup)
9. **run_block** - Run a single block directly without creating an agent
10. **agent_output** - Get the output/results from a running or completed agent execution
1. **find_agent** - Search for agents that solve the user's problem
2. **run_agent** - Run or schedule an agent (automatically handles setup)
</functions>
## ALWAYS GET THE USER'S NAME
**This is critical:** If you don't know the user's name, ask for it in your first response. Use a friendly, natural approach:
- "Hi! I'm Otto. What's your name?"
- "Hey there! Before we dive in, what should I call you?"
Once you have their name, immediately save it with `add_understanding(user_name="...")` and use it throughout the conversation.
## BUILDING USER UNDERSTANDING
**If no User Business Context is provided below**, gather information naturally during conversation - don't interrogate them.
**Key information to gather (in priority order):**
1. Their name (ALWAYS first if unknown)
2. Their job title and role
3. Their business/company and industry
4. Pain points and what they want to automate
5. Tools they currently use
**How to gather this information:**
- Ask naturally as part of helping them (e.g., "What's your role?" or "What industry are you in?")
- When they share information, immediately save it using `add_understanding`
- Don't ask all questions at once - spread them across the conversation
- Prioritize understanding their immediate problem first
**Example:**
```
User: "I need help automating my social media"
Otto: I can help with that! I'm Otto - what's your name?
User: "I'm Sarah"
Otto: [calls add_understanding with user_name="Sarah"]
Nice to meet you, Sarah! What's your role - are you a social media manager or business owner?
User: "I'm the marketing director at a fintech startup"
Otto: [calls add_understanding with job_title="Marketing Director", industry="fintech", business_size="startup"]
Great! Let me find social media automation agents for you.
[calls find_agent with query="social media automation marketing"]
```
## WHEN TO USE WHICH TOOL
**Finding existing agents:**
- `find_agent` - Search the marketplace for pre-built agents others have created
- `find_library_agent` - Search agents the user has already saved to their library
**Creating/editing agents:**
- `create_agent` - When user wants a custom agent that doesn't exist, or has specific requirements
- `edit_agent` - When user wants to modify an existing agent (change inputs, add blocks, etc.)
**Running agents:**
- `run_agent` - To execute an agent (handles credentials and inputs automatically)
- `agent_output` - To check the results of a running or completed agent execution
**Direct execution:**
- `run_block` - Run a single block directly without needing a full agent
## HOW run_agent WORKS
The `run_agent` tool automatically handles the entire setup flow:
@@ -89,61 +21,49 @@ Parameters:
- `use_defaults`: Set to `true` to run with default values (only after user confirms)
- `schedule_name` + `cron`: For scheduled execution
## HOW create_agent WORKS
Use `create_agent` when the user wants to build a custom automation:
- Describe what the agent should do
- The tool will create the agent structure with appropriate blocks
- Returns the agent ID for further editing or running
## HOW agent_output WORKS
Use `agent_output` to get results from agent executions:
- Pass the execution_id from a run_agent response
- Returns the current status and any outputs produced
- Useful for checking if an agent has completed and what it produced
## WORKFLOW
1. **Get their name** - If unknown, ask for it first
2. **Understand context** - Ask 1-2 questions about their problem while helping
3. **Find or create** - Use find_agent for existing solutions, create_agent for custom needs
4. **Set up and run** - Use run_agent to execute, agent_output to get results
1. **find_agent** - Search for agents that solve the user's problem
2. **run_agent** (first call, no inputs) - Get available inputs for the agent
3. **Ask user** what values they want to use OR if they want to use defaults
4. **run_agent** (second call) - Either with `inputs={...}` or `use_defaults=true`
## YOUR APPROACH
**Step 1: Greet and Identify**
- If you don't know their name, ask for it
- Be friendly and conversational
**Step 2: Understand the Problem**
**Step 1: Understand the Problem**
- Ask at most 1-2 targeted questions
- Focus on: What business problem are they solving?
- If they want to create/edit an agent, understand what it should do
- Move quickly to searching for solutions
**Step 3: Find or Create**
- For existing solutions: Use `find_agent` with relevant keywords
- For custom needs: Use `create_agent` with their requirements
- For modifications: Use `edit_agent` on an existing agent
**Step 2: Find Agents**
- Use `find_agent` immediately with relevant keywords
- Suggest the best option from search results
- Explain briefly how it solves their problem
**Step 4: Execute**
- Call `run_agent` without inputs first to see what's available
- Ask user what values they want or if defaults are okay
- Call `run_agent` again with inputs or `use_defaults=true`
- Use `agent_output` to check results when needed
**Step 3: Get Agent Inputs**
- Call `run_agent(username_agent_slug="creator/agent-name")` without inputs
- This returns the available inputs (required and optional)
- Present these to the user and ask what values they want
## USING add_understanding
**Step 4: Run with User's Choice**
- If user provides values: `run_agent(username_agent_slug="...", inputs={...})`
- If user says "use defaults": `run_agent(username_agent_slug="...", use_defaults=true)`
- On success, share the agent link with the user
Call `add_understanding` whenever you learn something about the user:
**For Scheduled Execution:**
- Add `schedule_name` and `cron` parameters
- Example: `run_agent(username_agent_slug="...", inputs={...}, schedule_name="Daily Report", cron="0 9 * * *")`
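(`0 9 * * *` is standard five-field cron — minute, hour, day of month, month, day of week — so this example schedule fires daily at 09:00.)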
**User info:** `user_name`, `job_title`
**Business:** `business_name`, `industry`, `business_size` (1-10, 11-50, 51-200, 201-1000, 1000+), `user_role` (decision maker, implementer, end user)
**Processes:** `key_workflows` (array), `daily_activities` (array)
**Pain points:** `pain_points` (array), `bottlenecks` (array), `manual_tasks` (array), `automation_goals` (array)
**Tools:** `current_software` (array), `existing_automation` (array)
**Other:** `additional_notes`
## FUNCTION CALL FORMAT
Example: `add_understanding(user_name="Sarah", job_title="Marketing Director", industry="fintech")`
To call a function, use this exact format:
`<function_call>function_name(parameter="value")</function_call>`
Examples:
- `<function_call>find_agent(query="social media automation")</function_call>`
- `<function_call>run_agent(username_agent_slug="creator/agent-name")</function_call>` (get inputs)
- `<function_call>run_agent(username_agent_slug="creator/agent-name", inputs={"topic": "AI news"})</function_call>`
- `<function_call>run_agent(username_agent_slug="creator/agent-name", use_defaults=true)</function_call>`
## KEY RULES
@@ -153,12 +73,8 @@ Example: `add_understanding(user_name="Sarah", job_title="Marketing Director", i
- Don't run agents without first showing available inputs to the user
- Don't use `use_defaults=true` without user explicitly confirming
- Don't write responses longer than 3 sentences
- Don't interrogate users with many questions - gather info naturally
**What You DO:**
- ALWAYS ask for user's name if you don't have it
- Save user information with `add_understanding` as you learn it
- Use their name when addressing them
- Always call run_agent first without inputs to see what's available
- Ask user what values they want OR if they want to use defaults
- Keep all responses to maximum 3 sentences
@@ -171,22 +87,18 @@ Example: `add_understanding(user_name="Sarah", job_title="Marketing Director", i
## RESPONSE STRUCTURE
Before responding, wrap your analysis in <thinking> tags to systematically plan your approach:
- Check if you know the user's name - if not, ask for it
- Check if you have user context - if not, plan to gather some naturally
- Extract the key business problem or request from the user's message
- Determine what function call (if any) you need to make next
- Plan your response to stay under the 3-sentence maximum
Example interaction:
```
User: "Hi, I want to build an agent that monitors my competitors"
Otto: <thinking>I don't know this user's name. I should ask for it while acknowledging their request.</thinking>
Hi! I'm Otto and I'd love to help you build a competitor monitoring agent. What's your name?
User: "I'm Mike"
Otto: [calls add_understanding with user_name="Mike"]
<thinking>Now I know Mike wants competitor monitoring. I should search for existing agents first.</thinking>
Great to meet you, Mike! Let me search for competitor monitoring agents.
[calls find_agent with query="competitor monitoring analysis"]
User: "Run the AI news agent for me"
Otto: <function_call>run_agent(username_agent_slug="autogpt/ai-news")</function_call>
[Tool returns: Agent accepts inputs - Required: topic. Optional: num_articles (default: 5)]
Otto: The AI News agent needs a topic. What topic would you like news about, or should I use the defaults?
User: "Use defaults"
Otto: <function_call>run_agent(username_agent_slug="autogpt/ai-news", use_defaults=true)</function_call>
```
KEEP ANSWERS TO 3 SENTENCES

View File

@@ -1,155 +0,0 @@
You are Otto, an AI Co-Pilot helping new users get started with AutoGPT, an AI Business Automation platform. Your mission is to welcome them, learn about their needs, and help them run their first successful agent.
Here are the functions available to you:
<functions>
**Understanding & Discovery:**
1. **add_understanding** - Save information about the user's business context (use this as you learn about them)
2. **find_agent** - Search the marketplace for pre-built agents that solve the user's problem
3. **find_library_agent** - Search the user's personal library of saved agents
4. **find_block** - Search for individual blocks (building components for agents)
5. **search_platform_docs** - Search AutoGPT documentation for help
**Agent Creation & Editing:**
6. **create_agent** - Create a new custom agent from scratch based on user requirements
7. **edit_agent** - Modify an existing agent (add/remove blocks, change configuration)
**Execution & Output:**
8. **run_agent** - Run or schedule an agent (automatically handles setup)
9. **run_block** - Run a single block directly without creating an agent
10. **agent_output** - Get the output/results from a running or completed agent execution
</functions>
## YOUR ONBOARDING MISSION
You are guiding a new user through their first experience with AutoGPT. Your goal is to:
1. Welcome them warmly and get their name
2. Learn about them and their business
3. Find or create an agent that solves a real problem for them
4. Get that agent running successfully
5. Celebrate their success and point them to next steps
## PHASE 1: WELCOME & INTRODUCTION
**Start every conversation by:**
- Giving a warm, friendly greeting
- Introducing yourself as Otto, their AI assistant
- Asking for their name immediately
**Example opening:**
```
Hi! I'm Otto, your AI assistant. Welcome to AutoGPT! I'm here to help you set up your first automation. What's your name?
```
Once you have their name, save it immediately with `add_understanding(user_name="...")` and use it throughout.
## PHASE 2: DISCOVERY
**After getting their name, learn about them:**
- What's their role/job title?
- What industry/business are they in?
- What's one thing they'd love to automate?
**Keep it conversational - don't interrogate. Example:**
```
Nice to meet you, Sarah! What do you do for work, and what's one task you wish you could automate?
```
Save everything you learn with `add_understanding`.
## PHASE 3: FIND OR CREATE AN AGENT
**Once you understand their need:**
- Search for existing agents with `find_agent`
- Present the best match and explain how it helps them
- If nothing fits, offer to create a custom agent with `create_agent`
**Be enthusiastic about the solution:**
```
I found a great agent for you! The "Social Media Scheduler" can automatically post to your accounts on a schedule. Want to try it?
```
## PHASE 4: SETUP & RUN
**Guide them through running the agent:**
1. Call `run_agent` without inputs first to see what's needed
2. Explain each input in simple terms
3. Ask what values they want to use
4. Run the agent with their inputs or defaults
**Don't mention credentials** - the UI handles that automatically.
## PHASE 5: CELEBRATE & HANDOFF
**After successful execution:**
- Congratulate them on their first automation!
- Tell them where to find this agent (their Library)
- Mention they can explore more agents in the Marketplace
- Offer to help with anything else
**Example:**
```
You did it! Your first agent is running. You can find it anytime in your Library. Ready to explore more automations?
```
## KEY RULES
**What You DON'T Do:**
- Don't help with login (frontend handles this)
- Don't mention credentials (UI handles automatically)
- Don't run agents without showing inputs first
- Don't use `use_defaults=true` without explicit confirmation
- Don't write responses longer than 3 sentences
- Don't overwhelm with too many questions at once
**What You DO:**
- ALWAYS get the user's name first
- Be warm, encouraging, and celebratory
- Save info with `add_understanding` as you learn it
- Use their name when addressing them
- Keep responses to maximum 3 sentences
- Make them feel successful at each step
## USING add_understanding
Save information as you learn it:
**User info:** `user_name`, `job_title`
**Business:** `business_name`, `industry`, `business_size`, `user_role`
**Pain points:** `pain_points`, `manual_tasks`, `automation_goals`
**Tools:** `current_software`
Example: `add_understanding(user_name="Sarah", job_title="Marketing Manager", automation_goals=["social media scheduling"])`
## HOW run_agent WORKS
1. **First call** (no inputs) → Shows available inputs
2. **Credentials** → UI handles automatically (don't mention)
3. **Execution** → Run with `inputs={...}` or `use_defaults=true`
## RESPONSE STRUCTURE
Before responding, plan your approach in <thinking> tags:
- What phase am I in? (Welcome/Discovery/Find/Setup/Celebrate)
- Do I know their name? If not, ask for it
- What's the next step to move them forward?
- Keep response under 3 sentences
**Example flow:**
```
User: "Hi"
Otto: <thinking>Phase 1 - I need to welcome them and get their name.</thinking>
Hi! I'm Otto, welcome to AutoGPT! I'm here to help you set up your first automation - what's your name?
User: "I'm Alex"
Otto: [calls add_understanding with user_name="Alex"]
<thinking>Got their name. Phase 2 - learn about them.</thinking>
Great to meet you, Alex! What do you do for work, and what's one task you'd love to automate?
User: "I run an e-commerce store and spend hours on customer support emails"
Otto: [calls add_understanding with industry="e-commerce", pain_points=["customer support emails"]]
<thinking>Phase 3 - search for agents.</thinking>
[calls find_agent with query="customer support email automation"]
```
KEEP ANSWERS TO 3 SENTENCES - Be warm, helpful, and focused on their success!

View File

@@ -26,14 +26,6 @@ router = APIRouter(
# ========== Request/Response Models ==========
class StreamChatRequest(BaseModel):
"""Request model for streaming chat with optional context."""
message: str
is_user_message: bool = True
context: dict[str, str] | None = None # {url: str, content: str}
class CreateSessionResponse(BaseModel):
"""Response model containing information on a newly created chat session."""
@@ -52,64 +44,9 @@ class SessionDetailResponse(BaseModel):
messages: list[dict]
class SessionSummaryResponse(BaseModel):
"""Response model for a session summary (without messages)."""
id: str
created_at: str
updated_at: str
title: str | None = None
class ListSessionsResponse(BaseModel):
"""Response model for listing chat sessions."""
sessions: list[SessionSummaryResponse]
total: int
# ========== Routes ==========
@router.get(
"/sessions",
dependencies=[Security(auth.requires_user)],
)
async def list_sessions(
user_id: Annotated[str, Security(auth.get_user_id)],
limit: int = Query(default=50, ge=1, le=100),
offset: int = Query(default=0, ge=0),
) -> ListSessionsResponse:
"""
List chat sessions for the authenticated user.
Returns a paginated list of chat sessions belonging to the current user,
ordered by most recently updated.
Args:
user_id: The authenticated user's ID.
limit: Maximum number of sessions to return (1-100).
offset: Number of sessions to skip for pagination.
Returns:
ListSessionsResponse: List of session summaries and total count.
"""
sessions = await chat_service.get_user_sessions(user_id, limit, offset)
return ListSessionsResponse(
sessions=[
SessionSummaryResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
updated_at=session.updated_at.isoformat(),
title=None, # TODO: Add title support
)
for session in sessions
],
total=len(sessions),
)
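A hypothetical client call against this listing endpoint; the host, route prefix, and auth header are assumptions, not taken from this diff:

```
# Hedged sketch of paging through the sessions list; URL and token are made up.
import requests

resp = requests.get(
    "https://api.example.com/api/chat/sessions",
    params={"limit": 20, "offset": 0},
    headers={"Authorization": "Bearer <jwt>"},
    timeout=10,
)
resp.raise_for_status()
for s in resp.json()["sessions"]:
    print(s["id"], s["updated_at"], s.get("title"))
```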
@router.post(
"/sessions",
)
@@ -165,89 +102,26 @@ async def get_session(
session = await chat_service.get_session(session_id, user_id)
if not session:
raise NotFoundError(f"Session {session_id} not found")
messages = [message.model_dump() for message in session.messages]
logger.info(
f"Returning session {session_id}: "
f"message_count={len(messages)}, "
f"roles={[m.get('role') for m in messages]}"
)
return SessionDetailResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
updated_at=session.updated_at.isoformat(),
user_id=session.user_id or None,
messages=messages,
)
@router.post(
"/sessions/{session_id}/stream",
)
async def stream_chat_post(
session_id: str,
request: StreamChatRequest,
user_id: str | None = Depends(auth.get_user_id),
):
"""
Stream chat responses for a session (POST with context support).
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
- Text fragments as they are generated
- Tool call UI elements (if invoked)
- Tool execution results
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
user_id: Optional authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks.
"""
# Validate session exists before starting the stream
# This prevents errors after the response has already started
session = await chat_service.get_session(session_id, user_id)
if not session:
raise NotFoundError(f"Session {session_id} not found. ")
if session.user_id is None and user_id is not None:
session = await chat_service.assign_user_to_session(session_id, user_id)
async def event_generator() -> AsyncGenerator[str, None]:
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
yield chunk.to_sse()
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Disable nginx buffering
},
messages=[message.model_dump() for message in session.messages],
)
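A hedged sketch of consuming this SSE endpoint from Python; the base URL and session ID are assumptions, and each `data:` line carries one JSON-encoded chunk as produced by `to_sse()`:

```
# Made-up host/session ID; only the request/response shapes come from the
# route above (POST body with message/is_user_message, SSE "data:" lines).
import httpx

with httpx.stream(
    "POST",
    "https://api.example.com/api/chat/sessions/SESSION_ID/stream",
    json={"message": "Find me a news agent", "is_user_message": True},
    timeout=None,
) as resp:
    for line in resp.iter_lines():
        if line.startswith("data: "):
            print(line[len("data: "):])
```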
@router.get(
"/sessions/{session_id}/stream",
)
async def stream_chat_get(
async def stream_chat(
session_id: str,
message: Annotated[str, Query(min_length=1, max_length=10000)],
user_id: str | None = Depends(auth.get_user_id),
is_user_message: bool = Query(default=True),
):
"""
Stream chat responses for a session (GET - legacy endpoint).
Stream chat responses for a session.
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
- Text fragments as they are generated
@@ -319,133 +193,6 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Onboarding Routes ==========
# These routes use a specialized onboarding system prompt
@router.post(
"/onboarding/sessions",
)
async def create_onboarding_session(
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CreateSessionResponse:
"""
Create a new onboarding chat session.
Initiates a new chat session specifically for user onboarding,
using a specialized prompt that guides users through their first
experience with AutoGPT.
Args:
user_id: The optional authenticated user ID parsed from the JWT.
Returns:
CreateSessionResponse: Details of the created onboarding session.
"""
logger.info(
f"Creating onboarding session with user_id: "
f"...{user_id[-8:] if user_id and len(user_id) > 8 else '<redacted>'}"
)
session = await chat_service.create_chat_session(user_id)
return CreateSessionResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
user_id=session.user_id or None,
)
@router.get(
"/onboarding/sessions/{session_id}",
)
async def get_onboarding_session(
session_id: str,
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> SessionDetailResponse:
"""
Retrieve the details of an onboarding chat session.
Args:
session_id: The unique identifier for the onboarding session.
user_id: The optional authenticated user ID.
Returns:
SessionDetailResponse: Details for the requested session.
"""
session = await chat_service.get_session(session_id, user_id)
if not session:
raise NotFoundError(f"Session {session_id} not found")
messages = [message.model_dump() for message in session.messages]
logger.info(
f"Returning onboarding session {session_id}: "
f"message_count={len(messages)}, "
f"roles={[m.get('role') for m in messages]}"
)
return SessionDetailResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
updated_at=session.updated_at.isoformat(),
user_id=session.user_id or None,
messages=messages,
)
@router.post(
"/onboarding/sessions/{session_id}/stream",
)
async def stream_onboarding_chat(
session_id: str,
request: StreamChatRequest,
user_id: str | None = Depends(auth.get_user_id),
):
"""
Stream onboarding chat responses for a session.
Uses the specialized onboarding system prompt to guide new users
through their first experience with AutoGPT. Streams AI responses
in real time over Server-Sent Events (SSE).
Args:
session_id: The onboarding session identifier.
request: Request body containing message and optional context.
user_id: Optional authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks.
"""
session = await chat_service.get_session(session_id, user_id)
if not session:
raise NotFoundError(f"Session {session_id} not found.")
if session.user_id is None and user_id is not None:
session = await chat_service.assign_user_to_session(session_id, user_id)
async def event_generator() -> AsyncGenerator[str, None]:
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session,
context=request.context,
prompt_type="onboarding", # Use onboarding system prompt
):
yield chunk.to_sse()
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
# ========== Health Check ==========

View File

@@ -4,18 +4,11 @@ from datetime import UTC, datetime
from typing import Any
import orjson
from langfuse import Langfuse
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
from backend.data.understanding import (
format_understanding_for_prompt,
get_business_understanding,
)
from backend.util.exceptions import NotFoundError
from backend.util.settings import Settings
from . import db as chat_db
from .config import ChatConfig
from .model import (
ChatMessage,
@@ -24,9 +17,6 @@ from .model import (
get_chat_session,
upsert_chat_session,
)
from .model import (
create_chat_session as model_create_chat_session,
)
from .response_model import (
StreamBaseResponse,
StreamEnd,
@@ -43,154 +33,8 @@ from .tools import execute_tool, tools
logger = logging.getLogger(__name__)
config = ChatConfig()
settings = Settings()
client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
# Langfuse client (lazy initialization)
_langfuse_client: Langfuse | None = None
def _get_langfuse_client() -> Langfuse:
"""Get or create the Langfuse client for prompt management."""
global _langfuse_client
if _langfuse_client is None:
if not settings.secrets.langfuse_public_key or not settings.secrets.langfuse_secret_key:
raise ValueError(
"Langfuse credentials not configured. "
"Set LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY environment variables."
)
_langfuse_client = Langfuse(
public_key=settings.secrets.langfuse_public_key,
secret_key=settings.secrets.langfuse_secret_key,
host=settings.secrets.langfuse_host or "https://cloud.langfuse.com",
)
return _langfuse_client
def _get_langfuse_prompt() -> str:
"""Fetch the latest production prompt from Langfuse.
Returns:
The compiled prompt text from Langfuse.
Raises:
Exception: If Langfuse is unavailable or prompt fetch fails.
"""
try:
langfuse = _get_langfuse_client()
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
prompt = langfuse.get_prompt(config.langfuse_prompt_name, cache_ttl_seconds=0)
compiled = prompt.compile()
logger.info(
f"Fetched prompt '{config.langfuse_prompt_name}' from Langfuse "
f"(version: {prompt.version})"
)
return compiled
except Exception as e:
logger.error(f"Failed to fetch prompt from Langfuse: {e}")
raise
async def _is_first_session(user_id: str) -> bool:
"""Check if this is the user's first chat session.
Returns True if the user has 1 or fewer sessions (meaning this is their first).
"""
try:
session_count = await chat_db.get_user_session_count(user_id)
return session_count <= 1
except Exception as e:
logger.warning(f"Failed to check session count for user {user_id}: {e}")
return False # Default to non-onboarding if we can't check
async def _build_system_prompt(
user_id: str | None, prompt_type: str = "default"
) -> str:
"""Build the full system prompt including business understanding if available.
Args:
user_id: The user ID for fetching business understanding
prompt_type: The type of prompt to load ("default" or "onboarding")
If "default" and this is the user's first session, will use "onboarding" instead.
Returns:
The full system prompt with business understanding context if available
"""
# Auto-detect: if using default prompt and this is user's first session, use onboarding
effective_prompt_type = prompt_type
if prompt_type == "default" and user_id:
if await _is_first_session(user_id):
logger.info("First session detected for user, using onboarding prompt")
effective_prompt_type = "onboarding"
# Start with the base system prompt for the specified type
if effective_prompt_type == "default":
# Fetch from Langfuse for the default prompt
base_prompt = _get_langfuse_prompt()
else:
# Use local file for other prompt types (e.g., onboarding)
base_prompt = config.get_system_prompt_for_type(effective_prompt_type)
# If user is authenticated, try to fetch their business understanding
if user_id:
try:
understanding = await get_business_understanding(user_id)
if understanding:
context = format_understanding_for_prompt(understanding)
if context:
return (
f"{base_prompt}\n\n---\n\n"
f"{context}\n\n"
"Use this context to provide more personalized recommendations "
"and to better understand the user's business needs when "
"suggesting agents and automations."
)
except Exception as e:
logger.warning(f"Failed to fetch business understanding: {e}")
return base_prompt
async def _generate_session_title(message: str) -> str | None:
"""Generate a concise title for a chat session based on the first message.
Args:
message: The first user message in the session
Returns:
A short title (3-6 words) or None if generation fails
"""
try:
response = await client.chat.completions.create(
model=config.title_model,
messages=[
{
"role": "system",
"content": (
"Generate a very short title (3-6 words) for a chat conversation "
"based on the user's first message. The title should capture the "
"main topic or intent. Return ONLY the title, no quotes or punctuation."
),
},
{"role": "user", "content": message[:500]}, # Limit input length
],
max_tokens=20,
temperature=0.7,
)
title = response.choices[0].message.content
if title:
# Clean up the title
title = title.strip().strip("\"'")
# Limit length
if len(title) > 50:
title = title[:47] + "..."
return title
return None
except Exception as e:
logger.warning(f"Failed to generate session title: {e}")
return None
async def create_chat_session(
user_id: str | None = None,
@@ -198,7 +42,9 @@ async def create_chat_session(
"""
Create a new chat session and persist it to the database.
"""
return await model_create_chat_session(user_id)
session = ChatSession.new(user_id)
# Persist the session immediately so it can be used for streaming
return await upsert_chat_session(session)
async def get_session(
@@ -211,19 +57,6 @@ async def get_session(
return await get_chat_session(session_id, user_id)
async def get_user_sessions(
user_id: str,
limit: int = 50,
offset: int = 0,
) -> list[ChatSession]:
"""
Get all chat sessions for a user.
"""
from .model import get_user_sessions as model_get_user_sessions
return await model_get_user_sessions(user_id, limit, offset)
async def assign_user_to_session(
session_id: str,
user_id: str,
@@ -245,8 +78,6 @@ async def stream_chat_completion(
user_id: str | None = None,
retry_count: int = 0,
session: ChatSession | None = None,
context: dict[str, str] | None = None, # {url: str, content: str}
prompt_type: str = "default",
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Main entry point for streaming chat completions with database handling.
@@ -258,7 +89,6 @@ async def stream_chat_completion(
user_message: User's input message
user_id: User ID for authentication (None for anonymous)
session: Optional pre-loaded session object (for recursive calls to avoid Redis refetch)
prompt_type: The type of prompt to use ("default" or "onboarding")
Yields:
StreamBaseResponse objects formatted as SSE
@@ -291,18 +121,9 @@ async def stream_chat_completion(
)
if message:
# Build message content with context if provided
message_content = message
if context and context.get("url") and context.get("content"):
context_text = f"Page URL: {context['url']}\n\nPage Content:\n{context['content']}\n\n---\n\nUser Message: {message}"
message_content = context_text
logger.info(
f"Including page context: URL={context['url']}, content_length={len(context['content'])}"
)
session.messages.append(
ChatMessage(
role="user" if is_user_message else "assistant", content=message_content
role="user" if is_user_message else "assistant", content=message
)
)
logger.info(
@@ -320,32 +141,6 @@ async def stream_chat_completion(
session = await upsert_chat_session(session)
assert session, "Session not found"
# Generate title for new sessions on first user message (non-blocking)
# Check: is_user_message, no title yet, and this is the first user message
if is_user_message and message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
# First user message - generate title in background
import asyncio
async def _update_title():
try:
title = await _generate_session_title(message)
if title:
session.title = title
await upsert_chat_session(session)
logger.info(
f"Generated title for session {session_id}: {title}"
)
except Exception as e:
logger.warning(f"Failed to update session title: {e}")
# Fire and forget - don't block the chat response
asyncio.create_task(_update_title())
# Build system prompt with business understanding
system_prompt = await _build_system_prompt(user_id, prompt_type)
assistant_response = ChatMessage(
role="assistant",
content="",
@@ -364,7 +159,6 @@ async def stream_chat_completion(
async for chunk in _stream_chat_chunks(
session=session,
tools=tools,
system_prompt=system_prompt,
):
if isinstance(chunk, StreamTextChunk):
@@ -485,7 +279,6 @@ async def stream_chat_completion(
user_id=user_id,
retry_count=retry_count + 1,
session=session,
prompt_type=prompt_type,
):
yield chunk
return # Exit after retry to avoid double-saving in finally block
@@ -531,7 +324,6 @@ async def stream_chat_completion(
session_id=session.session_id,
user_id=user_id,
session=session, # Pass session object to avoid Redis refetch
prompt_type=prompt_type,
):
yield chunk
@@ -539,7 +331,6 @@ async def stream_chat_completion(
async def _stream_chat_chunks(
session: ChatSession,
tools: list[ChatCompletionToolParam],
system_prompt: str | None = None,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""
Pure streaming function for OpenAI chat completions with tool calling.
@@ -547,9 +338,9 @@ async def _stream_chat_chunks(
This function is database-agnostic and focuses only on streaming logic.
Args:
session: Chat session with conversation history
tools: Available tools for the model
system_prompt: System prompt to prepend to messages
messages: Conversation context as ChatCompletionMessageParam list
session_id: Session ID
user_id: User ID for tool execution
Yields:
SSE formatted JSON response objects
@@ -559,17 +350,6 @@ async def _stream_chat_chunks(
logger.info("Starting pure chat stream")
# Build messages with system prompt prepended
messages = session.to_openai_messages()
if system_prompt:
from openai.types.chat import ChatCompletionSystemMessageParam
system_message = ChatCompletionSystemMessageParam(
role="system",
content=system_prompt,
)
messages = [system_message] + messages
# Loop to handle tool calls and continue conversation
while True:
try:
@@ -578,7 +358,7 @@ async def _stream_chat_chunks(
# Create the stream with proper types
stream = await client.chat.completions.create(
model=model,
messages=messages,
messages=session.to_openai_messages(),
tools=tools,
tool_choice="auto",
stream=True,
@@ -722,12 +502,8 @@ async def _yield_tool_call(
"""
logger.info(f"Yielding tool call: {tool_calls[yield_idx]}")
# Parse tool call arguments - handle empty arguments gracefully
raw_arguments = tool_calls[yield_idx]["function"]["arguments"]
if raw_arguments:
arguments = orjson.loads(raw_arguments)
else:
arguments = {}
# Parse tool call arguments - exceptions will propagate to caller
arguments = orjson.loads(tool_calls[yield_idx]["function"]["arguments"])
yield StreamToolCall(
tool_id=tool_calls[yield_idx]["id"],

View File

@@ -4,45 +4,21 @@ from openai.types.chat import ChatCompletionToolParam
from backend.api.features.chat.model import ChatSession
from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .create_agent import CreateAgentTool
from .edit_agent import EditAgentTool
from .find_agent import FindAgentTool
from .find_block import FindBlockTool
from .find_library_agent import FindLibraryAgentTool
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .search_docs import SearchDocsTool
if TYPE_CHECKING:
from backend.api.features.chat.response_model import StreamToolExecutionResult
# Initialize tool instances
add_understanding_tool = AddUnderstandingTool()
create_agent_tool = CreateAgentTool()
edit_agent_tool = EditAgentTool()
find_agent_tool = FindAgentTool()
find_block_tool = FindBlockTool()
find_library_agent_tool = FindLibraryAgentTool()
run_agent_tool = RunAgentTool()
run_block_tool = RunBlockTool()
search_docs_tool = SearchDocsTool()
agent_output_tool = AgentOutputTool()
# Export tools as OpenAI format
tools: list[ChatCompletionToolParam] = [
add_understanding_tool.as_openai_tool(),
create_agent_tool.as_openai_tool(),
edit_agent_tool.as_openai_tool(),
find_agent_tool.as_openai_tool(),
find_block_tool.as_openai_tool(),
find_library_agent_tool.as_openai_tool(),
run_agent_tool.as_openai_tool(),
run_block_tool.as_openai_tool(),
search_docs_tool.as_openai_tool(),
agent_output_tool.as_openai_tool(),
]
@@ -55,16 +31,8 @@ async def execute_tool(
) -> "StreamToolExecutionResult":
tool_map: dict[str, BaseTool] = {
"add_understanding": add_understanding_tool,
"create_agent": create_agent_tool,
"edit_agent": edit_agent_tool,
"find_agent": find_agent_tool,
"find_block": find_block_tool,
"find_library_agent": find_library_agent_tool,
"run_agent": run_agent_tool,
"run_block": run_block_tool,
"search_platform_docs": search_docs_tool,
"agent_output": agent_output_tool,
}
if tool_name not in tool_map:
raise ValueError(f"Tool {tool_name} not found")

View File

@@ -1,206 +0,0 @@
"""Tool for capturing user business understanding incrementally."""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.data.understanding import (
BusinessUnderstandingInput,
upsert_business_understanding,
)
from .base import BaseTool
from .models import (
ErrorResponse,
ToolResponseBase,
UnderstandingUpdatedResponse,
)
logger = logging.getLogger(__name__)
class AddUnderstandingTool(BaseTool):
"""Tool for capturing user's business understanding incrementally."""
@property
def name(self) -> str:
return "add_understanding"
@property
def description(self) -> str:
return """Capture and store information about the user's business context,
workflows, pain points, and automation goals. Call this tool whenever the user
shares information about their business. Each call incrementally adds to the
existing understanding - you don't need to provide all fields at once.
Use this to build a comprehensive profile that helps recommend better agents
and automations for the user's specific needs."""
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"user_name": {
"type": "string",
"description": "The user's name",
},
"job_title": {
"type": "string",
"description": "The user's job title (e.g., 'Marketing Manager', 'CEO', 'Software Engineer')",
},
"business_name": {
"type": "string",
"description": "Name of the user's business or organization",
},
"industry": {
"type": "string",
"description": "Industry or sector (e.g., 'e-commerce', 'healthcare', 'finance')",
},
"business_size": {
"type": "string",
"description": "Company size: '1-10', '11-50', '51-200', '201-1000', or '1000+'",
},
"user_role": {
"type": "string",
"description": "User's role in organization context (e.g., 'decision maker', 'implementer', 'end user')",
},
"key_workflows": {
"type": "array",
"items": {"type": "string"},
"description": "Key business workflows (e.g., 'lead qualification', 'content publishing')",
},
"daily_activities": {
"type": "array",
"items": {"type": "string"},
"description": "Regular daily activities the user performs",
},
"pain_points": {
"type": "array",
"items": {"type": "string"},
"description": "Current pain points or challenges",
},
"bottlenecks": {
"type": "array",
"items": {"type": "string"},
"description": "Process bottlenecks slowing things down",
},
"manual_tasks": {
"type": "array",
"items": {"type": "string"},
"description": "Manual or repetitive tasks that could be automated",
},
"automation_goals": {
"type": "array",
"items": {"type": "string"},
"description": "Desired automation outcomes or goals",
},
"current_software": {
"type": "array",
"items": {"type": "string"},
"description": "Software and tools currently in use",
},
"existing_automation": {
"type": "array",
"items": {"type": "string"},
"description": "Any existing automations or integrations",
},
"additional_notes": {
"type": "string",
"description": "Any other relevant context or notes",
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
"""Requires authentication to store user-specific data."""
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""
Capture and store business understanding incrementally.
Each call merges new data with existing understanding:
- String fields are overwritten if provided
- List fields are appended (with deduplication)
"""
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required to save business understanding.",
session_id=session_id,
)
# Check if any data was provided
if not any(v is not None for v in kwargs.values()):
return ErrorResponse(
message="Please provide at least one field to update.",
session_id=session_id,
)
# Build input model
input_data = BusinessUnderstandingInput(
user_name=kwargs.get("user_name"),
job_title=kwargs.get("job_title"),
business_name=kwargs.get("business_name"),
industry=kwargs.get("industry"),
business_size=kwargs.get("business_size"),
user_role=kwargs.get("user_role"),
key_workflows=kwargs.get("key_workflows"),
daily_activities=kwargs.get("daily_activities"),
pain_points=kwargs.get("pain_points"),
bottlenecks=kwargs.get("bottlenecks"),
manual_tasks=kwargs.get("manual_tasks"),
automation_goals=kwargs.get("automation_goals"),
current_software=kwargs.get("current_software"),
existing_automation=kwargs.get("existing_automation"),
additional_notes=kwargs.get("additional_notes"),
)
# Track which fields were updated
updated_fields = [k for k, v in kwargs.items() if v is not None]
# Upsert with merge
understanding = await upsert_business_understanding(user_id, input_data)
# Build current understanding summary for the response
current_understanding = {
"user_name": understanding.user_name,
"job_title": understanding.job_title,
"business_name": understanding.business_name,
"industry": understanding.industry,
"business_size": understanding.business_size,
"user_role": understanding.user_role,
"key_workflows": understanding.key_workflows,
"daily_activities": understanding.daily_activities,
"pain_points": understanding.pain_points,
"bottlenecks": understanding.bottlenecks,
"manual_tasks": understanding.manual_tasks,
"automation_goals": understanding.automation_goals,
"current_software": understanding.current_software,
"existing_automation": understanding.existing_automation,
"additional_notes": understanding.additional_notes,
}
# Filter out empty values for cleaner response
current_understanding = {
k: v
for k, v in current_understanding.items()
if v is not None and v != [] and v != ""
}
return UnderstandingUpdatedResponse(
message=f"Updated understanding with: {', '.join(updated_fields)}. "
"I now have a better picture of your business context.",
session_id=session_id,
updated_fields=updated_fields,
current_understanding=current_understanding,
)
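The merge semantics the docstring promises (strings overwrite, lists append with deduplication) can be illustrated without the real `upsert_business_understanding`:

```
# Stand-alone illustration of the described merge; the data is made up.
existing = {"industry": "retail", "pain_points": ["slow reporting"]}
update = {"industry": "e-commerce", "pain_points": ["slow reporting", "returns"]}

merged = dict(existing)
for key, value in update.items():
    if isinstance(value, list):  # append + dedupe, preserving order
        merged[key] = list(dict.fromkeys(merged.get(key, []) + value))
    elif value is not None:      # overwrite scalars
        merged[key] = value

print(merged)
# {'industry': 'e-commerce', 'pain_points': ['slow reporting', 'returns']}
```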

View File

@@ -1,29 +0,0 @@
"""Agent generator package - Creates agents from natural language."""
from .core import (
apply_agent_patch,
decompose_goal,
generate_agent,
generate_agent_patch,
get_agent_as_json,
save_agent_to_library,
)
from .fixer import apply_all_fixes
from .utils import get_blocks_info
from .validator import validate_agent
__all__ = [
# Core functions
"decompose_goal",
"generate_agent",
"generate_agent_patch",
"apply_agent_patch",
"save_agent_to_library",
"get_agent_as_json",
# Fixer
"apply_all_fixes",
# Validator
"validate_agent",
# Utils
"get_blocks_info",
]

View File

@@ -1,25 +0,0 @@
"""OpenRouter client configuration for agent generation."""
import os
from openai import AsyncOpenAI
# Configuration - use OPEN_ROUTER_API_KEY for consistency with chat/config.py
OPENROUTER_API_KEY = os.getenv("OPEN_ROUTER_API_KEY") or os.getenv("OPENROUTER_API_KEY")
AGENT_GENERATOR_MODEL = os.getenv("AGENT_GENERATOR_MODEL", "anthropic/claude-opus-4.5")
# OpenRouter client (OpenAI-compatible API)
_client: AsyncOpenAI | None = None
def get_client() -> AsyncOpenAI:
"""Get or create the OpenRouter client."""
global _client
if _client is None:
if not OPENROUTER_API_KEY:
raise ValueError("OPENROUTER_API_KEY environment variable is required")
_client = AsyncOpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=OPENROUTER_API_KEY,
)
return _client

View File

@@ -1,390 +0,0 @@
"""Core agent generation functions."""
import copy
import json
import logging
import uuid
from typing import Any
from backend.api.features.library import db as library_db
from backend.data.graph import Graph, Link, Node, create_graph
from .client import AGENT_GENERATOR_MODEL, get_client
from .prompts import DECOMPOSITION_PROMPT, GENERATION_PROMPT, PATCH_PROMPT
from .utils import get_block_summaries, parse_json_from_llm
logger = logging.getLogger(__name__)
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
"""Break down a goal into steps or return clarifying questions.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
Returns:
Dict with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
Or None on error
"""
client = get_client()
prompt = DECOMPOSITION_PROMPT.format(block_summaries=get_block_summaries())
full_description = description
if context:
full_description = f"{description}\n\nAdditional context:\n{context}"
try:
response = await client.chat.completions.create(
model=AGENT_GENERATOR_MODEL,
messages=[
{"role": "system", "content": prompt},
{"role": "user", "content": full_description},
],
temperature=0,
)
content = response.choices[0].message.content
if content is None:
logger.error("LLM returned empty content for decomposition")
return None
result = parse_json_from_llm(content)
if result is None:
logger.error(f"Failed to parse decomposition response: {content[:200]}")
return None
return result
except Exception as e:
logger.error(f"Error decomposing goal: {e}")
return None
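For reference, hypothetical payloads for the two return shapes the docstring names (field contents are made up for illustration):

```
clarifying = {
    "type": "clarifying_questions",
    "questions": ["Which social networks should the agent cover?"],
}
instructions = {
    "type": "instructions",
    "steps": ["Fetch posts", "Summarize them", "Email the digest"],
}
```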
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
Returns:
Agent JSON dict or None on error
"""
client = get_client()
prompt = GENERATION_PROMPT.format(block_summaries=get_block_summaries())
try:
response = await client.chat.completions.create(
model=AGENT_GENERATOR_MODEL,
messages=[
{"role": "system", "content": prompt},
{"role": "user", "content": json.dumps(instructions, indent=2)},
],
temperature=0,
)
content = response.choices[0].message.content
if content is None:
logger.error("LLM returned empty content for agent generation")
return None
result = parse_json_from_llm(content)
if result is None:
logger.error(f"Failed to parse agent JSON: {content[:200]}")
return None
# Ensure required fields
if "id" not in result:
result["id"] = str(uuid.uuid4())
if "version" not in result:
result["version"] = 1
if "is_active" not in result:
result["is_active"] = True
return result
except Exception as e:
logger.error(f"Error generating agent: {e}")
return None
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
"""Convert agent JSON dict to Graph model.
Args:
agent_json: Agent JSON with nodes and links
Returns:
Graph ready for saving
"""
nodes = []
for n in agent_json.get("nodes", []):
node = Node(
id=n.get("id", str(uuid.uuid4())),
block_id=n["block_id"],
input_default=n.get("input_default", {}),
metadata=n.get("metadata", {}),
)
nodes.append(node)
links = []
for link_data in agent_json.get("links", []):
link = Link(
id=link_data.get("id", str(uuid.uuid4())),
source_id=link_data["source_id"],
sink_id=link_data["sink_id"],
source_name=link_data["source_name"],
sink_name=link_data["sink_name"],
is_static=link_data.get("is_static", False),
)
links.append(link)
return Graph(
id=agent_json.get("id", str(uuid.uuid4())),
version=agent_json.get("version", 1),
is_active=agent_json.get("is_active", True),
name=agent_json.get("name", "Generated Agent"),
description=agent_json.get("description", ""),
nodes=nodes,
links=links,
)
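# Example (editor's sketch): the minimal JSON shape json_to_graph accepts.
# The node ID and block_id below are hypothetical placeholders:
#
#     graph = json_to_graph({
#         "name": "Echo",
#         "nodes": [
#             {"id": "6f1c0d3e-2b4a-4c5d-8e9f-0a1b2c3d4e5f",
#              "block_id": "<some-block-uuid>",
#              "input_default": {"value": "hi"}},
#         ],
#         "links": [],
#     })
#     # Omitted fields get defaults: graph.version == 1, graph.is_active is True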
def _reassign_node_ids(graph: Graph) -> None:
"""Reassign all node and link IDs to new UUIDs.
This is needed when creating a new version to avoid unique constraint violations.
"""
# Create mapping from old node IDs to new UUIDs
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
# Reassign node IDs
for node in graph.nodes:
node.id = id_map[node.id]
# Update link references to use new node IDs
for link in graph.links:
link.id = str(uuid.uuid4()) # Also give links new IDs
if link.source_id in id_map:
link.source_id = id_map[link.source_id]
if link.sink_id in id_map:
link.sink_id = id_map[link.sink_id]
async def save_agent_to_library(
agent_json: dict[str, Any], user_id: str, is_update: bool = False
) -> tuple[Graph, Any]:
"""Save agent to database and user's library.
Args:
agent_json: Agent JSON dict
user_id: User ID
is_update: Whether this is an update to an existing agent
Returns:
Tuple of (created Graph, LibraryAgent)
"""
from backend.data.graph import get_graph_all_versions
graph = json_to_graph(agent_json)
if is_update:
# For updates, keep the same graph ID but increment version
# and reassign node/link IDs to avoid conflicts
if graph.id:
existing_versions = await get_graph_all_versions(graph.id, user_id)
if existing_versions:
latest_version = max(v.version for v in existing_versions)
graph.version = latest_version + 1
# Reassign node IDs (but keep graph ID the same)
_reassign_node_ids(graph)
logger.info(f"Updating agent {graph.id} to version {graph.version}")
else:
# For new agents, always generate a fresh UUID to avoid collisions
graph.id = str(uuid.uuid4())
graph.version = 1
# Reassign all node IDs as well
_reassign_node_ids(graph)
logger.info(f"Creating new agent with ID {graph.id}")
# Save to database
created_graph = await create_graph(graph, user_id)
# Add to user's library (or update existing library agent)
library_agents = await library_db.create_library_agent(
graph=created_graph,
user_id=user_id,
create_library_agents_for_sub_graphs=False,
)
return created_graph, library_agents[0]
async def get_agent_as_json(
graph_id: str, user_id: str | None
) -> dict[str, Any] | None:
"""Fetch an agent and convert to JSON format for editing.
Args:
graph_id: Graph ID or library agent ID
user_id: User ID
Returns:
Agent as JSON dict or None if not found
"""
from backend.data.graph import get_graph
# Try to get the graph (version=None gets the active version)
graph = await get_graph(graph_id, version=None, user_id=user_id)
if not graph:
return None
# Convert to JSON format
nodes = []
for node in graph.nodes:
nodes.append(
{
"id": node.id,
"block_id": node.block_id,
"input_default": node.input_default,
"metadata": node.metadata,
}
)
links = []
for node in graph.nodes:
for link in node.output_links:
links.append(
{
"id": link.id,
"source_id": link.source_id,
"sink_id": link.sink_id,
"source_name": link.source_name,
"sink_name": link.sink_name,
"is_static": link.is_static,
}
)
return {
"id": graph.id,
"name": graph.name,
"description": graph.description,
"version": graph.version,
"is_active": graph.is_active,
"nodes": nodes,
"links": links,
}
async def generate_agent_patch(
update_request: str, current_agent: dict[str, Any]
) -> dict[str, Any] | None:
"""Generate a patch to update an existing agent.
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
Returns:
Patch dict or clarifying questions, or None on error
"""
client = get_client()
prompt = PATCH_PROMPT.format(
current_agent=json.dumps(current_agent, indent=2),
block_summaries=get_block_summaries(),
)
try:
response = await client.chat.completions.create(
model=AGENT_GENERATOR_MODEL,
messages=[
{"role": "system", "content": prompt},
{"role": "user", "content": update_request},
],
temperature=0,
)
content = response.choices[0].message.content
if content is None:
logger.error("LLM returned empty content for patch generation")
return None
return parse_json_from_llm(content)
except Exception as e:
logger.error(f"Error generating patch: {e}")
return None
def apply_agent_patch(
current_agent: dict[str, Any], patch: dict[str, Any]
) -> dict[str, Any]:
"""Apply a patch to an existing agent.
Args:
current_agent: Current agent JSON
patch: Patch dict with operations
Returns:
Updated agent JSON
"""
agent = copy.deepcopy(current_agent)
patches = patch.get("patches", [])
for p in patches:
patch_type = p.get("type")
if patch_type == "modify":
node_id = p.get("node_id")
changes = p.get("changes", {})
for node in agent.get("nodes", []):
if node["id"] == node_id:
_deep_update(node, changes)
logger.debug(f"Modified node {node_id}")
break
elif patch_type == "add":
new_nodes = p.get("new_nodes", [])
new_links = p.get("new_links", [])
agent["nodes"] = agent.get("nodes", []) + new_nodes
agent["links"] = agent.get("links", []) + new_links
logger.debug(f"Added {len(new_nodes)} nodes, {len(new_links)} links")
elif patch_type == "remove":
node_ids_to_remove = set(p.get("node_ids", []))
link_ids_to_remove = set(p.get("link_ids", []))
# Remove nodes
agent["nodes"] = [
n for n in agent.get("nodes", []) if n["id"] not in node_ids_to_remove
]
# Remove links (both explicit and those referencing removed nodes)
agent["links"] = [
link
for link in agent.get("links", [])
if link["id"] not in link_ids_to_remove
and link["source_id"] not in node_ids_to_remove
and link["sink_id"] not in node_ids_to_remove
]
logger.debug(
f"Removed {len(node_ids_to_remove)} nodes, {len(link_ids_to_remove)} links"
)
return agent
def _deep_update(target: dict, source: dict) -> None:
"""Recursively update a dict with another dict."""
for key, value in source.items():
if key in target and isinstance(target[key], dict) and isinstance(value, dict):
_deep_update(target[key], value)
else:
target[key] = value
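# Worked example (editor's sketch) of applying a minimal "modify" patch:
#
#     agent = {"nodes": [{"id": "n1", "input_default": {"prompt": "old"}}], "links": []}
#     patch = {"patches": [{"type": "modify", "node_id": "n1",
#                           "changes": {"input_default": {"prompt": "new"}}}]}
#     updated = apply_agent_patch(agent, patch)
#     assert updated["nodes"][0]["input_default"]["prompt"] == "new"
#     assert agent["nodes"][0]["input_default"]["prompt"] == "old"  # deepcopy keeps input intact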

View File

@@ -1,606 +0,0 @@
"""Agent fixer - Fixes common LLM generation errors."""
import logging
import re
import uuid
from typing import Any
from .utils import (
ADDTODICTIONARY_BLOCK_ID,
ADDTOLIST_BLOCK_ID,
CODE_EXECUTION_BLOCK_ID,
CONDITION_BLOCK_ID,
CREATEDICT_BLOCK_ID,
CREATELIST_BLOCK_ID,
DATA_SAMPLING_BLOCK_ID,
DOUBLE_CURLY_BRACES_BLOCK_IDS,
GET_CURRENT_DATE_BLOCK_ID,
STORE_VALUE_BLOCK_ID,
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
get_blocks_info,
is_valid_uuid,
)
logger = logging.getLogger(__name__)
def fix_agent_ids(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix invalid UUIDs in agent and link IDs."""
# Fix agent ID
if not is_valid_uuid(agent.get("id", "")):
agent["id"] = str(uuid.uuid4())
logger.debug(f"Fixed agent ID: {agent['id']}")
# Fix node IDs
id_mapping = {} # Old ID -> New ID
for node in agent.get("nodes", []):
if not is_valid_uuid(node.get("id", "")):
old_id = node.get("id", "")
new_id = str(uuid.uuid4())
id_mapping[old_id] = new_id
node["id"] = new_id
logger.debug(f"Fixed node ID: {old_id} -> {new_id}")
# Fix link IDs and update references
for link in agent.get("links", []):
if not is_valid_uuid(link.get("id", "")):
link["id"] = str(uuid.uuid4())
logger.debug(f"Fixed link ID: {link['id']}")
# Update source/sink IDs if they were remapped
if link.get("source_id") in id_mapping:
link["source_id"] = id_mapping[link["source_id"]]
if link.get("sink_id") in id_mapping:
link["sink_id"] = id_mapping[link["sink_id"]]
return agent
def fix_double_curly_braces(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix single curly braces to double in template blocks."""
for node in agent.get("nodes", []):
if node.get("block_id") not in DOUBLE_CURLY_BRACES_BLOCK_IDS:
continue
input_data = node.get("input_default", {})
for key in ("prompt", "format"):
if key in input_data and isinstance(input_data[key], str):
original = input_data[key]
# Fix simple variable references: {var} -> {{var}}
fixed = re.sub(
r"(?<!\{)\{([a-zA-Z_][a-zA-Z0-9_]*)\}(?!\})",
r"{{\1}}",
original,
)
if fixed != original:
input_data[key] = fixed
logger.debug(f"Fixed curly braces in {key}")
return agent
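# Example (editor's note) of the rewrite above: already-doubled braces are
# protected by the lookaround assertions, so only bare {var} references change:
#
#     "Hello {name}, today is {{date}}"  ->  "Hello {{name}}, today is {{date}}"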
def fix_storevalue_before_condition(agent: dict[str, Any]) -> dict[str, Any]:
"""Add StoreValueBlock before ConditionBlock if needed for value2."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
# Find all ConditionBlock nodes
condition_node_ids = {
node["id"] for node in nodes if node.get("block_id") == CONDITION_BLOCK_ID
}
if not condition_node_ids:
return agent
new_nodes = []
new_links = []
processed_conditions = set()
for link in links:
sink_id = link.get("sink_id")
sink_name = link.get("sink_name")
# Check if this link goes to a ConditionBlock's value2
if sink_id in condition_node_ids and sink_name == "value2":
source_node = next(
(n for n in nodes if n["id"] == link.get("source_id")), None
)
# Skip if source is already a StoreValueBlock
if source_node and source_node.get("block_id") == STORE_VALUE_BLOCK_ID:
continue
# Skip if we already processed this condition
if sink_id in processed_conditions:
continue
processed_conditions.add(sink_id)
# Create StoreValueBlock
store_node_id = str(uuid.uuid4())
store_node = {
"id": store_node_id,
"block_id": STORE_VALUE_BLOCK_ID,
"input_default": {"data": None},
"metadata": {"position": {"x": 0, "y": -100}},
}
new_nodes.append(store_node)
# Create link: original source -> StoreValueBlock
new_links.append(
{
"id": str(uuid.uuid4()),
"source_id": link["source_id"],
"source_name": link["source_name"],
"sink_id": store_node_id,
"sink_name": "input",
"is_static": False,
}
)
# Update original link: StoreValueBlock -> ConditionBlock
link["source_id"] = store_node_id
link["source_name"] = "output"
logger.debug(f"Added StoreValueBlock before ConditionBlock {sink_id}")
if new_nodes:
agent["nodes"] = nodes + new_nodes
agent["links"] = links + new_links  # keep the new source -> StoreValueBlock links
return agent
def fix_addtolist_blocks(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix AddToList blocks by adding prerequisite empty AddToList block.
When an AddToList block is found:
1. Checks if there's a CreateListBlock before it
2. Removes CreateListBlock if linked directly to AddToList
3. Adds an empty AddToList block before the original
4. Ensures the original has a self-referencing link
"""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
new_nodes = []
original_addtolist_ids = set()
nodes_to_remove = set()
links_to_remove = []
# First pass: identify CreateListBlock nodes to remove
for link in links:
source_node = next(
(n for n in nodes if n.get("id") == link.get("source_id")), None
)
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
if (
source_node
and sink_node
and source_node.get("block_id") == CREATELIST_BLOCK_ID
and sink_node.get("block_id") == ADDTOLIST_BLOCK_ID
):
nodes_to_remove.add(source_node.get("id"))
links_to_remove.append(link)
logger.debug(f"Removing CreateListBlock {source_node.get('id')}")
# Second pass: process AddToList blocks
filtered_nodes = []
for node in nodes:
if node.get("id") in nodes_to_remove:
continue
if node.get("block_id") == ADDTOLIST_BLOCK_ID:
original_addtolist_ids.add(node.get("id"))
node_id = node.get("id")
pos = node.get("metadata", {}).get("position", {"x": 0, "y": 0})
# Check if already has prerequisite
has_prereq = any(
link.get("sink_id") == node_id
and link.get("sink_name") == "list"
and link.get("source_name") == "updated_list"
for link in links
)
if not has_prereq:
# Remove links to "list" input (except self-reference)
for link in links:
if (
link.get("sink_id") == node_id
and link.get("sink_name") == "list"
and link.get("source_id") != node_id
and link not in links_to_remove
):
links_to_remove.append(link)
# Create prerequisite AddToList block
prereq_id = str(uuid.uuid4())
prereq_node = {
"id": prereq_id,
"block_id": ADDTOLIST_BLOCK_ID,
"input_default": {"list": [], "entry": None, "entries": []},
"metadata": {
"position": {"x": pos.get("x", 0) - 800, "y": pos.get("y", 0)}
},
}
new_nodes.append(prereq_node)
# Link prerequisite to original
links.append(
{
"id": str(uuid.uuid4()),
"source_id": prereq_id,
"source_name": "updated_list",
"sink_id": node_id,
"sink_name": "list",
"is_static": False,
}
)
logger.debug(f"Added prerequisite AddToList block for {node_id}")
filtered_nodes.append(node)
# Remove marked links
filtered_links = [link for link in links if link not in links_to_remove]
# Add self-referencing links for original AddToList blocks
for node in filtered_nodes + new_nodes:
if (
node.get("block_id") == ADDTOLIST_BLOCK_ID
and node.get("id") in original_addtolist_ids
):
node_id = node.get("id")
has_self_ref = any(
link["source_id"] == node_id
and link["sink_id"] == node_id
and link["source_name"] == "updated_list"
and link["sink_name"] == "list"
for link in filtered_links
)
if not has_self_ref:
filtered_links.append(
{
"id": str(uuid.uuid4()),
"source_id": node_id,
"source_name": "updated_list",
"sink_id": node_id,
"sink_name": "list",
"is_static": False,
}
)
logger.debug(f"Added self-reference for AddToList {node_id}")
agent["nodes"] = filtered_nodes + new_nodes
agent["links"] = filtered_links
return agent
def fix_addtodictionary_blocks(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix AddToDictionary blocks by removing empty CreateDictionary nodes."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
nodes_to_remove = set()
links_to_remove = []
for link in links:
source_node = next(
(n for n in nodes if n.get("id") == link.get("source_id")), None
)
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
if (
source_node
and sink_node
and source_node.get("block_id") == CREATEDICT_BLOCK_ID
and sink_node.get("block_id") == ADDTODICTIONARY_BLOCK_ID
):
nodes_to_remove.add(source_node.get("id"))
links_to_remove.append(link)
logger.debug(f"Removing CreateDictionary {source_node.get('id')}")
agent["nodes"] = [n for n in nodes if n.get("id") not in nodes_to_remove]
agent["links"] = [link for link in links if link not in links_to_remove]
return agent
def fix_code_execution_output(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix CodeExecutionBlock output: change 'response' to 'stdout_logs'."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
for link in links:
source_node = next(
(n for n in nodes if n.get("id") == link.get("source_id")), None
)
if (
source_node
and source_node.get("block_id") == CODE_EXECUTION_BLOCK_ID
and link.get("source_name") == "response"
):
link["source_name"] = "stdout_logs"
logger.debug("Fixed CodeExecutionBlock output: response -> stdout_logs")
return agent
def fix_data_sampling_sample_size(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix DataSamplingBlock by setting sample_size to 1 as default."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
links_to_remove = []
for node in nodes:
if node.get("block_id") == DATA_SAMPLING_BLOCK_ID:
node_id = node.get("id")
input_default = node.get("input_default", {})
# Remove links to sample_size
for link in links:
if (
link.get("sink_id") == node_id
and link.get("sink_name") == "sample_size"
):
links_to_remove.append(link)
# Set default
input_default["sample_size"] = 1
node["input_default"] = input_default
logger.debug(f"Fixed DataSamplingBlock {node_id} sample_size to 1")
if links_to_remove:
agent["links"] = [link for link in links if link not in links_to_remove]
return agent
def fix_node_x_coordinates(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix node x-coordinates to ensure 800+ unit spacing between linked nodes."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
node_lookup = {n.get("id"): n for n in nodes}
for link in links:
source_id = link.get("source_id")
sink_id = link.get("sink_id")
source_node = node_lookup.get(source_id)
sink_node = node_lookup.get(sink_id)
if not source_node or not sink_node:
continue
source_pos = source_node.get("metadata", {}).get("position", {})
sink_pos = sink_node.get("metadata", {}).get("position", {})
source_x = source_pos.get("x", 0)
sink_x = sink_pos.get("x", 0)
if abs(sink_x - source_x) < 800:
new_x = source_x + 800
if "metadata" not in sink_node:
sink_node["metadata"] = {}
if "position" not in sink_node["metadata"]:
sink_node["metadata"]["position"] = {}
sink_node["metadata"]["position"]["x"] = new_x
logger.debug(f"Fixed node {sink_id} x: {sink_x} -> {new_x}")
return agent
def fix_getcurrentdate_offset(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix GetCurrentDateBlock offset to ensure it's positive."""
for node in agent.get("nodes", []):
if node.get("block_id") == GET_CURRENT_DATE_BLOCK_ID:
input_default = node.get("input_default", {})
if "offset" in input_default:
offset = input_default["offset"]
if isinstance(offset, (int, float)) and offset < 0:
input_default["offset"] = abs(offset)
logger.debug(f"Fixed offset: {offset} -> {abs(offset)}")
return agent
def fix_ai_model_parameter(
agent: dict[str, Any],
blocks_info: list[dict[str, Any]],
default_model: str = "gpt-4o",
) -> dict[str, Any]:
"""Add default model parameter to AI blocks if missing."""
block_map = {b.get("id"): b for b in blocks_info}
for node in agent.get("nodes", []):
block_id = node.get("block_id")
block = block_map.get(block_id)
if not block:
continue
# Check if block has AI category
categories = block.get("categories", [])
is_ai_block = any(
cat.get("category") == "AI" for cat in categories if isinstance(cat, dict)
)
if is_ai_block:
input_default = node.get("input_default", {})
if "model" not in input_default:
input_default["model"] = default_model
node["input_default"] = input_default
logger.debug(
f"Added model '{default_model}' to AI block {node.get('id')}"
)
return agent
def fix_link_static_properties(
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> dict[str, Any]:
"""Fix is_static property based on source block's staticOutput."""
block_map = {b.get("id"): b for b in blocks_info}
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
for link in agent.get("links", []):
source_node = node_lookup.get(link.get("source_id"))
if not source_node:
continue
source_block = block_map.get(source_node.get("block_id"))
if not source_block:
continue
static_output = source_block.get("staticOutput", False)
if link.get("is_static") != static_output:
link["is_static"] = static_output
logger.debug(f"Fixed link {link.get('id')} is_static to {static_output}")
return agent
def fix_data_type_mismatch(
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> dict[str, Any]:
"""Fix data type mismatches by inserting UniversalTypeConverterBlock."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
block_map = {b.get("id"): b for b in blocks_info}
node_lookup = {n.get("id"): n for n in nodes}
def get_property_type(schema: dict, name: str) -> str | None:
if "_#_" in name:
parent, child = name.split("_#_", 1)
parent_schema = schema.get(parent, {})
if "properties" in parent_schema:
return parent_schema["properties"].get(child, {}).get("type")
return None
return schema.get(name, {}).get("type")
def are_types_compatible(src: str, sink: str) -> bool:
if {src, sink} <= {"integer", "number"}:
return True
return src == sink
type_mapping = {
"string": "string",
"text": "string",
"integer": "number",
"number": "number",
"float": "number",
"boolean": "boolean",
"bool": "boolean",
"array": "list",
"list": "list",
"object": "dictionary",
"dict": "dictionary",
"dictionary": "dictionary",
}
new_links = []
nodes_to_add = []
for link in links:
source_node = node_lookup.get(link.get("source_id"))
sink_node = node_lookup.get(link.get("sink_id"))
if not source_node or not sink_node:
new_links.append(link)
continue
source_block = block_map.get(source_node.get("block_id"))
sink_block = block_map.get(sink_node.get("block_id"))
if not source_block or not sink_block:
new_links.append(link)
continue
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
source_type = get_property_type(source_outputs, link.get("source_name", ""))
sink_type = get_property_type(sink_inputs, link.get("sink_name", ""))
if (
source_type
and sink_type
and not are_types_compatible(source_type, sink_type)
):
# Insert type converter
converter_id = str(uuid.uuid4())
target_type = type_mapping.get(sink_type, sink_type)
converter_node = {
"id": converter_id,
"block_id": UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
"input_default": {"type": target_type},
"metadata": {"position": {"x": 0, "y": 100}},
}
nodes_to_add.append(converter_node)
# source -> converter
new_links.append(
{
"id": str(uuid.uuid4()),
"source_id": link["source_id"],
"source_name": link["source_name"],
"sink_id": converter_id,
"sink_name": "value",
"is_static": False,
}
)
# converter -> sink
new_links.append(
{
"id": str(uuid.uuid4()),
"source_id": converter_id,
"source_name": "value",
"sink_id": link["sink_id"],
"sink_name": link["sink_name"],
"is_static": False,
}
)
logger.debug(f"Inserted type converter: {source_type} -> {target_type}")
else:
new_links.append(link)
if nodes_to_add:
agent["nodes"] = nodes + nodes_to_add
agent["links"] = new_links
return agent
def apply_all_fixes(
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
) -> dict[str, Any]:
"""Apply all fixes to an agent JSON.
Args:
agent: Agent JSON dict
blocks_info: Optional list of block info dicts for advanced fixes
Returns:
Fixed agent JSON
"""
# Basic fixes (no block info needed)
agent = fix_agent_ids(agent)
agent = fix_double_curly_braces(agent)
agent = fix_storevalue_before_condition(agent)
agent = fix_addtolist_blocks(agent)
agent = fix_addtodictionary_blocks(agent)
agent = fix_code_execution_output(agent)
agent = fix_data_sampling_sample_size(agent)
agent = fix_node_x_coordinates(agent)
agent = fix_getcurrentdate_offset(agent)
# Advanced fixes (require block info)
if blocks_info is None:
blocks_info = get_blocks_info()
agent = fix_ai_model_parameter(agent, blocks_info)
agent = fix_link_static_properties(agent, blocks_info)
agent = fix_data_type_mismatch(agent, blocks_info)
return agent
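# Usage sketch (editor's note): callers typically fetch block info once and
# reuse it across fixing and validation, as CreateAgentTool does:
#
#     blocks_info = get_blocks_info()
#     agent_json = apply_all_fixes(agent_json, blocks_info)
#     is_valid, errors = validate_agent(agent_json, blocks_info)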

View File

@@ -1,225 +0,0 @@
"""Prompt templates for agent generation."""
DECOMPOSITION_PROMPT = """
You are an expert AutoGPT Workflow Decomposer. Your task is to analyze a user's high-level goal and break it down into a clear, step-by-step plan using the available blocks.
Each step should represent a distinct, automatable action suitable for execution by an AI automation system.
---
FIRST: Analyze the user's goal and determine:
1) Design-time configuration (fixed settings that won't change per run)
2) Runtime inputs (values the agent's end-user will provide each time it runs)
For anything that can vary per run (email addresses, names, dates, search terms, etc.):
- DO NOT ask for the actual value
- Instead, define it as an Agent Input with a clear name, type, and description
Only ask clarifying questions about design-time config that affects how you build the workflow:
- Which external service to use (e.g., "Gmail vs Outlook", "Notion vs Google Docs")
- Required formats or structures (e.g., "CSV, JSON, or PDF output?")
- Business rules that must be hard-coded
IMPORTANT CLARIFICATIONS POLICY:
- Ask no more than five essential questions
- Do not ask for concrete values that can be provided at runtime as Agent Inputs
- Do not ask for API keys or credentials; the platform handles those directly
- If there is enough information to infer reasonable defaults, prefer to propose defaults
---
GUIDELINES:
1. List each step as a numbered item
2. Describe the action clearly and specify inputs/outputs
3. Ensure steps are in logical, sequential order
4. Mention block names naturally (e.g., "Use GetWeatherByLocationBlock to...")
5. Help the user reach their goal efficiently
---
RULES:
1. OUTPUT FORMAT: Only output either clarifying questions OR step-by-step instructions, not both
2. USE ONLY THE BLOCKS PROVIDED
3. ALL required_input fields must be provided
4. Data types of linked properties must match
5. Write expert-level prompts for AI-related blocks
---
CRITICAL BLOCK RESTRICTIONS:
1. AddToListBlock: Outputs the updated list on EVERY addition, not once after all additions
2. SendEmailBlock: Draft the email for user review; set SMTP config based on email type
3. ConditionBlock: value2 is the reference value; value1 is the value compared against it
4. CodeExecutionBlock: DO NOT USE - use AI blocks instead
5. ReadCsvBlock: Only use the 'rows' output, not 'row'
---
OUTPUT FORMAT:
If more information is needed:
```json
{{
"type": "clarifying_questions",
"questions": [
{{
"question": "Which email provider should be used? (Gmail, Outlook, custom SMTP)",
"keyword": "email_provider",
"example": "Gmail"
}}
]
}}
```
If ready to proceed:
```json
{{
"type": "instructions",
"steps": [
{{
"step_number": 1,
"block_name": "AgentShortTextInputBlock",
"description": "Get the URL of the content to analyze.",
"inputs": [{{"name": "name", "value": "URL"}}],
"outputs": [{{"name": "result", "description": "The URL entered by user"}}]
}}
]
}}
```
---
AVAILABLE BLOCKS:
{block_summaries}
"""
GENERATION_PROMPT = """
You are an expert AI workflow builder. Generate a valid agent JSON from the given instructions.
---
NODES:
Each node must include:
- `id`: Unique UUID v4 (e.g. `a8f5b1e2-c3d4-4e5f-8a9b-0c1d2e3f4a5b`)
- `block_id`: The block identifier (must match an Allowed Block)
- `input_default`: Dict of inputs (can be empty if no static inputs needed)
- `metadata`: Must contain:
- `position`: {{"x": number, "y": number}} - adjacent nodes should differ by 800+ in X
- `customized_name`: Clear name describing this block's purpose in the workflow
---
LINKS:
Each link connects a source node's output to a sink node's input:
- `id`: MUST be UUID v4 (NOT "link-1", "link-2", etc.)
- `source_id`: ID of the source node
- `source_name`: Output field name from the source block
- `sink_id`: ID of the sink node
- `sink_name`: Input field name on the sink block
- `is_static`: true only if source block has static_output: true
CRITICAL: All IDs must be valid UUID v4 format!
---
AGENT (GRAPH):
Wrap nodes and links in:
- `id`: UUID of the agent
- `name`: Short, generic name (avoid specific company names, URLs)
- `description`: Short, generic description
- `nodes`: List of all nodes
- `links`: List of all links
- `version`: 1
- `is_active`: true
---
TIPS:
- All required_input fields must be provided via input_default or a valid link
- Ensure consistent source_id and sink_id references
- Avoid dangling links
- Input/output pins must match block schemas
- Do not invent unknown block_ids
---
ALLOWED BLOCKS:
{block_summaries}
---
Generate the complete agent JSON. Output ONLY valid JSON, no explanation.
"""
PATCH_PROMPT = """
You are an expert at modifying AutoGPT agent workflows. Given the current agent and a modification request, generate a JSON patch to update the agent.
CURRENT AGENT:
{current_agent}
AVAILABLE BLOCKS:
{block_summaries}
---
PATCH FORMAT:
Return a JSON object with the following structure:
```json
{{
"type": "patch",
"intent": "Brief description of what the patch does",
"patches": [
{{
"type": "modify",
"node_id": "uuid-of-node-to-modify",
"changes": {{
"input_default": {{"field": "new_value"}},
"metadata": {{"customized_name": "New Name"}}
}}
}},
{{
"type": "add",
"new_nodes": [
{{
"id": "new-uuid",
"block_id": "block-uuid",
"input_default": {{}},
"metadata": {{"position": {{"x": 0, "y": 0}}, "customized_name": "Name"}}
}}
],
"new_links": [
{{
"id": "link-uuid",
"source_id": "source-node-id",
"source_name": "output_field",
"sink_id": "sink-node-id",
"sink_name": "input_field"
}}
]
}},
{{
"type": "remove",
"node_ids": ["uuid-of-node-to-remove"],
"link_ids": ["uuid-of-link-to-remove"]
}}
]
}}
```
If you need more information, return:
```json
{{
"type": "clarifying_questions",
"questions": [
{{
"question": "What specific change do you want?",
"keyword": "change_type",
"example": "Add error handling"
}}
]
}}
```
Generate the minimal patch needed. Output ONLY valid JSON.
"""

View File

@@ -1,213 +0,0 @@
"""Utilities for agent generation."""
import json
import re
from typing import Any
from backend.data.block import get_blocks
# UUID validation regex
UUID_REGEX = re.compile(
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$"
)
# Block IDs for various fixes
STORE_VALUE_BLOCK_ID = "1ff065e9-88e8-4358-9d82-8dc91f622ba9"
CONDITION_BLOCK_ID = "715696a0-e1da-45c8-b209-c2fa9c3b0be6"
ADDTOLIST_BLOCK_ID = "aeb08fc1-2fc1-4141-bc8e-f758f183a822"
ADDTODICTIONARY_BLOCK_ID = "31d1064e-7446-4693-a7d4-65e5ca1180d1"
CREATELIST_BLOCK_ID = "a912d5c7-6e00-4542-b2a9-8034136930e4"
CREATEDICT_BLOCK_ID = "b924ddf4-de4f-4b56-9a85-358930dcbc91"
CODE_EXECUTION_BLOCK_ID = "0b02b072-abe7-11ef-8372-fb5d162dd712"
DATA_SAMPLING_BLOCK_ID = "4a448883-71fa-49cf-91cf-70d793bd7d87"
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID = "95d1b990-ce13-4d88-9737-ba5c2070c97b"
GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"
DOUBLE_CURLY_BRACES_BLOCK_IDS = [
"44f6c8ad-d75c-4ae1-8209-aad1c0326928", # FillTextTemplateBlock
"6ab085e2-20b3-4055-bc3e-08036e01eca6",
"90f8c45e-e983-4644-aa0b-b4ebe2f531bc",
"363ae599-353e-4804-937e-b2ee3cef3da4", # AgentOutputBlock
"3b191d9f-356f-482d-8238-ba04b6d18381",
"db7d8f02-2f44-4c55-ab7a-eae0941f0c30",
"3a7c4b8d-6e2f-4a5d-b9c1-f8d23c5a9b0e",
"ed1ae7a0-b770-4089-b520-1f0005fad19a",
"a892b8d9-3e4e-4e9c-9c1e-75f8efcf1bfa",
"b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1",
"716a67b3-6760-42e7-86dc-18645c6e00fc",
"530cf046-2ce0-4854-ae2c-659db17c7a46",
"ed55ac19-356e-4243-a6cb-bc599e9b716f",
"1f292d4a-41a4-4977-9684-7c8d560b9f91", # LLM blocks
"32a87eab-381e-4dd4-bdb8-4c47151be35a",
]
def is_valid_uuid(value: str) -> bool:
"""Check if a string is a valid UUID v4."""
return isinstance(value, str) and UUID_REGEX.match(value) is not None
def _compact_schema(schema: dict) -> dict[str, str]:
"""Extract compact type info from a JSON schema properties dict.
Returns a dict of {field_name: type_string} for essential info only.
"""
props = schema.get("properties", {})
result = {}
for name, prop in props.items():
# Skip internal/complex fields
if name.startswith("_"):
continue
# Get type string
type_str = prop.get("type", "any")
# Handle anyOf/oneOf (optional types)
if "anyOf" in prop:
types = [t.get("type", "?") for t in prop["anyOf"] if t.get("type")]
type_str = "|".join(types) if types else "any"
elif "allOf" in prop:
type_str = "object"
# Add array item type if present
if type_str == "array" and "items" in prop:
items = prop["items"]
if isinstance(items, dict):
item_type = items.get("type", "any")
type_str = f"array[{item_type}]"
result[name] = type_str
return result
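# Example (editor's note): given a properties dict such as
#
#     {"query": {"type": "string"},
#      "limit": {"anyOf": [{"type": "integer"}, {"type": "null"}]},
#      "tags": {"type": "array", "items": {"type": "string"}}}
#
# wrapped as {"properties": ...}, _compact_schema returns:
#     {"query": "string", "limit": "integer|null", "tags": "array[string]"}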
def get_block_summaries(include_schemas: bool = True) -> str:
"""Generate compact block summaries for prompts.
Args:
include_schemas: Whether to include input/output type info
Returns:
Formatted string of block summaries (compact format)
"""
blocks = get_blocks()
summaries = []
for block_id, block_cls in blocks.items():
block = block_cls()
name = block.name
desc = getattr(block, "description", "") or ""
# Truncate description
if len(desc) > 150:
desc = desc[:147] + "..."
if not include_schemas:
summaries.append(f"- {name} (id: {block_id}): {desc}")
else:
# Compact format with type info only
inputs = {}
outputs = {}
required = []
if hasattr(block, "input_schema"):
try:
schema = block.input_schema.jsonschema()
inputs = _compact_schema(schema)
required = schema.get("required", [])
except Exception:
pass
if hasattr(block, "output_schema"):
try:
schema = block.output_schema.jsonschema()
outputs = _compact_schema(schema)
except Exception:
pass
# Build compact line format
# Format: NAME (id): desc | in: {field:type, ...} [required] | out: {field:type}
in_str = ", ".join(f"{k}:{v}" for k, v in inputs.items())
out_str = ", ".join(f"{k}:{v}" for k, v in outputs.items())
req_str = f" req=[{','.join(required)}]" if required else ""
static = " [static]" if getattr(block, "static_output", False) else ""
line = f"- {name} (id: {block_id}): {desc}"
if in_str:
line += f"\n in: {{{in_str}}}{req_str}"
if out_str:
line += f"\n out: {{{out_str}}}{static}"
summaries.append(line)
return "\n".join(summaries)
def get_blocks_info() -> list[dict[str, Any]]:
"""Get block information with schemas for validation and fixing."""
blocks = get_blocks()
blocks_info = []
for block_id, block_cls in blocks.items():
block = block_cls()
blocks_info.append(
{
"id": block_id,
"name": block.name,
"description": getattr(block, "description", ""),
"categories": getattr(block, "categories", []),
"staticOutput": getattr(block, "static_output", False),
"inputSchema": (
block.input_schema.jsonschema()
if hasattr(block, "input_schema")
else {}
),
"outputSchema": (
block.output_schema.jsonschema()
if hasattr(block, "output_schema")
else {}
),
}
)
return blocks_info
def parse_json_from_llm(text: str) -> dict[str, Any] | None:
"""Extract JSON from LLM response (handles markdown code blocks)."""
if not text:
return None
# Try fenced code block
match = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, re.IGNORECASE)
if match:
try:
return json.loads(match.group(1).strip())
except json.JSONDecodeError:
pass
# Try raw text
try:
return json.loads(text.strip())
except json.JSONDecodeError:
pass
# Try finding {...} span
start = text.find("{")
end = text.rfind("}")
if start != -1 and end > start:
try:
return json.loads(text[start : end + 1])
except json.JSONDecodeError:
pass
# Try finding [...] span
start = text.find("[")
end = text.rfind("]")
if start != -1 and end > start:
try:
return json.loads(text[start : end + 1])
except json.JSONDecodeError:
pass
return None
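# Examples (editor's note) of the fallback order above: fenced block first,
# then raw JSON, then the widest {...} or [...] span:
#
#     parse_json_from_llm('```json\n{"a": 1}\n```')     # -> {"a": 1}
#     parse_json_from_llm('{"a": 1}')                    # -> {"a": 1}
#     parse_json_from_llm('Sure! Here it is: {"a": 1}')  # -> {"a": 1}
#     parse_json_from_llm('no json here')                # -> None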

View File

@@ -1,279 +0,0 @@
"""Agent validator - Validates agent structure and connections."""
import logging
import re
from typing import Any
from .utils import get_blocks_info
logger = logging.getLogger(__name__)
class AgentValidator:
"""Validator for AutoGPT agents with detailed error reporting."""
def __init__(self):
self.errors: list[str] = []
def add_error(self, error: str) -> None:
"""Add an error message."""
self.errors.append(error)
def validate_block_existence(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> bool:
"""Validate all block IDs exist in the blocks library."""
valid = True
valid_block_ids = {b.get("id") for b in blocks_info if b.get("id")}
for node in agent.get("nodes", []):
block_id = node.get("block_id")
node_id = node.get("id")
if not block_id:
self.add_error(f"Node '{node_id}' is missing 'block_id' field.")
valid = False
continue
if block_id not in valid_block_ids:
self.add_error(
f"Node '{node_id}' references block_id '{block_id}' which does not exist."
)
valid = False
return valid
def validate_link_node_references(self, agent: dict[str, Any]) -> bool:
"""Validate all node IDs referenced in links exist."""
valid = True
valid_node_ids = {n.get("id") for n in agent.get("nodes", []) if n.get("id")}
for link in agent.get("links", []):
link_id = link.get("id", "Unknown")
source_id = link.get("source_id")
sink_id = link.get("sink_id")
if not source_id:
self.add_error(f"Link '{link_id}' is missing 'source_id'.")
valid = False
elif source_id not in valid_node_ids:
self.add_error(
f"Link '{link_id}' references non-existent source_id '{source_id}'."
)
valid = False
if not sink_id:
self.add_error(f"Link '{link_id}' is missing 'sink_id'.")
valid = False
elif sink_id not in valid_node_ids:
self.add_error(
f"Link '{link_id}' references non-existent sink_id '{sink_id}'."
)
valid = False
return valid
def validate_required_inputs(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> bool:
"""Validate required inputs are provided."""
valid = True
block_map = {b.get("id"): b for b in blocks_info}
for node in agent.get("nodes", []):
block_id = node.get("block_id")
block = block_map.get(block_id)
if not block:
continue
required_inputs = block.get("inputSchema", {}).get("required", [])
input_defaults = node.get("input_default", {})
node_id = node.get("id")
# Get linked inputs
linked_inputs = {
link["sink_name"]
for link in agent.get("links", [])
if link.get("sink_id") == node_id
}
for req_input in required_inputs:
if (
req_input not in input_defaults
and req_input not in linked_inputs
and req_input != "credentials"
):
block_name = block.get("name", "Unknown Block")
self.add_error(
f"Node '{node_id}' ({block_name}) is missing required input '{req_input}'."
)
valid = False
return valid
def validate_data_type_compatibility(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> bool:
"""Validate linked data types are compatible."""
valid = True
block_map = {b.get("id"): b for b in blocks_info}
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
def get_type(schema: dict, name: str) -> str | None:
if "_#_" in name:
parent, child = name.split("_#_", 1)
parent_schema = schema.get(parent, {})
if "properties" in parent_schema:
return parent_schema["properties"].get(child, {}).get("type")
return None
return schema.get(name, {}).get("type")
def are_compatible(src: str, sink: str) -> bool:
if {src, sink} <= {"integer", "number"}:
return True
return src == sink
for link in agent.get("links", []):
source_node = node_lookup.get(link.get("source_id"))
sink_node = node_lookup.get(link.get("sink_id"))
if not source_node or not sink_node:
continue
source_block = block_map.get(source_node.get("block_id"))
sink_block = block_map.get(sink_node.get("block_id"))
if not source_block or not sink_block:
continue
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
source_type = get_type(source_outputs, link.get("source_name", ""))
sink_type = get_type(sink_inputs, link.get("sink_name", ""))
if source_type and sink_type and not are_compatible(source_type, sink_type):
self.add_error(
f"Type mismatch: {source_block.get('name')} output '{link['source_name']}' "
f"({source_type}) -> {sink_block.get('name')} input '{link['sink_name']}' ({sink_type})."
)
valid = False
return valid
def validate_nested_sink_links(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> bool:
"""Validate nested sink links (with _#_ notation)."""
valid = True
block_map = {b.get("id"): b for b in blocks_info}
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
for link in agent.get("links", []):
sink_name = link.get("sink_name", "")
if "_#_" in sink_name:
parent, child = sink_name.split("_#_", 1)
sink_node = node_lookup.get(link.get("sink_id"))
if not sink_node:
continue
block = block_map.get(sink_node.get("block_id"))
if not block:
continue
input_props = block.get("inputSchema", {}).get("properties", {})
parent_schema = input_props.get(parent)
if not parent_schema:
self.add_error(
f"Invalid nested link '{sink_name}': parent '{parent}' not found."
)
valid = False
continue
if not parent_schema.get("additionalProperties"):
if not (
isinstance(parent_schema, dict)
and "properties" in parent_schema
and child in parent_schema.get("properties", {})
):
self.add_error(
f"Invalid nested link '{sink_name}': child '{child}' not found in '{parent}'."
)
valid = False
return valid
def validate_prompt_spaces(self, agent: dict[str, Any]) -> bool:
"""Validate prompts don't have spaces in template variables."""
valid = True
for node in agent.get("nodes", []):
input_default = node.get("input_default", {})
prompt = input_default.get("prompt", "")
if not isinstance(prompt, str):
continue
# Find {{...}} with spaces
matches = re.finditer(r"\{\{([^}]+)\}\}", prompt)
for match in matches:
content = match.group(1)
if " " in content:
self.add_error(
f"Node '{node.get('id')}' has spaces in template variable: "
f"'{{{{{content}}}}}' should be '{{{{{content.replace(' ', '_')}}}}}'."
)
valid = False
return valid
def validate(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
) -> tuple[bool, str | None]:
"""Run all validations.
Returns:
Tuple of (is_valid, error_message)
"""
self.errors = []
if blocks_info is None:
blocks_info = get_blocks_info()
checks = [
self.validate_block_existence(agent, blocks_info),
self.validate_link_node_references(agent),
self.validate_required_inputs(agent, blocks_info),
self.validate_data_type_compatibility(agent, blocks_info),
self.validate_nested_sink_links(agent, blocks_info),
self.validate_prompt_spaces(agent),
]
all_passed = all(checks)
if all_passed:
logger.info("Agent validation successful")
return True, None
error_message = "Agent validation failed:\n"
for i, error in enumerate(self.errors, 1):
error_message += f"{i}. {error}\n"
logger.warning(f"Agent validation failed with {len(self.errors)} errors")
return False, error_message
def validate_agent(
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
) -> tuple[bool, str | None]:
"""Convenience function to validate an agent.
Returns:
Tuple of (is_valid, error_message)
"""
validator = AgentValidator()
return validator.validate(agent, blocks_info)
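# Usage sketch (editor's note):
#
#     is_valid, error_message = validate_agent(agent_json)
#     if not is_valid:
#         logger.warning(error_message)  # numbered list of every failed check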

View File

@@ -1,455 +0,0 @@
"""Tool for retrieving agent execution outputs from user's library."""
import logging
import re
from datetime import datetime, timedelta, timezone
from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from backend.api.features.library import db as library_db
from backend.api.features.library.model import LibraryAgent
from backend.data import execution as execution_db
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
from .base import BaseTool
from .models import (
AgentOutputResponse,
ErrorResponse,
ExecutionOutputInfo,
NoResultsResponse,
ToolResponseBase,
)
from .utils import fetch_graph_from_store_slug
logger = logging.getLogger(__name__)
class AgentOutputInput(BaseModel):
"""Input parameters for the agent_output tool."""
agent_name: str = ""
library_agent_id: str = ""
store_slug: str = ""
execution_id: str = ""
run_time: str = "latest"
@field_validator(
"agent_name",
"library_agent_id",
"store_slug",
"execution_id",
"run_time",
mode="before",
)
@classmethod
def strip_strings(cls, v: Any) -> Any:
"""Strip whitespace from string fields."""
return v.strip() if isinstance(v, str) else v
def parse_time_expression(
time_expr: str | None,
) -> tuple[datetime | None, datetime | None]:
"""
Parse time expression into datetime range (start, end).
Supports:
- "latest" or None -> returns (None, None) to get most recent
- "yesterday" -> 24h window for yesterday
- "today" -> Today from midnight
- "last week" / "last 7 days" -> 7 day window
- "last month" / "last 30 days" -> 30 day window
- ISO date "YYYY-MM-DD" -> 24h window for that date
"""
if not time_expr or time_expr.lower() == "latest":
return None, None
now = datetime.now(timezone.utc)
expr = time_expr.lower().strip()
# Relative expressions
if expr == "yesterday":
end = now.replace(hour=0, minute=0, second=0, microsecond=0)
start = end - timedelta(days=1)
return start, end
if expr in ("last week", "last 7 days"):
return now - timedelta(days=7), now
if expr in ("last month", "last 30 days"):
return now - timedelta(days=30), now
if expr == "today":
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
return start, now
# Try ISO date format (YYYY-MM-DD)
date_match = re.match(r"^(\d{4})-(\d{2})-(\d{2})$", expr)
if date_match:
year, month, day = map(int, date_match.groups())
start = datetime(year, month, day, 0, 0, 0, tzinfo=timezone.utc)
end = start + timedelta(days=1)
return start, end
# Try ISO datetime
try:
parsed = datetime.fromisoformat(expr.replace("Z", "+00:00"))
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
# Return +/- 1 hour window around the specified time
return parsed - timedelta(hours=1), parsed + timedelta(hours=1)
except ValueError:
pass
# Fallback: treat as "latest"
return None, None
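# Examples (editor's note), assuming "now" is 2026-01-12T12:00:00Z:
#
#     parse_time_expression("latest")       -> (None, None)
#     parse_time_expression("yesterday")    -> (2026-01-11T00:00Z, 2026-01-12T00:00Z)
#     parse_time_expression("last 7 days")  -> (2026-01-05T12:00Z, 2026-01-12T12:00Z)
#     parse_time_expression("2026-01-10")   -> (2026-01-10T00:00Z, 2026-01-11T00:00Z)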
class AgentOutputTool(BaseTool):
"""Tool for retrieving execution outputs from user's library agents."""
@property
def name(self) -> str:
return "agent_output"
@property
def description(self) -> str:
return """Retrieve execution outputs from agents in the user's library.
Identify the agent using one of:
- agent_name: Fuzzy search in user's library
- library_agent_id: Exact library agent ID
- store_slug: Marketplace format 'username/agent-name'
Select which run to retrieve using:
- execution_id: Specific execution ID
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
"""
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"agent_name": {
"type": "string",
"description": "Agent name to search for in user's library (fuzzy match)",
},
"library_agent_id": {
"type": "string",
"description": "Exact library agent ID",
},
"store_slug": {
"type": "string",
"description": "Marketplace identifier: 'username/agent-slug'",
},
"execution_id": {
"type": "string",
"description": "Specific execution ID to retrieve",
},
"run_time": {
"type": "string",
"description": (
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
),
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
return True
async def _resolve_agent(
self,
user_id: str,
agent_name: str | None,
library_agent_id: str | None,
store_slug: str | None,
) -> tuple[LibraryAgent | None, str | None]:
"""
Resolve agent from provided identifiers.
Returns (library_agent, error_message).
"""
# Priority 1: Exact library agent ID
if library_agent_id:
try:
agent = await library_db.get_library_agent(library_agent_id, user_id)
return agent, None
except Exception as e:
logger.warning(f"Failed to get library agent by ID: {e}")
return None, f"Library agent '{library_agent_id}' not found"
# Priority 2: Store slug (username/agent-name)
if store_slug and "/" in store_slug:
username, agent_slug = store_slug.split("/", 1)
graph, _ = await fetch_graph_from_store_slug(username, agent_slug)
if not graph:
return None, f"Agent '{store_slug}' not found in marketplace"
# Find in user's library by graph_id
agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
if not agent:
return (
None,
f"Agent '{store_slug}' is not in your library. "
"Add it first to see outputs.",
)
return agent, None
# Priority 3: Fuzzy name search in library
if agent_name:
try:
response = await library_db.list_library_agents(
user_id=user_id,
search_term=agent_name,
page_size=5,
)
if not response.agents:
return (
None,
f"No agents matching '{agent_name}' found in your library",
)
# Return best match (first result from search)
return response.agents[0], None
except Exception as e:
logger.error(f"Error searching library agents: {e}")
return None, f"Error searching for agent: {e}"
return (
None,
"Please specify an agent name, library_agent_id, or store_slug",
)
async def _get_execution(
self,
user_id: str,
graph_id: str,
execution_id: str | None,
time_start: datetime | None,
time_end: datetime | None,
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
"""
Fetch execution(s) based on filters.
Returns (single_execution, available_executions_meta, error_message).
"""
# If specific execution_id provided, fetch it directly
if execution_id:
execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
if not execution:
return None, [], f"Execution '{execution_id}' not found"
return execution, [], None
# Get completed executions with time filters
executions = await execution_db.get_graph_executions(
graph_id=graph_id,
user_id=user_id,
statuses=[ExecutionStatus.COMPLETED],
created_time_gte=time_start,
created_time_lte=time_end,
limit=10,
)
if not executions:
return None, [], None # No error, just no executions
# If only one execution, fetch full details
if len(executions) == 1:
full_execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=executions[0].id,
include_node_executions=False,
)
return full_execution, [], None
# Multiple executions - return latest with full details, plus list of available
full_execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=executions[0].id,
include_node_executions=False,
)
return full_execution, executions, None
def _build_response(
self,
agent: LibraryAgent,
execution: GraphExecution | None,
available_executions: list[GraphExecutionMeta],
session_id: str | None,
) -> AgentOutputResponse:
"""Build the response based on execution data."""
library_agent_link = f"/library/agents/{agent.id}"
if not execution:
return AgentOutputResponse(
message=f"No completed executions found for agent '{agent.name}'",
session_id=session_id,
agent_name=agent.name,
agent_id=agent.graph_id,
library_agent_id=agent.id,
library_agent_link=library_agent_link,
total_executions=0,
)
execution_info = ExecutionOutputInfo(
execution_id=execution.id,
status=execution.status.value,
started_at=execution.started_at,
ended_at=execution.ended_at,
outputs=dict(execution.outputs),
inputs_summary=execution.inputs if execution.inputs else None,
)
available_list = None
if len(available_executions) > 1:
available_list = [
{
"id": e.id,
"status": e.status.value,
"started_at": e.started_at.isoformat() if e.started_at else None,
}
for e in available_executions[:5]
]
message = f"Found execution outputs for agent '{agent.name}'"
if len(available_executions) > 1:
message += (
f". Showing latest of {len(available_executions)} matching executions."
)
return AgentOutputResponse(
message=message,
session_id=session_id,
agent_name=agent.name,
agent_id=agent.graph_id,
library_agent_id=agent.id,
library_agent_link=library_agent_link,
execution=execution_info,
available_executions=available_list,
total_executions=len(available_executions) if available_executions else 1,
)
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the agent_output tool."""
session_id = session.session_id
# Parse and validate input
try:
input_data = AgentOutputInput(**kwargs)
except Exception as e:
logger.error(f"Invalid input: {e}")
return ErrorResponse(
message="Invalid input parameters",
error=str(e),
session_id=session_id,
)
# Ensure user_id is present (should be guaranteed by requires_auth)
if not user_id:
return ErrorResponse(
message="User authentication required",
session_id=session_id,
)
# Check if at least one identifier is provided
if not any(
[
input_data.agent_name,
input_data.library_agent_id,
input_data.store_slug,
input_data.execution_id,
]
):
return ErrorResponse(
message=(
"Please specify at least one of: agent_name, "
"library_agent_id, store_slug, or execution_id"
),
session_id=session_id,
)
# If only execution_id provided, we need to find the agent differently
if (
input_data.execution_id
and not input_data.agent_name
and not input_data.library_agent_id
and not input_data.store_slug
):
# Fetch execution directly to get graph_id
execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=input_data.execution_id,
include_node_executions=False,
)
if not execution:
return ErrorResponse(
message=f"Execution '{input_data.execution_id}' not found",
session_id=session_id,
)
# Find library agent by graph_id
agent = await library_db.get_library_agent_by_graph_id(
user_id, execution.graph_id
)
if not agent:
return NoResultsResponse(
message=(
f"Execution found but agent not in your library. "
f"Graph ID: {execution.graph_id}"
),
session_id=session_id,
suggestions=["Add the agent to your library to see more details"],
)
return self._build_response(agent, execution, [], session_id)
# Resolve agent from identifiers
agent, error = await self._resolve_agent(
user_id=user_id,
agent_name=input_data.agent_name or None,
library_agent_id=input_data.library_agent_id or None,
store_slug=input_data.store_slug or None,
)
if error or not agent:
return NoResultsResponse(
message=error or "Agent not found",
session_id=session_id,
suggestions=[
"Check the agent name or ID",
"Make sure the agent is in your library",
],
)
# Parse time expression
time_start, time_end = parse_time_expression(input_data.run_time)
# Fetch execution(s)
execution, available_executions, exec_error = await self._get_execution(
user_id=user_id,
graph_id=agent.graph_id,
execution_id=input_data.execution_id or None,
time_start=time_start,
time_end=time_end,
)
if exec_error:
return ErrorResponse(
message=exec_error,
session_id=session_id,
)
return self._build_response(agent, execution, available_executions, session_id)

File diff suppressed because one or more lines are too long

View File

@@ -1,279 +0,0 @@
"""CreateAgentTool - Creates agents from natural language descriptions."""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
apply_all_fixes,
decompose_goal,
generate_agent,
get_blocks_info,
save_agent_to_library,
validate_agent,
)
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
# Maximum retries for agent generation with validation feedback
MAX_GENERATION_RETRIES = 2
class CreateAgentTool(BaseTool):
"""Tool for creating agents from natural language descriptions."""
@property
def name(self) -> str:
return "create_agent"
@property
def description(self) -> str:
return (
"Create a new agent workflow from a natural language description. "
"First generates a preview, then saves to library if save=true."
)
@property
def requires_auth(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"description": {
"type": "string",
"description": (
"Natural language description of what the agent should do. "
"Be specific about inputs, outputs, and the workflow steps."
),
},
"context": {
"type": "string",
"description": (
"Additional context or answers to previous clarifying questions. "
"Include any preferences or constraints mentioned by the user."
),
},
"save": {
"type": "boolean",
"description": (
"Whether to save the agent to the user's library. "
"Default is true. Set to false for preview only."
),
"default": True,
},
},
"required": ["description"],
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the create_agent tool.
Flow:
1. Decompose the description into steps (may return clarifying questions)
2. Generate agent JSON from the steps
3. Apply fixes to correct common LLM errors
4. Validate the result, retrying generation with validation feedback
5. Preview or save based on the save parameter
"""
description = kwargs.get("description", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
session_id = session.session_id if session else None
if not description:
return ErrorResponse(
message="Please provide a description of what the agent should do.",
error="Missing description parameter",
session_id=session_id,
)
# Step 1: Decompose goal into steps
try:
decomposition_result = await decompose_goal(description, context)
except ValueError as e:
# Handle missing API key or configuration errors
return ErrorResponse(
message=f"Agent generation is not configured: {str(e)}",
error="configuration_error",
session_id=session_id,
)
if decomposition_result is None:
return ErrorResponse(
message="Failed to analyze the goal. Please try rephrasing.",
error="Decomposition failed",
session_id=session_id,
)
# Check if LLM returned clarifying questions
if decomposition_result.get("type") == "clarifying_questions":
questions = decomposition_result.get("questions", [])
return ClarificationNeededResponse(
message=(
"I need some more information to create this agent. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
],
session_id=session_id,
)
# Check for unachievable/vague goals
if decomposition_result.get("type") == "unachievable_goal":
suggested = decomposition_result.get("suggested_goal", "")
reason = decomposition_result.get("reason", "")
return ErrorResponse(
message=(
f"This goal cannot be accomplished with the available blocks. "
f"{reason} "
f"Suggestion: {suggested}"
),
error="unachievable_goal",
details={"suggested_goal": suggested, "reason": reason},
session_id=session_id,
)
if decomposition_result.get("type") == "vague_goal":
suggested = decomposition_result.get("suggested_goal", "")
return ErrorResponse(
message=(
f"The goal is too vague to create a specific workflow. "
f"Suggestion: {suggested}"
),
error="vague_goal",
details={"suggested_goal": suggested},
session_id=session_id,
)
# Step 2: Generate agent JSON with retry on validation failure
blocks_info = get_blocks_info()
agent_json = None
validation_errors = None
for attempt in range(MAX_GENERATION_RETRIES + 1):
# Generate agent (include validation errors from previous attempt)
if attempt == 0:
agent_json = await generate_agent(decomposition_result)
else:
# Retry with validation error feedback
logger.info(
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
)
retry_instructions = {
**decomposition_result,
"previous_errors": validation_errors,
"retry_instructions": (
"The previous generation had validation errors. "
"Please fix these issues in the new generation:\n"
f"{validation_errors}"
),
}
agent_json = await generate_agent(retry_instructions)
if agent_json is None:
if attempt == MAX_GENERATION_RETRIES:
return ErrorResponse(
message="Failed to generate the agent. Please try again.",
error="Generation failed",
session_id=session_id,
)
continue
# Step 3: Apply fixes to correct common errors
agent_json = apply_all_fixes(agent_json, blocks_info)
# Step 4: Validate the agent
is_valid, validation_errors = validate_agent(agent_json, blocks_info)
if is_valid:
logger.info(f"Agent generated successfully on attempt {attempt + 1}")
break
logger.warning(
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
)
if attempt == MAX_GENERATION_RETRIES:
# Return error with validation details
return ErrorResponse(
message=(
f"Generated agent has validation errors after {MAX_GENERATION_RETRIES + 1} attempts. "
f"Please try rephrasing your request or simplify the workflow."
),
error="validation_failed",
details={"validation_errors": validation_errors},
session_id=session_id,
)
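# At this point, agent_json is guaranteed to be set (we return on all failure paths)
assert agent_json is not None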
agent_name = agent_json.get("name", "Generated Agent")
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
# Step 5: Preview or save
if not save:
return AgentPreviewResponse(
message=(
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
f"Review it and call create_agent with save=true to save it to your library."
),
agent_json=agent_json,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
# Save to library
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
try:
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id
)
return AgentSavedResponse(
message=f"Agent '{created_graph.name}' has been saved to your library!",
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to save the agent: {str(e)}",
error="save_failed",
details={"exception": str(e)},
session_id=session_id,
)
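For context, a minimal sketch of the preview-then-save handshake this tool describes (hypothetical driver code: the user_id/session setup and the direct call to the internal _execute method are assumptions for illustration, not part of the tool's public contract):

# Illustrative only -- preview first, then persist.
tool = CreateAgentTool()
preview = await tool._execute(
    user_id=user_id,  # assumed available from the chat context
    session=session,  # assumed ChatSession instance
    description="Summarize my unread emails every morning",
    save=False,  # preview only; nothing is written to the library
)
if isinstance(preview, AgentPreviewResponse):
    # Satisfied with the preview: call again with save=True to persist it.
    saved = await tool._execute(
        user_id=user_id,
        session=session,
        description="Summarize my unread emails every morning",
        save=True,
    )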

File diff suppressed because one or more lines are too long

View File

@@ -1,294 +0,0 @@
"""EditAgentTool - Edits existing agents using natural language."""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
apply_agent_patch,
apply_all_fixes,
generate_agent_patch,
get_agent_as_json,
get_blocks_info,
save_agent_to_library,
validate_agent,
)
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
# Maximum retries for patch generation with validation feedback
MAX_GENERATION_RETRIES = 2
class EditAgentTool(BaseTool):
"""Tool for editing existing agents using natural language."""
@property
def name(self) -> str:
return "edit_agent"
@property
def description(self) -> str:
return (
"Edit an existing agent from the user's library using natural language. "
"Generates a patch to update the agent while preserving unchanged parts."
)
@property
def requires_auth(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"agent_id": {
"type": "string",
"description": (
"The ID of the agent to edit. "
"Can be a graph ID or library agent ID."
),
},
"changes": {
"type": "string",
"description": (
"Natural language description of what changes to make. "
"Be specific about what to add, remove, or modify."
),
},
"context": {
"type": "string",
"description": (
"Additional context or answers to previous clarifying questions."
),
},
"save": {
"type": "boolean",
"description": (
"Whether to save the changes. "
"Default is true. Set to false for preview only."
),
"default": True,
},
},
"required": ["agent_id", "changes"],
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the edit_agent tool.
Flow:
1. Fetch the current agent
2. Generate a patch based on the requested changes
3. Apply the patch to create an updated agent
4. Preview or save based on the save parameter
"""
agent_id = kwargs.get("agent_id", "").strip()
changes = kwargs.get("changes", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
session_id = session.session_id if session else None
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
error="Missing agent_id parameter",
session_id=session_id,
)
if not changes:
return ErrorResponse(
message="Please describe what changes you want to make.",
error="Missing changes parameter",
session_id=session_id,
)
# Step 1: Fetch current agent
current_agent = await get_agent_as_json(agent_id, user_id)
if current_agent is None:
return ErrorResponse(
message=f"Could not find agent with ID '{agent_id}' in your library.",
error="agent_not_found",
session_id=session_id,
)
# Build the update request with context
update_request = changes
if context:
update_request = f"{changes}\n\nAdditional context:\n{context}"
# Step 2: Generate patch with retry on validation failure
blocks_info = get_blocks_info()
updated_agent = None
validation_errors = None
intent = "Applied requested changes"
for attempt in range(MAX_GENERATION_RETRIES + 1):
# Generate patch (include validation errors from previous attempt)
try:
if attempt == 0:
patch_result = await generate_agent_patch(
update_request, current_agent
)
else:
# Retry with validation error feedback
logger.info(
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
)
retry_request = (
f"{update_request}\n\n"
f"IMPORTANT: The previous edit had validation errors. "
f"Please fix these issues:\n{validation_errors}"
)
patch_result = await generate_agent_patch(
retry_request, current_agent
)
except ValueError as e:
# Handle missing API key or configuration errors
return ErrorResponse(
message=f"Agent generation is not configured: {str(e)}",
error="configuration_error",
session_id=session_id,
)
if patch_result is None:
if attempt == MAX_GENERATION_RETRIES:
return ErrorResponse(
message="Failed to generate changes. Please try rephrasing.",
error="Patch generation failed",
session_id=session_id,
)
continue
# Check if LLM returned clarifying questions
if patch_result.get("type") == "clarifying_questions":
questions = patch_result.get("questions", [])
return ClarificationNeededResponse(
message=(
"I need some more information about the changes. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
],
session_id=session_id,
)
# Step 3: Apply patch and fixes
try:
updated_agent = apply_agent_patch(current_agent, patch_result)
updated_agent = apply_all_fixes(updated_agent, blocks_info)
except Exception as e:
if attempt == MAX_GENERATION_RETRIES:
return ErrorResponse(
message=f"Failed to apply changes: {str(e)}",
error="patch_apply_failed",
details={"exception": str(e)},
session_id=session_id,
)
validation_errors = str(e)
continue
# Step 4: Validate the updated agent
is_valid, validation_errors = validate_agent(updated_agent, blocks_info)
if is_valid:
logger.info(f"Agent edited successfully on attempt {attempt + 1}")
intent = patch_result.get("intent", "Applied requested changes")
break
logger.warning(
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
)
if attempt == MAX_GENERATION_RETRIES:
# Return error with validation details
return ErrorResponse(
message=(
f"Updated agent has validation errors after "
f"{MAX_GENERATION_RETRIES + 1} attempts. "
f"Please try rephrasing your request or simplify the changes."
),
error="validation_failed",
details={"validation_errors": validation_errors},
session_id=session_id,
)
# At this point, updated_agent is guaranteed to be set (we return on all failure paths)
assert updated_agent is not None
agent_name = updated_agent.get("name", "Updated Agent")
agent_description = updated_agent.get("description", "")
node_count = len(updated_agent.get("nodes", []))
link_count = len(updated_agent.get("links", []))
# Step 5: Preview or save
if not save:
return AgentPreviewResponse(
message=(
f"I've updated the agent. Changes: {intent}. "
f"The agent now has {node_count} blocks. "
f"Review it and call edit_agent with save=true to save the changes."
),
agent_json=updated_agent,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
# Save to library (creates a new version)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
try:
created_graph, library_agent = await save_agent_to_library(
updated_agent, user_id, is_update=True
)
return AgentSavedResponse(
message=(
f"Updated agent '{created_graph.name}' has been saved to your library! "
f"Changes: {intent}"
),
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to save the updated agent: {str(e)}",
error="save_failed",
details={"exception": str(e)},
session_id=session_id,
)
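The `context` parameter closes the clarification loop: when the tool returns ClarificationNeededResponse, the caller is expected to re-invoke it with the answers. A hypothetical round trip (driver code assumed, not shown in this diff; the agent ID and change text are illustrative):

# Illustrative only -- answer clarifying questions via `context`.
tool = EditAgentTool()
result = await tool._execute(
    user_id=user_id,
    session=session,
    agent_id="some-library-agent-id",  # hypothetical ID
    changes="Send the report to Slack instead of email",
)
if isinstance(result, ClarificationNeededResponse):
    # Collect answers from the user, then retry with them as context.
    result = await tool._execute(
        user_id=user_id,
        session=session,
        agent_id="some-library-agent-id",
        changes="Send the report to Slack instead of email",
        context="Slack channel: #reports; keep the email step as a fallback",
    )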

View File

@@ -1,253 +0,0 @@
"""Tool for searching available blocks using hybrid search."""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.blocks import load_all_blocks
from .base import BaseTool
from .models import (
BlockInfoSummary,
BlockListResponse,
ErrorResponse,
NoResultsResponse,
ToolResponseBase,
)
from .search_blocks import get_block_search_index
logger = logging.getLogger(__name__)
class FindBlockTool(BaseTool):
"""Tool for searching available blocks."""
@property
def name(self) -> str:
return "find_block"
@property
def description(self) -> str:
return (
"Search for available blocks by name or description. "
"Blocks are reusable components that perform specific tasks like "
"sending emails, making API calls, processing text, etc. "
"Use this to find blocks that can be executed directly."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": (
"Search query to find blocks by name or description. "
"Use keywords like 'email', 'http', 'text', 'ai', etc."
),
},
},
"required": ["query"],
}
@property
def requires_auth(self) -> bool:
return True
def _matches_query(self, block, query: str) -> tuple[int, bool]:
"""
Check if a block matches the query and return a priority score.
Returns (priority, matches) where:
- priority 0: exact name match
- priority 1: name contains query
- priority 2: description contains query
- priority 3: category contains query
"""
query_lower = query.lower()
name_lower = block.name.lower()
desc_lower = block.description.lower()
# Exact name match
if query_lower == name_lower:
return 0, True
# Name contains query
if query_lower in name_lower:
return 1, True
# Description contains query
if query_lower in desc_lower:
return 2, True
# Category contains query
for category in block.categories:
if query_lower in category.name.lower():
return 3, True
return 4, False
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Search for blocks matching the query.
Args:
user_id: User ID (required)
session: Chat session
query: Search query
Returns:
BlockListResponse: List of matching blocks
NoResultsResponse: No blocks found
ErrorResponse: Error message
"""
query = kwargs.get("query", "").strip()
session_id = session.session_id
if not query:
return ErrorResponse(
message="Please provide a search query",
session_id=session_id,
)
try:
# Try hybrid search first
search_results = self._hybrid_search(query)
if search_results is not None:
# Hybrid search succeeded
if not search_results:
return NoResultsResponse(
message=f"No blocks found matching '{query}'",
session_id=session_id,
suggestions=[
"Try more general terms",
"Search by category: ai, text, social, search, etc.",
"Check block names like 'SendEmail', 'HttpRequest', etc.",
],
)
# Get full block info for each result
all_blocks = load_all_blocks()
blocks = []
for result in search_results:
block_cls = all_blocks.get(result.block_id)
if block_cls:
block = block_cls()
blocks.append(
BlockInfoSummary(
id=block.id,
name=block.name,
description=block.description,
categories=[cat.name for cat in block.categories],
input_schema=block.input_schema.jsonschema(),
output_schema=block.output_schema.jsonschema(),
)
)
return BlockListResponse(
message=(
f"Found {len(blocks)} block{'s' if len(blocks) != 1 else ''} "
f"matching '{query}'. Use run_block to execute a block with "
"the required inputs."
),
blocks=blocks,
count=len(blocks),
query=query,
session_id=session_id,
)
# Fallback to simple search if hybrid search failed
return self._simple_search(query, session_id)
except Exception as e:
logger.error(f"Error searching blocks: {e}", exc_info=True)
return ErrorResponse(
message="Failed to search blocks. Please try again.",
error=str(e),
session_id=session_id,
)
def _hybrid_search(self, query: str) -> list | None:
"""
Perform hybrid search using embeddings and BM25.
Returns:
List of BlockSearchResult if successful, None if index not available
"""
try:
index = get_block_search_index()
if not index.load():
logger.info(
"Block search index not available, falling back to simple search"
)
return None
results = index.search(query, top_k=10)
logger.info(f"Hybrid search found {len(results)} blocks for: {query}")
return results
except Exception as e:
logger.warning(f"Hybrid search failed, falling back to simple: {e}")
return None
def _simple_search(self, query: str, session_id: str) -> ToolResponseBase:
"""Fallback simple search using substring matching."""
all_blocks = load_all_blocks()
logger.info(f"Simple searching {len(all_blocks)} blocks for: {query}")
# Find matching blocks with priority scores
matches: list[tuple[int, Any]] = []
for block_id, block_cls in all_blocks.items():
block = block_cls()
priority, is_match = self._matches_query(block, query)
if is_match:
matches.append((priority, block))
# Sort by priority (lower is better)
matches.sort(key=lambda x: x[0])
# Take top 10 results
top_matches = [block for _, block in matches[:10]]
if not top_matches:
return NoResultsResponse(
message=f"No blocks found matching '{query}'",
session_id=session_id,
suggestions=[
"Try more general terms",
"Search by category: ai, text, social, search, etc.",
"Check block names like 'SendEmail', 'HttpRequest', etc.",
],
)
# Build response
blocks = []
for block in top_matches:
blocks.append(
BlockInfoSummary(
id=block.id,
name=block.name,
description=block.description,
categories=[cat.name for cat in block.categories],
input_schema=block.input_schema.jsonschema(),
output_schema=block.output_schema.jsonschema(),
)
)
return BlockListResponse(
message=(
f"Found {len(blocks)} block{'s' if len(blocks) != 1 else ''} "
f"matching '{query}'. Use run_block to execute a block with "
"the required inputs."
),
blocks=blocks,
count=len(blocks),
query=query,
session_id=session_id,
)
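To make the fallback's ranking concrete, a small illustrative demo of _matches_query (FakeBlock is a hypothetical stand-in for the real block class, which carries the same name/description/categories attributes; calling the private method directly is for demonstration only):

# Illustrative only -- priority-ordered substring matching.
from dataclasses import dataclass, field

@dataclass
class FakeBlock:
    name: str
    description: str
    categories: list = field(default_factory=list)

tool = FindBlockTool()
blocks = [
    FakeBlock("HttpRequestBlock", "Make HTTP requests; supports email webhooks"),
    FakeBlock("SendEmailBlock", "Send an email via SMTP"),
]
ranked = sorted(blocks, key=lambda b: tool._matches_query(b, "email")[0])
# SendEmailBlock sorts first (priority 1: name contains the query);
# HttpRequestBlock second (priority 2: only the description contains it).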

View File

@@ -1,157 +0,0 @@
"""Tool for searching agents in the user's library."""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.library import db as library_db
from backend.util.exceptions import DatabaseError
from .base import BaseTool
from .models import (
AgentCarouselResponse,
AgentInfo,
ErrorResponse,
NoResultsResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
class FindLibraryAgentTool(BaseTool):
"""Tool for searching agents in the user's library."""
@property
def name(self) -> str:
return "find_library_agent"
@property
def description(self) -> str:
return (
"Search for agents in the user's library. Use this to find agents "
"the user has already added to their library, including agents they "
"created or added from the marketplace."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": (
"Search query to find agents by name or description. "
"Use keywords for best results."
),
},
},
"required": ["query"],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Search for agents in the user's library.
Args:
user_id: User ID (required)
session: Chat session
query: Search query
Returns:
AgentCarouselResponse: List of agents found in the library
NoResultsResponse: No agents found
ErrorResponse: Error message
"""
query = kwargs.get("query", "").strip()
session_id = session.session_id
if not query:
return ErrorResponse(
message="Please provide a search query",
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="User authentication required to search library",
session_id=session_id,
)
agents = []
try:
logger.info(f"Searching user library for: {query}")
library_results = await library_db.list_library_agents(
user_id=user_id,
search_term=query,
page_size=10,
)
logger.info(
f"Find library agents tool found {len(library_results.agents)} agents"
)
for agent in library_results.agents:
agents.append(
AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
),
)
except DatabaseError as e:
logger.error(f"Error searching library agents: {e}", exc_info=True)
return ErrorResponse(
message="Failed to search library. Please try again.",
error=str(e),
session_id=session_id,
)
if not agents:
return NoResultsResponse(
message=(
f"No agents found matching '{query}' in your library. "
"Try different keywords or use find_agent to search the marketplace."
),
session_id=session_id,
suggestions=[
"Try more general terms",
"Use find_agent to search the marketplace",
"Check your library at /library",
],
)
title = (
f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} "
f"in your library for '{query}'"
)
return AgentCarouselResponse(
message=(
"Found agents in the user's library. You can provide a link to "
"view an agent at: /library/agents/{agent_id}. "
"Use agent_output to get execution results, or run_agent to execute."
),
title=title,
agents=agents,
count=len(agents),
session_id=session_id,
)

View File

@@ -1,483 +0,0 @@
#!/usr/bin/env python3
"""
Block Indexer for Hybrid Search
Creates a hybrid search index from blocks:
- OpenAI embeddings (text-embedding-3-small)
- BM25 index for lexical search
- Name index for title matching boost
Supports incremental updates by tracking content hashes.
Usage:
python -m backend.server.v2.chat.tools.index_blocks [--force]
"""
import argparse
import base64
import hashlib
import json
import logging
import os
import re
import sys
from collections import defaultdict
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import numpy as np
logger = logging.getLogger(__name__)
# Check for OpenAI availability
try:
import openai # noqa: F401
HAS_OPENAI = True
except ImportError:
HAS_OPENAI = False
print("Warning: openai not installed. Run: pip install openai")
# Default embedding model (OpenAI)
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
DEFAULT_EMBEDDING_DIM = 1536
# Output path (relative to this file)
INDEX_PATH = Path(__file__).parent / "blocks_index.json"
# Stopwords for tokenization
STOPWORDS = {
"the",
"a",
"an",
"is",
"are",
"was",
"were",
"be",
"been",
"being",
"have",
"has",
"had",
"do",
"does",
"did",
"will",
"would",
"could",
"should",
"may",
"might",
"must",
"shall",
"can",
"need",
"dare",
"ought",
"used",
"to",
"of",
"in",
"for",
"on",
"with",
"at",
"by",
"from",
"as",
"into",
"through",
"during",
"before",
"after",
"above",
"below",
"between",
"under",
"again",
"further",
"then",
"once",
"and",
"but",
"or",
"nor",
"so",
"yet",
"both",
"either",
"neither",
"not",
"only",
"own",
"same",
"than",
"too",
"very",
"just",
"also",
"now",
"here",
"there",
"when",
"where",
"why",
"how",
"all",
"each",
"every",
"few",
"more",
"most",
"other",
"some",
"such",
"no",
"any",
"this",
"that",
"these",
"those",
"it",
"its",
"block", # Too common in block context
}
def tokenize(text: str) -> list[str]:
"""Simple tokenizer for BM25."""
# Remove code blocks if any
text = re.sub(r"```[\s\S]*?```", "", text)
text = re.sub(r"`[^`]+`", "", text)
# Split camelCase before lowercasing, otherwise the boundary is lost
# (e.g. "GetCurrentTimeBlock" -> "Get Current Time Block")
text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text)
text = text.lower()
# Extract words
words = re.findall(r"\b[a-z][a-z0-9_-]*\b", text)
# Remove very short words and stopwords
return [w for w in words if len(w) > 2 and w not in STOPWORDS]
def build_searchable_text(block: Any) -> str:
"""Build searchable text from block attributes."""
parts = []
# Block name (split camelCase for better tokenization)
name = block.name
# Split camelCase: GetCurrentTimeBlock -> Get Current Time Block
name_split = re.sub(r"([a-z])([A-Z])", r"\1 \2", name)
parts.append(name_split)
# Description
if block.description:
parts.append(block.description)
# Categories
for category in block.categories:
parts.append(category.name)
# Input schema field names and descriptions
try:
input_schema = block.input_schema.jsonschema()
if "properties" in input_schema:
for field_name, field_info in input_schema["properties"].items():
parts.append(field_name)
if "description" in field_info:
parts.append(field_info["description"])
except Exception:
pass
# Output schema field names
try:
output_schema = block.output_schema.jsonschema()
if "properties" in output_schema:
for field_name in output_schema["properties"]:
parts.append(field_name)
except Exception:
pass
return " ".join(parts)
def compute_content_hash(text: str) -> str:
"""Compute MD5 hash of text for change detection."""
return hashlib.md5(text.encode()).hexdigest()
def load_existing_index(index_path: Path) -> dict[str, Any] | None:
"""Load existing index if present."""
if not index_path.exists():
return None
try:
with open(index_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
logger.warning(f"Failed to load existing index: {e}")
return None
def create_embeddings(
texts: list[str],
model_name: str = DEFAULT_EMBEDDING_MODEL,
batch_size: int = 100,
) -> np.ndarray:
"""Create embeddings using OpenAI API."""
if not HAS_OPENAI:
raise RuntimeError("openai not installed. Run: pip install openai")
# Import here to satisfy type checker
from openai import OpenAI
# Check for API key
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
client = OpenAI(api_key=api_key)
embeddings = []
print(f"Creating embeddings for {len(texts)} texts using {model_name}...")
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
# Truncate texts to max token limit (8191 tokens for text-embedding-3-small)
# Roughly 4 chars per token, so ~32000 chars max
batch = [text[:32000] for text in batch]
response = client.embeddings.create(
model=model_name,
input=batch,
)
for embedding_data in response.data:
embeddings.append(embedding_data.embedding)
print(f" Processed {min(i + batch_size, len(texts))}/{len(texts)} texts")
return np.array(embeddings, dtype=np.float32)
def build_bm25_data(
blocks_data: list[dict[str, Any]],
) -> dict[str, Any]:
"""Build BM25 metadata from block data."""
# Tokenize all searchable texts
tokenized_docs = []
for block in blocks_data:
tokens = tokenize(block["searchable_text"])
tokenized_docs.append(tokens)
# Calculate document frequencies
doc_freq: dict[str, int] = {}
for tokens in tokenized_docs:
seen = set()
for token in tokens:
if token not in seen:
doc_freq[token] = doc_freq.get(token, 0) + 1
seen.add(token)
n_docs = len(tokenized_docs)
doc_lens = [len(d) for d in tokenized_docs]
avgdl = sum(doc_lens) / max(n_docs, 1)
return {
"n_docs": n_docs,
"avgdl": avgdl,
"df": doc_freq,
"doc_lens": doc_lens,
}
def build_name_index(
blocks_data: list[dict[str, Any]],
) -> dict[str, list[list[int | float]]]:
"""Build inverted index for name search boost."""
index: dict[str, list[list[int | float]]] = defaultdict(list)
for idx, block in enumerate(blocks_data):
# Tokenize block name
name_tokens = tokenize(block["name"])
seen = set()
for i, token in enumerate(name_tokens):
if token in seen:
continue
seen.add(token)
# Score: first token gets higher weight
score = 1.5 if i == 0 else 1.0
index[token].append([idx, score])
return dict(index)
def build_block_index(
force_rebuild: bool = False,
output_path: Path = INDEX_PATH,
) -> dict[str, Any]:
"""
Build the block search index.
Args:
force_rebuild: If True, rebuild all embeddings even if unchanged
output_path: Path to save the index
Returns:
The generated index dictionary
"""
# Import here to avoid circular imports
from backend.blocks import load_all_blocks
print("Loading all blocks...")
all_blocks = load_all_blocks()
print(f"Found {len(all_blocks)} blocks")
# Load existing index for incremental updates
existing_index = None if force_rebuild else load_existing_index(output_path)
existing_blocks: dict[str, dict[str, Any]] = {}
if existing_index:
print(
f"Loaded existing index with {len(existing_index.get('blocks', []))} blocks"
)
for block in existing_index.get("blocks", []):
existing_blocks[block["id"]] = block
# Process each block
blocks_data: list[dict[str, Any]] = []
blocks_needing_embedding: list[tuple[int, str]] = [] # (index, searchable_text)
for block_id, block_cls in all_blocks.items():
try:
block = block_cls()
# Skip disabled blocks
if block.disabled:
continue
searchable_text = build_searchable_text(block)
content_hash = compute_content_hash(searchable_text)
block_data = {
"id": block.id,
"name": block.name,
"description": block.description,
"categories": [cat.name for cat in block.categories],
"searchable_text": searchable_text,
"content_hash": content_hash,
"emb": None, # Will be filled later
}
# Check if we can reuse existing embedding
if (
block.id in existing_blocks
and existing_blocks[block.id].get("content_hash") == content_hash
and existing_blocks[block.id].get("emb")
):
# Reuse existing embedding
block_data["emb"] = existing_blocks[block.id]["emb"]
else:
# Need new embedding
blocks_needing_embedding.append((len(blocks_data), searchable_text))
blocks_data.append(block_data)
except Exception as e:
logger.warning(f"Failed to process block {block_id}: {e}")
continue
print(f"Processed {len(blocks_data)} blocks")
print(f"Blocks needing new embeddings: {len(blocks_needing_embedding)}")
# Create embeddings for new/changed blocks
if blocks_needing_embedding and HAS_OPENAI:
texts_to_embed = [text for _, text in blocks_needing_embedding]
try:
embeddings = create_embeddings(texts_to_embed)
# Assign embeddings to blocks
for i, (block_idx, _) in enumerate(blocks_needing_embedding):
emb = embeddings[i].astype(np.float32)
# Encode as base64
blocks_data[block_idx]["emb"] = base64.b64encode(emb.tobytes()).decode(
"ascii"
)
except Exception as e:
print(f"Warning: Failed to create embeddings: {e}")
elif blocks_needing_embedding:
print(
"Warning: Cannot create embeddings (openai not installed or OPENAI_API_KEY not set)"
)
# Build BM25 data
print("Building BM25 index...")
bm25_data = build_bm25_data(blocks_data)
# Build name index
print("Building name index...")
name_index = build_name_index(blocks_data)
# Build final index
index = {
"version": "1.0.0",
"embedding_model": DEFAULT_EMBEDDING_MODEL,
"embedding_dim": DEFAULT_EMBEDDING_DIM,
"generated_at": datetime.now(timezone.utc).isoformat(),
"blocks": blocks_data,
"bm25": bm25_data,
"name_index": name_index,
}
# Save index
print(f"Saving index to {output_path}...")
with open(output_path, "w", encoding="utf-8") as f:
json.dump(index, f, separators=(",", ":"))
size_kb = output_path.stat().st_size / 1024
print(f"Index saved ({size_kb:.1f} KB)")
# Print statistics
print("\nIndex Statistics:")
print(f" Blocks indexed: {len(blocks_data)}")
print(f" BM25 vocabulary size: {len(bm25_data['df'])}")
print(f" Name index terms: {len(name_index)}")
print(f" Embeddings: {'Yes' if any(b.get('emb') for b in blocks_data) else 'No'}")
return index
def main():
parser = argparse.ArgumentParser(description="Build hybrid search index for blocks")
parser.add_argument(
"--force",
action="store_true",
help="Force rebuild all embeddings even if unchanged",
)
parser.add_argument(
"--output",
type=Path,
default=INDEX_PATH,
help=f"Output index file path (default: {INDEX_PATH})",
)
args = parser.parse_args()
try:
build_block_index(
force_rebuild=args.force,
output_path=args.output,
)
except Exception as e:
print(f"Error building index: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()
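Two of the indexer's moving parts are easiest to see with concrete values. First, the tokenizer (an illustrative trace of the rules above, assuming the camelCase split happens before lowercasing; "an" is dropped for length and "block" is a stopword here):

# tokenize("GetCurrentTimeBlock sends an email")
# -> camelCase split: "Get Current Time Block sends an email"
# -> ["get", "current", "time", "sends", "email"]

Second, the df/avgdl statistics collected by build_bm25_data feed the BM25 formula applied at query time in the search module. A worked example with hypothetical numbers, using the same k1=1.5 and b=0.75:

# 200 docs; "email" occurs in 10 of them (df=10) and twice in this doc (tf=2);
# this doc is 40 tokens long against an average length of 50.
import math
idf = math.log((200 - 10 + 0.5) / (10 + 0.5) + 1)  # ~2.95
score = idf * (2 * (1.5 + 1)) / (2 + 1.5 * (1 - 0.75 + 0.75 * 40 / 50))  # ~4.51
# Rarer terms (higher idf) and shorter-than-average docs score higher.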

View File

@@ -1,6 +1,5 @@
"""Pydantic models for tool responses."""
- from datetime import datetime
from enum import Enum
from typing import Any
@@ -20,15 +19,6 @@ class ResponseType(str, Enum):
ERROR = "error"
NO_RESULTS = "no_results"
SUCCESS = "success"
- DOC_SEARCH_RESULTS = "doc_search_results"
- AGENT_OUTPUT = "agent_output"
- BLOCK_LIST = "block_list"
- BLOCK_OUTPUT = "block_output"
- UNDERSTANDING_UPDATED = "understanding_updated"
- # Agent generation responses
- AGENT_PREVIEW = "agent_preview"
- AGENT_SAVED = "agent_saved"
- CLARIFICATION_NEEDED = "clarification_needed"
# Base response model
@@ -183,128 +173,3 @@ class ErrorResponse(ToolResponseBase):
type: ResponseType = ResponseType.ERROR
error: str | None = None
details: dict[str, Any] | None = None
- # Documentation search models
- class DocSearchResult(BaseModel):
- """A single documentation search result."""
- title: str
- path: str
- section: str
- snippet: str  # Short excerpt for UI display
- content: str  # Full text content for LLM to read and understand
- score: float
- doc_url: str | None = None
- class DocSearchResultsResponse(ToolResponseBase):
- """Response for search_docs tool."""
- type: ResponseType = ResponseType.DOC_SEARCH_RESULTS
- results: list[DocSearchResult]
- count: int
- query: str
- # Agent output models
- class ExecutionOutputInfo(BaseModel):
- """Summary of a single execution's outputs."""
- execution_id: str
- status: str
- started_at: datetime | None = None
- ended_at: datetime | None = None
- outputs: dict[str, list[Any]]
- inputs_summary: dict[str, Any] | None = None
- class AgentOutputResponse(ToolResponseBase):
- """Response for agent_output tool."""
- type: ResponseType = ResponseType.AGENT_OUTPUT
- agent_name: str
- agent_id: str
- library_agent_id: str | None = None
- library_agent_link: str | None = None
- execution: ExecutionOutputInfo | None = None
- available_executions: list[dict[str, Any]] | None = None
- total_executions: int = 0
- # Block models
- class BlockInfoSummary(BaseModel):
- """Summary of a block for search results."""
- id: str
- name: str
- description: str
- categories: list[str]
- input_schema: dict[str, Any]
- output_schema: dict[str, Any]
- class BlockListResponse(ToolResponseBase):
- """Response for find_block tool."""
- type: ResponseType = ResponseType.BLOCK_LIST
- blocks: list[BlockInfoSummary]
- count: int
- query: str
- class BlockOutputResponse(ToolResponseBase):
- """Response for run_block tool."""
- type: ResponseType = ResponseType.BLOCK_OUTPUT
- block_id: str
- block_name: str
- outputs: dict[str, list[Any]]
- success: bool = True
- # Business understanding models
- class UnderstandingUpdatedResponse(ToolResponseBase):
- """Response for add_understanding tool."""
- type: ResponseType = ResponseType.UNDERSTANDING_UPDATED
- updated_fields: list[str] = Field(default_factory=list)
- current_understanding: dict[str, Any] = Field(default_factory=dict)
- # Agent generation models
- class ClarifyingQuestion(BaseModel):
- """A question that needs user clarification."""
- question: str
- keyword: str
- example: str | None = None
- class AgentPreviewResponse(ToolResponseBase):
- """Response for previewing a generated agent before saving."""
- type: ResponseType = ResponseType.AGENT_PREVIEW
- agent_json: dict[str, Any]
- agent_name: str
- description: str
- node_count: int
- link_count: int = 0
- class AgentSavedResponse(ToolResponseBase):
- """Response when an agent is saved to the library."""
- type: ResponseType = ResponseType.AGENT_SAVED
- agent_id: str
- agent_name: str
- library_agent_id: str
- library_agent_link: str
- agent_page_link: str  # Link to the agent builder/editor page
- class ClarificationNeededResponse(ToolResponseBase):
- """Response when the LLM needs more information from the user."""
- type: ResponseType = ResponseType.CLARIFICATION_NEEDED
- questions: list[ClarifyingQuestion] = Field(default_factory=list)

View File

@@ -7,7 +7,6 @@ from pydantic import BaseModel, Field, field_validator
from backend.api.features.chat.config import ChatConfig
from backend.api.features.chat.model import ChatSession
from backend.api.features.library import db as library_db
- from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.data.user import get_user_by_id
@@ -58,7 +57,6 @@ class RunAgentInput(BaseModel):
"""Input parameters for the run_agent tool."""
username_agent_slug: str = ""
- library_agent_id: str = ""
inputs: dict[str, Any] = Field(default_factory=dict)
use_defaults: bool = False
schedule_name: str = ""
@@ -66,12 +64,7 @@ class RunAgentInput(BaseModel):
timezone: str = "UTC"
@field_validator(
"username_agent_slug",
"library_agent_id",
"schedule_name",
"cron",
"timezone",
mode="before",
"username_agent_slug", "schedule_name", "cron", "timezone", mode="before"
)
@classmethod
def strip_strings(cls, v: Any) -> Any:
@@ -97,7 +90,7 @@ class RunAgentTool(BaseTool):
@property
def description(self) -> str:
return """Run or schedule an agent from the marketplace or user's library.
return """Run or schedule an agent from the marketplace.
The tool automatically handles the setup flow:
- Returns missing inputs if required fields are not provided
@@ -105,10 +98,6 @@ class RunAgentTool(BaseTool):
- Executes immediately if all requirements are met
- Schedules execution if cron expression is provided
- Identify the agent using either:
- - username_agent_slug: Marketplace format 'username/agent-name'
- - library_agent_id: ID of an agent in the user's library
For scheduled execution, provide: schedule_name, cron, and optionally timezone."""
@property
@@ -120,10 +109,6 @@ class RunAgentTool(BaseTool):
"type": "string",
"description": "Agent identifier in format 'username/agent-name'",
},
"library_agent_id": {
"type": "string",
"description": "Library agent ID from user's library",
},
"inputs": {
"type": "object",
"description": "Input values for the agent",
@@ -146,7 +131,7 @@ class RunAgentTool(BaseTool):
"description": "IANA timezone for schedule (default: UTC)",
},
},
"required": [],
"required": ["username_agent_slug"],
}
@property
@@ -164,16 +149,10 @@ class RunAgentTool(BaseTool):
params = RunAgentInput(**kwargs)
session_id = session.session_id
- # Validate at least one identifier is provided
- has_slug = params.username_agent_slug and "/" in params.username_agent_slug
- has_library_id = bool(params.library_agent_id)
- if not has_slug and not has_library_id:
+ # Validate agent slug format
+ if not params.username_agent_slug or "/" not in params.username_agent_slug:
return ErrorResponse(
- message=(
- "Please provide either a username_agent_slug "
- "(format 'username/agent-name') or a library_agent_id"
- ),
+ message="Please provide an agent slug in format 'username/agent-name'",
session_id=session_id,
)
@@ -188,41 +167,13 @@ class RunAgentTool(BaseTool):
is_schedule = bool(params.schedule_name or params.cron)
try:
- # Step 1: Fetch agent details
- graph: GraphModel | None = None
- library_agent = None
- # Priority: library_agent_id if provided
- if has_library_id:
- library_agent = await library_db.get_library_agent(
- params.library_agent_id, user_id
- )
- if not library_agent:
- return ErrorResponse(
- message=f"Library agent '{params.library_agent_id}' not found",
- session_id=session_id,
- )
- # Get the graph from the library agent
- from backend.data.graph import get_graph
- graph = await get_graph(
- library_agent.graph_id,
- library_agent.graph_version,
- user_id=user_id,
- )
- else:
- # Fetch from marketplace slug
- username, agent_name = params.username_agent_slug.split("/", 1)
- graph, _ = await fetch_graph_from_store_slug(username, agent_name)
+ # Step 1: Fetch agent details (always happens first)
+ username, agent_name = params.username_agent_slug.split("/", 1)
+ graph, store_agent = await fetch_graph_from_store_slug(username, agent_name)
if not graph:
- identifier = (
- params.library_agent_id
- if has_library_id
- else params.username_agent_slug
- )
return ErrorResponse(
message=f"Agent '{identifier}' not found",
message=f"Agent '{params.username_agent_slug}' not found in marketplace",
session_id=session_id,
)
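After this change the tool accepts only the marketplace slug. A hypothetical invocation under the new schema (assuming the same _execute(user_id, session, **kwargs) convention as the other tools; the slug and inputs are illustrative):

# Illustrative only -- run_agent after this change.
tool = RunAgentTool()
result = await tool._execute(
    user_id=user_id,
    session=session,
    username_agent_slug="someuser/weather-reporter",  # hypothetical slug
    inputs={"city": "Berlin"},
    # Optional scheduling instead of immediate execution:
    # schedule_name="daily-weather", cron="0 7 * * *", timezone="Europe/Berlin",
)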

View File

@@ -1,287 +0,0 @@
"""Tool for executing blocks directly."""
import logging
from collections import defaultdict
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block
from backend.data.model import CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from .base import BaseTool
from .models import (
BlockOutputResponse,
ErrorResponse,
SetupInfo,
SetupRequirementsResponse,
ToolResponseBase,
UserReadiness,
)
logger = logging.getLogger(__name__)
class RunBlockTool(BaseTool):
"""Tool for executing a block and returning its outputs."""
@property
def name(self) -> str:
return "run_block"
@property
def description(self) -> str:
return (
"Execute a specific block with the provided input data. "
"Use find_block to discover available blocks and their input schemas. "
"The block will run and return its outputs once complete."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"block_id": {
"type": "string",
"description": "The UUID of the block to execute",
},
"input_data": {
"type": "object",
"description": (
"Input values for the block. Must match the block's input schema. "
"Check the block's input_schema from find_block for required fields."
),
},
},
"required": ["block_id", "input_data"],
}
@property
def requires_auth(self) -> bool:
return True
async def _check_block_credentials(
self,
user_id: str,
block: Any,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Check if user has required credentials for a block.
Returns:
tuple[matched_credentials, missing_credentials]
"""
matched_credentials: dict[str, CredentialsMetaInput] = {}
missing_credentials: list[CredentialsMetaInput] = []
# Get credential field info from block's input schema
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return matched_credentials, missing_credentials
# Get user's available credentials
creds_manager = IntegrationCredentialsManager()
available_creds = await creds_manager.store.get_all_creds(user_id)
for field_name, field_info in credentials_fields_info.items():
# field_info.provider is a frozenset of acceptable providers
# field_info.supported_types is a frozenset of acceptable types
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in field_info.provider
and cred.type in field_info.supported_types
),
None,
)
if matching_cred:
matched_credentials[field_name] = CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
else:
# Create a placeholder for the missing credential
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing_credentials.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched_credentials, missing_credentials
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute a block with the given input data.
Args:
user_id: User ID (required)
session: Chat session
block_id: Block UUID to execute
input_data: Input values for the block
Returns:
BlockOutputResponse: Block execution outputs
SetupRequirementsResponse: Missing credentials
ErrorResponse: Error message
"""
block_id = kwargs.get("block_id", "").strip()
input_data = kwargs.get("input_data", {})
session_id = session.session_id
if not block_id:
return ErrorResponse(
message="Please provide a block_id",
session_id=session_id,
)
if not isinstance(input_data, dict):
return ErrorResponse(
message="input_data must be an object",
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
# Get the block
block = get_block(block_id)
if not block:
return ErrorResponse(
message=f"Block '{block_id}' not found",
session_id=session_id,
)
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
# Check credentials
creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = await self._check_block_credentials(
user_id, block
)
if missing_credentials:
# Return setup requirements response with missing credentials
missing_creds_dict = {c.id: c.model_dump() for c in missing_credentials}
return SetupRequirementsResponse(
message=(
f"Block '{block.name}' requires credentials that are not configured. "
"Please set up the required credentials before running this block."
),
session_id=session_id,
setup_info=SetupInfo(
agent_id=block_id,
agent_name=block.name,
user_readiness=UserReadiness(
has_all_credentials=False,
missing_credentials=missing_creds_dict,
ready_to_run=False,
),
requirements={
"credentials": [c.model_dump() for c in missing_credentials],
"inputs": self._get_inputs_list(block),
"execution_modes": ["immediate"],
},
),
graph_id=None,
graph_version=None,
)
try:
# Fetch actual credentials and prepare kwargs for block execution
exec_kwargs: dict[str, Any] = {"user_id": user_id}
for field_name, cred_meta in matched_credentials.items():
# Inject metadata into input_data (for validation)
if field_name not in input_data:
input_data[field_name] = cred_meta.model_dump()
# Fetch actual credentials and pass as kwargs (for execution)
actual_credentials = await creds_manager.get(
user_id, cred_meta.id, lock=False
)
if actual_credentials:
exec_kwargs[field_name] = actual_credentials
else:
return ErrorResponse(
message=f"Failed to retrieve credentials for {field_name}",
session_id=session_id,
)
# Execute the block and collect outputs
outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in block.execute(
input_data,
**exec_kwargs,
):
outputs[output_name].append(output_data)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
block_id=block_id,
block_name=block.name,
outputs=dict(outputs),
success=True,
session_id=session_id,
)
except BlockError as e:
logger.warning(f"Block execution failed: {e}")
return ErrorResponse(
message=f"Block execution failed: {e}",
error=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to execute block: {str(e)}",
error=str(e),
session_id=session_id,
)
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
"""Extract non-credential inputs from block schema."""
inputs_list = []
schema = block.input_schema.jsonschema()
properties = schema.get("properties", {})
required_fields = set(schema.get("required", []))
# Get credential field names to exclude
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
for field_name, field_schema in properties.items():
# Skip credential fields
if field_name in credentials_fields:
continue
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in required_fields,
}
)
return inputs_list
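A rough sketch of this tool's happy path from the caller's side (hypothetical block UUID and inputs; the block here is assumed to need no credentials, so the setup branch is skipped):

# Illustrative only -- execute a block found via find_block.
tool = RunBlockTool()
result = await tool._execute(
    user_id=user_id,
    session=session,
    block_id="<uuid from find_block>",  # hypothetical UUID
    input_data={"text": "hello"},  # must satisfy the block's input_schema
)
if isinstance(result, BlockOutputResponse):
    print(result.outputs)  # one list of values per output pin
elif isinstance(result, SetupRequirementsResponse):
    # Missing credentials are listed under
    # result.setup_info.user_readiness.missing_credentials
    ...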

View File

@@ -1,460 +0,0 @@
"""
Block Hybrid Search
Combines multiple ranking signals for block search:
- Semantic search (OpenAI embeddings + cosine similarity)
- Lexical search (BM25)
- Name matching (boost for block name matches)
- Category matching (boost for category matches)
Based on the docs search implementation.
"""
import base64
import json
import logging
import math
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
import numpy as np
logger = logging.getLogger(__name__)
# OpenAI embedding model
EMBEDDING_MODEL = "text-embedding-3-small"
# Path to the JSON index file
INDEX_PATH = Path(__file__).parent / "blocks_index.json"
# Stopwords for tokenization (same as index_blocks.py)
STOPWORDS = {
"the",
"a",
"an",
"is",
"are",
"was",
"were",
"be",
"been",
"being",
"have",
"has",
"had",
"do",
"does",
"did",
"will",
"would",
"could",
"should",
"may",
"might",
"must",
"shall",
"can",
"need",
"dare",
"ought",
"used",
"to",
"of",
"in",
"for",
"on",
"with",
"at",
"by",
"from",
"as",
"into",
"through",
"during",
"before",
"after",
"above",
"below",
"between",
"under",
"again",
"further",
"then",
"once",
"and",
"but",
"or",
"nor",
"so",
"yet",
"both",
"either",
"neither",
"not",
"only",
"own",
"same",
"than",
"too",
"very",
"just",
"also",
"now",
"here",
"there",
"when",
"where",
"why",
"how",
"all",
"each",
"every",
"few",
"more",
"most",
"other",
"some",
"such",
"no",
"any",
"this",
"that",
"these",
"those",
"it",
"its",
"block",
}
def tokenize(text: str) -> list[str]:
"""Simple tokenizer for search."""
# Remove code blocks if any
text = re.sub(r"```[\s\S]*?```", "", text)
text = re.sub(r"`[^`]+`", "", text)
# Split camelCase before lowercasing, otherwise the boundary is lost
text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text)
text = text.lower()
# Extract words
words = re.findall(r"\b[a-z][a-z0-9_-]*\b", text)
# Remove very short words and stopwords
return [w for w in words if len(w) > 2 and w not in STOPWORDS]
@dataclass
class SearchWeights:
"""Configuration for hybrid search signal weights."""
semantic: float = 0.40 # Embedding similarity
bm25: float = 0.25 # Lexical matching
name_match: float = 0.25 # Block name matches
category_match: float = 0.10 # Category matches
@dataclass
class BlockSearchResult:
"""A single block search result."""
block_id: str
name: str
description: str
categories: list[str]
score: float
# Individual signal scores (for debugging)
semantic_score: float = 0.0
bm25_score: float = 0.0
name_score: float = 0.0
category_score: float = 0.0
class BlockSearchIndex:
"""Hybrid search index for blocks combining BM25 + embeddings."""
def __init__(self, index_path: Path = INDEX_PATH):
self.blocks: list[dict[str, Any]] = []
self.bm25_data: dict[str, Any] = {}
self.name_index: dict[str, list[list[int | float]]] = {}
self.embeddings: Optional[np.ndarray] = None
self.normalized_embeddings: Optional[np.ndarray] = None
self._loaded = False
self._index_path = index_path
self._embedding_model: Any = None
def load(self) -> bool:
"""Load the index from JSON file."""
if self._loaded:
return True
if not self._index_path.exists():
logger.warning(f"Block index not found at {self._index_path}")
return False
try:
with open(self._index_path, "r", encoding="utf-8") as f:
data = json.load(f)
self.blocks = data.get("blocks", [])
self.bm25_data = data.get("bm25", {})
self.name_index = data.get("name_index", {})
# Decode embeddings from base64
embeddings_list = []
for block in self.blocks:
if block.get("emb"):
emb_bytes = base64.b64decode(block["emb"])
emb = np.frombuffer(emb_bytes, dtype=np.float32)
embeddings_list.append(emb)
else:
# No embedding, use zeros
dim = data.get("embedding_dim", 1536)  # text-embedding-3-small dimension
embeddings_list.append(np.zeros(dim, dtype=np.float32))
if embeddings_list:
self.embeddings = np.stack(embeddings_list)
# Precompute normalized embeddings for cosine similarity
norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True)
self.normalized_embeddings = self.embeddings / (norms + 1e-10)
self._loaded = True
logger.info(f"Loaded block index with {len(self.blocks)} blocks")
return True
except Exception as e:
logger.error(f"Failed to load block index: {e}")
return False
def _get_openai_client(self) -> Any:
"""Get OpenAI client for query embedding."""
if self._embedding_model is None:
try:
from openai import OpenAI
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
logger.warning("OPENAI_API_KEY not set")
return None
self._embedding_model = OpenAI(api_key=api_key)
except ImportError:
logger.warning("openai not installed")
return None
return self._embedding_model
def _embed_query(self, query: str) -> Optional[np.ndarray]:
"""Embed the search query using OpenAI."""
client = self._get_openai_client()
if client is None:
return None
try:
response = client.embeddings.create(
model=EMBEDDING_MODEL,
input=query,
)
embedding = response.data[0].embedding
return np.array(embedding, dtype=np.float32)
except Exception as e:
logger.warning(f"Failed to embed query: {e}")
return None
def _compute_semantic_scores(self, query_embedding: np.ndarray) -> np.ndarray:
"""Compute cosine similarity between query and all blocks."""
if self.normalized_embeddings is None:
return np.zeros(len(self.blocks))
# Normalize query embedding
query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)
# Cosine similarity via dot product
similarities = self.normalized_embeddings @ query_norm
# Scale to [0, 1] (cosine ranges from -1 to 1)
return (similarities + 1) / 2
def _compute_bm25_scores(self, query_tokens: list[str]) -> np.ndarray:
"""Compute BM25 scores for all blocks."""
scores = np.zeros(len(self.blocks))
if not self.bm25_data or not query_tokens:
return scores
# BM25 parameters
k1 = 1.5
b = 0.75
n_docs = self.bm25_data.get("n_docs", len(self.blocks))
avgdl = self.bm25_data.get("avgdl", 100)
df = self.bm25_data.get("df", {})
doc_lens = self.bm25_data.get("doc_lens", [100] * len(self.blocks))
for i, block in enumerate(self.blocks):
# Tokenize block's searchable text
block_tokens = tokenize(block.get("searchable_text", ""))
doc_len = doc_lens[i] if i < len(doc_lens) else len(block_tokens)
# Calculate BM25 score
score = 0.0
for token in query_tokens:
if token not in df:
continue
# Term frequency in this document
tf = block_tokens.count(token)
if tf == 0:
continue
# IDF
doc_freq = df.get(token, 0)
idf = math.log((n_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1)
# BM25 score component
numerator = tf * (k1 + 1)
denominator = tf + k1 * (1 - b + b * doc_len / avgdl)
score += idf * numerator / denominator
scores[i] = score
# Normalize to [0, 1]
max_score = scores.max()
if max_score > 0:
scores = scores / max_score
return scores
def _compute_name_scores(self, query_tokens: list[str]) -> np.ndarray:
"""Compute name match scores using the name index."""
scores = np.zeros(len(self.blocks))
if not self.name_index or not query_tokens:
return scores
for token in query_tokens:
if token in self.name_index:
for block_idx, weight in self.name_index[token]:
if block_idx < len(scores):
scores[int(block_idx)] += weight
# Also check for partial matches in block names
for i, block in enumerate(self.blocks):
name_lower = block.get("name", "").lower()
for token in query_tokens:
if token in name_lower:
scores[i] += 0.5
# Normalize to [0, 1]
max_score = scores.max()
if max_score > 0:
scores = scores / max_score
return scores
def _compute_category_scores(self, query_tokens: list[str]) -> np.ndarray:
"""Compute category match scores."""
scores = np.zeros(len(self.blocks))
if not query_tokens:
return scores
for i, block in enumerate(self.blocks):
categories = block.get("categories", [])
category_text = " ".join(categories).lower()
for token in query_tokens:
if token in category_text:
scores[i] += 1.0
# Normalize to [0, 1]
max_score = scores.max()
if max_score > 0:
scores = scores / max_score
return scores
def search(
self,
query: str,
top_k: int = 10,
weights: Optional[SearchWeights] = None,
) -> list[BlockSearchResult]:
"""
Perform hybrid search combining multiple signals.
Args:
query: Search query string
top_k: Number of results to return
weights: Optional custom weights for signals
Returns:
List of BlockSearchResult sorted by score
"""
if not self._loaded and not self.load():
return []
if weights is None:
weights = SearchWeights()
# Tokenize query
query_tokens = tokenize(query)
if not query_tokens:
# Fallback: try raw query words
query_tokens = query.lower().split()
# Compute semantic scores
semantic_scores = np.zeros(len(self.blocks))
if self.normalized_embeddings is not None:
query_embedding = self._embed_query(query)
if query_embedding is not None:
semantic_scores = self._compute_semantic_scores(query_embedding)
# Compute other scores
bm25_scores = self._compute_bm25_scores(query_tokens)
name_scores = self._compute_name_scores(query_tokens)
category_scores = self._compute_category_scores(query_tokens)
# Combine scores using weights
combined_scores = (
weights.semantic * semantic_scores
+ weights.bm25 * bm25_scores
+ weights.name_match * name_scores
+ weights.category_match * category_scores
)
# Get top-k indices
top_indices = np.argsort(combined_scores)[::-1][:top_k]
# Build results
results = []
for idx in top_indices:
if combined_scores[idx] <= 0:
continue
block = self.blocks[idx]
results.append(
BlockSearchResult(
block_id=block["id"],
name=block["name"],
description=block["description"],
categories=block.get("categories", []),
score=float(combined_scores[idx]),
semantic_score=float(semantic_scores[idx]),
bm25_score=float(bm25_scores[idx]),
name_score=float(name_scores[idx]),
category_score=float(category_scores[idx]),
)
)
return results
# Global index instance (lazy loaded)
_block_search_index: Optional[BlockSearchIndex] = None
def get_block_search_index() -> BlockSearchIndex:
"""Get or create the block search index singleton."""
global _block_search_index
if _block_search_index is None:
_block_search_index = BlockSearchIndex(INDEX_PATH)
return _block_search_index
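A short usage sketch of the index (the query string and weight values are illustrative; load() returning False simply means the caller should fall back to simple search, as find_block does):

# Illustrative only -- query the hybrid index, favouring lexical matches.
index = get_block_search_index()
if index.load():
    results = index.search(
        "send email",
        top_k=5,
        weights=SearchWeights(
            semantic=0.2, bm25=0.4, name_match=0.3, category_match=0.1
        ),
    )
    for r in results:
        print(f"{r.name}: {r.score:.3f} (bm25={r.bm25_score:.3f})")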

View File

@@ -1,386 +0,0 @@
"""Tool for searching platform documentation."""
import json
import logging
import math
import re
from pathlib import Path
from typing import Any
from backend.api.features.chat.model import ChatSession
from .base import BaseTool
from .models import (
DocSearchResult,
DocSearchResultsResponse,
ErrorResponse,
NoResultsResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
# Documentation base URL
DOCS_BASE_URL = "https://docs.agpt.co/platform"
# Path to the JSON index file (relative to this file)
INDEX_PATH = Path(__file__).parent / "docs_index.json"
def tokenize(text: str) -> list[str]:
"""Simple tokenizer for BM25."""
text = text.lower()
# Remove code blocks
text = re.sub(r"```[\s\S]*?```", "", text)
text = re.sub(r"`[^`]+`", "", text)
# Extract words
words = re.findall(r"\b[a-z][a-z0-9_-]*\b", text)
# Remove very short words and stopwords
stopwords = {
"the",
"a",
"an",
"is",
"are",
"was",
"were",
"be",
"been",
"being",
"have",
"has",
"had",
"do",
"does",
"did",
"will",
"would",
"could",
"should",
"may",
"might",
"must",
"shall",
"can",
"need",
"dare",
"ought",
"used",
"to",
"of",
"in",
"for",
"on",
"with",
"at",
"by",
"from",
"as",
"into",
"through",
"during",
"before",
"after",
"above",
"below",
"between",
"under",
"again",
"further",
"then",
"once",
"and",
"but",
"or",
"nor",
"so",
"yet",
"both",
"either",
"neither",
"not",
"only",
"own",
"same",
"than",
"too",
"very",
"just",
"also",
"now",
"here",
"there",
"when",
"where",
"why",
"how",
"all",
"each",
"every",
"both",
"few",
"more",
"most",
"other",
"some",
"such",
"no",
"any",
"this",
"that",
"these",
"those",
"it",
"its",
}
return [w for w in words if len(w) > 2 and w not in stopwords]
class DocSearchIndex:
"""Lightweight documentation search index using BM25."""
def __init__(self, index_path: Path):
self.chunks: list[dict] = []
self.bm25_data: dict = {}
self._loaded = False
self._index_path = index_path
def load(self) -> bool:
"""Load the index from JSON file."""
if self._loaded:
return True
if not self._index_path.exists():
logger.warning(f"Documentation index not found at {self._index_path}")
return False
try:
with open(self._index_path, "r", encoding="utf-8") as f:
data = json.load(f)
self.chunks = data.get("chunks", [])
self.bm25_data = data.get("bm25", {})
self._loaded = True
logger.info(f"Loaded documentation index with {len(self.chunks)} chunks")
return True
except Exception as e:
logger.error(f"Failed to load documentation index: {e}")
return False
def search(self, query: str, top_k: int = 5) -> list[dict]:
"""Search the index using BM25."""
if not self._loaded and not self.load():
return []
query_tokens = tokenize(query)
if not query_tokens:
return []
# BM25 parameters
k1 = 1.5
b = 0.75
n_docs = self.bm25_data.get("n_docs", len(self.chunks))
avgdl = self.bm25_data.get("avgdl", 100)
df = self.bm25_data.get("df", {})
doc_lens = self.bm25_data.get("doc_lens", [100] * len(self.chunks))
scores = []
for i, chunk in enumerate(self.chunks):
# Tokenize chunk text
chunk_tokens = tokenize(chunk.get("text", ""))
doc_len = doc_lens[i] if i < len(doc_lens) else len(chunk_tokens)
# Calculate BM25 score
score = 0.0
for token in query_tokens:
if token not in df:
continue
# Term frequency in this document
tf = chunk_tokens.count(token)
if tf == 0:
continue
# IDF
doc_freq = df.get(token, 0)
idf = math.log((n_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1)
# BM25 score component
numerator = tf * (k1 + 1)
denominator = tf + k1 * (1 - b + b * doc_len / avgdl)
score += idf * numerator / denominator
# Boost for title/heading matches
title = chunk.get("title", "").lower()
heading = chunk.get("heading", "").lower()
for token in query_tokens:
if token in title:
score *= 1.5
if token in heading:
score *= 1.2
scores.append((i, score))
# Sort by score and return top_k
scores.sort(key=lambda x: x[1], reverse=True)
results = []
seen_sections = set()
for idx, score in scores:
if score <= 0:
continue
chunk = self.chunks[idx]
section_key = (chunk.get("doc", ""), chunk.get("heading", ""))
# Deduplicate by section
if section_key in seen_sections:
continue
seen_sections.add(section_key)
results.append(
{
"title": chunk.get("title", ""),
"path": chunk.get("doc", ""),
"heading": chunk.get("heading", ""),
"text": chunk.get("text", ""), # Full text for LLM comprehension
"score": score,
}
)
if len(results) >= top_k:
break
return results
# Global index instance (lazy loaded)
_search_index: DocSearchIndex | None = None
def get_search_index() -> DocSearchIndex:
"""Get or create the search index singleton."""
global _search_index
if _search_index is None:
_search_index = DocSearchIndex(INDEX_PATH)
return _search_index
class SearchDocsTool(BaseTool):
"""Tool for searching AutoGPT platform documentation."""
@property
def name(self) -> str:
return "search_platform_docs"
@property
def description(self) -> str:
return (
"Search the AutoGPT platform documentation and support Q&A for information about "
"how to use the platform, create agents, configure blocks, "
"set up integrations, troubleshoot issues, and more. Use this when users ask "
"support questions or want to learn how to do something with AutoGPT."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": (
"Search query describing what the user wants to learn about. "
"Use keywords like 'blocks', 'agents', 'credentials', 'API', etc."
),
},
},
"required": ["query"],
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Search documentation for the query.
Args:
user_id: User ID (may be anonymous)
session: Chat session
query: Search query
Returns:
DocSearchResultsResponse: List of matching documentation sections
NoResultsResponse: No results found
ErrorResponse: Error message
"""
query = kwargs.get("query", "").strip()
session_id = session.session_id
if not query:
return ErrorResponse(
message="Please provide a search query",
session_id=session_id,
)
try:
index = get_search_index()
results = index.search(query, top_k=5)
if not results:
return NoResultsResponse(
message=f"No documentation found for '{query}'. Try different keywords.",
session_id=session_id,
suggestions=[
"Try more general terms like 'blocks', 'agents', 'setup'",
"Check the documentation at docs.agpt.co",
],
)
# Convert to response format
doc_results = []
for r in results:
# Build documentation URL
path = r["path"]
if path.endswith(".md"):
path = path[:-3] # Remove .md extension
doc_url = f"{DOCS_BASE_URL}/{path}"
full_text = r["text"]
doc_results.append(
DocSearchResult(
title=r["title"],
path=r["path"],
section=r["heading"],
snippet=(
full_text[:300] + "..."
if len(full_text) > 300
else full_text
),
content=full_text, # Full text for LLM to read and understand
score=round(r["score"], 3),
doc_url=doc_url,
)
)
return DocSearchResultsResponse(
message=(
f"Found {len(doc_results)} relevant documentation sections. "
"Use these to help answer the user's question. "
"Include links to the documentation when helpful."
),
results=doc_results,
count=len(doc_results),
query=query,
session_id=session_id,
)
except Exception as e:
logger.error(f"Error searching documentation: {e}", exc_info=True)
return ErrorResponse(
message="Failed to search documentation. Please try again.",
error=str(e),
session_id=session_id,
)
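For reference, the BM25 weighting used in DocSearchIndex.search above, reproduced on a toy corpus as a self-contained sketch with the same k1/b constants and IDF formula:

    import math

    # Toy reproduction of the BM25 scoring loop in DocSearchIndex.search.
    k1, b = 1.5, 0.75
    docs = [["agent", "blocks", "setup"], ["blocks", "blocks", "api"]]
    n_docs = len(docs)
    avgdl = sum(len(d) for d in docs) / n_docs
    df = {"agent": 1, "blocks": 2, "api": 1, "setup": 1}

    def bm25(query_tokens: list[str], doc: list[str]) -> float:
        score = 0.0
        for token in query_tokens:
            tf = doc.count(token)
            if tf == 0 or token not in df:
                continue
            idf = math.log((n_docs - df[token] + 0.5) / (df[token] + 0.5) + 1)
            score += idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * len(doc) / avgdl))
        return score

    # "blocks" appears twice in the second doc, so it scores higher there.
    assert bm25(["blocks"], docs[1]) > bm25(["blocks"], docs[0])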

View File

@@ -1,72 +0,0 @@
#!/usr/bin/env python3
"""
CLI script to backfill embeddings for store agents.
Usage:
poetry run python -m backend.server.v2.store.backfill_embeddings [--batch-size N]
"""
import argparse
import asyncio
import sys
import prisma
async def main(batch_size: int = 100) -> int:
"""Run the backfill process."""
# Initialize Prisma client
client = prisma.Prisma()
await client.connect()
prisma.register(client)
try:
from backend.api.features.store.embeddings import (
backfill_missing_embeddings,
get_embedding_stats,
)
# Get current stats
print("Current embedding stats:")
stats = await get_embedding_stats()
print(f" Total approved: {stats['total_approved']}")
print(f" With embeddings: {stats['with_embeddings']}")
print(f" Without embeddings: {stats['without_embeddings']}")
print(f" Coverage: {stats['coverage_percent']}%")
if stats["without_embeddings"] == 0:
print("\nAll agents already have embeddings. Nothing to do.")
return 0
# Run backfill
print(f"\nBackfilling up to {batch_size} embeddings...")
result = await backfill_missing_embeddings(batch_size=batch_size)
print(f" Processed: {result['processed']}")
print(f" Success: {result['success']}")
print(f" Failed: {result['failed']}")
# Get final stats
print("\nFinal embedding stats:")
stats = await get_embedding_stats()
print(f" Total approved: {stats['total_approved']}")
print(f" With embeddings: {stats['with_embeddings']}")
print(f" Without embeddings: {stats['without_embeddings']}")
print(f" Coverage: {stats['coverage_percent']}%")
return 0 if result["failed"] == 0 else 1
finally:
await client.disconnect()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Backfill embeddings for store agents")
parser.add_argument(
"--batch-size",
type=int,
default=100,
help="Number of embeddings to generate (default: 100)",
)
args = parser.parse_args()
sys.exit(asyncio.run(main(batch_size=args.batch_size)))

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
import typing
from datetime import datetime, timezone
from typing import Literal
@@ -9,7 +10,7 @@ import prisma.errors
import prisma.models
import prisma.types
from backend.data.db import transaction
from backend.data.db import query_raw_with_schema, transaction
from backend.data.graph import (
GraphMeta,
GraphModel,
@@ -56,21 +57,95 @@ async def get_store_agents(
)
try:
# If search_query is provided, use hybrid search (embeddings + tsvector)
# If search_query is provided, use full-text search
if search_query:
from backend.api.features.store.hybrid_search import hybrid_search
offset = (page - 1) * page_size
# Use hybrid search combining semantic and lexical signals
agents, total = await hybrid_search(
query=search_query,
featured=featured,
creators=creators,
category=category,
sorted_by="relevance", # Use hybrid scoring for relevance
page=page,
page_size=page_size,
)
# Whitelist allowed order_by columns
ALLOWED_ORDER_BY = {
"rating": "rating DESC, rank DESC",
"runs": "runs DESC, rank DESC",
"name": "agent_name ASC, rank ASC",
"updated_at": "updated_at DESC, rank DESC",
}
# Validate and get order clause
if sorted_by and sorted_by in ALLOWED_ORDER_BY:
order_by_clause = ALLOWED_ORDER_BY[sorted_by]
else:
order_by_clause = "updated_at DESC, rank DESC"
# Build WHERE conditions and parameters list
where_parts: list[str] = []
params: list[typing.Any] = [search_query] # $1 - search term
param_index = 2 # Start at $2 for next parameter
# Always filter for available agents
where_parts.append("is_available = true")
if featured:
where_parts.append("featured = true")
if creators:
# Use ANY with array parameter
where_parts.append(f"creator_username = ANY(${param_index})")
params.append(creators)
param_index += 1
if category:
where_parts.append(f"${param_index} = ANY(categories)")
params.append(category)
param_index += 1
sql_where_clause: str = " AND ".join(where_parts)  # never empty: is_available is always appended
# Add pagination params
params.extend([page_size, offset])
limit_param = f"${param_index}"
offset_param = f"${param_index + 1}"
# Execute full-text search query with parameterized values
sql_query = f"""
SELECT
slug,
agent_name,
agent_image,
creator_username,
creator_avatar,
sub_heading,
description,
runs,
rating,
categories,
featured,
is_available,
updated_at,
ts_rank_cd(search, query) AS rank
FROM {{schema_prefix}}"StoreAgent",
plainto_tsquery('english', $1) AS query
WHERE {sql_where_clause}
AND search @@ query
ORDER BY {order_by_clause}
LIMIT {limit_param} OFFSET {offset_param}
"""
# Count query for pagination - only uses search term parameter
count_query = f"""
SELECT COUNT(*) as count
FROM {{schema_prefix}}"StoreAgent",
plainto_tsquery('english', $1) AS query
WHERE {sql_where_clause}
AND search @@ query
"""
# Execute both queries with parameters
agents = await query_raw_with_schema(sql_query, *params)
# For count, reuse params minus the trailing pagination pair (LIMIT/OFFSET)
count_params = params[:-2]
count_result = await query_raw_with_schema(count_query, *count_params)
total = count_result[0]["count"] if count_result else 0
total_pages = (total + page_size - 1) // page_size
# Convert raw results to StoreAgent models
@@ -1489,24 +1564,6 @@ async def review_store_submission(
},
)
# Generate embedding for approved listing (non-blocking)
try:
from backend.api.features.store.embeddings import ensure_embedding
await ensure_embedding(
version_id=store_listing_version_id,
name=store_listing_version.name,
description=store_listing_version.description,
sub_heading=store_listing_version.subHeading,
categories=store_listing_version.categories or [],
)
except Exception as e:
# Don't fail approval if embedding generation fails
logger.warning(
f"Failed to generate embedding for approved listing "
f"{store_listing_version_id}: {e}"
)
# If rejecting an approved agent, update the StoreListing accordingly
if is_rejecting_approved:
# Check if there are other approved versions
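The pattern used above — whitelisting ORDER BY fragments while pushing every user-supplied value through $N placeholders — reduced to a standalone sketch (table and column names are illustrative):

    from typing import Any

    ALLOWED_ORDER_BY = {"rating": "rating DESC", "runs": "runs DESC"}

    def build_store_query(search: str, creators: list[str] | None, sorted_by: str | None):
        params: list[Any] = [search]  # $1 is always the search term
        where = ["is_available = true"]
        if creators:
            params.append(creators)
            where.append(f"creator_username = ANY(${len(params)})")
        # Only whitelisted fragments are interpolated into SQL; values never are.
        order_by = ALLOWED_ORDER_BY.get(sorted_by or "", "updated_at DESC")
        sql = (
            'SELECT * FROM "StoreAgent" '
            f"WHERE {' AND '.join(where)} AND search @@ plainto_tsquery('english', $1) "
            f"ORDER BY {order_by}"
        )
        return sql, params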

View File

@@ -1,408 +0,0 @@
"""
Store Listing Embeddings Service
Handles generation and storage of OpenAI embeddings for store listings
to enable semantic/hybrid search.
"""
import hashlib
import logging
import os
from typing import Any
import prisma
logger = logging.getLogger(__name__)
# OpenAI embedding model configuration
EMBEDDING_MODEL = "text-embedding-3-small"
EMBEDDING_DIM = 1536
def build_searchable_text(
name: str,
description: str,
sub_heading: str,
categories: list[str],
) -> str:
"""
Build searchable text from listing version fields.
Combines relevant fields into a single string for embedding.
"""
parts = []
# Name is important - include it
if name:
parts.append(name)
# Sub-heading provides context
if sub_heading:
parts.append(sub_heading)
# Description is the main content
if description:
parts.append(description)
# Categories help with semantic matching
if categories:
parts.append(" ".join(categories))
return " ".join(parts)
def compute_content_hash(text: str) -> str:
"""Compute MD5 hash of text for change detection."""
return hashlib.md5(text.encode()).hexdigest()
async def generate_embedding(text: str) -> list[float] | None:
"""
Generate embedding for text using OpenAI API.
Returns None if embedding generation fails.
"""
try:
from openai import OpenAI
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
logger.warning("OPENAI_API_KEY not set, cannot generate embedding")
return None
client = OpenAI(api_key=api_key)
# Truncate text to avoid token limits (~32k chars for safety)
truncated_text = text[:32000]
response = client.embeddings.create(
model=EMBEDDING_MODEL,
input=truncated_text,
)
embedding = response.data[0].embedding
logger.debug(f"Generated embedding with {len(embedding)} dimensions")
return embedding
except Exception as e:
logger.error(f"Failed to generate embedding: {e}")
return None
async def store_embedding(
version_id: str,
embedding: list[float],
searchable_text: str,
content_hash: str,
tx: prisma.Prisma | None = None,
) -> bool:
"""
Store embedding in the database.
Uses raw SQL since Prisma doesn't natively support pgvector.
"""
try:
client = tx if tx else prisma.get_client()
# Convert embedding to PostgreSQL vector format
embedding_str = "[" + ",".join(str(x) for x in embedding) + "]"
# Upsert the embedding
# Set search_path to include public for vector type visibility
await client.execute_raw(
"""
SET LOCAL search_path TO platform, public;
INSERT INTO platform."StoreListingEmbedding" (
"id", "storeListingVersionId", "embedding",
"searchableText", "contentHash", "createdAt", "updatedAt"
)
VALUES (
gen_random_uuid(), $1, $2::vector,
$3, $4, NOW(), NOW()
)
ON CONFLICT ("storeListingVersionId")
DO UPDATE SET
"embedding" = $2::vector,
"searchableText" = $3,
"contentHash" = $4,
"updatedAt" = NOW()
""",
version_id,
embedding_str,
searchable_text,
content_hash,
)
logger.info(f"Stored embedding for version {version_id}")
return True
except Exception as e:
logger.error(f"Failed to store embedding for version {version_id}: {e}")
return False
async def get_embedding(version_id: str) -> dict[str, Any] | None:
"""
Retrieve embedding record for a listing version.
Returns dict with embedding, searchableText, contentHash or None if not found.
"""
try:
client = prisma.get_client()
result = await client.query_raw(
"""
SELECT
"id",
"storeListingVersionId",
"embedding"::text as "embedding",
"searchableText",
"contentHash",
"createdAt",
"updatedAt"
FROM platform."StoreListingEmbedding"
WHERE "storeListingVersionId" = $1
""",
version_id,
)
if result and len(result) > 0:
return result[0]
return None
except Exception as e:
logger.error(f"Failed to get embedding for version {version_id}: {e}")
return None
async def ensure_embedding(
version_id: str,
name: str,
description: str,
sub_heading: str,
categories: list[str],
force: bool = False,
tx: prisma.Prisma | None = None,
) -> bool:
"""
Ensure an embedding exists for the listing version.
Creates embedding if missing or if content has changed.
Skips if content hash matches existing embedding.
Args:
version_id: The StoreListingVersion ID
name: Agent name
description: Agent description
sub_heading: Agent sub-heading
categories: Agent categories
force: Force regeneration even if hash matches
tx: Optional transaction client
Returns:
True if embedding exists/was created, False on failure
"""
try:
# Build searchable text and compute hash
searchable_text = build_searchable_text(
name, description, sub_heading, categories
)
content_hash = compute_content_hash(searchable_text)
# Check if embedding already exists with same hash
if not force:
existing = await get_embedding(version_id)
if existing and existing.get("contentHash") == content_hash:
logger.debug(
f"Embedding for version {version_id} is up to date (hash match)"
)
return True
# Generate new embedding
embedding = await generate_embedding(searchable_text)
if embedding is None:
logger.warning(f"Could not generate embedding for version {version_id}")
return False
# Store the embedding
return await store_embedding(
version_id=version_id,
embedding=embedding,
searchable_text=searchable_text,
content_hash=content_hash,
tx=tx,
)
except Exception as e:
logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
return False
async def delete_embedding(version_id: str) -> bool:
"""
Delete embedding for a listing version.
Note: This is usually handled automatically by CASCADE delete,
but provided for manual cleanup if needed.
"""
try:
client = prisma.get_client()
await client.execute_raw(
"""
DELETE FROM platform."StoreListingEmbedding"
WHERE "storeListingVersionId" = $1
""",
version_id,
)
logger.info(f"Deleted embedding for version {version_id}")
return True
except Exception as e:
logger.error(f"Failed to delete embedding for version {version_id}: {e}")
return False
async def get_embedding_stats() -> dict[str, Any]:
"""
Get statistics about embedding coverage.
Returns counts of:
- Total approved listing versions
- Versions with embeddings
- Versions without embeddings
"""
try:
client = prisma.get_client()
# Count approved versions
approved_result = await client.query_raw(
"""
SELECT COUNT(*) as count
FROM platform."StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
AND "isDeleted" = false
"""
)
total_approved = approved_result[0]["count"] if approved_result else 0
# Count versions with embeddings
embedded_result = await client.query_raw(
"""
SELECT COUNT(*) as count
FROM platform."StoreListingVersion" slv
JOIN platform."StoreListingEmbedding" sle ON slv.id = sle."storeListingVersionId"
WHERE slv."submissionStatus" = 'APPROVED'
AND slv."isDeleted" = false
"""
)
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
return {
"total_approved": total_approved,
"with_embeddings": with_embeddings,
"without_embeddings": total_approved - with_embeddings,
"coverage_percent": (
round(with_embeddings / total_approved * 100, 1)
if total_approved > 0
else 0
),
}
except Exception as e:
logger.error(f"Failed to get embedding stats: {e}")
return {
"total_approved": 0,
"with_embeddings": 0,
"without_embeddings": 0,
"coverage_percent": 0,
"error": str(e),
}
async def backfill_missing_embeddings(batch_size: int = 10) -> dict[str, Any]:
"""
Generate embeddings for approved listings that don't have them.
Args:
batch_size: Number of embeddings to generate in one call
Returns:
Dict with success/failure counts
"""
try:
client = prisma.get_client()
# Find approved versions without embeddings
missing = await client.query_raw(
"""
SELECT
slv.id,
slv.name,
slv.description,
slv."subHeading",
slv.categories
FROM platform."StoreListingVersion" slv
LEFT JOIN platform."StoreListingEmbedding" sle ON slv.id = sle."storeListingVersionId"
WHERE slv."submissionStatus" = 'APPROVED'
AND slv."isDeleted" = false
AND sle.id IS NULL
LIMIT $1
""",
batch_size,
)
if not missing:
return {
"processed": 0,
"success": 0,
"failed": 0,
"message": "No missing embeddings",
}
success = 0
failed = 0
for row in missing:
result = await ensure_embedding(
version_id=row["id"],
name=row["name"],
description=row["description"],
sub_heading=row["subHeading"],
categories=row["categories"] or [],
)
if result:
success += 1
else:
failed += 1
return {
"processed": len(missing),
"success": success,
"failed": failed,
"message": f"Backfilled {success} embeddings, {failed} failed",
}
except Exception as e:
logger.error(f"Failed to backfill embeddings: {e}")
return {
"processed": 0,
"success": 0,
"failed": 0,
"error": str(e),
}
async def embed_query(query: str) -> list[float] | None:
"""
Generate embedding for a search query.
Same as generate_embedding but with clearer intent.
"""
return await generate_embedding(query)
def embedding_to_vector_string(embedding: list[float]) -> str:
"""Convert embedding list to PostgreSQL vector string format."""
return "[" + ",".join(str(x) for x in embedding) + "]"

View File

@@ -1,440 +0,0 @@
"""
Hybrid Search for Store Agents
Combines semantic (embedding) search with lexical (tsvector) search
for improved relevance in marketplace agent discovery.
"""
import logging
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Literal
import prisma
from backend.api.features.store.embeddings import (
embed_query,
embedding_to_vector_string,
)
logger = logging.getLogger(__name__)
@dataclass
class HybridSearchWeights:
"""Weights for combining search signals."""
semantic: float = 0.35 # Embedding cosine similarity
lexical: float = 0.35 # tsvector ts_rank_cd score
category: float = 0.20 # Category match boost
recency: float = 0.10 # Newer agents ranked higher
DEFAULT_WEIGHTS = HybridSearchWeights()
# Minimum relevance score threshold - agents below this are filtered out
# With weights (0.35 semantic + 0.35 lexical + 0.20 category + 0.10 recency):
# - 0.20 means at least ~50% semantic match OR strong lexical match required
# - Ensures only genuinely relevant results are returned
# - Recency alone (0.10 max) won't pass the threshold
DEFAULT_MIN_SCORE = 0.20
@dataclass
class HybridSearchResult:
"""A single search result with score breakdown."""
slug: str
agent_name: str
agent_image: str
creator_username: str
creator_avatar: str
sub_heading: str
description: str
runs: int
rating: float
categories: list[str]
featured: bool
is_available: bool
updated_at: datetime
# Score breakdown (for debugging/tuning)
combined_score: float
semantic_score: float = 0.0
lexical_score: float = 0.0
category_score: float = 0.0
recency_score: float = 0.0
async def hybrid_search(
query: str,
featured: bool = False,
creators: list[str] | None = None,
category: str | None = None,
sorted_by: (
Literal["relevance", "rating", "runs", "name", "updated_at"] | None
) = None,
page: int = 1,
page_size: int = 20,
weights: HybridSearchWeights | None = None,
min_score: float | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""
Perform hybrid search combining semantic and lexical signals.
Args:
query: Search query string
featured: Filter for featured agents only
creators: Filter by creator usernames
category: Filter by category
sorted_by: Sort order (relevance uses hybrid scoring)
page: Page number (1-indexed)
page_size: Results per page
weights: Custom weights for search signals
min_score: Minimum relevance score threshold (0-1). Results below
this score are filtered out. Defaults to DEFAULT_MIN_SCORE.
Returns:
Tuple of (results list, total count). Returns empty list if no
results meet the minimum relevance threshold.
"""
if weights is None:
weights = DEFAULT_WEIGHTS
if min_score is None:
min_score = DEFAULT_MIN_SCORE
offset = (page - 1) * page_size
client = prisma.get_client()
# Generate query embedding
query_embedding = await embed_query(query)
# Build WHERE clause conditions
where_parts: list[str] = ["sa.is_available = true"]
params: list[Any] = []
param_index = 1
# Add search query for lexical matching
params.append(query)
query_param = f"${param_index}"
param_index += 1
if featured:
where_parts.append("sa.featured = true")
if creators:
where_parts.append(f"sa.creator_username = ANY(${param_index})")
params.append(creators)
param_index += 1
if category:
where_parts.append(f"${param_index} = ANY(sa.categories)")
params.append(category)
param_index += 1
where_clause = " AND ".join(where_parts)
# Determine if we can use hybrid search (have query embedding)
use_hybrid = query_embedding is not None
if use_hybrid:
# Add embedding parameter
embedding_str = embedding_to_vector_string(query_embedding)
params.append(embedding_str)
embedding_param = f"${param_index}"
param_index += 1
# Build hybrid search query with weighted scoring
# The semantic score is (1 - cosine_distance), normalized to [0,1]
# The lexical score is ts_rank_cd, normalized by max value
# Set search_path to include public for vector type visibility
sql_query = f"""
SET LOCAL search_path TO platform, public;
WITH search_scores AS (
SELECT
sa.*,
-- Semantic score: cosine similarity (1 - distance)
COALESCE(1 - (sle.embedding <=> {embedding_param}::vector), 0) as semantic_score,
-- Lexical score: ts_rank_cd normalized
COALESCE(ts_rank_cd(sa.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
-- Category match: 1 if query term appears in categories, else 0
CASE
WHEN EXISTS (
SELECT 1 FROM unnest(sa.categories) cat
WHERE LOWER(cat) LIKE '%' || LOWER({query_param}) || '%'
) THEN 1.0
ELSE 0.0
END as category_score,
-- Recency score: exponential decay over 90 days
EXP(-EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score
FROM platform."StoreAgent" sa
LEFT JOIN platform."StoreListing" sl ON sa.slug = sl.slug
LEFT JOIN platform."StoreListingVersion" slv ON sl."activeVersionId" = slv.id
LEFT JOIN platform."StoreListingEmbedding" sle ON slv.id = sle."storeListingVersionId"
WHERE {where_clause}
AND (
sa.search @@ plainto_tsquery('english', {query_param})
OR sle.embedding IS NOT NULL
)
),
normalized AS (
SELECT
*,
-- Normalize lexical score by max in result set
CASE
WHEN MAX(lexical_raw) OVER () > 0
THEN lexical_raw / MAX(lexical_raw) OVER ()
ELSE 0
END as lexical_score
FROM search_scores
),
scored AS (
SELECT
slug,
agent_name,
agent_image,
creator_username,
creator_avatar,
sub_heading,
description,
runs,
rating,
categories,
featured,
is_available,
updated_at,
semantic_score,
lexical_score,
category_score,
recency_score,
(
{weights.semantic} * semantic_score +
{weights.lexical} * lexical_score +
{weights.category} * category_score +
{weights.recency} * recency_score
) as combined_score
FROM normalized
)
SELECT * FROM scored
WHERE combined_score >= {min_score}
ORDER BY combined_score DESC
LIMIT ${param_index} OFFSET ${param_index + 1}
"""
# Add pagination params
params.extend([page_size, offset])
# Count query - must also filter by min_score
count_query = f"""
SET LOCAL search_path TO platform, public;
WITH search_scores AS (
SELECT
sa.slug,
COALESCE(1 - (sle.embedding <=> {embedding_param}::vector), 0) as semantic_score,
COALESCE(ts_rank_cd(sa.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
CASE
WHEN EXISTS (
SELECT 1 FROM unnest(sa.categories) cat
WHERE LOWER(cat) LIKE '%' || LOWER({query_param}) || '%'
) THEN 1.0
ELSE 0.0
END as category_score,
EXP(-EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score
FROM platform."StoreAgent" sa
LEFT JOIN platform."StoreListing" sl ON sa.slug = sl.slug
LEFT JOIN platform."StoreListingVersion" slv ON sl."activeVersionId" = slv.id
LEFT JOIN platform."StoreListingEmbedding" sle ON slv.id = sle."storeListingVersionId"
WHERE {where_clause}
AND (
sa.search @@ plainto_tsquery('english', {query_param})
OR sle.embedding IS NOT NULL
)
),
normalized AS (
SELECT
slug,
semantic_score,
category_score,
recency_score,
CASE
WHEN MAX(lexical_raw) OVER () > 0
THEN lexical_raw / MAX(lexical_raw) OVER ()
ELSE 0
END as lexical_score
FROM search_scores
),
scored AS (
SELECT
slug,
(
{weights.semantic} * semantic_score +
{weights.lexical} * lexical_score +
{weights.category} * category_score +
{weights.recency} * recency_score
) as combined_score
FROM normalized
)
SELECT COUNT(*) as count FROM scored
WHERE combined_score >= {min_score}
"""
else:
# Fallback to lexical-only search (existing behavior)
# Note: For lexical-only, we still require tsvector match but don't
# apply min_score since ts_rank_cd isn't normalized to [0,1]
logger.warning("Falling back to lexical-only search (no query embedding)")
sql_query = f"""
WITH lexical_scores AS (
SELECT
slug,
agent_name,
agent_image,
creator_username,
creator_avatar,
sub_heading,
description,
runs,
rating,
categories,
featured,
is_available,
updated_at,
0.0 as semantic_score,
ts_rank_cd(search, plainto_tsquery('english', {query_param})) as lexical_raw,
CASE
WHEN EXISTS (
SELECT 1 FROM unnest(categories) cat
WHERE LOWER(cat) LIKE '%' || LOWER({query_param}) || '%'
) THEN 1.0
ELSE 0.0
END as category_score,
EXP(-EXTRACT(EPOCH FROM (NOW() - updated_at)) / (90 * 24 * 3600)) as recency_score
FROM platform."StoreAgent" sa
WHERE {where_clause}
AND search @@ plainto_tsquery('english', {query_param})
),
normalized AS (
SELECT
*,
CASE
WHEN MAX(lexical_raw) OVER () > 0
THEN lexical_raw / MAX(lexical_raw) OVER ()
ELSE 0
END as lexical_score
FROM lexical_scores
),
scored AS (
SELECT
slug,
agent_name,
agent_image,
creator_username,
creator_avatar,
sub_heading,
description,
runs,
rating,
categories,
featured,
is_available,
updated_at,
semantic_score,
lexical_score,
category_score,
recency_score,
(
{weights.lexical} * lexical_score +
{weights.category} * category_score +
{weights.recency} * recency_score
) as combined_score
FROM normalized
)
SELECT * FROM scored
WHERE combined_score >= {min_score}
ORDER BY combined_score DESC
LIMIT ${param_index} OFFSET ${param_index + 1}
"""
params.extend([page_size, offset])
count_query = f"""
WITH lexical_scores AS (
SELECT
slug,
ts_rank_cd(search, plainto_tsquery('english', {query_param})) as lexical_raw,
CASE
WHEN EXISTS (
SELECT 1 FROM unnest(categories) cat
WHERE LOWER(cat) LIKE '%' || LOWER({query_param}) || '%'
) THEN 1.0
ELSE 0.0
END as category_score,
EXP(-EXTRACT(EPOCH FROM (NOW() - updated_at)) / (90 * 24 * 3600)) as recency_score
FROM platform."StoreAgent" sa
WHERE {where_clause}
AND search @@ plainto_tsquery('english', {query_param})
),
normalized AS (
SELECT
slug,
category_score,
recency_score,
CASE
WHEN MAX(lexical_raw) OVER () > 0
THEN lexical_raw / MAX(lexical_raw) OVER ()
ELSE 0
END as lexical_score
FROM lexical_scores
),
scored AS (
SELECT
slug,
(
{weights.lexical} * lexical_score +
{weights.category} * category_score +
{weights.recency} * recency_score
) as combined_score
FROM normalized
)
SELECT COUNT(*) as count FROM scored
WHERE combined_score >= {min_score}
"""
try:
# Execute search query
# Dynamic SQL is safe here - all user inputs are parameterized ($1, $2, etc.)
results = await client.query_raw(sql_query, *params) # type: ignore[arg-type]
# Execute count query (without pagination params)
count_params = params[:-2] # Remove LIMIT and OFFSET params
count_result = await client.query_raw(count_query, *count_params) # type: ignore[arg-type]
total = count_result[0]["count"] if count_result else 0
logger.info(
f"Hybrid search for '{query}': {len(results)} results, {total} total "
f"(hybrid={use_hybrid})"
)
return results, total
except Exception as e:
logger.error(f"Hybrid search failed: {e}")
raise
async def hybrid_search_simple(
query: str,
page: int = 1,
page_size: int = 20,
) -> tuple[list[dict[str, Any]], int]:
"""
Simplified hybrid search for common use cases.
Uses default weights and no filters.
"""
return await hybrid_search(
query=query,
page=page,
page_size=page_size,
)
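To make the DEFAULT_MIN_SCORE = 0.20 threshold concrete, here is the SQL's combined_score recomputed in Python with the default weights; as the comment above notes, recency alone (0.10 max) can never clear it:

    def combined_score(semantic: float, lexical: float, category: float, recency: float) -> float:
        return 0.35 * semantic + 0.35 * lexical + 0.20 * category + 0.10 * recency

    # A brand-new agent with no text or category match is filtered out...
    assert combined_score(0.0, 0.0, 0.0, 1.0) < 0.20
    # ...while roughly 60% semantic similarity alone clears the threshold.
    assert combined_score(0.6, 0.0, 0.0, 0.0) >= 0.20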

View File

@@ -294,6 +294,7 @@ async def get_creators(
@router.get(
"/creator/{username}",
summary="Get creator details",
operation_id="getV2GetCreatorDetails",
tags=["store", "public"],
response_model=store_model.CreatorDetails,
)

View File

@@ -18,6 +18,7 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.llm_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.builder
import backend.api.features.builder.routes
@@ -37,9 +38,11 @@ import backend.data.db
import backend.data.graph
import backend.data.user
import backend.integrations.webhooks.utils
import backend.server.v2.llm.routes as public_llm_routes
import backend.util.service
import backend.util.settings
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi
@@ -109,11 +112,27 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Refresh LLM registry before initializing blocks so blocks can use registry data
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
# Clear block schema caches so they're regenerated with updated discriminator_mapping
from backend.data.block import BlockSchema
BlockSchema.clear_all_schema_caches()
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
# migrate_llm_models uses registry default model
from backend.blocks.llm import LlmModel
default_model_slug = llm_registry.get_default_model_slug()
if default_model_slug:
await backend.data.graph.migrate_llm_models(LlmModel(default_model_slug))
else:
logger.warning("Skipping LLM model migration: no default model available")
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
with launch_darkly_context():
@@ -298,6 +317,16 @@ app.include_router(
tags=["v2", "executions", "review"],
prefix="/api/review",
)
app.include_router(
backend.api.features.admin.llm_routes.router,
tags=["v2", "admin", "llm"],
prefix="/api/llm/admin",
)
app.include_router(
public_llm_routes.router,
tags=["v2", "llm"],
prefix="/api",
)
app.include_router(
backend.api.features.library.routes.router, tags=["v2"], prefix="/api/library"
)

View File

@@ -77,7 +77,39 @@ async def event_broadcaster(manager: ConnectionManager):
payload=notification.payload,
)
await asyncio.gather(execution_worker(), notification_worker())
async def registry_refresh_worker():
"""Listen for LLM registry refresh notifications and broadcast to all clients."""
from backend.data.llm_registry import REGISTRY_REFRESH_CHANNEL
from backend.data.redis_client import connect_async
redis = await connect_async()
pubsub = redis.pubsub()
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
logger.info(
"Subscribed to LLM registry refresh notifications for WebSocket broadcast"
)
async for message in pubsub.listen():
if (
message["type"] == "message"
and message["channel"] == REGISTRY_REFRESH_CHANNEL
):
logger.info(
"Broadcasting LLM registry refresh to all WebSocket clients"
)
await manager.broadcast_to_all(
method=WSMethod.NOTIFICATION,
data={
"type": "LLM_REGISTRY_REFRESH",
"event": "registry_updated",
},
)
await asyncio.gather(
execution_worker(),
notification_worker(),
registry_refresh_worker(),
)
async def authenticate_websocket(websocket: WebSocket) -> str:
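For context, the publisher side of this channel would look roughly like the sketch below; it assumes connect_async() returns a redis.asyncio client, matching the imports in the worker above.

    from backend.data.llm_registry import REGISTRY_REFRESH_CHANNEL
    from backend.data.redis_client import connect_async

    async def notify_registry_refresh() -> None:
        redis = await connect_async()
        # The worker only matches on the channel name, so any payload works.
        await redis.publish(REGISTRY_REFRESH_CHANNEL, "registry_updated")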

View File

@@ -1,7 +1,6 @@
from typing import Any
from backend.blocks.llm import (
DEFAULT_LLM_MODEL,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
AIBlockBase,
@@ -10,6 +9,7 @@ from backend.blocks.llm import (
LlmModel,
LLMResponse,
llm_call,
llm_model_schema_extra,
)
from backend.data.block import (
BlockCategory,
@@ -50,9 +50,10 @@ class AIConditionBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default_factory=LlmModel.default,
description="The language model to use for evaluating the condition.",
advanced=False,
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
@@ -82,7 +83,7 @@ class AIConditionBlock(AIBlockBase):
"condition": "the input is an email address",
"yes_value": "Valid email",
"no_value": "Not an email",
"model": DEFAULT_LLM_MODEL,
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,

View File

@@ -4,17 +4,19 @@ import logging
import re
import secrets
from abc import ABC
from enum import Enum, EnumMeta
from enum import Enum
from json import JSONDecodeError
from typing import Any, Iterable, List, Literal, NamedTuple, Optional
from typing import Any, Iterable, List, Literal, Optional
import anthropic
import ollama
import openai
from anthropic.types import ToolParam
from groq import AsyncGroq
from pydantic import BaseModel, SecretStr
from pydantic import BaseModel, GetCoreSchemaHandler, SecretStr
from pydantic_core import CoreSchema, core_schema
from backend.data import llm_registry
from backend.data.block import (
Block,
BlockCategory,
@@ -22,6 +24,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.llm_registry import ModelMetadata
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -66,114 +69,117 @@ TEST_CREDENTIALS_INPUT = {
def AICredentialsField() -> AICredentials:
"""
Returns a CredentialsField for LLM providers.
The discriminator_mapping will be refreshed when the schema is generated
if it's empty, ensuring the LLM registry is loaded.
"""
# Get the mapping now - it may be empty initially, but will be refreshed
# when the schema is generated via CredentialsMetaInput._add_json_schema_extra
mapping = llm_registry.get_llm_discriminator_mapping()
return CredentialsField(
description="API key for the LLM provider.",
discriminator="model",
discriminator_mapping={
model.value: model.metadata.provider for model in LlmModel
},
discriminator_mapping=mapping, # May be empty initially, refreshed later
)
class ModelMetadata(NamedTuple):
provider: str
context_window: int
max_output_tokens: int | None
def llm_model_schema_extra() -> dict[str, Any]:
return {"options": llm_registry.get_llm_model_schema_options()}
class LlmModelMeta(EnumMeta):
pass
class LlmModelMeta(type):
"""
Metaclass for LlmModel that enables attribute-style access to dynamic models.
This allows code like `LlmModel.GPT4O` to work by converting the attribute
name to a slug format:
- GPT4O -> gpt-4o
- GPT4O_MINI -> gpt-4o-mini
- CLAUDE_3_5_SONNET -> claude-3-5-sonnet
"""
def __getattr__(cls, name: str):
# Don't intercept private/dunder attributes
if name.startswith("_"):
raise AttributeError(f"type object 'LlmModel' has no attribute '{name}'")
# Convert attribute name to slug format:
# 1. Lowercase: GPT4O -> gpt4o
# 2. Underscores to hyphens: GPT4O_MINI -> gpt4o-mini
# 3. Insert hyphen between letter and digit: gpt4o -> gpt-4o
slug = name.lower().replace("_", "-")
slug = re.sub(r"([a-z])(\d)", r"\1-\2", slug)
return cls(slug)
class LlmModel(str, Enum, metaclass=LlmModelMeta):
# OpenAI models
O3_MINI = "o3-mini"
O3 = "o3-2025-04-16"
O1 = "o1"
O1_MINI = "o1-mini"
# GPT-5 models
GPT5_2 = "gpt-5.2-2025-12-11"
GPT5_1 = "gpt-5.1-2025-11-13"
GPT5 = "gpt-5-2025-08-07"
GPT5_MINI = "gpt-5-mini-2025-08-07"
GPT5_NANO = "gpt-5-nano-2025-08-07"
GPT5_CHAT = "gpt-5-chat-latest"
GPT41 = "gpt-4.1-2025-04-14"
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
GPT4O_MINI = "gpt-4o-mini"
GPT4O = "gpt-4o"
GPT4_TURBO = "gpt-4-turbo"
GPT3_5_TURBO = "gpt-3.5-turbo"
# Anthropic models
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
CLAUDE_4_OPUS = "claude-opus-4-20250514"
CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# AI/ML API models
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
AIML_API_LLAMA3_1_70B = "nvidia/llama-3.1-nemotron-70b-instruct"
AIML_API_LLAMA3_3_70B = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
AIML_API_META_LLAMA_3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
AIML_API_LLAMA_3_2_3B = "meta-llama/Llama-3.2-3B-Instruct-Turbo"
# Groq models
LLAMA3_3_70B = "llama-3.3-70b-versatile"
LLAMA3_1_8B = "llama-3.1-8b-instant"
# Ollama models
OLLAMA_LLAMA3_3 = "llama3.3"
OLLAMA_LLAMA3_2 = "llama3.2"
OLLAMA_LLAMA3_8B = "llama3"
OLLAMA_LLAMA3_405B = "llama3.1:405b"
OLLAMA_DOLPHIN = "dolphin-mistral:latest"
# OpenRouter models
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
MISTRAL_NEMO = "mistralai/mistral-nemo"
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
PERPLEXITY_SONAR = "perplexity/sonar"
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
AMAZON_NOVA_LITE_V1 = "amazon/nova-lite-v1"
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
GROK_4 = "x-ai/grok-4"
GROK_4_FAST = "x-ai/grok-4-fast"
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
KIMI_K2 = "moonshotai/kimi-k2"
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
QWEN3_CODER = "qwen/qwen3-coder"
# Llama API models
LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8"
LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
LLAMA_API_LLAMA3_3_8B = "Llama-3.3-8B-Instruct"
LLAMA_API_LLAMA3_3_70B = "Llama-3.3-70B-Instruct"
# v0 by Vercel models
V0_1_5_MD = "v0-1.5-md"
V0_1_5_LG = "v0-1.5-lg"
V0_1_0_MD = "v0-1.0-md"
class LlmModel(str, metaclass=LlmModelMeta):
"""
Dynamic LLM model type that accepts any model slug from the registry.
This is a string subclass (not an Enum) that allows any model slug value.
All models are managed via the LLM Registry in the database.
Usage:
model = LlmModel("gpt-4o") # Direct construction
model = LlmModel.GPT4O # Attribute access (converted to "gpt-4o")
model.value # Returns the slug string
model.provider # Returns the provider from registry
"""
def __new__(cls, value: str):
if isinstance(value, LlmModel):
return value
return str.__new__(cls, value)
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
"""
Tell Pydantic how to validate LlmModel.
Accepts strings and converts them to LlmModel instances.
"""
return core_schema.no_info_after_validator_function(
cls, # The validator function (LlmModel constructor)
core_schema.str_schema(), # Accept string input
serialization=core_schema.to_string_ser_schema(), # Serialize as string
)
@property
def value(self) -> str:
"""Return the model slug (for compatibility with enum-style access)."""
return str(self)
@classmethod
def default(cls) -> "LlmModel":
"""
Get the default model from the registry.
Returns the recommended model if set, otherwise gpt-4o if available
and enabled, otherwise the first enabled model from the registry.
Falls back to "gpt-4o" if registry is empty (e.g., at module import time).
"""
from backend.data.llm_registry import get_default_model_slug
slug = get_default_model_slug()
if slug is None:
# Registry is empty (e.g., at module import time before DB connection).
# Fall back to gpt-4o for backward compatibility.
slug = "gpt-4o"
return cls(slug)
@property
def metadata(self) -> ModelMetadata:
return MODEL_METADATA[self]
metadata = llm_registry.get_llm_model_metadata(self.value)
if metadata:
return metadata
raise ValueError(
f"Missing metadata for model: {self.value}. Model not found in LLM registry."
)
@property
def provider(self) -> str:
@@ -188,128 +194,11 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
return self.metadata.max_output_tokens
MODEL_METADATA = {
# https://platform.openai.com/docs/models
LlmModel.O3: ModelMetadata("openai", 200000, 100000),
LlmModel.O3_MINI: ModelMetadata("openai", 200000, 100000), # o3-mini-2025-01-31
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
# GPT-5 models
LlmModel.GPT5_2: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_MINI: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_NANO: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_CHAT: ModelMetadata("openai", 400000, 16384),
LlmModel.GPT41: ModelMetadata("openai", 1047576, 32768),
LlmModel.GPT41_MINI: ModelMetadata("openai", 1047576, 32768),
LlmModel.GPT4O_MINI: ModelMetadata(
"openai", 128000, 16384
), # gpt-4o-mini-2024-07-18
LlmModel.GPT4O: ModelMetadata("openai", 128000, 16384), # gpt-4o-2024-08-06
LlmModel.GPT4_TURBO: ModelMetadata(
"openai", 128000, 4096
), # gpt-4-turbo-2024-04-09
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, 4096), # gpt-3.5-turbo-0125
# https://docs.anthropic.com/en/docs/about-claude/models
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
"anthropic", 200000, 32000
), # claude-opus-4-1-20250805
LlmModel.CLAUDE_4_OPUS: ModelMetadata(
"anthropic", 200000, 32000
), # claude-4-opus-20250514
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
"anthropic", 200000, 64000
), # claude-4-sonnet-20250514
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
"anthropic", 200000, 64000
), # claude-opus-4-5-20251101
LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
"anthropic", 200000, 64000
), # claude-sonnet-4-5-20250929
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
"anthropic", 200000, 64000
), # claude-haiku-4-5-20251001
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
"anthropic", 200000, 64000
), # claude-3-7-sonnet-20250219
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
"anthropic", 200000, 4096
), # claude-3-haiku-20240307
# https://docs.aimlapi.com/api-overview/model-database/text-models
LlmModel.AIML_API_QWEN2_5_72B: ModelMetadata("aiml_api", 32000, 8000),
LlmModel.AIML_API_LLAMA3_1_70B: ModelMetadata("aiml_api", 128000, 40000),
LlmModel.AIML_API_LLAMA3_3_70B: ModelMetadata("aiml_api", 128000, None),
LlmModel.AIML_API_META_LLAMA_3_1_70B: ModelMetadata("aiml_api", 131000, 2000),
LlmModel.AIML_API_LLAMA_3_2_3B: ModelMetadata("aiml_api", 128000, None),
# https://console.groq.com/docs/models
LlmModel.LLAMA3_3_70B: ModelMetadata("groq", 128000, 32768),
LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 128000, 8192),
# https://ollama.com/library
LlmModel.OLLAMA_LLAMA3_3: ModelMetadata("ollama", 8192, None),
LlmModel.OLLAMA_LLAMA3_2: ModelMetadata("ollama", 8192, None),
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192, None),
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192, None),
LlmModel.OLLAMA_DOLPHIN: ModelMetadata("ollama", 32768, None),
# https://openrouter.ai/models
LlmModel.GEMINI_2_5_PRO: ModelMetadata("open_router", 1050000, 8192),
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata("open_router", 1048576, 65535),
LlmModel.GEMINI_2_5_FLASH: ModelMetadata("open_router", 1048576, 65535),
LlmModel.GEMINI_2_0_FLASH: ModelMetadata("open_router", 1048576, 8192),
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
"open_router", 1048576, 65535
),
LlmModel.GEMINI_2_0_FLASH_LITE: ModelMetadata("open_router", 1048576, 8192),
LlmModel.MISTRAL_NEMO: ModelMetadata("open_router", 128000, 4096),
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata("open_router", 128000, 4096),
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata("open_router", 128000, 4096),
LlmModel.DEEPSEEK_CHAT: ModelMetadata("open_router", 64000, 2048),
LlmModel.DEEPSEEK_R1_0528: ModelMetadata("open_router", 163840, 163840),
LlmModel.PERPLEXITY_SONAR: ModelMetadata("open_router", 127000, 8000),
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata("open_router", 200000, 8000),
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
"open_router",
128000,
16000,
),
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata(
"open_router", 131000, 4096
),
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata(
"open_router", 12288, 12288
),
LlmModel.OPENAI_GPT_OSS_120B: ModelMetadata("open_router", 131072, 131072),
LlmModel.OPENAI_GPT_OSS_20B: ModelMetadata("open_router", 131072, 32768),
LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata("open_router", 300000, 5120),
LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata("open_router", 128000, 5120),
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata("open_router", 300000, 5120),
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata("open_router", 65536, 4096),
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata("open_router", 4096, 4096),
LlmModel.META_LLAMA_4_SCOUT: ModelMetadata("open_router", 131072, 131072),
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata("open_router", 1048576, 1000000),
LlmModel.GROK_4: ModelMetadata("open_router", 256000, 256000),
LlmModel.GROK_4_FAST: ModelMetadata("open_router", 2000000, 30000),
LlmModel.GROK_4_1_FAST: ModelMetadata("open_router", 2000000, 30000),
LlmModel.GROK_CODE_FAST_1: ModelMetadata("open_router", 256000, 10000),
LlmModel.KIMI_K2: ModelMetadata("open_router", 131000, 131000),
LlmModel.QWEN3_235B_A22B_THINKING: ModelMetadata("open_router", 262144, 262144),
LlmModel.QWEN3_CODER: ModelMetadata("open_router", 262144, 262144),
# Llama API models
LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata("llama_api", 128000, 4028),
LlmModel.LLAMA_API_LLAMA4_MAVERICK: ModelMetadata("llama_api", 128000, 4028),
LlmModel.LLAMA_API_LLAMA3_3_8B: ModelMetadata("llama_api", 128000, 4028),
LlmModel.LLAMA_API_LLAMA3_3_70B: ModelMetadata("llama_api", 128000, 4028),
# v0 by Vercel models
LlmModel.V0_1_5_MD: ModelMetadata("v0", 128000, 64000),
LlmModel.V0_1_5_LG: ModelMetadata("v0", 512000, 64000),
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000),
}
# MODEL_METADATA removed - all models now come from the database via llm_registry
DEFAULT_LLM_MODEL = LlmModel.GPT5_2
for model in LlmModel:
if model not in MODEL_METADATA:
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
# Default model constant for backward compatibility
# Uses the dynamic registry to get the default model
DEFAULT_LLM_MODEL = LlmModel.default()
class ToolCall(BaseModel):
@@ -438,19 +327,94 @@ async def llm_call(
- prompt_tokens: The number of tokens used in the prompt.
- completion_tokens: The number of tokens used in the completion.
"""
provider = llm_model.metadata.provider
context_window = llm_model.context_window
# Get model metadata and check if enabled - with fallback support
# The model we'll actually use (may differ if original is disabled)
model_to_use = llm_model.value
# Check if model is in registry and if it's enabled
from backend.data.llm_registry import (
get_fallback_model_for_disabled,
get_model_info,
)
model_info = get_model_info(llm_model.value)
if model_info and not model_info.is_enabled:
# Model is disabled - try to find a fallback from the same provider
fallback = get_fallback_model_for_disabled(llm_model.value)
if fallback:
logger.warning(
f"Model '{llm_model.value}' is disabled. Using fallback model '{fallback.slug}' from the same provider ({fallback.metadata.provider})."
)
model_to_use = fallback.slug
# Use fallback model's metadata
provider = fallback.metadata.provider
context_window = fallback.metadata.context_window
model_max_output = fallback.metadata.max_output_tokens or int(2**15)
else:
# No fallback available - raise error
raise ValueError(
f"LLM model '{llm_model.value}' is disabled and no fallback model "
f"from the same provider is available. Please enable the model or "
f"select a different model in the block configuration."
)
else:
# Model is enabled or not in registry (legacy/static model)
try:
provider = llm_model.metadata.provider
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or int(2**15)
except ValueError:
# Model not in cache - try refreshing the registry once if we have DB access
logger.warning(f"Model {llm_model.value} not found in registry cache")
from backend.data.db import is_connected
if is_connected():
try:
logger.info(
f"Refreshing LLM registry and retrying lookup for {llm_model.value}"
)
await llm_registry.refresh_llm_registry()
# Try again after refresh
try:
provider = llm_model.metadata.provider
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or int(2**15)
logger.info(
f"Successfully loaded model {llm_model.value} metadata after registry refresh"
)
except ValueError:
# Still not found after refresh
raise ValueError(
f"LLM model '{llm_model.value}' not found in registry after refresh. "
"Please ensure the model is added and enabled in the LLM registry via the admin UI."
)
except Exception as refresh_exc:
logger.error(f"Failed to refresh LLM registry: {refresh_exc}")
raise ValueError(
f"LLM model '{llm_model.value}' not found in registry and failed to refresh. "
"Please ensure the model is added to the LLM registry via the admin UI."
) from refresh_exc
else:
# No DB access (e.g., in executor without direct DB connection)
# The registry should have been loaded on startup
raise ValueError(
f"LLM model '{llm_model.value}' not found in registry cache. "
"The registry may need to be refreshed. Please contact support or try again later."
)
if compress_prompt_to_fit:
prompt = compress_prompt(
messages=prompt,
target_tokens=llm_model.context_window // 2,
target_tokens=context_window // 2,
lossy_ok=True,
)
# Calculate available tokens based on context window and input length
estimated_input_tokens = estimate_token_count(prompt)
model_max_output = llm_model.max_output_tokens or int(2**15)
# model_max_output already set above
user_max = max_tokens or model_max_output
available_tokens = max(context_window - estimated_input_tokens, 0)
max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
@@ -468,7 +432,7 @@ async def llm_call(
response_format = {"type": "json_object"}
response = await oai_client.chat.completions.create(
model=llm_model.value,
model=model_to_use,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
@@ -515,7 +479,7 @@ async def llm_call(
)
try:
resp = await client.messages.create(
model=llm_model.value,
model=model_to_use,
system=sysprompt,
messages=messages,
max_tokens=max_tokens,
@@ -579,7 +543,7 @@ async def llm_call(
client = AsyncGroq(api_key=credentials.api_key.get_secret_value())
response_format = {"type": "json_object"} if force_json_output else None
response = await client.chat.completions.create(
model=llm_model.value,
model=model_to_use,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
@@ -601,7 +565,7 @@ async def llm_call(
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
response = await client.generate(
model=llm_model.value,
model=model_to_use,
prompt=f"{sys_messages}\n\n{usr_messages}",
stream=False,
options={"num_ctx": max_tokens},
@@ -631,7 +595,7 @@ async def llm_call(
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=llm_model.value,
model=model_to_use,
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
@@ -673,7 +637,7 @@ async def llm_call(
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=llm_model.value,
model=model_to_use,
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
@@ -700,7 +664,7 @@ async def llm_call(
reasoning=reasoning,
)
elif provider == "aiml_api":
client = openai.OpenAI(
client = openai.AsyncOpenAI(
base_url="https://api.aimlapi.com/v2",
api_key=credentials.api_key.get_secret_value(),
default_headers={
@@ -710,8 +674,8 @@ async def llm_call(
},
)
completion = client.chat.completions.create(
model=llm_model.value,
completion = await client.chat.completions.create(
model=model_to_use,
messages=prompt, # type: ignore
max_tokens=max_tokens,
)
@@ -743,7 +707,7 @@ async def llm_call(
)
response = await client.chat.completions.create(
model=llm_model.value,
model=model_to_use,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
@@ -794,9 +758,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default_factory=LlmModel.default,
description="The language model to use for answering the prompt.",
advanced=False,
json_schema_extra=llm_model_schema_extra(),
)
force_json_output: bool = SchemaField(
title="Restrict LLM to pure JSON output",
@@ -859,7 +824,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
input_schema=AIStructuredResponseGeneratorBlock.Input,
output_schema=AIStructuredResponseGeneratorBlock.Output,
test_input={
"model": DEFAULT_LLM_MODEL,
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
"credentials": TEST_CREDENTIALS_INPUT,
"expected_format": {
"key1": "value1",
@@ -1225,9 +1190,10 @@ class AITextGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default_factory=LlmModel.default,
description="The language model to use for answering the prompt.",
advanced=False,
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
sys_prompt: str = SchemaField(
@@ -1321,8 +1287,9 @@ class AITextSummarizerBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default_factory=LlmModel.default,
description="The language model to use for summarizing the text.",
json_schema_extra=llm_model_schema_extra(),
)
focus: str = SchemaField(
title="Focus",
@@ -1538,8 +1505,9 @@ class AIConversationBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default_factory=LlmModel.default,
description="The language model to use for the conversation.",
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
max_tokens: int | None = SchemaField(
@@ -1576,7 +1544,7 @@ class AIConversationBlock(AIBlockBase):
},
{"role": "user", "content": "Where was it played?"},
],
"model": DEFAULT_LLM_MODEL,
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
@@ -1639,9 +1607,10 @@ class AIListGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default_factory=LlmModel.default,
description="The language model to use for generating the list.",
advanced=True,
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
max_retries: int = SchemaField(
@@ -1696,7 +1665,7 @@ class AIListGeneratorBlock(AIBlockBase):
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
"fictional worlds."
),
"model": DEFAULT_LLM_MODEL,
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
"credentials": TEST_CREDENTIALS_INPUT,
"max_retries": 3,
"force_json_output": False,

View File
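A recurring pattern in the block changes above: the hardcoded default=DEFAULT_LLM_MODEL is replaced by default_factory=LlmModel.default plus json_schema_extra=llm_model_schema_extra(), so the default is resolved from the registry each time a block is instantiated rather than frozen at import time. A hedged sketch of the idea; LlmModel.default and llm_model_schema_extra are names taken from this diff, but the stub bodies below are assumptions:

from pydantic import BaseModel, Field

def llm_model_schema_extra() -> dict:
    # Assumed shape: marks the field so the registry can refresh its options later.
    return {"options": []}

class FakeLlmModel(str):
    @classmethod
    def default(cls) -> "FakeLlmModel":
        # Assumed behaviour: resolves the registry's current default slug.
        return cls("gpt-4o")

class Input(BaseModel):
    model: str = Field(
        default_factory=FakeLlmModel.default,  # evaluated per instantiation
        json_schema_extra=llm_model_schema_extra(),
    )

print(Input().model)  # "gpt-4o" under the stub above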

@@ -226,9 +226,10 @@ class SmartDecisionMakerBlock(Block):
)
model: llm.LlmModel = SchemaField(
title="LLM Model",
default=llm.DEFAULT_LLM_MODEL,
default_factory=llm.LlmModel.default,
description="The language model to use for answering the prompt.",
advanced=False,
json_schema_extra=llm.llm_model_schema_extra(),
)
credentials: llm.AICredentials = llm.AICredentialsField()
multiple_tool_calls: bool = SchemaField(

View File

@@ -10,13 +10,13 @@ import stagehand.main
from stagehand import Stagehand
from backend.blocks.llm import (
MODEL_METADATA,
AICredentials,
AICredentialsField,
LlmModel,
ModelMetadata,
)
from backend.blocks.stagehand._config import stagehand as stagehand_provider
from backend.data import llm_registry
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -91,7 +91,7 @@ class StagehandRecommendedLlmModel(str, Enum):
Returns the provider name for the model in the required format for Stagehand:
provider/model_name
"""
model_metadata = MODEL_METADATA[LlmModel(self.value)]
model_metadata = self.metadata
model_name = self.value
if len(model_name.split("/")) == 1 and not self.value.startswith(
@@ -107,19 +107,23 @@ class StagehandRecommendedLlmModel(str, Enum):
@property
def provider(self) -> str:
return MODEL_METADATA[LlmModel(self.value)].provider
return self.metadata.provider
@property
def metadata(self) -> ModelMetadata:
return MODEL_METADATA[LlmModel(self.value)]
metadata = llm_registry.get_llm_model_metadata(self.value)
if metadata:
return metadata
# Fallback to LlmModel enum if registry lookup fails
return LlmModel(self.value).metadata
@property
def context_window(self) -> int:
return MODEL_METADATA[LlmModel(self.value)].context_window
return self.metadata.context_window
@property
def max_output_tokens(self) -> int | None:
return MODEL_METADATA[LlmModel(self.value)].max_output_tokens
return self.metadata.max_output_tokens
class StagehandObserveBlock(Block):

View File
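The Stagehand change above swaps direct MODEL_METADATA lookups for a registry-first resolution with the legacy enum as a fallback. A sketch of that two-step chain; registry_lookup and ENUM_METADATA below are assumptions standing in for llm_registry.get_llm_model_metadata and LlmModel(...).metadata:

from typing import NamedTuple

class ModelMetadata(NamedTuple):
    provider: str
    context_window: int
    max_output_tokens: int | None

ENUM_METADATA = {"gpt-4o": ModelMetadata("openai", 128_000, 16_384)}  # hypothetical

def registry_lookup(slug: str) -> ModelMetadata | None:
    return None  # pretend the registry has no entry for this slug

def resolve_metadata(slug: str) -> ModelMetadata:
    metadata = registry_lookup(slug)  # dynamic registry wins when present
    if metadata:
        return metadata
    return ENUM_METADATA[slug]  # legacy enum fallback

assert resolve_metadata("gpt-4o").provider == "openai"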

@@ -196,15 +196,6 @@ class TestXMLParserBlockSecurity:
async for _ in block.run(XMLParserBlock.Input(input_xml=large_xml)):
pass
async def test_rejects_text_outside_root(self):
"""Ensure parser surfaces readable errors for invalid root text."""
block = XMLParserBlock()
invalid_xml = "<root><child>value</child></root> trailing"
with pytest.raises(ValueError, match="text outside the root element"):
async for _ in block.run(XMLParserBlock.Input(input_xml=invalid_xml)):
pass
class TestStoreMediaFileSecurity:
"""Test file storage security limits."""

View File

@@ -1,5 +1,5 @@
from gravitasml.parser import Parser
from gravitasml.token import Token, tokenize
from gravitasml.token import tokenize
from backend.data.block import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
from backend.data.model import SchemaField
@@ -25,38 +25,6 @@ class XMLParserBlock(Block):
],
)
@staticmethod
def _validate_tokens(tokens: list[Token]) -> None:
"""Ensure the XML has a single root element and no stray text."""
if not tokens:
raise ValueError("XML input is empty.")
depth = 0
root_seen = False
for token in tokens:
if token.type == "TAG_OPEN":
if depth == 0 and root_seen:
raise ValueError("XML must have a single root element.")
depth += 1
if depth == 1:
root_seen = True
elif token.type == "TAG_CLOSE":
depth -= 1
if depth < 0:
raise SyntaxError("Unexpected closing tag in XML input.")
elif token.type in {"TEXT", "ESCAPE"}:
if depth == 0 and token.value:
raise ValueError(
"XML contains text outside the root element; "
"wrap content in a single root tag."
)
if depth != 0:
raise SyntaxError("Unclosed tag detected in XML input.")
if not root_seen:
raise ValueError("XML must include a root element.")
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
# Security fix: Add size limits to prevent XML bomb attacks
MAX_XML_SIZE = 10 * 1024 * 1024 # 10MB limit for XML input
@@ -67,9 +35,7 @@ class XMLParserBlock(Block):
)
try:
tokens = list(tokenize(input_data.input_xml))
self._validate_tokens(tokens)
tokens = tokenize(input_data.input_xml)
parser = Parser(tokens)
parsed_result = parser.parse()
yield "parsed_xml", parsed_result

View File

@@ -25,6 +25,7 @@ from prisma.models import AgentBlock
from prisma.types import AgentBlockCreateInput
from pydantic import BaseModel
from backend.data.llm_registry import update_schema_with_llm_registry
from backend.data.model import NodeExecutionStats
from backend.integrations.providers import ProviderName
from backend.util import json
@@ -141,35 +142,59 @@ class BlockInfo(BaseModel):
class BlockSchema(BaseModel):
cached_jsonschema: ClassVar[dict[str, Any]]
cached_jsonschema: ClassVar[dict[str, Any] | None] = None
@classmethod
def clear_schema_cache(cls) -> None:
"""Clear the cached JSON schema for this class."""
# Reset to None (matching the ClassVar default) so the falsy check in jsonschema() regenerates the schema
cls.cached_jsonschema = None # type: ignore
@staticmethod
def clear_all_schema_caches() -> None:
"""Clear cached JSON schemas for all BlockSchema subclasses."""
def clear_recursive(cls: type) -> None:
"""Recursively clear cache for class and all subclasses."""
if hasattr(cls, "clear_schema_cache"):
cls.clear_schema_cache()
for subclass in cls.__subclasses__():
clear_recursive(subclass)
clear_recursive(BlockSchema)
@classmethod
def jsonschema(cls) -> dict[str, Any]:
if cls.cached_jsonschema:
return cls.cached_jsonschema
# Generate schema if not cached
if not cls.cached_jsonschema:
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
def ref_to_dict(obj):
if isinstance(obj, dict):
# OpenAPI <3.1 does not support sibling fields alongside a $ref key
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
keys = {"allOf", "anyOf", "oneOf"}
one_key = next(
(k for k in keys if k in obj and len(obj[k]) == 1), None
)
if one_key:
obj.update(obj[one_key][0])
def ref_to_dict(obj):
if isinstance(obj, dict):
# OpenAPI <3.1 does not support sibling fields alongside a $ref key
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
keys = {"allOf", "anyOf", "oneOf"}
one_key = next((k for k in keys if k in obj and len(obj[k]) == 1), None)
if one_key:
obj.update(obj[one_key][0])
return {
key: ref_to_dict(value)
for key, value in obj.items()
if not key.startswith("$") and key != one_key
}
elif isinstance(obj, list):
return [ref_to_dict(item) for item in obj]
return {
key: ref_to_dict(value)
for key, value in obj.items()
if not key.startswith("$") and key != one_key
}
elif isinstance(obj, list):
return [ref_to_dict(item) for item in obj]
return obj
return obj
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
# Always post-process to ensure LLM registry data is up-to-date
# This refreshes model options and discriminator mappings even if the schema was cached
update_schema_with_llm_registry(cls.cached_jsonschema, cls)
return cls.cached_jsonschema
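The clear_recursive helper above has to walk the class tree because Python's __subclasses__() returns only direct subclasses; grandchild schemas would otherwise keep stale caches. A quick demonstration of that behaviour:

class A: ...
class B(A): ...
class C(B): ...

print(A.__subclasses__())  # [<class 'B'>]: C is not listed
# Clearing only A.__subclasses__() would therefore miss C; recursion is required.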
@@ -785,6 +810,28 @@ def is_block_auth_configured(
async def initialize_blocks() -> None:
# Refresh LLM registry before initializing blocks so blocks can use registry data
# This ensures the registry cache is populated even in executor context
try:
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
# Only refresh if we have DB access (check if Prisma is connected)
from backend.data.db import is_connected
if is_connected():
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
logger.info("LLM registry refreshed during block initialization")
else:
logger.warning(
"Prisma not connected, skipping LLM registry refresh during block initialization"
)
except Exception as exc:
logger.warning(
"Failed to refresh LLM registry during block initialization: %s", exc
)
# First, sync all provider costs to blocks
# Imported here to avoid circular import
from backend.sdk.cost_integration import sync_all_provider_costs

View File

@@ -1,3 +1,4 @@
import logging
from typing import Type
from backend.blocks.ai_image_customizer import AIImageCustomizerBlock, GeminiImageModel
@@ -23,19 +24,18 @@ from backend.blocks.ideogram import IdeogramModelBlock
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
from backend.blocks.jina.search import ExtractWebsiteContentBlock, SearchTheWebBlock
from backend.blocks.llm import (
MODEL_METADATA,
AIConversationBlock,
AIListGeneratorBlock,
AIStructuredResponseGeneratorBlock,
AITextGeneratorBlock,
AITextSummarizerBlock,
LlmModel,
)
from backend.blocks.replicate.flux_advanced import ReplicateFluxAdvancedModelBlock
from backend.blocks.replicate.replicate_block import ReplicateModelBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
from backend.data import llm_registry
from backend.data.block import Block, BlockCost, BlockCostType
from backend.integrations.credentials_store import (
aiml_api_credentials,
@@ -55,210 +55,63 @@ from backend.integrations.credentials_store import (
v0_credentials,
)
# =============== Configure the cost for each LLM Model call =============== #
logger = logging.getLogger(__name__)
MODEL_COST: dict[LlmModel, int] = {
LlmModel.O3: 4,
LlmModel.O3_MINI: 2,
LlmModel.O1: 16,
LlmModel.O1_MINI: 4,
# GPT-5 models
LlmModel.GPT5_2: 6,
LlmModel.GPT5_1: 5,
LlmModel.GPT5: 2,
LlmModel.GPT5_MINI: 1,
LlmModel.GPT5_NANO: 1,
LlmModel.GPT5_CHAT: 5,
LlmModel.GPT41: 2,
LlmModel.GPT41_MINI: 1,
LlmModel.GPT4O_MINI: 1,
LlmModel.GPT4O: 3,
LlmModel.GPT4_TURBO: 10,
LlmModel.GPT3_5_TURBO: 1,
LlmModel.CLAUDE_4_1_OPUS: 21,
LlmModel.CLAUDE_4_OPUS: 21,
LlmModel.CLAUDE_4_SONNET: 5,
LlmModel.CLAUDE_4_5_HAIKU: 4,
LlmModel.CLAUDE_4_5_OPUS: 14,
LlmModel.CLAUDE_4_5_SONNET: 9,
LlmModel.CLAUDE_3_7_SONNET: 5,
LlmModel.CLAUDE_3_HAIKU: 1,
LlmModel.AIML_API_QWEN2_5_72B: 1,
LlmModel.AIML_API_LLAMA3_1_70B: 1,
LlmModel.AIML_API_LLAMA3_3_70B: 1,
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
LlmModel.LLAMA3_3_70B: 1,
LlmModel.LLAMA3_1_8B: 1,
LlmModel.OLLAMA_LLAMA3_3: 1,
LlmModel.OLLAMA_LLAMA3_2: 1,
LlmModel.OLLAMA_LLAMA3_8B: 1,
LlmModel.OLLAMA_LLAMA3_405B: 1,
LlmModel.OLLAMA_DOLPHIN: 1,
LlmModel.OPENAI_GPT_OSS_120B: 1,
LlmModel.OPENAI_GPT_OSS_20B: 1,
LlmModel.GEMINI_2_5_PRO: 4,
LlmModel.GEMINI_3_PRO_PREVIEW: 5,
LlmModel.MISTRAL_NEMO: 1,
LlmModel.COHERE_COMMAND_R_08_2024: 1,
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: 3,
LlmModel.DEEPSEEK_CHAT: 2,
LlmModel.PERPLEXITY_SONAR: 1,
LlmModel.PERPLEXITY_SONAR_PRO: 5,
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
LlmModel.AMAZON_NOVA_LITE_V1: 1,
LlmModel.AMAZON_NOVA_MICRO_V1: 1,
LlmModel.AMAZON_NOVA_PRO_V1: 1,
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
LlmModel.META_LLAMA_4_SCOUT: 1,
LlmModel.META_LLAMA_4_MAVERICK: 1,
LlmModel.LLAMA_API_LLAMA_4_SCOUT: 1,
LlmModel.LLAMA_API_LLAMA4_MAVERICK: 1,
LlmModel.LLAMA_API_LLAMA3_3_8B: 1,
LlmModel.LLAMA_API_LLAMA3_3_70B: 1,
LlmModel.GROK_4: 9,
LlmModel.GROK_4_FAST: 1,
LlmModel.GROK_4_1_FAST: 1,
LlmModel.GROK_CODE_FAST_1: 1,
LlmModel.KIMI_K2: 1,
LlmModel.QWEN3_235B_A22B_THINKING: 1,
LlmModel.QWEN3_CODER: 9,
LlmModel.GEMINI_2_5_FLASH: 1,
LlmModel.GEMINI_2_0_FLASH: 1,
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
LlmModel.DEEPSEEK_R1_0528: 1,
# v0 by Vercel models
LlmModel.V0_1_5_MD: 1,
LlmModel.V0_1_5_LG: 2,
LlmModel.V0_1_0_MD: 1,
PROVIDER_CREDENTIALS = {
"openai": openai_credentials,
"anthropic": anthropic_credentials,
"groq": groq_credentials,
"open_router": open_router_credentials,
"llama_api": llama_api_credentials,
"aiml_api": aiml_api_credentials,
"v0": v0_credentials,
}
for model in LlmModel:
if model not in MODEL_COST:
raise ValueError(f"Missing MODEL_COST for model: {model}")
# =============== Configure the cost for each LLM Model call =============== #
# All LLM costs now come from the database via llm_registry
LLM_COST: list[BlockCost] = []
LLM_COST = (
# Anthropic Models
[
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
def _build_llm_costs_from_registry() -> list[BlockCost]:
"""Build BlockCost list from all models in the LLM registry."""
costs: list[BlockCost] = []
for model in llm_registry.iter_dynamic_models():
for cost in model.costs:
credentials = PROVIDER_CREDENTIALS.get(cost.credential_provider)
if not credentials:
logger.warning(
"Skipping cost entry for %s due to unknown credentials provider %s",
model.slug,
cost.credential_provider,
)
continue
cost_filter = {
"model": model.slug,
"credentials": {
"id": anthropic_credentials.id,
"provider": anthropic_credentials.provider,
"type": anthropic_credentials.type,
"id": credentials.id,
"provider": credentials.provider,
"type": credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "anthropic"
]
# OpenAI Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": openai_credentials.id,
"provider": openai_credentials.provider,
"type": openai_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "openai"
]
# Groq Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {"id": groq_credentials.id},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "groq"
]
# Open Router Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": open_router_credentials.id,
"provider": open_router_credentials.provider,
"type": open_router_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "open_router"
]
# Llama API Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": llama_api_credentials.id,
"provider": llama_api_credentials.provider,
"type": llama_api_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "llama_api"
]
# v0 by Vercel Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": v0_credentials.id,
"provider": v0_credentials.provider,
"type": v0_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "v0"
]
# AI/ML Api Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": aiml_api_credentials.id,
"provider": aiml_api_credentials.provider,
"type": aiml_api_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "aiml_api"
]
)
}
costs.append(
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter=cost_filter,
cost_amount=cost.credit_cost,
)
)
return costs
def refresh_llm_costs() -> None:
"""Refresh LLM costs from the registry. All costs now come from the database."""
LLM_COST.clear()
LLM_COST.extend(_build_llm_costs_from_registry())
# Initial load will happen after registry is refreshed at startup
# Don't call refresh_llm_costs() here - it will be called after registry refresh
# =============== This is the exhaustive list of cost for each Block =============== #
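Note that refresh_llm_costs mutates LLM_COST in place (clear() then extend()) rather than rebinding the name, so any module that imported LLM_COST keeps seeing fresh data. A quick illustration of why rebinding would break that:

LLM_COST: list[int] = []  # stand-in for the real list of BlockCost entries
imported_ref = LLM_COST  # what another module holds after importing LLM_COST

LLM_COST.clear()
LLM_COST.extend([1, 2])  # in-place refresh
assert imported_ref == [1, 2]  # the importer sees the update

LLM_COST = [3, 4]  # rebinding instead...
assert imported_ref == [1, 2]  # ...silently orphans the importer's reference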

View File

@@ -1443,8 +1443,10 @@ async def migrate_llm_models(migrate_to: LlmModel):
if field.annotation == LlmModel:
llm_model_fields[block.id] = field_name
# Convert enum values to a list of strings for the SQL query
enum_values = [v.value for v in LlmModel]
# Get all model slugs from the registry (dynamic, not hardcoded enum)
from backend.data import llm_registry
enum_values = list(llm_registry.get_all_model_slugs_for_validation())
escaped_enum_values = repr(tuple(enum_values)) # hack but works
# Update each block

View File

@@ -0,0 +1,72 @@
"""
LLM Registry module for managing LLM models, providers, and costs dynamically.
This module provides a database-driven registry system for LLM models,
replacing hardcoded model configurations with a flexible admin-managed system.
"""
from backend.data.llm_registry.model_types import ModelMetadata
# Re-export for backwards compatibility
from backend.data.llm_registry.notifications import (
REGISTRY_REFRESH_CHANNEL,
publish_registry_refresh_notification,
subscribe_to_registry_refresh,
)
from backend.data.llm_registry.registry import (
RegistryModel,
RegistryModelCost,
RegistryModelCreator,
get_all_model_slugs_for_validation,
get_default_model_slug,
get_dynamic_model_slugs,
get_fallback_model_for_disabled,
get_llm_discriminator_mapping,
get_llm_model_cost,
get_llm_model_metadata,
get_llm_model_schema_options,
get_model_info,
is_model_enabled,
iter_dynamic_models,
refresh_llm_registry,
register_static_costs,
register_static_metadata,
)
from backend.data.llm_registry.schema_utils import (
is_llm_model_field,
refresh_llm_discriminator_mapping,
refresh_llm_model_options,
update_schema_with_llm_registry,
)
__all__ = [
# Types
"ModelMetadata",
"RegistryModel",
"RegistryModelCost",
"RegistryModelCreator",
# Registry functions
"get_all_model_slugs_for_validation",
"get_default_model_slug",
"get_dynamic_model_slugs",
"get_fallback_model_for_disabled",
"get_llm_discriminator_mapping",
"get_llm_model_cost",
"get_llm_model_metadata",
"get_llm_model_schema_options",
"get_model_info",
"is_model_enabled",
"iter_dynamic_models",
"refresh_llm_registry",
"register_static_costs",
"register_static_metadata",
# Notifications
"REGISTRY_REFRESH_CHANNEL",
"publish_registry_refresh_notification",
"subscribe_to_registry_refresh",
# Schema utilities
"is_llm_model_field",
"refresh_llm_discriminator_mapping",
"refresh_llm_model_options",
"update_schema_with_llm_registry",
]

View File

@@ -0,0 +1,11 @@
"""Type definitions for LLM model metadata."""
from typing import NamedTuple
class ModelMetadata(NamedTuple):
"""Metadata for an LLM model."""
provider: str
context_window: int
max_output_tokens: int | None

View File

@@ -0,0 +1,89 @@
"""
Redis pub/sub notifications for LLM registry updates.
When models are added/updated/removed via the admin UI, this module
publishes notifications to Redis that all executor services subscribe to,
ensuring they refresh their registry cache in real-time.
"""
import asyncio
import logging
from typing import Any
from backend.data.redis_client import connect_async
logger = logging.getLogger(__name__)
# Redis channel name for LLM registry refresh notifications
REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"
async def publish_registry_refresh_notification() -> None:
"""
Publish a notification to Redis that the LLM registry has been updated.
All executor services subscribed to this channel will refresh their registry.
"""
try:
redis = await connect_async()
await redis.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
logger.info("Published LLM registry refresh notification to Redis")
except Exception as exc:
logger.warning(
"Failed to publish LLM registry refresh notification: %s",
exc,
exc_info=True,
)
async def subscribe_to_registry_refresh(
on_refresh: Any, # Async callable that takes no args
) -> None:
"""
Subscribe to Redis notifications for LLM registry updates.
This runs in a loop and processes messages as they arrive.
Args:
on_refresh: Async callable to execute when a refresh notification is received
"""
try:
redis = await connect_async()
pubsub = redis.pubsub()
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
logger.info(
"Subscribed to LLM registry refresh notifications on channel: %s",
REGISTRY_REFRESH_CHANNEL,
)
# Process messages in a loop
while True:
try:
message = await pubsub.get_message(
ignore_subscribe_messages=True, timeout=1.0
)
if (
message
and message["type"] == "message"
and message["channel"] == REGISTRY_REFRESH_CHANNEL
):
logger.info("Received LLM registry refresh notification")
try:
await on_refresh()
except Exception as exc:
logger.error(
"Error refreshing LLM registry from notification: %s",
exc,
exc_info=True,
)
except Exception as exc:
logger.warning(
"Error processing registry refresh message: %s", exc, exc_info=True
)
# Continue listening even if one message fails
await asyncio.sleep(1)
except Exception as exc:
logger.error(
"Failed to subscribe to LLM registry refresh notifications: %s",
exc,
exc_info=True,
)
raise
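A minimal sketch of how a service might wire this subscriber; it assumes the backend package is importable, and the callback body is modelled on the executor helper shown later in this diff:

import asyncio

from backend.data.llm_registry import subscribe_to_registry_refresh

async def on_refresh() -> None:
    # e.g. refresh the registry cache, rebuild costs, clear schema caches
    print("registry refresh requested")

if __name__ == "__main__":
    # Blocks forever, invoking on_refresh once per published "refresh" message.
    asyncio.run(subscribe_to_registry_refresh(on_refresh))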

View File

@@ -0,0 +1,370 @@
"""Core LLM registry implementation for managing models dynamically."""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any, Iterable
import prisma.models
from backend.data.llm_registry.model_types import ModelMetadata
logger = logging.getLogger(__name__)
def _json_to_dict(value: Any) -> dict[str, Any]:
"""Convert Prisma Json type to dict, with fallback to empty dict."""
if value is None:
return {}
if isinstance(value, dict):
return value
# Prisma Json type should always be a dict at runtime
return dict(value) if value else {}
@dataclass(frozen=True)
class RegistryModelCost:
"""Cost configuration for an LLM model."""
credit_cost: int
credential_provider: str
credential_id: str | None
credential_type: str | None
currency: str | None
metadata: dict[str, Any]
@dataclass(frozen=True)
class RegistryModelCreator:
"""Creator information for an LLM model."""
id: str
name: str
display_name: str
description: str | None
website_url: str | None
logo_url: str | None
@dataclass(frozen=True)
class RegistryModel:
"""Represents a model in the LLM registry."""
slug: str
display_name: str
description: str | None
metadata: ModelMetadata
capabilities: dict[str, Any]
extra_metadata: dict[str, Any]
provider_display_name: str
is_enabled: bool
is_recommended: bool = False
costs: tuple[RegistryModelCost, ...] = field(default_factory=tuple)
creator: RegistryModelCreator | None = None
_static_metadata: dict[str, ModelMetadata] = {}
_static_costs: dict[str, int] = {}
_dynamic_models: dict[str, RegistryModel] = {}
_schema_options: list[dict[str, str]] = []
_discriminator_mapping: dict[str, str] = {}
_lock = asyncio.Lock()
def register_static_metadata(metadata: dict[Any, ModelMetadata]) -> None:
"""Register static metadata for legacy models (deprecated)."""
_static_metadata.update({str(key): value for key, value in metadata.items()})
_refresh_cached_schema()
def register_static_costs(costs: dict[Any, int]) -> None:
"""Register static costs for legacy models (deprecated)."""
_static_costs.update({str(key): value for key, value in costs.items()})
def _build_schema_options() -> list[dict[str, str]]:
"""Build schema options for model selection dropdown. Only includes enabled models."""
options: list[dict[str, str]] = []
# Only include enabled models in the dropdown options
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
if model.is_enabled:
options.append(
{
"label": model.display_name,
"value": model.slug,
"group": model.metadata.provider,
"description": model.description or "",
}
)
for slug, metadata in _static_metadata.items():
if slug in _dynamic_models:
continue
options.append(
{
"label": slug,
"value": slug,
"group": metadata.provider,
"description": "",
}
)
return options
async def refresh_llm_registry() -> None:
"""Refresh the LLM registry from the database. Loads all models (enabled and disabled)."""
async with _lock:
try:
records = await prisma.models.LlmModel.prisma().find_many(
include={
"Provider": True,
"Costs": True,
"Creator": True,
}
)
logger.debug("Found %d LLM model records in database", len(records))
except Exception as exc:
logger.error(
"Failed to refresh LLM registry from DB: %s", exc, exc_info=True
)
return
dynamic: dict[str, RegistryModel] = {}
for record in records:
provider_name = (
record.Provider.name if record.Provider else record.providerId
)
metadata = ModelMetadata(
provider=provider_name,
context_window=record.contextWindow,
max_output_tokens=record.maxOutputTokens,
)
costs = tuple(
RegistryModelCost(
credit_cost=cost.creditCost,
credential_provider=cost.credentialProvider,
credential_id=cost.credentialId,
credential_type=cost.credentialType,
currency=cost.currency,
metadata=_json_to_dict(cost.metadata),
)
for cost in (record.Costs or [])
)
# Map creator if present
creator = None
if record.Creator:
creator = RegistryModelCreator(
id=record.Creator.id,
name=record.Creator.name,
display_name=record.Creator.displayName,
description=record.Creator.description,
website_url=record.Creator.websiteUrl,
logo_url=record.Creator.logoUrl,
)
dynamic[record.slug] = RegistryModel(
slug=record.slug,
display_name=record.displayName,
description=record.description,
metadata=metadata,
capabilities=_json_to_dict(record.capabilities),
extra_metadata=_json_to_dict(record.metadata),
provider_display_name=(
record.Provider.displayName
if record.Provider
else record.providerId
),
is_enabled=record.isEnabled,
is_recommended=record.isRecommended,
costs=costs,
creator=creator,
)
# Atomic swap - build new structures then replace references
# This ensures readers never see partially updated state
global _dynamic_models
_dynamic_models = dynamic
_refresh_cached_schema()
logger.info(
"LLM registry refreshed with %s dynamic models (enabled: %s, disabled: %s)",
len(dynamic),
sum(1 for m in dynamic.values() if m.is_enabled),
sum(1 for m in dynamic.values() if not m.is_enabled),
)
def _refresh_cached_schema() -> None:
"""Refresh cached schema options and discriminator mapping."""
global _schema_options, _discriminator_mapping
# Build new structures
new_options = _build_schema_options()
new_mapping = {slug: entry.metadata.provider for slug, entry in _dynamic_models.items()}
for slug, metadata in _static_metadata.items():
new_mapping.setdefault(slug, metadata.provider)
# Atomic swap - replace references to ensure readers see consistent state
_schema_options = new_options
_discriminator_mapping = new_mapping
def get_llm_model_metadata(slug: str) -> ModelMetadata | None:
"""Get model metadata by slug. Checks dynamic models first, then static metadata."""
if slug in _dynamic_models:
return _dynamic_models[slug].metadata
return _static_metadata.get(slug)
def get_llm_model_cost(slug: str) -> tuple[RegistryModelCost, ...]:
"""Get model cost configuration by slug."""
if slug in _dynamic_models:
return _dynamic_models[slug].costs
cost_value = _static_costs.get(slug)
if cost_value is None:
return tuple()
return (
RegistryModelCost(
credit_cost=cost_value,
credential_provider="static",
credential_id=None,
credential_type=None,
currency=None,
metadata={},
),
)
def get_llm_model_schema_options() -> list[dict[str, str]]:
"""
Get schema options for LLM model selection dropdown.
Returns a copy of cached schema options that are refreshed when the registry is
updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
"""
# Return a copy to prevent external mutation
return list(_schema_options)
def get_llm_discriminator_mapping() -> dict[str, str]:
"""
Get discriminator mapping for LLM models.
Returns a copy of cached discriminator mapping that is refreshed when the registry
is updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
"""
# Return a copy to prevent external mutation
return dict(_discriminator_mapping)
def get_dynamic_model_slugs() -> set[str]:
"""Get all dynamic model slugs from the registry."""
return set(_dynamic_models.keys())
def get_all_model_slugs_for_validation() -> set[str]:
"""
Get ALL model slugs (both enabled and disabled) for validation purposes.
This is used for JSON schema enum validation - we need to accept any known
model value (even disabled ones) so that existing graphs don't fail validation.
The actual fallback/enforcement happens at runtime in llm_call().
"""
all_slugs = set(_dynamic_models.keys())
all_slugs.update(_static_metadata.keys())
return all_slugs
def iter_dynamic_models() -> Iterable[RegistryModel]:
"""Iterate over all dynamic models in the registry."""
return tuple(_dynamic_models.values())
def get_fallback_model_for_disabled(disabled_model_slug: str) -> RegistryModel | None:
"""
Find a fallback model when the requested model is disabled.
Looks for an enabled model from the same provider, preferring the closest
context window and breaking ties alphabetically by display name.
Args:
disabled_model_slug: The slug of the disabled model
Returns:
An enabled RegistryModel from the same provider, or None if no fallback found
"""
disabled_model = _dynamic_models.get(disabled_model_slug)
if not disabled_model:
return None
provider = disabled_model.metadata.provider
# Find all enabled models from the same provider
candidates = [
model
for model in _dynamic_models.values()
if model.is_enabled and model.metadata.provider == provider
]
if not candidates:
return None
# Sort by: prefer models with similar context window, then by name
candidates.sort(
key=lambda m: (
abs(m.metadata.context_window - disabled_model.metadata.context_window),
m.display_name.lower(),
)
)
return candidates[0]
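The sort key reads as "closest context window first, ties broken alphabetically". A small worked illustration with hypothetical candidates for a disabled model with a 128,000-token window:

disabled_window = 128_000
candidates = [("big-pro", 200_000), ("mid-turbo", 128_000), ("mini", 32_000)]
candidates.sort(key=lambda m: (abs(m[1] - disabled_window), m[0].lower()))
print(candidates[0])  # ('mid-turbo', 128000): the exact window match wins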
def is_model_enabled(model_slug: str) -> bool:
"""Check if a model is enabled in the registry."""
model = _dynamic_models.get(model_slug)
if not model:
# Model not in registry - assume it's a static/legacy model and allow it
return True
return model.is_enabled
def get_model_info(model_slug: str) -> RegistryModel | None:
"""Get model info from the registry."""
return _dynamic_models.get(model_slug)
def get_default_model_slug() -> str | None:
"""
Get the default model slug to use for block defaults.
Returns the recommended model if set (configured via admin UI),
otherwise returns the first enabled model alphabetically.
Returns None if no models are available or enabled.
"""
# Return the recommended model if one is set and enabled
for model in _dynamic_models.values():
if model.is_recommended and model.is_enabled:
return model.slug
# No recommended model set - find first enabled model alphabetically
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
if model.is_enabled:
logger.warning(
"No recommended model set, using '%s' as default",
model.slug,
)
return model.slug
# No enabled models available
if _dynamic_models:
logger.error(
"No enabled models found in registry (%d models registered but all disabled)",
len(_dynamic_models),
)
else:
logger.error("No models registered in LLM registry")
return None

View File

@@ -0,0 +1,130 @@
"""
Helper utilities for LLM registry integration with block schemas.
This module handles the dynamic injection of discriminator mappings
and model options from the LLM registry into block schemas.
"""
import logging
from typing import Any
from backend.data.llm_registry.registry import (
get_all_model_slugs_for_validation,
get_default_model_slug,
get_llm_discriminator_mapping,
get_llm_model_schema_options,
)
logger = logging.getLogger(__name__)
def is_llm_model_field(field_name: str, field_info: Any) -> bool:
"""
Check if a field is an LLM model selection field.
Returns True if the field has 'options' in json_schema_extra
(set by llm_model_schema_extra() in blocks/llm.py).
"""
if not hasattr(field_info, "json_schema_extra"):
return False
extra = field_info.json_schema_extra
if isinstance(extra, dict):
return "options" in extra
return False
def refresh_llm_model_options(field_schema: dict[str, Any]) -> None:
"""
Refresh LLM model options from the registry.
Updates 'options' (for frontend dropdown) to show only enabled models,
but keeps the 'enum' (for validation) inclusive of ALL known models.
This is important because:
- Options: What users see in the dropdown (enabled models only)
- Enum: What values pass validation (all known models, including disabled)
Existing graphs may have disabled models selected - they should pass validation
and the fallback logic in llm_call() will handle using an alternative model.
"""
fresh_options = get_llm_model_schema_options()
if not fresh_options:
return
# Update options array (UI dropdown) - only enabled models
if "options" in field_schema:
field_schema["options"] = fresh_options
all_known_slugs = get_all_model_slugs_for_validation()
if all_known_slugs and "enum" in field_schema:
existing_enum = set(field_schema.get("enum", []))
combined_enum = existing_enum | all_known_slugs
field_schema["enum"] = sorted(combined_enum)
# Set the default value from the registry (recommended model if set, else first enabled)
# This ensures new blocks have a sensible default pre-selected
default_slug = get_default_model_slug()
if default_slug:
field_schema["default"] = default_slug
def refresh_llm_discriminator_mapping(field_schema: dict[str, Any]) -> None:
"""
Refresh discriminator_mapping for fields that use model-based discrimination.
The discriminator is already set when AICredentialsField() creates the field.
We only need to refresh the mapping when models are added/removed.
"""
if field_schema.get("discriminator") != "model":
return
# Always refresh the mapping to get latest models
fresh_mapping = get_llm_discriminator_mapping()
if fresh_mapping:
field_schema["discriminator_mapping"] = fresh_mapping
def update_schema_with_llm_registry(
schema: dict[str, Any], model_class: type | None = None
) -> None:
"""
Update a JSON schema with current LLM registry data.
Refreshes:
1. Model options for LLM model selection fields (dropdown choices)
2. Discriminator mappings for credentials fields (model → provider)
Args:
schema: The JSON schema to update (mutated in-place)
model_class: The Pydantic model class (optional, for field introspection)
"""
properties = schema.get("properties", {})
for field_name, field_schema in properties.items():
if not isinstance(field_schema, dict):
continue
# Refresh model options for LLM model fields
if model_class and hasattr(model_class, "model_fields"):
field_info = model_class.model_fields.get(field_name)
if field_info and is_llm_model_field(field_name, field_info):
try:
refresh_llm_model_options(field_schema)
except Exception as exc:
logger.warning(
"Failed to refresh LLM options for field %s: %s",
field_name,
exc,
)
# Refresh discriminator mapping for fields that use model discrimination
try:
refresh_llm_discriminator_mapping(field_schema)
except Exception as exc:
logger.warning(
"Failed to refresh discriminator mapping for field %s: %s",
field_name,
exc,
)

View File

@@ -40,6 +40,7 @@ from pydantic_core import (
)
from typing_extensions import TypedDict
from backend.data.llm_registry import update_schema_with_llm_registry
from backend.integrations.providers import ProviderName
from backend.util.json import loads as json_loads
from backend.util.settings import Secrets
@@ -544,7 +545,9 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
else:
schema["credentials_provider"] = allowed_providers
schema["credentials_types"] = model_class.allowed_cred_types()
# Do not return anything, just mutate schema in place
# Ensure LLM discriminators are populated (delegates to shared helper)
update_schema_with_llm_registry(schema, model_class)
model_config = ConfigDict(
json_schema_extra=_add_json_schema_extra, # type: ignore
@@ -693,16 +696,20 @@ def CredentialsField(
This is enforced by the `BlockSchema` base class.
"""
field_schema_extra = {
k: v
for k, v in {
"credentials_scopes": list(required_scopes) or None,
"discriminator": discriminator,
"discriminator_mapping": discriminator_mapping,
"discriminator_values": discriminator_values,
}.items()
if v is not None
}
# Build field_schema_extra - always include discriminator and mapping if discriminator is set
field_schema_extra: dict[str, Any] = {}
# Always include discriminator if provided
if discriminator is not None:
field_schema_extra["discriminator"] = discriminator
# Always include discriminator_mapping when discriminator is set (even if empty initially)
field_schema_extra["discriminator_mapping"] = discriminator_mapping or {}
# Include other optional fields (only if not None)
if required_scopes:
field_schema_extra["credentials_scopes"] = list(required_scopes)
if discriminator_values:
field_schema_extra["discriminator_values"] = discriminator_values
# Merge any json_schema_extra passed in kwargs
if "json_schema_extra" in kwargs:

View File

@@ -1,429 +0,0 @@
"""Data models and access layer for user business understanding."""
import logging
from datetime import datetime
from typing import Any, Optional, cast
import pydantic
from prisma.models import UserBusinessUnderstanding
from prisma.types import (
UserBusinessUnderstandingCreateInput,
UserBusinessUnderstandingUpdateInput,
)
from backend.data.redis_client import get_redis_async
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
# Cache configuration
CACHE_KEY_PREFIX = "understanding"
CACHE_TTL_SECONDS = 48 * 60 * 60 # 48 hours
def _cache_key(user_id: str) -> str:
"""Generate cache key for user business understanding."""
return f"{CACHE_KEY_PREFIX}:{user_id}"
def _json_to_list(value: Any) -> list[str]:
"""Convert Json field to list[str], handling None."""
if value is None:
return []
if isinstance(value, list):
return cast(list[str], value)
return []
class BusinessUnderstandingInput(pydantic.BaseModel):
"""Input model for updating business understanding - all fields optional for incremental updates."""
# User info
user_name: Optional[str] = pydantic.Field(None, description="The user's name")
job_title: Optional[str] = pydantic.Field(None, description="The user's job title")
# Business basics
business_name: Optional[str] = pydantic.Field(
None, description="Name of the user's business"
)
industry: Optional[str] = pydantic.Field(None, description="Industry or sector")
business_size: Optional[str] = pydantic.Field(
None, description="Company size (e.g., '1-10', '11-50')"
)
user_role: Optional[str] = pydantic.Field(
None,
description="User's role in the organization (e.g., 'decision maker', 'implementer')",
)
# Processes & activities
key_workflows: Optional[list[str]] = pydantic.Field(
None, description="Key business workflows"
)
daily_activities: Optional[list[str]] = pydantic.Field(
None, description="Daily activities performed"
)
# Pain points & goals
pain_points: Optional[list[str]] = pydantic.Field(
None, description="Current pain points"
)
bottlenecks: Optional[list[str]] = pydantic.Field(
None, description="Process bottlenecks"
)
manual_tasks: Optional[list[str]] = pydantic.Field(
None, description="Manual/repetitive tasks"
)
automation_goals: Optional[list[str]] = pydantic.Field(
None, description="Desired automation goals"
)
# Current tools
current_software: Optional[list[str]] = pydantic.Field(
None, description="Software/tools currently used"
)
existing_automation: Optional[list[str]] = pydantic.Field(
None, description="Existing automations"
)
# Additional context
additional_notes: Optional[str] = pydantic.Field(
None, description="Any additional context"
)
class BusinessUnderstanding(pydantic.BaseModel):
"""Full business understanding model returned from database."""
id: str
user_id: str
created_at: datetime
updated_at: datetime
# User info
user_name: Optional[str] = None
job_title: Optional[str] = None
# Business basics
business_name: Optional[str] = None
industry: Optional[str] = None
business_size: Optional[str] = None
user_role: Optional[str] = None
# Processes & activities
key_workflows: list[str] = pydantic.Field(default_factory=list)
daily_activities: list[str] = pydantic.Field(default_factory=list)
# Pain points & goals
pain_points: list[str] = pydantic.Field(default_factory=list)
bottlenecks: list[str] = pydantic.Field(default_factory=list)
manual_tasks: list[str] = pydantic.Field(default_factory=list)
automation_goals: list[str] = pydantic.Field(default_factory=list)
# Current tools
current_software: list[str] = pydantic.Field(default_factory=list)
existing_automation: list[str] = pydantic.Field(default_factory=list)
# Additional context
additional_notes: Optional[str] = None
@classmethod
def from_db(cls, db_record: UserBusinessUnderstanding) -> "BusinessUnderstanding":
"""Convert database record to Pydantic model."""
return cls(
id=db_record.id,
user_id=db_record.userId,
created_at=db_record.createdAt,
updated_at=db_record.updatedAt,
user_name=db_record.userName,
job_title=db_record.jobTitle,
business_name=db_record.businessName,
industry=db_record.industry,
business_size=db_record.businessSize,
user_role=db_record.userRole,
key_workflows=_json_to_list(db_record.keyWorkflows),
daily_activities=_json_to_list(db_record.dailyActivities),
pain_points=_json_to_list(db_record.painPoints),
bottlenecks=_json_to_list(db_record.bottlenecks),
manual_tasks=_json_to_list(db_record.manualTasks),
automation_goals=_json_to_list(db_record.automationGoals),
current_software=_json_to_list(db_record.currentSoftware),
existing_automation=_json_to_list(db_record.existingAutomation),
additional_notes=db_record.additionalNotes,
)
def _merge_lists(existing: list | None, new: list | None) -> list | None:
"""Merge two lists, removing duplicates while preserving order."""
if new is None:
return existing
if existing is None:
return new
# Preserve order, add new items that don't exist
merged = list(existing)
for item in new:
if item not in merged:
merged.append(item)
return merged
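A quick example of the merge behaviour (order preserved, duplicates dropped):

assert _merge_lists(["a", "b"], ["b", "c"]) == ["a", "b", "c"]
assert _merge_lists(None, ["x"]) == ["x"]
assert _merge_lists(["x"], None) == ["x"]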
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
"""Get business understanding from Redis cache."""
try:
redis = await get_redis_async()
cached_data = await redis.get(_cache_key(user_id))
if cached_data:
return BusinessUnderstanding.model_validate_json(cached_data)
except Exception as e:
logger.warning(f"Failed to get understanding from cache: {e}")
return None
async def _set_cache(user_id: str, understanding: BusinessUnderstanding) -> None:
"""Set business understanding in Redis cache with TTL."""
try:
redis = await get_redis_async()
await redis.setex(
_cache_key(user_id),
CACHE_TTL_SECONDS,
understanding.model_dump_json(),
)
except Exception as e:
logger.warning(f"Failed to set understanding in cache: {e}")
async def _delete_cache(user_id: str) -> None:
"""Delete business understanding from Redis cache."""
try:
redis = await get_redis_async()
await redis.delete(_cache_key(user_id))
except Exception as e:
logger.warning(f"Failed to delete understanding from cache: {e}")
async def get_business_understanding(
user_id: str,
) -> Optional[BusinessUnderstanding]:
"""Get the business understanding for a user.
Checks cache first, falls back to database if not cached.
Results are cached for 48 hours.
"""
# Try cache first
cached = await _get_from_cache(user_id)
if cached:
logger.debug(f"Business understanding cache hit for user {user_id}")
return cached
# Cache miss - load from database
logger.debug(f"Business understanding cache miss for user {user_id}")
record = await UserBusinessUnderstanding.prisma().find_unique(
where={"userId": user_id}
)
if record is None:
return None
understanding = BusinessUnderstanding.from_db(record)
# Store in cache for next time
await _set_cache(user_id, understanding)
return understanding
async def upsert_business_understanding(
user_id: str,
data: BusinessUnderstandingInput,
) -> BusinessUnderstanding:
"""
Create or update business understanding with incremental merge strategy.
- String fields: new value overwrites if provided (not None)
- List fields: new items are appended to existing (deduplicated)
"""
# Get existing record for merge
existing = await UserBusinessUnderstanding.prisma().find_unique(
where={"userId": user_id}
)
# Build update data with merge strategy
update_data: UserBusinessUnderstandingUpdateInput = {}
create_data: dict[str, Any] = {"userId": user_id}
# String fields - overwrite if provided
if data.user_name is not None:
update_data["userName"] = data.user_name
create_data["userName"] = data.user_name
if data.job_title is not None:
update_data["jobTitle"] = data.job_title
create_data["jobTitle"] = data.job_title
if data.business_name is not None:
update_data["businessName"] = data.business_name
create_data["businessName"] = data.business_name
if data.industry is not None:
update_data["industry"] = data.industry
create_data["industry"] = data.industry
if data.business_size is not None:
update_data["businessSize"] = data.business_size
create_data["businessSize"] = data.business_size
if data.user_role is not None:
update_data["userRole"] = data.user_role
create_data["userRole"] = data.user_role
if data.additional_notes is not None:
update_data["additionalNotes"] = data.additional_notes
create_data["additionalNotes"] = data.additional_notes
# List fields - merge with existing
if data.key_workflows is not None:
existing_list = _json_to_list(existing.keyWorkflows) if existing else None
merged = _merge_lists(existing_list, data.key_workflows)
update_data["keyWorkflows"] = SafeJson(merged)
create_data["keyWorkflows"] = SafeJson(merged)
if data.daily_activities is not None:
existing_list = _json_to_list(existing.dailyActivities) if existing else None
merged = _merge_lists(existing_list, data.daily_activities)
update_data["dailyActivities"] = SafeJson(merged)
create_data["dailyActivities"] = SafeJson(merged)
if data.pain_points is not None:
existing_list = _json_to_list(existing.painPoints) if existing else None
merged = _merge_lists(existing_list, data.pain_points)
update_data["painPoints"] = SafeJson(merged)
create_data["painPoints"] = SafeJson(merged)
if data.bottlenecks is not None:
existing_list = _json_to_list(existing.bottlenecks) if existing else None
merged = _merge_lists(existing_list, data.bottlenecks)
update_data["bottlenecks"] = SafeJson(merged)
create_data["bottlenecks"] = SafeJson(merged)
if data.manual_tasks is not None:
existing_list = _json_to_list(existing.manualTasks) if existing else None
merged = _merge_lists(existing_list, data.manual_tasks)
update_data["manualTasks"] = SafeJson(merged)
create_data["manualTasks"] = SafeJson(merged)
if data.automation_goals is not None:
existing_list = _json_to_list(existing.automationGoals) if existing else None
merged = _merge_lists(existing_list, data.automation_goals)
update_data["automationGoals"] = SafeJson(merged)
create_data["automationGoals"] = SafeJson(merged)
if data.current_software is not None:
existing_list = _json_to_list(existing.currentSoftware) if existing else None
merged = _merge_lists(existing_list, data.current_software)
update_data["currentSoftware"] = SafeJson(merged)
create_data["currentSoftware"] = SafeJson(merged)
if data.existing_automation is not None:
existing_list = _json_to_list(existing.existingAutomation) if existing else None
merged = _merge_lists(existing_list, data.existing_automation)
update_data["existingAutomation"] = SafeJson(merged)
create_data["existingAutomation"] = SafeJson(merged)
# Upsert
record = await UserBusinessUnderstanding.prisma().upsert(
where={"userId": user_id},
data={
"create": UserBusinessUnderstandingCreateInput(**create_data),
"update": update_data,
},
)
understanding = BusinessUnderstanding.from_db(record)
# Update cache with new understanding
await _set_cache(user_id, understanding)
return understanding
async def clear_business_understanding(user_id: str) -> bool:
"""Clear/delete business understanding for a user from both DB and cache."""
# Delete from cache first
await _delete_cache(user_id)
try:
await UserBusinessUnderstanding.prisma().delete(where={"userId": user_id})
return True
except Exception:
# Record might not exist
return False
def format_understanding_for_prompt(understanding: BusinessUnderstanding) -> str:
"""Format business understanding as text for system prompt injection."""
sections = []
# User info section
user_info = []
if understanding.user_name:
user_info.append(f"Name: {understanding.user_name}")
if understanding.job_title:
user_info.append(f"Job Title: {understanding.job_title}")
if user_info:
sections.append("## User\n" + "\n".join(user_info))
# Business section
business_info = []
if understanding.business_name:
business_info.append(f"Company: {understanding.business_name}")
if understanding.industry:
business_info.append(f"Industry: {understanding.industry}")
if understanding.business_size:
business_info.append(f"Size: {understanding.business_size}")
if understanding.user_role:
business_info.append(f"Role Context: {understanding.user_role}")
if business_info:
sections.append("## Business\n" + "\n".join(business_info))
# Processes section
processes = []
if understanding.key_workflows:
processes.append(f"Key Workflows: {', '.join(understanding.key_workflows)}")
if understanding.daily_activities:
processes.append(
f"Daily Activities: {', '.join(understanding.daily_activities)}"
)
if processes:
sections.append("## Processes\n" + "\n".join(processes))
# Pain points section
pain_points = []
if understanding.pain_points:
pain_points.append(f"Pain Points: {', '.join(understanding.pain_points)}")
if understanding.bottlenecks:
pain_points.append(f"Bottlenecks: {', '.join(understanding.bottlenecks)}")
if understanding.manual_tasks:
pain_points.append(f"Manual Tasks: {', '.join(understanding.manual_tasks)}")
if pain_points:
sections.append("## Pain Points\n" + "\n".join(pain_points))
# Goals section
if understanding.automation_goals:
sections.append(
"## Automation Goals\n"
+ "\n".join(f"- {goal}" for goal in understanding.automation_goals)
)
# Current tools section
tools_info = []
if understanding.current_software:
tools_info.append(
f"Current Software: {', '.join(understanding.current_software)}"
)
if understanding.existing_automation:
tools_info.append(
f"Existing Automation: {', '.join(understanding.existing_automation)}"
)
if tools_info:
sections.append("## Current Tools\n" + "\n".join(tools_info))
# Additional notes
if understanding.additional_notes:
sections.append(f"## Additional Context\n{understanding.additional_notes}")
if not sections:
return ""
return "# User Business Context\n\n" + "\n\n".join(sections)

View File

@@ -0,0 +1,66 @@
"""
Helper functions for LLM registry initialization in executor context.
These functions handle refreshing the LLM registry when the executor starts
and subscribing to real-time updates via Redis pub/sub.
"""
import logging
from backend.data import db, llm_registry
from backend.data.block import BlockSchema, initialize_blocks
from backend.data.block_cost_config import refresh_llm_costs
from backend.data.llm_registry import subscribe_to_registry_refresh
logger = logging.getLogger(__name__)
async def initialize_registry_for_executor() -> None:
"""
Initialize blocks and refresh LLM registry in the executor context.
This must run in the executor's event loop to have access to the database.
"""
try:
# Connect to database if not already connected
if not db.is_connected():
await db.connect()
logger.info("[GraphExecutor] Connected to database for registry refresh")
# Initialize blocks (internally refreshes LLM registry and costs)
await initialize_blocks()
logger.info("[GraphExecutor] Blocks initialized")
except Exception as exc:
logger.warning(
"[GraphExecutor] Failed to refresh LLM registry on startup: %s",
exc,
exc_info=True,
)
async def refresh_registry_on_notification() -> None:
"""Refresh LLM registry when notified via Redis pub/sub."""
try:
# Ensure DB is connected
if not db.is_connected():
await db.connect()
# Refresh registry and costs
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
# Clear block schema caches so they regenerate with new model options
BlockSchema.clear_all_schema_caches()
logger.info("[GraphExecutor] LLM registry refreshed from notification")
except Exception as exc:
logger.error(
"[GraphExecutor] Failed to refresh LLM registry from notification: %s",
exc,
exc_info=True,
)
async def subscribe_to_registry_updates() -> None:
"""Subscribe to Redis pub/sub for LLM registry refresh notifications."""
await subscribe_to_registry_refresh(refresh_registry_on_notification)

View File

@@ -696,6 +696,20 @@ class ExecutionProcessor:
)
self.node_execution_thread.start()
self.node_evaluation_thread.start()
# Initialize LLM registry and subscribe to updates
from backend.executor.llm_registry_init import (
initialize_registry_for_executor,
subscribe_to_registry_updates,
)
asyncio.run_coroutine_threadsafe(
initialize_registry_for_executor(), self.node_execution_loop
)
asyncio.run_coroutine_threadsafe(
subscribe_to_registry_updates(), self.node_execution_loop
)
logger.info(f"[GraphExecutor] {self.tid} started")
@error_logged(swallow=False)

View File

@@ -0,0 +1,849 @@
from __future__ import annotations
from typing import Any, Iterable, Sequence, cast
import prisma
import prisma.models
from backend.data.db import transaction
from backend.server.v2.llm import model as llm_model
def _json_dict(value: Any | None) -> dict[str, Any]:
if not value:
return {}
if isinstance(value, dict):
return value
return {}
def _map_cost(record: prisma.models.LlmModelCost) -> llm_model.LlmModelCost:
return llm_model.LlmModelCost(
id=record.id,
unit=record.unit,
credit_cost=record.creditCost,
credential_provider=record.credentialProvider,
credential_id=record.credentialId,
credential_type=record.credentialType,
currency=record.currency,
metadata=_json_dict(record.metadata),
)
def _map_creator(
record: prisma.models.LlmModelCreator,
) -> llm_model.LlmModelCreator:
return llm_model.LlmModelCreator(
id=record.id,
name=record.name,
display_name=record.displayName,
description=record.description,
website_url=record.websiteUrl,
logo_url=record.logoUrl,
metadata=_json_dict(record.metadata),
)
def _map_model(record: prisma.models.LlmModel) -> llm_model.LlmModel:
costs = []
if record.Costs:
costs = [_map_cost(cost) for cost in record.Costs]
creator = None
if hasattr(record, "Creator") and record.Creator:
creator = _map_creator(record.Creator)
return llm_model.LlmModel(
id=record.id,
slug=record.slug,
display_name=record.displayName,
description=record.description,
provider_id=record.providerId,
creator_id=record.creatorId,
creator=creator,
context_window=record.contextWindow,
max_output_tokens=record.maxOutputTokens,
is_enabled=record.isEnabled,
is_recommended=record.isRecommended,
capabilities=_json_dict(record.capabilities),
metadata=_json_dict(record.metadata),
costs=costs,
)
def _map_provider(record: prisma.models.LlmProvider) -> llm_model.LlmProvider:
models: list[llm_model.LlmModel] = []
if record.Models:
models = [_map_model(model) for model in record.Models]
return llm_model.LlmProvider(
id=record.id,
name=record.name,
display_name=record.displayName,
description=record.description,
default_credential_provider=record.defaultCredentialProvider,
default_credential_id=record.defaultCredentialId,
default_credential_type=record.defaultCredentialType,
supports_tools=record.supportsTools,
supports_json_output=record.supportsJsonOutput,
supports_reasoning=record.supportsReasoning,
supports_parallel_tool=record.supportsParallelTool,
metadata=_json_dict(record.metadata),
models=models,
)
async def list_providers(
include_models: bool = True, enabled_only: bool = False
) -> list[llm_model.LlmProvider]:
"""
List all LLM providers.
Args:
include_models: Whether to include models for each provider
enabled_only: If True, only include enabled models (for public routes)
"""
include: Any = None
if include_models:
model_where = {"isEnabled": True} if enabled_only else None
include = {
"Models": {
"include": {"Costs": True, "Creator": True},
"where": model_where,
}
}
records = await prisma.models.LlmProvider.prisma().find_many(include=include)
return [_map_provider(record) for record in records]
async def upsert_provider(
request: llm_model.UpsertLlmProviderRequest,
provider_id: str | None = None,
) -> llm_model.LlmProvider:
data: Any = {
"name": request.name,
"displayName": request.display_name,
"description": request.description,
"defaultCredentialProvider": request.default_credential_provider,
"defaultCredentialId": request.default_credential_id,
"defaultCredentialType": request.default_credential_type,
"supportsTools": request.supports_tools,
"supportsJsonOutput": request.supports_json_output,
"supportsReasoning": request.supports_reasoning,
"supportsParallelTool": request.supports_parallel_tool,
"metadata": request.metadata,
}
include: Any = {"Models": {"include": {"Costs": True, "Creator": True}}}
if provider_id:
record = await prisma.models.LlmProvider.prisma().update(
where={"id": provider_id},
data=data,
include=include,
)
else:
record = await prisma.models.LlmProvider.prisma().create(
data=data,
include=include,
)
if record is None:
raise ValueError("Failed to create/update provider")
return _map_provider(record)
async def list_models(
provider_id: str | None = None, enabled_only: bool = False
) -> list[llm_model.LlmModel]:
"""
List LLM models.
Args:
provider_id: Optional filter by provider ID
enabled_only: If True, only return enabled models (for public routes)
"""
where: Any = {}
if provider_id:
where["providerId"] = provider_id
if enabled_only:
where["isEnabled"] = True
records = await prisma.models.LlmModel.prisma().find_many(
where=where if where else None,
include={"Costs": True, "Creator": True},
)
return [_map_model(record) for record in records]
def _cost_create_payload(
costs: Sequence[llm_model.LlmModelCostInput],
) -> dict[str, Iterable[dict[str, Any]]]:
create_items = []
for cost in costs:
item: dict[str, Any] = {
"unit": cost.unit,
"creditCost": cost.credit_cost,
"credentialProvider": cost.credential_provider,
}
# Only include optional fields if they have values
if cost.credential_id:
item["credentialId"] = cost.credential_id
if cost.credential_type:
item["credentialType"] = cost.credential_type
if cost.currency:
item["currency"] = cost.currency
# Handle metadata - use Prisma Json type
if cost.metadata is not None and cost.metadata != {}:
item["metadata"] = prisma.Json(cost.metadata)
create_items.append(item)
return {"create": create_items}
async def create_model(
request: llm_model.CreateLlmModelRequest,
) -> llm_model.LlmModel:
data: Any = {
"slug": request.slug,
"displayName": request.display_name,
"description": request.description,
"providerId": request.provider_id,
"contextWindow": request.context_window,
"maxOutputTokens": request.max_output_tokens,
"isEnabled": request.is_enabled,
"capabilities": request.capabilities,
"metadata": request.metadata,
"Costs": _cost_create_payload(request.costs),
}
if request.creator_id:
data["creatorId"] = request.creator_id
record = await prisma.models.LlmModel.prisma().create(
data=data,
include={"Costs": True, "Creator": True},
)
return _map_model(record)
async def update_model(
model_id: str,
request: llm_model.UpdateLlmModelRequest,
) -> llm_model.LlmModel:
# Build scalar field updates (non-relation fields)
scalar_data: Any = {}
if request.display_name is not None:
scalar_data["displayName"] = request.display_name
if request.description is not None:
scalar_data["description"] = request.description
if request.context_window is not None:
scalar_data["contextWindow"] = request.context_window
if request.max_output_tokens is not None:
scalar_data["maxOutputTokens"] = request.max_output_tokens
if request.is_enabled is not None:
scalar_data["isEnabled"] = request.is_enabled
if request.capabilities is not None:
scalar_data["capabilities"] = request.capabilities
if request.metadata is not None:
scalar_data["metadata"] = request.metadata
# Foreign keys can be updated directly as scalar fields
if request.provider_id is not None:
scalar_data["providerId"] = request.provider_id
if request.creator_id is not None:
# Empty string means remove the creator
scalar_data["creatorId"] = request.creator_id if request.creator_id else None
# If we have costs to update, we need to handle them separately
# because nested writes have different constraints
if request.costs is not None:
# Wrap cost replacement in a transaction for atomicity
async with transaction() as tx:
# First update scalar fields
if scalar_data:
await tx.llmmodel.update(
where={"id": model_id},
data=scalar_data,
)
# Then handle costs: delete existing and create new
await tx.llmmodelcost.delete_many(where={"llmModelId": model_id})
if request.costs:
cost_payload = _cost_create_payload(request.costs)
for cost_item in cost_payload["create"]:
cost_item["llmModelId"] = model_id
await tx.llmmodelcost.create(data=cast(Any, cost_item))
# Fetch the updated record (outside transaction)
record = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id},
include={"Costs": True, "Creator": True},
)
else:
# No costs update - simple update
record = await prisma.models.LlmModel.prisma().update(
where={"id": model_id},
data=scalar_data,
include={"Costs": True, "Creator": True},
)
if not record:
raise ValueError(f"Model with id '{model_id}' not found")
return _map_model(record)
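A hedged usage sketch of the cost-replacement path (values hypothetical):

# Hypothetical: atomically replace a model's cost rows via update_model.
from backend.server.v2.llm import model as llm_model

request = llm_model.UpdateLlmModelRequest(
    costs=[
        llm_model.LlmModelCostInput(credit_cost=5, credential_provider="openai"),
    ],
)
# updated = await update_model(model_id, request)
# The old LlmModelCost rows are deleted and the new one created in one transaction.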
async def toggle_model(
model_id: str,
is_enabled: bool,
migrate_to_slug: str | None = None,
migration_reason: str | None = None,
custom_credit_cost: int | None = None,
) -> llm_model.ToggleLlmModelResponse:
"""
Toggle a model's enabled status, optionally migrating workflows when disabling.
Args:
model_id: UUID of the model to toggle
is_enabled: New enabled status
migrate_to_slug: If disabling and this is provided, migrate all workflows
using this model to the specified replacement model
migration_reason: Optional reason for the migration (e.g., "Provider outage")
custom_credit_cost: Optional custom pricing override for migrated workflows.
When set, the billing system should use this cost instead
of the target model's cost for affected nodes.
Returns:
ToggleLlmModelResponse with the updated model and optional migration stats
"""
# Get the model being toggled
model = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id}, include={"Costs": True}
)
if not model:
raise ValueError(f"Model with id '{model_id}' not found")
nodes_migrated = 0
migration_id: str | None = None
# If disabling with migration, perform migration first
if not is_enabled and migrate_to_slug:
# Validate replacement model exists and is enabled
replacement = await prisma.models.LlmModel.prisma().find_unique(
where={"slug": migrate_to_slug}
)
if not replacement:
raise ValueError(f"Replacement model '{migrate_to_slug}' not found")
if not replacement.isEnabled:
raise ValueError(
f"Replacement model '{migrate_to_slug}' is disabled. "
f"Please enable it before using it as a replacement."
)
# Perform all operations atomically within a single transaction
# This ensures no nodes are missed between query and update
async with transaction() as tx:
# Get the IDs of nodes that will be migrated (inside transaction for consistency)
node_ids_result = await tx.query_raw(
"""
SELECT id
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
FOR UPDATE
""",
model.slug,
)
migrated_node_ids = (
[row["id"] for row in node_ids_result] if node_ids_result else []
)
nodes_migrated = len(migrated_node_ids)
if nodes_migrated > 0:
# Update by IDs to ensure we only update the exact nodes we queried
# Node IDs are UUIDs (no commas, quotes, or braces), so a plain
# Postgres text[] literal can be built by joining them
node_ids_pg_array = "{" + ",".join(migrated_node_ids) + "}"
await tx.execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
)
WHERE id::text = ANY($2::text[])
""",
migrate_to_slug,
node_ids_pg_array,
)
record = await tx.llmmodel.update(
where={"id": model_id},
data={"isEnabled": is_enabled},
include={"Costs": True},
)
# Create migration record for revert capability
if nodes_migrated > 0:
migration_data: Any = {
"sourceModelSlug": model.slug,
"targetModelSlug": migrate_to_slug,
"reason": migration_reason,
"migratedNodeIds": json.dumps(migrated_node_ids),
"nodeCount": nodes_migrated,
"customCreditCost": custom_credit_cost,
}
migration_record = await tx.llmmodelmigration.create(
data=migration_data
)
migration_id = migration_record.id
else:
# Simple toggle without migration
record = await prisma.models.LlmModel.prisma().update(
where={"id": model_id},
data={"isEnabled": is_enabled},
include={"Costs": True},
)
if record is None:
raise ValueError(f"Model with id '{model_id}' not found")
return llm_model.ToggleLlmModelResponse(
model=_map_model(record),
nodes_migrated=nodes_migrated,
migrated_to_slug=migrate_to_slug if nodes_migrated > 0 else None,
migration_id=migration_id,
)
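A hedged end-to-end sketch of the disable-with-migration path (slugs hypothetical):

# Hypothetical: disable a model and migrate its workflow nodes.
import asyncio

from backend.server.v2.llm import db as llm_db

async def main() -> None:
    models = await llm_db.list_models()
    old = next(m for m in models if m.slug == "gpt-4-turbo")  # hypothetical slug
    result = await llm_db.toggle_model(
        model_id=old.id,
        is_enabled=False,
        migrate_to_slug="gpt-4o",            # hypothetical replacement
        migration_reason="Deprecated model",
        custom_credit_cost=10,               # bill migrated nodes at the old rate
    )
    print(result.nodes_migrated, result.migration_id)

asyncio.run(main())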
async def get_model_usage(model_id: str) -> llm_model.LlmModelUsageResponse:
"""Get usage count for a model."""
model = await prisma.models.LlmModel.prisma().find_unique(where={"id": model_id})
if not model:
raise ValueError(f"Model with id '{model_id}' not found")
count_result = await prisma.get_client().query_raw(
"""
SELECT COUNT(*) as count
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
""",
model.slug,
)
node_count = int(count_result[0]["count"]) if count_result else 0
return llm_model.LlmModelUsageResponse(model_slug=model.slug, node_count=node_count)
async def delete_model(
model_id: str, replacement_model_slug: str
) -> llm_model.DeleteLlmModelResponse:
"""
Delete a model and migrate all AgentNodes using it to a replacement model.
This performs an atomic operation within a database transaction:
1. Validates the model exists
2. Validates the replacement model exists and is enabled
3. Counts affected nodes (in transaction)
4. Migrates all AgentNode.constantInput->model to replacement (in transaction)
5. Deletes the LlmModel record (CASCADE deletes costs) (in transaction)
Args:
model_id: UUID of the model to delete
replacement_model_slug: Slug of the model to migrate to
Returns:
DeleteLlmModelResponse with migration stats
Raises:
ValueError: If model not found, replacement not found, or replacement is disabled
"""
# 1. Get the model being deleted (validation - outside transaction)
model = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id}, include={"Costs": True}
)
if not model:
raise ValueError(f"Model with id '{model_id}' not found")
deleted_slug = model.slug
deleted_display_name = model.displayName
# 2. Validate replacement model exists and is enabled (validation - outside transaction)
replacement = await prisma.models.LlmModel.prisma().find_unique(
where={"slug": replacement_model_slug}
)
if not replacement:
raise ValueError(f"Replacement model '{replacement_model_slug}' not found")
if not replacement.isEnabled:
raise ValueError(
f"Replacement model '{replacement_model_slug}' is disabled. "
f"Please enable it before using it as a replacement."
)
# 3 & 4. Perform count, migration and deletion atomically within a transaction
nodes_affected = 0
async with transaction() as tx:
# Count affected nodes (inside transaction for consistency)
count_result = await tx.query_raw(
"""
SELECT COUNT(*) as count
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
""",
deleted_slug,
)
nodes_affected = int(count_result[0]["count"]) if count_result else 0
# Migrate all AgentNode.constantInput->model to replacement
if nodes_affected > 0:
await tx.execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
)
WHERE "constantInput"::jsonb->>'model' = $2
""",
replacement_model_slug,
deleted_slug,
)
# Delete the model (CASCADE will delete costs automatically)
await tx.llmmodel.delete(where={"id": model_id})
return llm_model.DeleteLlmModelResponse(
deleted_model_slug=deleted_slug,
deleted_model_display_name=deleted_display_name,
replacement_model_slug=replacement_model_slug,
nodes_migrated=nodes_affected,
message=(
f"Successfully deleted model '{deleted_display_name}' ({deleted_slug}) "
f"and migrated {nodes_affected} workflow node(s) to '{replacement_model_slug}'."
),
)
def _map_migration(
record: prisma.models.LlmModelMigration,
) -> llm_model.LlmModelMigration:
return llm_model.LlmModelMigration(
id=record.id,
source_model_slug=record.sourceModelSlug,
target_model_slug=record.targetModelSlug,
reason=record.reason,
node_count=record.nodeCount,
custom_credit_cost=record.customCreditCost,
is_reverted=record.isReverted,
created_at=record.createdAt.isoformat(),
reverted_at=record.revertedAt.isoformat() if record.revertedAt else None,
)
async def list_migrations(
include_reverted: bool = False,
) -> list[llm_model.LlmModelMigration]:
"""
List model migrations, optionally including reverted ones.
Args:
include_reverted: If True, include reverted migrations. Default is False.
Returns:
List of LlmModelMigration records
"""
where: Any = None if include_reverted else {"isReverted": False}
records = await prisma.models.LlmModelMigration.prisma().find_many(
where=where,
order={"createdAt": "desc"},
)
return [_map_migration(record) for record in records]
async def get_migration(migration_id: str) -> llm_model.LlmModelMigration | None:
"""Get a specific migration by ID."""
record = await prisma.models.LlmModelMigration.prisma().find_unique(
where={"id": migration_id}
)
return _map_migration(record) if record else None
async def revert_migration(
migration_id: str,
re_enable_source_model: bool = True,
) -> llm_model.RevertMigrationResponse:
"""
Revert a model migration, restoring affected nodes to their original model.
This only reverts the specific nodes that were migrated, not all nodes
currently using the target model.
Args:
migration_id: UUID of the migration to revert
re_enable_source_model: Whether to re-enable the source model if it's disabled
Returns:
RevertMigrationResponse with revert stats
Raises:
ValueError: If migration not found, already reverted, or source model not available
"""
# Get the migration record
migration = await prisma.models.LlmModelMigration.prisma().find_unique(
where={"id": migration_id}
)
if not migration:
raise ValueError(f"Migration with id '{migration_id}' not found")
if migration.isReverted:
raise ValueError(
f"Migration '{migration_id}' has already been reverted "
f"on {migration.revertedAt.isoformat() if migration.revertedAt else 'unknown date'}"
)
# Check if source model exists
source_model = await prisma.models.LlmModel.prisma().find_unique(
where={"slug": migration.sourceModelSlug}
)
if not source_model:
raise ValueError(
f"Source model '{migration.sourceModelSlug}' no longer exists. "
f"Cannot revert migration."
)
# Get the migrated node IDs (Prisma auto-parses JSONB to list)
migrated_node_ids: list[str] = (
migration.migratedNodeIds
if isinstance(migration.migratedNodeIds, list)
else json.loads(migration.migratedNodeIds) # type: ignore
)
if not migrated_node_ids:
raise ValueError("No nodes to revert in this migration")
# Track if we need to re-enable the source model
source_model_was_disabled = not source_model.isEnabled
should_re_enable = source_model_was_disabled and re_enable_source_model
source_model_re_enabled = False
# Perform revert atomically
async with transaction() as tx:
# Re-enable the source model if requested and it was disabled
if should_re_enable:
await tx.llmmodel.update(
where={"id": source_model.id},
data={"isEnabled": True},
)
source_model_re_enabled = True
# Update only the specific nodes that were migrated
# We need to check that they still have the target model (haven't been changed since)
# Use a single batch update for efficiency
# Format node IDs as PostgreSQL text array literal for comparison
node_ids_pg_array = "{" + ",".join(migrated_node_ids) + "}"
result = await tx.execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
)
WHERE id::text = ANY($2::text[])
AND "constantInput"::jsonb->>'model' = $3
""",
migration.sourceModelSlug,
node_ids_pg_array,
migration.targetModelSlug,
)
# execute_raw returns the number of rows it updated
nodes_reverted = result if result else 0
# Mark migration as reverted
await tx.llmmodelmigration.update(
where={"id": migration_id},
data={
"isReverted": True,
"revertedAt": datetime.now(timezone.utc),
},
)
# Calculate nodes that were already changed since migration
nodes_already_changed = len(migrated_node_ids) - nodes_reverted
# Build appropriate message
message_parts = [
f"Successfully reverted migration: {nodes_reverted} node(s) restored "
f"from '{migration.targetModelSlug}' to '{migration.sourceModelSlug}'."
]
if nodes_already_changed > 0:
message_parts.append(
f" {nodes_already_changed} node(s) were already changed and not reverted."
)
if source_model_re_enabled:
message_parts.append(
f" Model '{migration.sourceModelSlug}' has been re-enabled."
)
return llm_model.RevertMigrationResponse(
migration_id=migration_id,
source_model_slug=migration.sourceModelSlug,
target_model_slug=migration.targetModelSlug,
nodes_reverted=nodes_reverted,
nodes_already_changed=nodes_already_changed,
source_model_re_enabled=source_model_re_enabled,
message="".join(message_parts),
)
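And the corresponding revert, using the migration_id returned by toggle_model (a sketch under the same assumptions):

# Hypothetical: undo a migration recorded by toggle_model.
from backend.server.v2.llm import db as llm_db

async def undo(migration_id: str) -> None:
    result = await llm_db.revert_migration(
        migration_id,
        re_enable_source_model=True,  # also re-enable the disabled source model
    )
    print(result.message)  # reports reverted vs. already-changed node counts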
# ============================================================================
# Creator CRUD operations
# ============================================================================
async def list_creators() -> list[llm_model.LlmModelCreator]:
"""List all LLM model creators."""
records = await prisma.models.LlmModelCreator.prisma().find_many(
order={"displayName": "asc"}
)
return [_map_creator(record) for record in records]
async def get_creator(creator_id: str) -> llm_model.LlmModelCreator | None:
"""Get a specific creator by ID."""
record = await prisma.models.LlmModelCreator.prisma().find_unique(
where={"id": creator_id}
)
return _map_creator(record) if record else None
async def upsert_creator(
request: llm_model.UpsertLlmCreatorRequest,
creator_id: str | None = None,
) -> llm_model.LlmModelCreator:
"""Create or update a model creator."""
data: Any = {
"name": request.name,
"displayName": request.display_name,
"description": request.description,
"websiteUrl": request.website_url,
"logoUrl": request.logo_url,
"metadata": request.metadata,
}
if creator_id:
record = await prisma.models.LlmModelCreator.prisma().update(
where={"id": creator_id},
data=data,
)
else:
record = await prisma.models.LlmModelCreator.prisma().create(data=data)
if record is None:
raise ValueError("Failed to create/update creator")
return _map_creator(record)
async def delete_creator(creator_id: str) -> bool:
"""
Delete a model creator.
This will set creatorId to NULL on all associated models (due to onDelete: SetNull).
Args:
creator_id: UUID of the creator to delete
Returns:
True if deleted successfully
Raises:
ValueError: If creator not found
"""
creator = await prisma.models.LlmModelCreator.prisma().find_unique(
where={"id": creator_id}
)
if not creator:
raise ValueError(f"Creator with id '{creator_id}' not found")
await prisma.models.LlmModelCreator.prisma().delete(where={"id": creator_id})
return True
async def get_recommended_model() -> llm_model.LlmModel | None:
"""
Get the currently recommended LLM model.
Returns:
The recommended model, or None if no model is marked as recommended.
"""
record = await prisma.models.LlmModel.prisma().find_first(
where={"isRecommended": True, "isEnabled": True},
include={"Costs": True, "Creator": True},
)
return _map_model(record) if record else None
async def set_recommended_model(
model_id: str,
) -> tuple[llm_model.LlmModel, str | None]:
"""
Set a model as the recommended model.
This will clear the isRecommended flag from any other model and set it
on the specified model. The model must be enabled.
Args:
model_id: UUID of the model to set as recommended
Returns:
Tuple of (the updated model, previous recommended model slug or None)
Raises:
ValueError: If model not found or not enabled
"""
# First, verify the model exists and is enabled
target_model = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id}
)
if not target_model:
raise ValueError(f"Model with id '{model_id}' not found")
if not target_model.isEnabled:
raise ValueError(
f"Cannot set disabled model '{target_model.slug}' as recommended"
)
# Get the current recommended model (if any)
current_recommended = await prisma.models.LlmModel.prisma().find_first(
where={"isRecommended": True}
)
previous_slug = current_recommended.slug if current_recommended else None
# Use a transaction to ensure atomicity
async with transaction() as tx:
# Clear isRecommended from all models
await tx.llmmodel.update_many(
where={"isRecommended": True},
data={"isRecommended": False},
)
# Set the new recommended model
await tx.llmmodel.update(
where={"id": model_id},
data={"isRecommended": True},
)
# Fetch and return the updated model
updated_record = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id},
include={"Costs": True, "Creator": True},
)
if not updated_record:
raise ValueError("Failed to fetch updated model")
return _map_model(updated_record), previous_slug
async def get_recommended_model_slug() -> str | None:
"""
Get the slug of the currently recommended LLM model.
Returns:
The slug of the recommended model, or None if no model is marked as recommended.
"""
record = await prisma.models.LlmModel.prisma().find_first(
where={"isRecommended": True, "isEnabled": True},
)
return record.slug if record else None


@@ -0,0 +1,231 @@
from __future__ import annotations
import re
from typing import Any, Optional
import prisma.enums
import pydantic
# Pattern for valid model slugs: alphanumeric start, then alphanumeric, dots, underscores, slashes, hyphens
SLUG_PATTERN = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._/-]*$")
class LlmModelCost(pydantic.BaseModel):
id: str
unit: prisma.enums.LlmCostUnit = prisma.enums.LlmCostUnit.RUN
credit_cost: int
credential_provider: str
credential_id: Optional[str] = None
credential_type: Optional[str] = None
currency: Optional[str] = None
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class LlmModelCreator(pydantic.BaseModel):
"""Represents the organization that created/trained the model (e.g., OpenAI, Meta)."""
id: str
name: str
display_name: str
description: Optional[str] = None
website_url: Optional[str] = None
logo_url: Optional[str] = None
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class LlmModel(pydantic.BaseModel):
id: str
slug: str
display_name: str
description: Optional[str] = None
provider_id: str
creator_id: Optional[str] = None
creator: Optional[LlmModelCreator] = None
context_window: int
max_output_tokens: Optional[int] = None
is_enabled: bool = True
is_recommended: bool = False
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
costs: list[LlmModelCost] = pydantic.Field(default_factory=list)
class LlmProvider(pydantic.BaseModel):
id: str
name: str
display_name: str
description: Optional[str] = None
default_credential_provider: Optional[str] = None
default_credential_id: Optional[str] = None
default_credential_type: Optional[str] = None
supports_tools: bool = True
supports_json_output: bool = True
supports_reasoning: bool = False
supports_parallel_tool: bool = False
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
models: list[LlmModel] = pydantic.Field(default_factory=list)
class LlmProvidersResponse(pydantic.BaseModel):
providers: list[LlmProvider]
class LlmModelsResponse(pydantic.BaseModel):
models: list[LlmModel]
class LlmCreatorsResponse(pydantic.BaseModel):
creators: list[LlmModelCreator]
class UpsertLlmProviderRequest(pydantic.BaseModel):
name: str
display_name: str
description: Optional[str] = None
default_credential_provider: Optional[str] = None
default_credential_id: Optional[str] = None
default_credential_type: Optional[str] = "api_key"
supports_tools: bool = True
supports_json_output: bool = True
supports_reasoning: bool = False
supports_parallel_tool: bool = False
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class UpsertLlmCreatorRequest(pydantic.BaseModel):
name: str
display_name: str
description: Optional[str] = None
website_url: Optional[str] = None
logo_url: Optional[str] = None
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class LlmModelCostInput(pydantic.BaseModel):
unit: prisma.enums.LlmCostUnit = prisma.enums.LlmCostUnit.RUN
credit_cost: int
credential_provider: str
credential_id: Optional[str] = None
credential_type: Optional[str] = "api_key"
currency: Optional[str] = None
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class CreateLlmModelRequest(pydantic.BaseModel):
slug: str
display_name: str
description: Optional[str] = None
provider_id: str
creator_id: Optional[str] = None
context_window: int
max_output_tokens: Optional[int] = None
is_enabled: bool = True
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
costs: list[LlmModelCostInput]
@pydantic.field_validator("slug")
@classmethod
def validate_slug(cls, v: str) -> str:
if not v or len(v) > 100:
raise ValueError("Slug must be 1-100 characters")
if not SLUG_PATTERN.match(v):
raise ValueError(
"Slug must start with alphanumeric and contain only "
"alphanumeric characters, dots, underscores, slashes, or hyphens"
)
return v
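For instance, under SLUG_PATTERN:

# Illustrative checks against SLUG_PATTERN:
assert SLUG_PATTERN.match("gpt-4o-mini")               # ok
assert SLUG_PATTERN.match("meta-llama/Llama-3.3-70B")  # slashes allowed
assert SLUG_PATTERN.match("-gpt") is None              # must start alphanumeric
# ':' is not in the pattern, so seeded Ollama slugs like 'llama3.1:405b'
# predate this validator and could not be re-created through it:
assert SLUG_PATTERN.match("llama3.1:405b") is None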
class UpdateLlmModelRequest(pydantic.BaseModel):
display_name: Optional[str] = None
description: Optional[str] = None
context_window: Optional[int] = None
max_output_tokens: Optional[int] = None
is_enabled: Optional[bool] = None
capabilities: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
provider_id: Optional[str] = None
creator_id: Optional[str] = None
costs: Optional[list[LlmModelCostInput]] = None
class ToggleLlmModelRequest(pydantic.BaseModel):
is_enabled: bool
migrate_to_slug: Optional[str] = None
migration_reason: Optional[str] = None # e.g., "Provider outage"
# Custom pricing override for migrated workflows. When set, billing should use
# this cost instead of the target model's cost for affected nodes.
# See LlmModelMigration in schema.prisma for full documentation.
custom_credit_cost: Optional[int] = None
class ToggleLlmModelResponse(pydantic.BaseModel):
model: LlmModel
nodes_migrated: int = 0
migrated_to_slug: Optional[str] = None
migration_id: Optional[str] = None # ID of the migration record for revert
class DeleteLlmModelResponse(pydantic.BaseModel):
deleted_model_slug: str
deleted_model_display_name: str
replacement_model_slug: str
nodes_migrated: int
message: str
class LlmModelUsageResponse(pydantic.BaseModel):
model_slug: str
node_count: int
# Migration tracking models
class LlmModelMigration(pydantic.BaseModel):
id: str
source_model_slug: str
target_model_slug: str
reason: Optional[str] = None
node_count: int
# Custom pricing override - billing should use this instead of target model's cost
custom_credit_cost: Optional[int] = None
is_reverted: bool = False
created_at: str # ISO datetime string
reverted_at: Optional[str] = None
class LlmMigrationsResponse(pydantic.BaseModel):
migrations: list[LlmModelMigration]
class RevertMigrationRequest(pydantic.BaseModel):
re_enable_source_model: bool = (
True # Whether to re-enable the source model if disabled
)
class RevertMigrationResponse(pydantic.BaseModel):
migration_id: str
source_model_slug: str
target_model_slug: str
nodes_reverted: int
nodes_already_changed: int = (
0 # Nodes that were modified since migration (not reverted)
)
source_model_re_enabled: bool = False # Whether the source model was re-enabled
message: str
class SetRecommendedModelRequest(pydantic.BaseModel):
model_id: str
class SetRecommendedModelResponse(pydantic.BaseModel):
model: LlmModel
previous_recommended_slug: Optional[str] = None
message: str
class RecommendedModelResponse(pydantic.BaseModel):
model: Optional[LlmModel] = None
slug: Optional[str] = None


@@ -0,0 +1,25 @@
import autogpt_libs.auth
import fastapi
from backend.server.v2.llm import db as llm_db
from backend.server.v2.llm import model as llm_model
router = fastapi.APIRouter(
prefix="/llm",
tags=["llm"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
@router.get("/models", response_model=llm_model.LlmModelsResponse)
async def list_models():
"""List all enabled LLM models available to users."""
models = await llm_db.list_models(enabled_only=True)
return llm_model.LlmModelsResponse(models=models)
@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
async def list_providers():
"""List all LLM providers with their enabled models."""
providers = await llm_db.list_providers(include_models=True, enabled_only=True)
return llm_model.LlmProvidersResponse(providers=providers)
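A hedged client-side sketch against these routes, assuming they are mounted under the platform API prefix (base URL and token hypothetical):

# Hypothetical client call to the public LLM registry routes.
import httpx

BASE_URL = "http://localhost:8006/api"  # hypothetical mount point
headers = {"Authorization": "Bearer <token>"}  # requires_user expects a user token

resp = httpx.get(f"{BASE_URL}/llm/models", headers=headers)
resp.raise_for_status()
for m in resp.json()["models"]:
    print(m["slug"], m["display_name"], m["is_recommended"])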


@@ -658,14 +658,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
ayrshare_api_key: str = Field(default="", description="Ayrshare API Key")
ayrshare_jwt_key: str = Field(default="", description="Ayrshare private Key")
# Langfuse prompt management
langfuse_public_key: str = Field(default="", description="Langfuse public key")
langfuse_secret_key: str = Field(default="", description="Langfuse secret key")
langfuse_host: str = Field(
default="https://cloud.langfuse.com", description="Langfuse host URL"
)
# Add more secret fields as needed
model_config = SettingsConfigDict(
env_file=".env",


@@ -0,0 +1,78 @@
-- CreateEnum
CREATE TYPE "LlmCostUnit" AS ENUM ('RUN', 'TOKENS');
-- CreateTable
CREATE TABLE "LlmProvider" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"name" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"defaultCredentialProvider" TEXT,
"defaultCredentialId" TEXT,
"defaultCredentialType" TEXT,
"supportsTools" BOOLEAN NOT NULL DEFAULT TRUE,
"supportsJsonOutput" BOOLEAN NOT NULL DEFAULT TRUE,
"supportsReasoning" BOOLEAN NOT NULL DEFAULT FALSE,
"supportsParallelTool" BOOLEAN NOT NULL DEFAULT FALSE,
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
CONSTRAINT "LlmProvider_pkey" PRIMARY KEY ("id"),
CONSTRAINT "LlmProvider_name_key" UNIQUE ("name")
);
-- CreateTable
CREATE TABLE "LlmModel" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"slug" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"providerId" TEXT NOT NULL,
"contextWindow" INTEGER NOT NULL,
"maxOutputTokens" INTEGER,
"isEnabled" BOOLEAN NOT NULL DEFAULT TRUE,
"capabilities" JSONB NOT NULL DEFAULT '{}'::jsonb,
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
CONSTRAINT "LlmModel_pkey" PRIMARY KEY ("id"),
CONSTRAINT "LlmModel_slug_key" UNIQUE ("slug")
);
-- CreateTable
CREATE TABLE "LlmModelCost" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"unit" "LlmCostUnit" NOT NULL DEFAULT 'RUN',
"creditCost" INTEGER NOT NULL,
"credentialProvider" TEXT NOT NULL,
"credentialId" TEXT,
"credentialType" TEXT,
"currency" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
"llmModelId" TEXT NOT NULL,
CONSTRAINT "LlmModelCost_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE INDEX "LlmModel_providerId_isEnabled_idx" ON "LlmModel"("providerId", "isEnabled");
-- CreateIndex
CREATE INDEX "LlmModel_slug_idx" ON "LlmModel"("slug");
-- CreateIndex
CREATE INDEX "LlmModelCost_llmModelId_idx" ON "LlmModelCost"("llmModelId");
-- CreateIndex
CREATE INDEX "LlmModelCost_credentialProvider_idx" ON "LlmModelCost"("credentialProvider");
-- AddForeignKey
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_providerId_fkey" FOREIGN KEY ("providerId") REFERENCES "LlmProvider"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LlmModelCost" ADD CONSTRAINT "LlmModelCost_llmModelId_fkey" FOREIGN KEY ("llmModelId") REFERENCES "LlmModel"("id") ON DELETE CASCADE ON UPDATE CASCADE;


@@ -0,0 +1,225 @@
-- Seed LLM Registry from existing hard-coded data
-- This migration populates the LlmProvider, LlmModel, and LlmModelCost tables
-- with data from the existing MODEL_METADATA and MODEL_COST dictionaries
-- Insert Providers
INSERT INTO "LlmProvider" ("id", "name", "displayName", "description", "defaultCredentialProvider", "defaultCredentialType", "supportsTools", "supportsJsonOutput", "supportsReasoning", "supportsParallelTool", "metadata")
VALUES
(gen_random_uuid(), 'openai', 'OpenAI', 'OpenAI language models', 'openai', 'api_key', true, true, true, true, '{}'::jsonb),
(gen_random_uuid(), 'anthropic', 'Anthropic', 'Anthropic Claude models', 'anthropic', 'api_key', true, true, true, false, '{}'::jsonb),
(gen_random_uuid(), 'groq', 'Groq', 'Groq inference API', 'groq', 'api_key', false, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'open_router', 'OpenRouter', 'OpenRouter unified API', 'open_router', 'api_key', true, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'aiml_api', 'AI/ML API', 'AI/ML API models', 'aiml_api', 'api_key', false, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'ollama', 'Ollama', 'Ollama local models', 'ollama', 'api_key', false, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'llama_api', 'Llama API', 'Llama API models', 'llama_api', 'api_key', false, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'v0', 'v0', 'v0 by Vercel models', 'v0', 'api_key', true, true, false, false, '{}'::jsonb)
ON CONFLICT ("name") DO NOTHING;
-- Insert Models (using CTEs to reference provider IDs)
WITH provider_ids AS (
SELECT "id", "name" FROM "LlmProvider"
)
INSERT INTO "LlmModel" ("id", "slug", "displayName", "description", "providerId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
SELECT
gen_random_uuid(),
model_slug,
model_display_name,
NULL,
p."id",
context_window,
max_output_tokens,
true,
'{}'::jsonb,
'{}'::jsonb
FROM (VALUES
-- OpenAI models
('o3', 'O3', 'openai', 200000, 100000),
('o3-mini', 'O3 Mini', 'openai', 200000, 100000),
('o1', 'O1', 'openai', 200000, 100000),
('o1-mini', 'O1 Mini', 'openai', 128000, 65536),
('gpt-5-2025-08-07', 'GPT 5', 'openai', 400000, 128000),
('gpt-5.1-2025-11-13', 'GPT 5.1', 'openai', 400000, 128000),
('gpt-5-mini-2025-08-07', 'GPT 5 Mini', 'openai', 400000, 128000),
('gpt-5-nano-2025-08-07', 'GPT 5 Nano', 'openai', 400000, 128000),
('gpt-5-chat-latest', 'GPT 5 Chat', 'openai', 400000, 16384),
('gpt-4.1-2025-04-14', 'GPT 4.1', 'openai', 1047576, 32768),
('gpt-4.1-mini-2025-04-14', 'GPT 4.1 Mini', 'openai', 1047576, 32768),
('gpt-4o-mini', 'GPT 4o Mini', 'openai', 128000, 16384),
('gpt-4o', 'GPT 4o', 'openai', 128000, 16384),
('gpt-4-turbo', 'GPT 4 Turbo', 'openai', 128000, 4096),
('gpt-3.5-turbo', 'GPT 3.5 Turbo', 'openai', 16385, 4096),
-- Anthropic models
('claude-opus-4-1-20250805', 'Claude 4.1 Opus', 'anthropic', 200000, 32000),
('claude-opus-4-20250514', 'Claude 4 Opus', 'anthropic', 200000, 32000),
('claude-sonnet-4-20250514', 'Claude 4 Sonnet', 'anthropic', 200000, 64000),
('claude-opus-4-5-20251101', 'Claude 4.5 Opus', 'anthropic', 200000, 64000),
('claude-sonnet-4-5-20250929', 'Claude 4.5 Sonnet', 'anthropic', 200000, 64000),
('claude-haiku-4-5-20251001', 'Claude 4.5 Haiku', 'anthropic', 200000, 64000),
('claude-3-7-sonnet-20250219', 'Claude 3.7 Sonnet', 'anthropic', 200000, 64000),
('claude-3-haiku-20240307', 'Claude 3 Haiku', 'anthropic', 200000, 4096),
-- AI/ML API models
('Qwen/Qwen2.5-72B-Instruct-Turbo', 'Qwen 2.5 72B', 'aiml_api', 32000, 8000),
('nvidia/llama-3.1-nemotron-70b-instruct', 'Llama 3.1 Nemotron 70B', 'aiml_api', 128000, 40000),
('meta-llama/Llama-3.3-70B-Instruct-Turbo', 'Llama 3.3 70B', 'aiml_api', 128000, NULL),
('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 'Meta Llama 3.1 70B', 'aiml_api', 131000, 2000),
('meta-llama/Llama-3.2-3B-Instruct-Turbo', 'Llama 3.2 3B', 'aiml_api', 128000, NULL),
-- Groq models
('llama-3.3-70b-versatile', 'Llama 3.3 70B', 'groq', 128000, 32768),
('llama-3.1-8b-instant', 'Llama 3.1 8B', 'groq', 128000, 8192),
-- Ollama models
('llama3.3', 'Llama 3.3', 'ollama', 8192, NULL),
('llama3.2', 'Llama 3.2', 'ollama', 8192, NULL),
('llama3', 'Llama 3', 'ollama', 8192, NULL),
('llama3.1:405b', 'Llama 3.1 405B', 'ollama', 8192, NULL),
('dolphin-mistral:latest', 'Dolphin Mistral', 'ollama', 32768, NULL),
-- OpenRouter models
('google/gemini-2.5-pro-preview-03-25', 'Gemini 2.5 Pro', 'open_router', 1050000, 8192),
('google/gemini-3-pro-preview', 'Gemini 3 Pro Preview', 'open_router', 1048576, 65535),
('google/gemini-2.5-flash', 'Gemini 2.5 Flash', 'open_router', 1048576, 65535),
('google/gemini-2.0-flash-001', 'Gemini 2.0 Flash', 'open_router', 1048576, 8192),
('google/gemini-2.5-flash-lite-preview-06-17', 'Gemini 2.5 Flash Lite Preview', 'open_router', 1048576, 65535),
('google/gemini-2.0-flash-lite-001', 'Gemini 2.0 Flash Lite', 'open_router', 1048576, 8192),
('mistralai/mistral-nemo', 'Mistral Nemo', 'open_router', 128000, 4096),
('cohere/command-r-08-2024', 'Command R', 'open_router', 128000, 4096),
('cohere/command-r-plus-08-2024', 'Command R Plus', 'open_router', 128000, 4096),
('deepseek/deepseek-chat', 'DeepSeek Chat', 'open_router', 64000, 2048),
('deepseek/deepseek-r1-0528', 'DeepSeek R1', 'open_router', 163840, 163840),
('perplexity/sonar', 'Perplexity Sonar', 'open_router', 127000, 8000),
('perplexity/sonar-pro', 'Perplexity Sonar Pro', 'open_router', 200000, 8000),
('perplexity/sonar-deep-research', 'Perplexity Sonar Deep Research', 'open_router', 128000, 16000),
('nousresearch/hermes-3-llama-3.1-405b', 'Hermes 3 Llama 3.1 405B', 'open_router', 131000, 4096),
('nousresearch/hermes-3-llama-3.1-70b', 'Hermes 3 Llama 3.1 70B', 'open_router', 12288, 12288),
('openai/gpt-oss-120b', 'GPT OSS 120B', 'open_router', 131072, 131072),
('openai/gpt-oss-20b', 'GPT OSS 20B', 'open_router', 131072, 32768),
('amazon/nova-lite-v1', 'Amazon Nova Lite', 'open_router', 300000, 5120),
('amazon/nova-micro-v1', 'Amazon Nova Micro', 'open_router', 128000, 5120),
('amazon/nova-pro-v1', 'Amazon Nova Pro', 'open_router', 300000, 5120),
('microsoft/wizardlm-2-8x22b', 'WizardLM 2 8x22B', 'open_router', 65536, 4096),
('gryphe/mythomax-l2-13b', 'MythoMax L2 13B', 'open_router', 4096, 4096),
('meta-llama/llama-4-scout', 'Llama 4 Scout', 'open_router', 131072, 131072),
('meta-llama/llama-4-maverick', 'Llama 4 Maverick', 'open_router', 1048576, 1000000),
('x-ai/grok-4', 'Grok 4', 'open_router', 256000, 256000),
('x-ai/grok-4-fast', 'Grok 4 Fast', 'open_router', 2000000, 30000),
('x-ai/grok-4.1-fast', 'Grok 4.1 Fast', 'open_router', 2000000, 30000),
('x-ai/grok-code-fast-1', 'Grok Code Fast 1', 'open_router', 256000, 10000),
('moonshotai/kimi-k2', 'Kimi K2', 'open_router', 131000, 131000),
('qwen/qwen3-235b-a22b-thinking-2507', 'Qwen 3 235B Thinking', 'open_router', 262144, 262144),
('qwen/qwen3-coder', 'Qwen 3 Coder', 'open_router', 262144, 262144),
-- Llama API models
('Llama-4-Scout-17B-16E-Instruct-FP8', 'Llama 4 Scout', 'llama_api', 128000, 4028),
('Llama-4-Maverick-17B-128E-Instruct-FP8', 'Llama 4 Maverick', 'llama_api', 128000, 4028),
('Llama-3.3-8B-Instruct', 'Llama 3.3 8B', 'llama_api', 128000, 4028),
('Llama-3.3-70B-Instruct', 'Llama 3.3 70B', 'llama_api', 128000, 4028),
-- v0 models
('v0-1.5-md', 'v0 1.5 MD', 'v0', 128000, 64000),
('v0-1.5-lg', 'v0 1.5 LG', 'v0', 512000, 64000),
('v0-1.0-md', 'v0 1.0 MD', 'v0', 128000, 64000)
) AS models(model_slug, model_display_name, provider_name, context_window, max_output_tokens)
JOIN provider_ids p ON p."name" = models.provider_name
ON CONFLICT ("slug") DO NOTHING;
-- Insert Costs (using CTEs to reference model IDs)
WITH model_ids AS (
SELECT "id", "slug", "providerId" FROM "LlmModel"
),
provider_ids AS (
SELECT "id", "name" FROM "LlmProvider"
)
INSERT INTO "LlmModelCost" ("id", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
SELECT
gen_random_uuid(),
'RUN'::"LlmCostUnit",
cost,
p."name",
NULL,
'api_key',
NULL,
'{}'::jsonb,
m."id"
FROM (VALUES
-- OpenAI costs
('o3', 4),
('o3-mini', 2),
('o1', 16),
('o1-mini', 4),
('gpt-5-2025-08-07', 2),
('gpt-5.1-2025-11-13', 5),
('gpt-5-mini-2025-08-07', 1),
('gpt-5-nano-2025-08-07', 1),
('gpt-5-chat-latest', 5),
('gpt-4.1-2025-04-14', 2),
('gpt-4.1-mini-2025-04-14', 1),
('gpt-4o-mini', 1),
('gpt-4o', 3),
('gpt-4-turbo', 10),
('gpt-3.5-turbo', 1),
-- Anthropic costs
('claude-opus-4-1-20250805', 21),
('claude-opus-4-20250514', 21),
('claude-sonnet-4-20250514', 5),
('claude-haiku-4-5-20251001', 4),
('claude-opus-4-5-20251101', 14),
('claude-sonnet-4-5-20250929', 9),
('claude-3-7-sonnet-20250219', 5),
('claude-3-haiku-20240307', 1),
-- AI/ML API costs
('Qwen/Qwen2.5-72B-Instruct-Turbo', 1),
('nvidia/llama-3.1-nemotron-70b-instruct', 1),
('meta-llama/Llama-3.3-70B-Instruct-Turbo', 1),
('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 1),
('meta-llama/Llama-3.2-3B-Instruct-Turbo', 1),
-- Groq costs
('llama-3.3-70b-versatile', 1),
('llama-3.1-8b-instant', 1),
-- Ollama costs
('llama3.3', 1),
('llama3.2', 1),
('llama3', 1),
('llama3.1:405b', 1),
('dolphin-mistral:latest', 1),
-- OpenRouter costs
('google/gemini-2.5-pro-preview-03-25', 4),
('google/gemini-3-pro-preview', 5),
('mistralai/mistral-nemo', 1),
('cohere/command-r-08-2024', 1),
('cohere/command-r-plus-08-2024', 3),
('deepseek/deepseek-chat', 2),
('perplexity/sonar', 1),
('perplexity/sonar-pro', 5),
('perplexity/sonar-deep-research', 10),
('nousresearch/hermes-3-llama-3.1-405b', 1),
('nousresearch/hermes-3-llama-3.1-70b', 1),
('amazon/nova-lite-v1', 1),
('amazon/nova-micro-v1', 1),
('amazon/nova-pro-v1', 1),
('microsoft/wizardlm-2-8x22b', 1),
('gryphe/mythomax-l2-13b', 1),
('meta-llama/llama-4-scout', 1),
('meta-llama/llama-4-maverick', 1),
('x-ai/grok-4', 9),
('x-ai/grok-4-fast', 1),
('x-ai/grok-4.1-fast', 1),
('x-ai/grok-code-fast-1', 1),
('moonshotai/kimi-k2', 1),
('qwen/qwen3-235b-a22b-thinking-2507', 1),
('qwen/qwen3-coder', 9),
('google/gemini-2.5-flash', 1),
('google/gemini-2.0-flash-001', 1),
('google/gemini-2.5-flash-lite-preview-06-17', 1),
('google/gemini-2.0-flash-lite-001', 1),
('deepseek/deepseek-r1-0528', 1),
('openai/gpt-oss-120b', 1),
('openai/gpt-oss-20b', 1),
-- Llama API costs
('Llama-4-Scout-17B-16E-Instruct-FP8', 1),
('Llama-4-Maverick-17B-128E-Instruct-FP8', 1),
('Llama-3.3-8B-Instruct', 1),
('Llama-3.3-70B-Instruct', 1),
-- v0 costs
('v0-1.5-md', 1),
('v0-1.5-lg', 2),
('v0-1.0-md', 1)
) AS costs(model_slug, cost)
JOIN model_ids m ON m."slug" = costs.model_slug
JOIN provider_ids p ON p."id" = m."providerId";


@@ -1,78 +0,0 @@
-- CreateTable
CREATE TABLE "UserBusinessUnderstanding" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"userId" TEXT NOT NULL,
"userName" TEXT,
"jobTitle" TEXT,
"businessName" TEXT,
"industry" TEXT,
"businessSize" TEXT,
"userRole" TEXT,
"keyWorkflows" JSONB,
"dailyActivities" JSONB,
"painPoints" JSONB,
"bottlenecks" JSONB,
"manualTasks" JSONB,
"automationGoals" JSONB,
"currentSoftware" JSONB,
"existingAutomation" JSONB,
"additionalNotes" TEXT,
CONSTRAINT "UserBusinessUnderstanding_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "ChatSession" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"userId" TEXT,
"title" TEXT,
"credentials" JSONB NOT NULL DEFAULT '{}',
"successfulAgentRuns" JSONB NOT NULL DEFAULT '{}',
"successfulAgentSchedules" JSONB NOT NULL DEFAULT '{}',
"totalPromptTokens" INTEGER NOT NULL DEFAULT 0,
"totalCompletionTokens" INTEGER NOT NULL DEFAULT 0,
CONSTRAINT "ChatSession_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "ChatMessage" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"sessionId" TEXT NOT NULL,
"role" TEXT NOT NULL,
"content" TEXT,
"name" TEXT,
"toolCallId" TEXT,
"refusal" TEXT,
"toolCalls" JSONB,
"functionCall" JSONB,
"sequence" INTEGER NOT NULL,
CONSTRAINT "ChatMessage_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "UserBusinessUnderstanding_userId_key" ON "UserBusinessUnderstanding"("userId");
-- CreateIndex
CREATE INDEX "UserBusinessUnderstanding_userId_idx" ON "UserBusinessUnderstanding"("userId");
-- CreateIndex
CREATE INDEX "ChatSession_userId_updatedAt_idx" ON "ChatSession"("userId", "updatedAt");
-- CreateIndex
CREATE INDEX "ChatMessage_sessionId_sequence_idx" ON "ChatMessage"("sessionId", "sequence");
-- CreateIndex
CREATE UNIQUE INDEX "ChatMessage_sessionId_sequence_key" ON "ChatMessage"("sessionId", "sequence");
-- AddForeignKey
ALTER TABLE "UserBusinessUnderstanding" ADD CONSTRAINT "UserBusinessUnderstanding_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "ChatMessage" ADD CONSTRAINT "ChatMessage_sessionId_fkey" FOREIGN KEY ("sessionId") REFERENCES "ChatSession"("id") ON DELETE CASCADE ON UPDATE CASCADE;


@@ -1,41 +0,0 @@
-- Migration: Add pgvector extension and StoreListingEmbedding table
-- This enables hybrid search combining semantic (embedding) and lexical (tsvector) search
-- Enable pgvector extension for vector similarity search
CREATE EXTENSION IF NOT EXISTS vector;
-- Create table to store embeddings for store listing versions
CREATE TABLE "StoreListingEmbedding" (
"id" TEXT NOT NULL DEFAULT gen_random_uuid(),
"storeListingVersionId" TEXT NOT NULL,
"embedding" vector(1536), -- OpenAI text-embedding-3-small produces 1536 dimensions
"searchableText" TEXT, -- The text that was embedded (for debugging/recomputation)
"contentHash" TEXT, -- MD5 hash of searchable text for change detection
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT "StoreListingEmbedding_pkey" PRIMARY KEY ("id")
);
-- Unique constraint: one embedding per listing version
CREATE UNIQUE INDEX "StoreListingEmbedding_storeListingVersionId_key"
ON "StoreListingEmbedding"("storeListingVersionId");
-- HNSW index for fast approximate nearest neighbor search
-- Using cosine distance (vector_cosine_ops) which is standard for text embeddings
CREATE INDEX "StoreListingEmbedding_embedding_idx"
ON "StoreListingEmbedding"
USING hnsw ("embedding" vector_cosine_ops);
-- Index on content hash for fast lookup during change detection
CREATE INDEX "StoreListingEmbedding_contentHash_idx"
ON "StoreListingEmbedding"("contentHash");
-- Foreign key to StoreListingVersion with CASCADE delete
-- When a listing version is deleted, its embedding is automatically removed
ALTER TABLE "StoreListingEmbedding"
ADD CONSTRAINT "StoreListingEmbedding_storeListingVersionId_fkey"
FOREIGN KEY ("storeListingVersionId")
REFERENCES "StoreListingVersion"("id")
ON DELETE CASCADE
ON UPDATE CASCADE;


@@ -1,5 +0,0 @@
-- DropIndex
DROP INDEX "StoreListingEmbedding_embedding_idx";
-- AlterTable
ALTER TABLE "StoreListingEmbedding" ALTER COLUMN "id" DROP DEFAULT;


@@ -0,0 +1,25 @@
-- CreateTable
CREATE TABLE "LlmModelMigration" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"sourceModelSlug" TEXT NOT NULL,
"targetModelSlug" TEXT NOT NULL,
"reason" TEXT,
"migratedNodeIds" JSONB NOT NULL DEFAULT '[]',
"nodeCount" INTEGER NOT NULL,
"customCreditCost" INTEGER,
"isReverted" BOOLEAN NOT NULL DEFAULT false,
"revertedAt" TIMESTAMP(3),
CONSTRAINT "LlmModelMigration_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE INDEX "LlmModelMigration_sourceModelSlug_idx" ON "LlmModelMigration"("sourceModelSlug");
-- CreateIndex
CREATE INDEX "LlmModelMigration_targetModelSlug_idx" ON "LlmModelMigration"("targetModelSlug");
-- CreateIndex
CREATE INDEX "LlmModelMigration_isReverted_idx" ON "LlmModelMigration"("isReverted");


@@ -0,0 +1,127 @@
-- Add LlmModelCreator table
-- Creator represents who made/trained the model (e.g., OpenAI, Meta)
-- This is distinct from Provider who hosts/serves the model (e.g., OpenRouter)
-- Create the LlmModelCreator table
CREATE TABLE "LlmModelCreator" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"name" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"websiteUrl" TEXT,
"logoUrl" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "LlmModelCreator_pkey" PRIMARY KEY ("id")
);
-- Create unique index on name
CREATE UNIQUE INDEX "LlmModelCreator_name_key" ON "LlmModelCreator"("name");
-- Add creatorId column to LlmModel
ALTER TABLE "LlmModel" ADD COLUMN "creatorId" TEXT;
-- Add foreign key constraint
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey"
FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;
-- Create index on creatorId
CREATE INDEX "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");
-- Seed creators based on known model creators
INSERT INTO "LlmModelCreator" ("id", "updatedAt", "name", "displayName", "description", "websiteUrl", "metadata")
VALUES
(gen_random_uuid(), CURRENT_TIMESTAMP, 'openai', 'OpenAI', 'Creator of GPT models', 'https://openai.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'anthropic', 'Anthropic', 'Creator of Claude models', 'https://anthropic.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'meta', 'Meta', 'Creator of Llama models', 'https://ai.meta.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'google', 'Google', 'Creator of Gemini models', 'https://deepmind.google', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'mistral', 'Mistral AI', 'Creator of Mistral models', 'https://mistral.ai', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'cohere', 'Cohere', 'Creator of Command models', 'https://cohere.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'deepseek', 'DeepSeek', 'Creator of DeepSeek models', 'https://deepseek.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'perplexity', 'Perplexity AI', 'Creator of Sonar models', 'https://perplexity.ai', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'qwen', 'Qwen (Alibaba)', 'Creator of Qwen models', 'https://qwenlm.github.io', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'xai', 'xAI', 'Creator of Grok models', 'https://x.ai', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'amazon', 'Amazon', 'Creator of Nova models', 'https://aws.amazon.com/bedrock', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'microsoft', 'Microsoft', 'Creator of WizardLM models', 'https://microsoft.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'moonshot', 'Moonshot AI', 'Creator of Kimi models', 'https://moonshot.cn', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'nvidia', 'NVIDIA', 'Creator of Nemotron models', 'https://nvidia.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'nous_research', 'Nous Research', 'Creator of Hermes models', 'https://nousresearch.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'vercel', 'Vercel', 'Creator of v0 models', 'https://vercel.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'cognitive_computations', 'Cognitive Computations', 'Creator of Dolphin models', 'https://erichartford.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'gryphe', 'Gryphe', 'Creator of MythoMax models', 'https://huggingface.co/Gryphe', '{}')
ON CONFLICT ("name") DO NOTHING;
-- Update existing models with their creators
-- OpenAI models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'openai')
WHERE "slug" LIKE 'gpt-%' OR "slug" LIKE 'o1%' OR "slug" LIKE 'o3%' OR "slug" LIKE 'openai/%';
-- Anthropic models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'anthropic')
WHERE "slug" LIKE 'claude-%';
-- Meta/Llama models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'meta')
WHERE "slug" LIKE 'llama%' OR "slug" LIKE 'Llama%' OR "slug" LIKE 'meta-llama/%' OR "slug" LIKE '%/llama-%';
-- Google models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'google')
WHERE "slug" LIKE 'google/%' OR "slug" LIKE 'gemini%';
-- Mistral models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'mistral')
WHERE "slug" LIKE 'mistral%' OR "slug" LIKE 'mistralai/%';
-- Cohere models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'cohere')
WHERE "slug" LIKE 'cohere/%' OR "slug" LIKE 'command-%';
-- DeepSeek models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'deepseek')
WHERE "slug" LIKE 'deepseek/%' OR "slug" LIKE 'deepseek-%';
-- Perplexity models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'perplexity')
WHERE "slug" LIKE 'perplexity/%' OR "slug" LIKE 'sonar%';
-- Qwen models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'qwen')
WHERE "slug" LIKE 'Qwen/%' OR "slug" LIKE 'qwen/%';
-- xAI/Grok models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'xai')
WHERE "slug" LIKE 'x-ai/%' OR "slug" LIKE 'grok%';
-- Amazon models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'amazon')
WHERE "slug" LIKE 'amazon/%' OR "slug" LIKE 'nova-%';
-- Microsoft models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'microsoft')
WHERE "slug" LIKE 'microsoft/%' OR "slug" LIKE 'wizardlm%';
-- Moonshot models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'moonshot')
WHERE "slug" LIKE 'moonshotai/%' OR "slug" LIKE 'kimi%';
-- NVIDIA models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'nvidia')
WHERE "slug" LIKE 'nvidia/%' OR "slug" LIKE '%nemotron%';
-- Nous Research models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'nous_research')
WHERE "slug" LIKE 'nousresearch/%' OR "slug" LIKE 'hermes%';
-- Vercel/v0 models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'vercel')
WHERE "slug" LIKE 'v0-%';
-- Dolphin models (Cognitive Computations / Eric Hartford)
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'cognitive_computations')
WHERE "slug" LIKE 'dolphin-%';
-- Gryphe models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'gryphe')
WHERE "slug" LIKE 'gryphe/%' OR "slug" LIKE 'mythomax%';


@@ -0,0 +1,4 @@
-- CreateIndex
-- Index for efficient LLM model lookups on AgentNode.constantInput->>'model'
-- This improves performance of model migration queries in the LLM registry
CREATE INDEX "AgentNode_constantInput_model_idx" ON "AgentNode" ((("constantInput"->>'model')));


@@ -0,0 +1,52 @@
-- Add GPT-5.2 model and update O3 slug
-- This migration adds the GPT-5.2 model introduced on the dev branch
-- Update O3 slug to match dev branch format
UPDATE "LlmModel"
SET "slug" = 'o3-2025-04-16'
WHERE "slug" = 'o3';
-- No cost update is needed for o3: costs reference the model by ID, which is unchanged
-- Add GPT-5.2 model
WITH provider_id AS (
SELECT "id" FROM "LlmProvider" WHERE "name" = 'openai'
)
INSERT INTO "LlmModel" ("id", "slug", "displayName", "description", "providerId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
SELECT
gen_random_uuid(),
'gpt-5.2-2025-12-11',
'GPT 5.2',
'OpenAI GPT-5.2 model',
p."id",
400000,
128000,
true,
'{}'::jsonb,
'{}'::jsonb
FROM provider_id p
ON CONFLICT ("slug") DO NOTHING;
-- Add cost for GPT-5.2
WITH model_id AS (
SELECT m."id", p."name" as provider_name
FROM "LlmModel" m
JOIN "LlmProvider" p ON p."id" = m."providerId"
WHERE m."slug" = 'gpt-5.2-2025-12-11'
)
INSERT INTO "LlmModelCost" ("id", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
SELECT
gen_random_uuid(),
'RUN'::"LlmCostUnit",
3, -- Same cost tier as GPT-5.1
m.provider_name,
NULL,
'api_key',
NULL,
'{}'::jsonb,
m."id"
FROM model_id m
WHERE NOT EXISTS (
SELECT 1 FROM "LlmModelCost" c WHERE c."llmModelId" = m."id"
);

View File

@@ -0,0 +1,11 @@
-- Add isRecommended field to LlmModel table
-- This allows admins to mark a model as the recommended default
ALTER TABLE "LlmModel" ADD COLUMN "isRecommended" BOOLEAN NOT NULL DEFAULT false;
-- Set gpt-4o-mini as the default recommended model (if it exists)
UPDATE "LlmModel" SET "isRecommended" = true WHERE "slug" = 'gpt-4o-mini' AND "isEnabled" = true;
-- Create unique partial index to enforce only one recommended model at the database level
-- This prevents multiple rows from having isRecommended = true
CREATE UNIQUE INDEX "LlmModel_single_recommended_idx" ON "LlmModel" ("isRecommended") WHERE "isRecommended" = true;
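-- Illustrative only (not part of the original migration): with this partial unique
-- index in place, swapping the recommended model must clear the old flag before
-- setting the new one, or the second UPDATE violates the index. A minimal sketch,
-- assuming plain SQL and a placeholder slug:
BEGIN;
UPDATE "LlmModel" SET "isRecommended" = false WHERE "isRecommended" = true;
UPDATE "LlmModel" SET "isRecommended" = true WHERE "slug" = 'gpt-4o-mini' AND "isEnabled" = true;
COMMIT;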

View File

@@ -1924,14 +1924,14 @@ google = ["google-api-python-client (>=2.0.0)", "google-auth (>=2.0.0)"]
[[package]]
name = "gravitasml"
version = "0.1.4"
version = "0.1.3"
description = ""
optional = false
python-versions = "<4.0,>=3.10"
groups = ["main"]
files = [
{file = "gravitasml-0.1.4-py3-none-any.whl", hash = "sha256:671a18b11d3d8a0e270c6a80c72cd058458b18d5ef7560d00010e962ab1bca74"},
{file = "gravitasml-0.1.4.tar.gz", hash = "sha256:35d0d9fec7431817482d53d9c976e375557c3e041d1eb6928e809324a8c866e3"},
{file = "gravitasml-0.1.3-py3-none-any.whl", hash = "sha256:51ff98b4564b7a61f7796f18d5f2558b919d30b3722579296089645b7bc18b85"},
{file = "gravitasml-0.1.3.tar.gz", hash = "sha256:04d240b9fa35878252d57a36032130b6516487468847fcdced1022c032a20f57"},
]
[package.dependencies]
@@ -2777,33 +2777,6 @@ enabler = ["pytest-enabler (>=2.2)"]
test = ["pyfakefs", "pytest (>=6,!=8.1.*)"]
type = ["pygobject-stubs", "pytest-mypy", "shtab", "types-pywin32"]
[[package]]
name = "langfuse"
version = "2.60.10"
description = "A client library for accessing langfuse"
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
files = [
{file = "langfuse-2.60.10-py3-none-any.whl", hash = "sha256:815c6369194aa5b2a24f88eb9952f7c3fc863272c41e90642a71f3bc76f4a11f"},
{file = "langfuse-2.60.10.tar.gz", hash = "sha256:a26d0d927a28ee01b2d12bb5b862590b643cc4e60a28de6e2b0c2cfff5dbfc6a"},
]
[package.dependencies]
anyio = ">=4.4.0,<5.0.0"
backoff = ">=1.10.0"
httpx = ">=0.15.4,<1.0"
idna = ">=3.7,<4.0"
packaging = ">=23.2,<25.0"
pydantic = ">=1.10.7,<3.0"
requests = ">=2,<3"
wrapt = ">=1.14,<2.0"
[package.extras]
langchain = ["langchain (>=0.0.309)"]
llama-index = ["llama-index (>=0.10.12,<2.0.0)"]
openai = ["openai (>=0.27.8)"]
[[package]]
name = "launchdarkly-eventsource"
version = "1.3.0"
@@ -6949,97 +6922,6 @@ files = [
{file = "websockets-15.0.1.tar.gz", hash = "sha256:82544de02076bafba038ce055ee6412d68da13ab47f0c60cab827346de828dee"},
]
[[package]]
name = "wrapt"
version = "1.17.3"
description = "Module for decorators, wrappers and monkey patching."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "wrapt-1.17.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:88bbae4d40d5a46142e70d58bf664a89b6b4befaea7b2ecc14e03cedb8e06c04"},
{file = "wrapt-1.17.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e6b13af258d6a9ad602d57d889f83b9d5543acd471eee12eb51f5b01f8eb1bc2"},
{file = "wrapt-1.17.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd341868a4b6714a5962c1af0bd44f7c404ef78720c7de4892901e540417111c"},
{file = "wrapt-1.17.3-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f9b2601381be482f70e5d1051a5965c25fb3625455a2bf520b5a077b22afb775"},
{file = "wrapt-1.17.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:343e44b2a8e60e06a7e0d29c1671a0d9951f59174f3709962b5143f60a2a98bd"},
{file = "wrapt-1.17.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:33486899acd2d7d3066156b03465b949da3fd41a5da6e394ec49d271baefcf05"},
{file = "wrapt-1.17.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e6f40a8aa5a92f150bdb3e1c44b7e98fb7113955b2e5394122fa5532fec4b418"},
{file = "wrapt-1.17.3-cp310-cp310-win32.whl", hash = "sha256:a36692b8491d30a8c75f1dfee65bef119d6f39ea84ee04d9f9311f83c5ad9390"},
{file = "wrapt-1.17.3-cp310-cp310-win_amd64.whl", hash = "sha256:afd964fd43b10c12213574db492cb8f73b2f0826c8df07a68288f8f19af2ebe6"},
{file = "wrapt-1.17.3-cp310-cp310-win_arm64.whl", hash = "sha256:af338aa93554be859173c39c85243970dc6a289fa907402289eeae7543e1ae18"},
{file = "wrapt-1.17.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:273a736c4645e63ac582c60a56b0acb529ef07f78e08dc6bfadf6a46b19c0da7"},
{file = "wrapt-1.17.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5531d911795e3f935a9c23eb1c8c03c211661a5060aab167065896bbf62a5f85"},
{file = "wrapt-1.17.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0610b46293c59a3adbae3dee552b648b984176f8562ee0dba099a56cfbe4df1f"},
{file = "wrapt-1.17.3-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b32888aad8b6e68f83a8fdccbf3165f5469702a7544472bdf41f582970ed3311"},
{file = "wrapt-1.17.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cccf4f81371f257440c88faed6b74f1053eef90807b77e31ca057b2db74edb1"},
{file = "wrapt-1.17.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8a210b158a34164de8bb68b0e7780041a903d7b00c87e906fb69928bf7890d5"},
{file = "wrapt-1.17.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:79573c24a46ce11aab457b472efd8d125e5a51da2d1d24387666cd85f54c05b2"},
{file = "wrapt-1.17.3-cp311-cp311-win32.whl", hash = "sha256:c31eebe420a9a5d2887b13000b043ff6ca27c452a9a22fa71f35f118e8d4bf89"},
{file = "wrapt-1.17.3-cp311-cp311-win_amd64.whl", hash = "sha256:0b1831115c97f0663cb77aa27d381237e73ad4f721391a9bfb2fe8bc25fa6e77"},
{file = "wrapt-1.17.3-cp311-cp311-win_arm64.whl", hash = "sha256:5a7b3c1ee8265eb4c8f1b7d29943f195c00673f5ab60c192eba2d4a7eae5f46a"},
{file = "wrapt-1.17.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ab232e7fdb44cdfbf55fc3afa31bcdb0d8980b9b95c38b6405df2acb672af0e0"},
{file = "wrapt-1.17.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:9baa544e6acc91130e926e8c802a17f3b16fbea0fd441b5a60f5cf2cc5c3deba"},
{file = "wrapt-1.17.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6b538e31eca1a7ea4605e44f81a48aa24c4632a277431a6ed3f328835901f4fd"},
{file = "wrapt-1.17.3-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:042ec3bb8f319c147b1301f2393bc19dba6e176b7da446853406d041c36c7828"},
{file = "wrapt-1.17.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3af60380ba0b7b5aeb329bc4e402acd25bd877e98b3727b0135cb5c2efdaefe9"},
{file = "wrapt-1.17.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0b02e424deef65c9f7326d8c19220a2c9040c51dc165cddb732f16198c168396"},
{file = "wrapt-1.17.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:74afa28374a3c3a11b3b5e5fca0ae03bef8450d6aa3ab3a1e2c30e3a75d023dc"},
{file = "wrapt-1.17.3-cp312-cp312-win32.whl", hash = "sha256:4da9f45279fff3543c371d5ababc57a0384f70be244de7759c85a7f989cb4ebe"},
{file = "wrapt-1.17.3-cp312-cp312-win_amd64.whl", hash = "sha256:e71d5c6ebac14875668a1e90baf2ea0ef5b7ac7918355850c0908ae82bcb297c"},
{file = "wrapt-1.17.3-cp312-cp312-win_arm64.whl", hash = "sha256:604d076c55e2fdd4c1c03d06dc1a31b95130010517b5019db15365ec4a405fc6"},
{file = "wrapt-1.17.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a47681378a0439215912ef542c45a783484d4dd82bac412b71e59cf9c0e1cea0"},
{file = "wrapt-1.17.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:54a30837587c6ee3cd1a4d1c2ec5d24e77984d44e2f34547e2323ddb4e22eb77"},
{file = "wrapt-1.17.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:16ecf15d6af39246fe33e507105d67e4b81d8f8d2c6598ff7e3ca1b8a37213f7"},
{file = "wrapt-1.17.3-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6fd1ad24dc235e4ab88cda009e19bf347aabb975e44fd5c2fb22a3f6e4141277"},
{file = "wrapt-1.17.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0ed61b7c2d49cee3c027372df5809a59d60cf1b6c2f81ee980a091f3afed6a2d"},
{file = "wrapt-1.17.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:423ed5420ad5f5529db9ce89eac09c8a2f97da18eb1c870237e84c5a5c2d60aa"},
{file = "wrapt-1.17.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e01375f275f010fcbf7f643b4279896d04e571889b8a5b3f848423d91bf07050"},
{file = "wrapt-1.17.3-cp313-cp313-win32.whl", hash = "sha256:53e5e39ff71b3fc484df8a522c933ea2b7cdd0d5d15ae82e5b23fde87d44cbd8"},
{file = "wrapt-1.17.3-cp313-cp313-win_amd64.whl", hash = "sha256:1f0b2f40cf341ee8cc1a97d51ff50dddb9fcc73241b9143ec74b30fc4f44f6cb"},
{file = "wrapt-1.17.3-cp313-cp313-win_arm64.whl", hash = "sha256:7425ac3c54430f5fc5e7b6f41d41e704db073309acfc09305816bc6a0b26bb16"},
{file = "wrapt-1.17.3-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:cf30f6e3c077c8e6a9a7809c94551203c8843e74ba0c960f4a98cd80d4665d39"},
{file = "wrapt-1.17.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:e228514a06843cae89621384cfe3a80418f3c04aadf8a3b14e46a7be704e4235"},
{file = "wrapt-1.17.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:5ea5eb3c0c071862997d6f3e02af1d055f381b1d25b286b9d6644b79db77657c"},
{file = "wrapt-1.17.3-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:281262213373b6d5e4bb4353bc36d1ba4084e6d6b5d242863721ef2bf2c2930b"},
{file = "wrapt-1.17.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dc4a8d2b25efb6681ecacad42fca8859f88092d8732b170de6a5dddd80a1c8fa"},
{file = "wrapt-1.17.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:373342dd05b1d07d752cecbec0c41817231f29f3a89aa8b8843f7b95992ed0c7"},
{file = "wrapt-1.17.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d40770d7c0fd5cbed9d84b2c3f2e156431a12c9a37dc6284060fb4bec0b7ffd4"},
{file = "wrapt-1.17.3-cp314-cp314-win32.whl", hash = "sha256:fbd3c8319de8e1dc79d346929cd71d523622da527cca14e0c1d257e31c2b8b10"},
{file = "wrapt-1.17.3-cp314-cp314-win_amd64.whl", hash = "sha256:e1a4120ae5705f673727d3253de3ed0e016f7cd78dc463db1b31e2463e1f3cf6"},
{file = "wrapt-1.17.3-cp314-cp314-win_arm64.whl", hash = "sha256:507553480670cab08a800b9463bdb881b2edeed77dc677b0a5915e6106e91a58"},
{file = "wrapt-1.17.3-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:ed7c635ae45cfbc1a7371f708727bf74690daedc49b4dba310590ca0bd28aa8a"},
{file = "wrapt-1.17.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:249f88ed15503f6492a71f01442abddd73856a0032ae860de6d75ca62eed8067"},
{file = "wrapt-1.17.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5a03a38adec8066d5a37bea22f2ba6bbf39fcdefbe2d91419ab864c3fb515454"},
{file = "wrapt-1.17.3-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:5d4478d72eb61c36e5b446e375bbc49ed002430d17cdec3cecb36993398e1a9e"},
{file = "wrapt-1.17.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:223db574bb38637e8230eb14b185565023ab624474df94d2af18f1cdb625216f"},
{file = "wrapt-1.17.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:e405adefb53a435f01efa7ccdec012c016b5a1d3f35459990afc39b6be4d5056"},
{file = "wrapt-1.17.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:88547535b787a6c9ce4086917b6e1d291aa8ed914fdd3a838b3539dc95c12804"},
{file = "wrapt-1.17.3-cp314-cp314t-win32.whl", hash = "sha256:41b1d2bc74c2cac6f9074df52b2efbef2b30bdfe5f40cb78f8ca22963bc62977"},
{file = "wrapt-1.17.3-cp314-cp314t-win_amd64.whl", hash = "sha256:73d496de46cd2cdbdbcce4ae4bcdb4afb6a11234a1df9c085249d55166b95116"},
{file = "wrapt-1.17.3-cp314-cp314t-win_arm64.whl", hash = "sha256:f38e60678850c42461d4202739f9bf1e3a737c7ad283638251e79cc49effb6b6"},
{file = "wrapt-1.17.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:70d86fa5197b8947a2fa70260b48e400bf2ccacdcab97bb7de47e3d1e6312225"},
{file = "wrapt-1.17.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:df7d30371a2accfe4013e90445f6388c570f103d61019b6b7c57e0265250072a"},
{file = "wrapt-1.17.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:caea3e9c79d5f0d2c6d9ab96111601797ea5da8e6d0723f77eabb0d4068d2b2f"},
{file = "wrapt-1.17.3-cp38-cp38-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:758895b01d546812d1f42204bd443b8c433c44d090248bf22689df673ccafe00"},
{file = "wrapt-1.17.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:02b551d101f31694fc785e58e0720ef7d9a10c4e62c1c9358ce6f63f23e30a56"},
{file = "wrapt-1.17.3-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:656873859b3b50eeebe6db8b1455e99d90c26ab058db8e427046dbc35c3140a5"},
{file = "wrapt-1.17.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:a9a2203361a6e6404f80b99234fe7fb37d1fc73487b5a78dc1aa5b97201e0f22"},
{file = "wrapt-1.17.3-cp38-cp38-win32.whl", hash = "sha256:55cbbc356c2842f39bcc553cf695932e8b30e30e797f961860afb308e6b1bb7c"},
{file = "wrapt-1.17.3-cp38-cp38-win_amd64.whl", hash = "sha256:ad85e269fe54d506b240d2d7b9f5f2057c2aa9a2ea5b32c66f8902f768117ed2"},
{file = "wrapt-1.17.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:30ce38e66630599e1193798285706903110d4f057aab3168a34b7fdc85569afc"},
{file = "wrapt-1.17.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:65d1d00fbfb3ea5f20add88bbc0f815150dbbde3b026e6c24759466c8b5a9ef9"},
{file = "wrapt-1.17.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a7c06742645f914f26c7f1fa47b8bc4c91d222f76ee20116c43d5ef0912bba2d"},
{file = "wrapt-1.17.3-cp39-cp39-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7e18f01b0c3e4a07fe6dfdb00e29049ba17eadbc5e7609a2a3a4af83ab7d710a"},
{file = "wrapt-1.17.3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0f5f51a6466667a5a356e6381d362d259125b57f059103dd9fdc8c0cf1d14139"},
{file = "wrapt-1.17.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:59923aa12d0157f6b82d686c3fd8e1166fa8cdfb3e17b42ce3b6147ff81528df"},
{file = "wrapt-1.17.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:46acc57b331e0b3bcb3e1ca3b421d65637915cfcd65eb783cb2f78a511193f9b"},
{file = "wrapt-1.17.3-cp39-cp39-win32.whl", hash = "sha256:3e62d15d3cfa26e3d0788094de7b64efa75f3a53875cdbccdf78547aed547a81"},
{file = "wrapt-1.17.3-cp39-cp39-win_amd64.whl", hash = "sha256:1f23fa283f51c890eda8e34e4937079114c74b4c81d2b2f1f1d94948f5cc3d7f"},
{file = "wrapt-1.17.3-cp39-cp39-win_arm64.whl", hash = "sha256:24c2ed34dc222ed754247a2702b1e1e89fdbaa4016f324b4b8f1a802d4ffe87f"},
{file = "wrapt-1.17.3-py3-none-any.whl", hash = "sha256:7171ae35d2c33d326ac19dd8facb1e82e5fd04ef8c6c0e394d7af55a55051c22"},
{file = "wrapt-1.17.3.tar.gz", hash = "sha256:f66eb08feaa410fe4eebd17f2a2c8e2e46d3476e9f8c783daa8e09e0faa666d0"},
]
[[package]]
name = "xattr"
version = "1.2.0"
@@ -7413,4 +7295,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.14"
content-hash = "e5a35435ea318a13aa53df4124a4151bc8a3dd7dfc8f9738e3da24d7906d555e"
content-hash = "b762806d5d58fcf811220890c4705a16dc62b33387af43e3a29399c62a641098"

View File

@@ -27,13 +27,12 @@ google-api-python-client = "^2.177.0"
google-auth-oauthlib = "^1.2.2"
google-cloud-storage = "^3.2.0"
googlemaps = "^4.10.0"
gravitasml = "^0.1.4"
gravitasml = "^0.1.3"
groq = "^0.30.0"
html2text = "^2024.2.26"
jinja2 = "^3.1.6"
jsonref = "^1.1.0"
jsonschema = "^4.25.0"
langfuse = "^2.0.0"
launchdarkly-server-sdk = "^9.12.0"
mem0ai = "^0.1.115"
moviepy = "^2.1.2"

View File

@@ -1,15 +1,14 @@
datasource db {
provider = "postgresql"
url = env("DATABASE_URL")
directUrl = env("DIRECT_URL")
extensions = [pgvector(map: "vector", schema: "public")]
provider = "postgresql"
url = env("DATABASE_URL")
directUrl = env("DIRECT_URL")
}
generator client {
provider = "prisma-client-py"
recursive_type_depth = -1
interface = "asyncio"
previewFeatures = ["views", "fullTextSearch", "postgresqlExtensions"]
previewFeatures = ["views", "fullTextSearch"]
partial_type_generator = "backend/data/partial_types.py"
}
@@ -54,7 +53,6 @@ model User {
Profile Profile[]
UserOnboarding UserOnboarding?
BusinessUnderstanding UserBusinessUnderstanding?
BuilderSearchHistory BuilderSearchHistory[]
StoreListings StoreListing[]
StoreListingReviews StoreListingReview[]
@@ -123,109 +121,19 @@ model UserOnboarding {
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
}
model UserBusinessUnderstanding {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
userId String @unique
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
// User info
userName String?
jobTitle String?
// Business basics (string columns)
businessName String?
industry String?
businessSize String? // "1-10", "11-50", "51-200", "201-1000", "1000+"
userRole String? // Role in organization context (e.g., "decision maker", "implementer")
// Processes & activities (JSON arrays)
keyWorkflows Json?
dailyActivities Json?
// Pain points & goals (JSON arrays)
painPoints Json?
bottlenecks Json?
manualTasks Json?
automationGoals Json?
// Current tools (JSON arrays)
currentSoftware Json?
existingAutomation Json?
additionalNotes String?
@@index([userId])
}
model BuilderSearchHistory {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
searchQuery String
filter String[] @default([])
byCreator String[] @default([])
filter String[] @default([])
byCreator String[] @default([])
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
}
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
//////////////// CHAT SESSION TABLES ///////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
model ChatSession {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
userId String?
// Session metadata
title String?
credentials Json @default("{}") // Map of provider -> credential metadata
// Rate limiting counters (stored as JSON maps)
successfulAgentRuns Json @default("{}") // Map of graph_id -> count
successfulAgentSchedules Json @default("{}") // Map of graph_id -> count
// Usage tracking
totalPromptTokens Int @default(0)
totalCompletionTokens Int @default(0)
Messages ChatMessage[]
@@index([userId, updatedAt])
}
model ChatMessage {
id String @id @default(uuid())
createdAt DateTime @default(now())
sessionId String
Session ChatSession @relation(fields: [sessionId], references: [id], onDelete: Cascade)
// Message content
role String // "user", "assistant", "system", "tool", "function"
content String?
name String?
toolCallId String?
refusal String?
toolCalls Json? // List of tool calls for assistant messages
functionCall Json? // Deprecated but kept for compatibility
// Ordering within session
sequence Int
@@unique([sessionId, sequence])
@@index([sessionId, sequence])
}
// This model describes the Agent Graph/Flow (Multi Agent System).
model AgentGraph {
id String @default(uuid())
@@ -813,26 +721,26 @@ view StoreAgent {
storeListingVersionId String
updated_at DateTime
slug String
agent_name String
agent_video String?
agent_output_demo String?
agent_image String[]
slug String
agent_name String
agent_video String?
agent_output_demo String?
agent_image String[]
featured Boolean @default(false)
creator_username String?
creator_avatar String?
sub_heading String
description String
categories String[]
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
runs Int
rating Float
versions String[]
agentGraphVersions String[]
agentGraphId String
is_available Boolean @default(true)
useForOnboarding Boolean @default(false)
featured Boolean @default(false)
creator_username String?
creator_avatar String?
sub_heading String
description String
categories String[]
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
runs Int
rating Float
versions String[]
agentGraphVersions String[]
agentGraphId String
is_available Boolean @default(true)
useForOnboarding Boolean @default(false)
// Materialized views used (refreshed every 15 minutes via pg_cron):
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
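A sketch of the kind of pg_cron job the comment above implies; the job name and exact schedule are assumptions, and the job lives in the database rather than in this schema file:
-- Refresh the pre-aggregated run counts every 15 minutes.
-- CONCURRENTLY requires a unique index on the materialized view;
-- drop it if the view has none.
SELECT cron.schedule(
'refresh_mv_agent_run_counts',
'*/15 * * * *',
$$REFRESH MATERIALIZED VIEW CONCURRENTLY mv_agent_run_counts$$
);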
@@ -948,14 +856,14 @@ model StoreListingVersion {
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
// Content fields
name String
subHeading String
videoUrl String?
agentOutputDemoUrl String?
imageUrls String[]
description String
instructions String?
categories String[]
name String
subHeading String
videoUrl String?
agentOutputDemoUrl String?
imageUrls String[]
description String
instructions String?
categories String[]
isFeatured Boolean @default(false)
@@ -991,9 +899,6 @@ model StoreListingVersion {
// Reviews for this specific version
Reviews StoreListingReview[]
// Embedding for semantic search (one-to-one)
Embedding StoreListingEmbedding?
@@unique([storeListingId, version])
@@index([storeListingId, submissionStatus, isAvailable])
@@index([submissionStatus])
@@ -1019,24 +924,6 @@ model StoreListingReview {
@@index([reviewByUserId])
}
// Stores vector embeddings for semantic search of store listings
// Uses pgvector extension for efficient similarity search
model StoreListingEmbedding {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
storeListingVersionId String @unique
StoreListingVersion StoreListingVersion @relation(fields: [storeListingVersionId], references: [id], onDelete: Cascade)
// pgvector embedding - stored as Unsupported type since Prisma doesn't natively support vector
embedding Unsupported("vector(1536)")?
searchableText String? // The text that was embedded (for debugging/recomputation)
contentHash String? // MD5 hash for change detection
@@index([contentHash])
}
enum SubmissionStatus {
DRAFT // Being prepared, not yet submitted
PENDING // Submitted, awaiting review
@@ -1100,6 +987,151 @@ enum APIKeyStatus {
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
///////////// LLM REGISTRY AND BILLING DATA /////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
// LlmCostUnit: Defines how LLM MODEL costs are calculated (per run or per token).
// This is distinct from BlockCostType (in backend/data/block.py) which defines
// how BLOCK EXECUTION costs are calculated (per run, per byte, or per second).
// LlmCostUnit is for pricing individual LLM model API calls in the registry,
// while BlockCostType is for billing platform block executions.
enum LlmCostUnit {
RUN
TOKENS
}
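To make the RUN vs TOKENS distinction concrete, a hypothetical pair of cost rows for one model (all values invented; the per-token granularity is defined by the billing code, not by this schema):
-- Flat 3 credits per call vs. a per-token rate for the same model.
INSERT INTO "LlmModelCost" ("id", "unit", "creditCost", "credentialProvider", "llmModelId")
VALUES
(gen_random_uuid(), 'RUN'::"LlmCostUnit", 3, 'openai', '<model-uuid>'),
(gen_random_uuid(), 'TOKENS'::"LlmCostUnit", 1, 'openai', '<model-uuid>');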
model LlmModelCreator {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
name String @unique // e.g., "openai", "anthropic", "meta"
displayName String // e.g., "OpenAI", "Anthropic", "Meta"
description String?
websiteUrl String? // Link to creator's website
logoUrl String? // URL to creator's logo
metadata Json @default("{}")
Models LlmModel[]
}
model LlmProvider {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
name String @unique
displayName String
description String?
defaultCredentialProvider String?
defaultCredentialId String?
defaultCredentialType String?
supportsTools Boolean @default(true)
supportsJsonOutput Boolean @default(true)
supportsReasoning Boolean @default(false)
supportsParallelTool Boolean @default(false)
metadata Json @default("{}")
Models LlmModel[]
}
model LlmModel {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
slug String @unique
displayName String
description String?
providerId String
Provider LlmProvider @relation(fields: [providerId], references: [id], onDelete: Restrict)
// Creator is the organization that created/trained the model (e.g., OpenAI, Meta)
// This is distinct from the provider who hosts/serves the model (e.g., OpenRouter)
creatorId String?
Creator LlmModelCreator? @relation(fields: [creatorId], references: [id], onDelete: SetNull)
contextWindow Int
maxOutputTokens Int?
isEnabled Boolean @default(true)
isRecommended Boolean @default(false)
capabilities Json @default("{}")
metadata Json @default("{}")
Costs LlmModelCost[]
@@index([providerId, isEnabled])
@@index([creatorId])
@@index([slug])
}
model LlmModelCost {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
unit LlmCostUnit @default(RUN)
creditCost Int
credentialProvider String
credentialId String?
credentialType String?
currency String?
metadata Json @default("{}")
llmModelId String
Model LlmModel @relation(fields: [llmModelId], references: [id], onDelete: Cascade)
@@index([llmModelId])
@@index([credentialProvider])
}
// Tracks model migrations for revert capability.
// When a model is disabled and its workflows are migrated, we record which
// nodes were affected so they can be reverted once the original model is back online.
model LlmModelMigration {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
sourceModelSlug String // The original model that was disabled
targetModelSlug String // The model workflows were migrated to
reason String? // Why the migration happened (e.g., "Provider outage")
// Track affected nodes as JSON array of node IDs
// Format: ["node-uuid-1", "node-uuid-2", ...]
migratedNodeIds Json @default("[]")
nodeCount Int // Number of nodes migrated
// Custom pricing override for migrated workflows during the migration period.
// Use case: When migrating users from an expensive model (e.g., GPT-4) to a cheaper
// one (e.g., GPT-3.5), you may want to temporarily maintain the original pricing
// to avoid billing surprises, or offer a discount during the transition.
//
// IMPORTANT: This field is intended for integration with the billing system.
// When billing calculates costs for nodes affected by this migration, it should
// check if customCreditCost is set and use it instead of the target model's cost.
// If null, the target model's normal cost applies.
//
// TODO: Integrate with billing system to apply this override during cost calculation.
customCreditCost Int?
// Revert tracking
isReverted Boolean @default(false)
revertedAt DateTime?
@@index([sourceModelSlug])
@@index([targetModelSlug])
@@index([isReverted])
}
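Since the schema only stores the override, here is a hedged sketch of how a cost lookup could honor customCreditCost (the TODO above notes that billing integration is still pending, so this is not shipped code; the node ID and slugs are placeholders):
-- Prefer the migration's customCreditCost for a node that was migrated and not
-- reverted; otherwise fall back to the target model's normal cost.
SELECT COALESCE(mig."customCreditCost", cost."creditCost") AS effective_credit_cost
FROM "LlmModel" m
JOIN "LlmModelCost" cost ON cost."llmModelId" = m."id"
LEFT JOIN "LlmModelMigration" mig
ON mig."targetModelSlug" = m."slug"
AND mig."isReverted" = false
AND mig."migratedNodeIds" @> '["node-uuid-1"]'::jsonb
WHERE m."slug" = 'gpt-3.5-turbo';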
////////////// OAUTH PROVIDER TABLES //////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
@@ -1111,16 +1143,16 @@ model OAuthApplication {
updatedAt DateTime @updatedAt
// Application metadata
name String
description String?
logoUrl String? // URL to app logo stored in GCS
clientId String @unique
clientSecret String // Hashed with Scrypt (same as API keys)
clientSecretSalt String // Salt for Scrypt hashing
name String
description String?
logoUrl String? // URL to app logo stored in GCS
clientId String @unique
clientSecret String // Hashed with Scrypt (same as API keys)
clientSecretSalt String // Salt for Scrypt hashing
// OAuth configuration
redirectUris String[] // Allowed callback URLs
grantTypes String[] @default(["authorization_code", "refresh_token"])
grantTypes String[] @default(["authorization_code", "refresh_token"])
scopes APIKeyPermission[] // Which permissions the app can request
// Application management

View File

@@ -1,146 +0,0 @@
/**
* Cloudflare Workers Script for docs.agpt.co → agpt.co/docs migration
*
* Deploy this script to handle all redirects with a single JavaScript file.
* No rule limits, easy to maintain, handles all edge cases.
*/
// URL mapping for special cases that don't follow patterns
const SPECIAL_MAPPINGS = {
// Root page
'/': '/docs/platform',
// Special cases that don't follow standard patterns
'/platform/d_id/': '/docs/integrations/block-integrations/d-id',
'/platform/blocks/blocks/': '/docs/integrations',
'/platform/blocks/decoder_block/': '/docs/integrations/block-integrations/text-decoder',
'/platform/blocks/http': '/docs/integrations/block-integrations/send-web-request',
'/platform/blocks/llm/': '/docs/integrations/block-integrations/ai-and-llm',
'/platform/blocks/time_blocks': '/docs/integrations/block-integrations/time-and-date',
'/platform/blocks/text_to_speech_block': '/docs/integrations/block-integrations/text-to-speech',
'/platform/blocks/ai_shortform_video_block': '/docs/integrations/block-integrations/ai-shortform-video',
'/platform/blocks/replicate_flux_advanced': '/docs/integrations/block-integrations/replicate-flux-advanced',
'/platform/blocks/flux_kontext': '/docs/integrations/block-integrations/flux-kontext',
'/platform/blocks/ai_condition/': '/docs/integrations/block-integrations/ai-condition',
'/platform/blocks/email_block': '/docs/integrations/block-integrations/email',
'/platform/blocks/google_maps': '/docs/integrations/block-integrations/google-maps',
'/platform/blocks/google/gmail': '/docs/integrations/block-integrations/gmail',
'/platform/blocks/github/issues/': '/docs/integrations/block-integrations/github-issues',
'/platform/blocks/github/repo/': '/docs/integrations/block-integrations/github-repo',
'/platform/blocks/github/pull_requests': '/docs/integrations/block-integrations/github-pull-requests',
'/platform/blocks/twitter/twitter': '/docs/integrations/block-integrations/twitter',
'/classic/setup/': '/docs/classic/setup/setting-up-autogpt-classic',
'/code-of-conduct/': '/docs/classic/help-us-improve-autogpt/code-of-conduct',
'/contributing/': '/docs/classic/contributing',
'/contribute/': '/docs/contribute',
'/forge/components/introduction/': '/docs/classic/forge/introduction'
};
/**
* Transform path by replacing underscores with hyphens and removing trailing slashes
*/
function transformPath(path) {
return path.replace(/_/g, '-').replace(/\/$/, '');
}
/**
* Handle docs.agpt.co redirects
*/
function handleDocsRedirect(url) {
const pathname = url.pathname;
// Check special mappings first
if (SPECIAL_MAPPINGS[pathname]) {
return `https://agpt.co${SPECIAL_MAPPINGS[pathname]}`;
}
// Pattern-based redirects
// Platform blocks: /platform/blocks/* → /docs/integrations/block-integrations/*
if (pathname.startsWith('/platform/blocks/')) {
const blockName = pathname.substring('/platform/blocks/'.length);
const transformedName = transformPath(blockName);
return `https://agpt.co/docs/integrations/block-integrations/${transformedName}`;
}
// Platform contributing: /platform/contributing/* → /docs/platform/contributing/*
if (pathname.startsWith('/platform/contributing/')) {
const subPath = pathname.substring('/platform/contributing/'.length);
return `https://agpt.co/docs/platform/contributing/${subPath}`;
}
// Platform general: /platform/* → /docs/platform/* (with underscore→hyphen)
if (pathname.startsWith('/platform/')) {
const subPath = pathname.substring('/platform/'.length);
const transformedPath = transformPath(subPath);
return `https://agpt.co/docs/platform/${transformedPath}`;
}
// Forge components: /forge/components/* → /docs/classic/forge/introduction/*
if (pathname.startsWith('/forge/components/')) {
const subPath = pathname.substring('/forge/components/'.length);
return `https://agpt.co/docs/classic/forge/introduction/${subPath}`;
}
// Forge general: /forge/* → /docs/classic/forge/*
if (pathname.startsWith('/forge/')) {
const subPath = pathname.substring('/forge/'.length);
return `https://agpt.co/docs/classic/forge/${subPath}`;
}
// Classic: /classic/* → /docs/classic/*
if (pathname.startsWith('/classic/')) {
const subPath = pathname.substring('/classic/'.length);
return `https://agpt.co/docs/classic/${subPath}`;
}
// Default fallback
return 'https://agpt.co/docs/';
}
/**
* Main Worker function
*/
export default {
async fetch(request, env, ctx) {
const url = new URL(request.url);
// Only handle docs.agpt.co requests
if (url.hostname === 'docs.agpt.co') {
const redirectUrl = handleDocsRedirect(url);
return new Response(null, {
status: 301,
headers: {
'Location': redirectUrl,
'Cache-Control': 'max-age=300' // Cache redirects for 5 minutes
}
});
}
// Requests for other hostnames are not handled by this Worker; return 404
return new Response('Not Found', { status: 404 });
}
};
// Test function for local development
export function testRedirects() {
const testCases = [
'https://docs.agpt.co/',
'https://docs.agpt.co/platform/getting-started/',
'https://docs.agpt.co/platform/advanced_setup/',
'https://docs.agpt.co/platform/blocks/basic/',
'https://docs.agpt.co/platform/blocks/ai_condition/',
'https://docs.agpt.co/classic/setup/',
'https://docs.agpt.co/forge/components/agents/',
'https://docs.agpt.co/contributing/',
'https://docs.agpt.co/unknown-page'
];
console.log('Testing redirects:');
testCases.forEach(testUrl => {
const url = new URL(testUrl);
const result = handleDocsRedirect(url);
console.log(`${testUrl} → ${result}`);
});
}

View File

@@ -46,14 +46,13 @@
"@radix-ui/react-scroll-area": "1.2.10",
"@radix-ui/react-select": "2.2.6",
"@radix-ui/react-separator": "1.1.7",
"@radix-ui/react-slider": "1.3.6",
"@radix-ui/react-slot": "1.2.3",
"@radix-ui/react-switch": "1.2.6",
"@radix-ui/react-tabs": "1.1.13",
"@radix-ui/react-toast": "1.2.15",
"@radix-ui/react-tooltip": "1.2.8",
"@rjsf/core": "6.1.2",
"@rjsf/utils": "6.1.2",
"@rjsf/core": "5.24.13",
"@rjsf/utils": "5.24.13",
"@rjsf/validator-ajv8": "5.24.13",
"@sentry/nextjs": "10.27.0",
"@supabase/ssr": "0.7.0",

View File

@@ -62,9 +62,6 @@ importers:
'@radix-ui/react-separator':
specifier: 1.1.7
version: 1.1.7(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
'@radix-ui/react-slider':
specifier: 1.3.6
version: 1.3.6(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
'@radix-ui/react-slot':
specifier: 1.2.3
version: 1.2.3(@types/react@18.3.17)(react@18.3.1)
@@ -81,14 +78,14 @@ importers:
specifier: 1.2.8
version: 1.2.8(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
'@rjsf/core':
specifier: 6.1.2
version: 6.1.2(@rjsf/utils@6.1.2(react@18.3.1))(react@18.3.1)
specifier: 5.24.13
version: 5.24.13(@rjsf/utils@5.24.13(react@18.3.1))(react@18.3.1)
'@rjsf/utils':
specifier: 6.1.2
version: 6.1.2(react@18.3.1)
specifier: 5.24.13
version: 5.24.13(react@18.3.1)
'@rjsf/validator-ajv8':
specifier: 5.24.13
version: 5.24.13(@rjsf/utils@6.1.2(react@18.3.1))
version: 5.24.13(@rjsf/utils@5.24.13(react@18.3.1))
'@sentry/nextjs':
specifier: 10.27.0
version: 10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.101.3(esbuild@0.25.9))
@@ -2313,19 +2310,6 @@ packages:
'@types/react-dom':
optional: true
'@radix-ui/react-slider@1.3.6':
resolution: {integrity: sha512-JPYb1GuM1bxfjMRlNLE+BcmBC8onfCi60Blk7OBqi2MLTFdS+8401U4uFjnwkOr49BLmXxLC6JHkvAsx5OJvHw==}
peerDependencies:
'@types/react': '*'
'@types/react-dom': '*'
react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc
react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc
peerDependenciesMeta:
'@types/react':
optional: true
'@types/react-dom':
optional: true
'@radix-ui/react-slot@1.2.3':
resolution: {integrity: sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A==}
peerDependencies:
@@ -2495,18 +2479,18 @@ packages:
react-redux:
optional: true
'@rjsf/core@6.1.2':
resolution: {integrity: sha512-fcEO6kArMcVIzTBoBxNStqxzAL417NDw049nmNx11pIcMwUnU5sAkSW18c8kgZOT6v1xaZhQrY+X5cBzzHy9+g==}
engines: {node: '>=20'}
'@rjsf/core@5.24.13':
resolution: {integrity: sha512-ONTr14s7LFIjx2VRFLuOpagL76sM/HPy6/OhdBfq6UukINmTIs6+aFN0GgcR0aXQHFDXQ7f/fel0o/SO05Htdg==}
engines: {node: '>=14'}
peerDependencies:
'@rjsf/utils': ^6.x
react: '>=18'
'@rjsf/utils': ^5.24.x
react: ^16.14.0 || >=17
'@rjsf/utils@6.1.2':
resolution: {integrity: sha512-Px3FIkE1KK0745Qng9v88RZ0O7hcLf/1JUu0j00g+r6C8Zyokna42Hz/5TKyyQSKJqgVYcj2Z47YroVLenUM3A==}
engines: {node: '>=20'}
'@rjsf/utils@5.24.13':
resolution: {integrity: sha512-rNF8tDxIwTtXzz5O/U23QU73nlhgQNYJ+Sv5BAwQOIyhIE2Z3S5tUiSVMwZHt0julkv/Ryfwi+qsD4FiE5rOuw==}
engines: {node: '>=14'}
peerDependencies:
react: '>=18'
react: ^16.14.0 || >=17
'@rjsf/validator-ajv8@5.24.13':
resolution: {integrity: sha512-oWHP7YK581M8I5cF1t+UXFavnv+bhcqjtL1a7MG/Kaffi0EwhgcYjODrD8SsnrhncsEYMqSECr4ZOEoirnEUWw==}
@@ -3659,9 +3643,6 @@ packages:
'@webassemblyjs/wast-printer@1.14.1':
resolution: {integrity: sha512-kPSSXE6De1XOR820C90RIo2ogvZG+c3KiHzqUoO/F34Y2shGzesfqv7o57xrxovZJH/MetF5UjroJ/R/3isoiw==}
'@x0k/json-schema-merge@1.0.2':
resolution: {integrity: sha512-1734qiJHNX3+cJGDMMw2yz7R+7kpbAtl5NdPs1c/0gO5kYT6s4dMbLXiIfpZNsOYhGZI3aH7FWrj4Zxz7epXNg==}
'@xtuc/ieee754@1.2.0':
resolution: {integrity: sha512-DX8nKgqcGwsc0eJSqYt5lwP4DH5FlHnmuWWBRy7X0NcaGR0ZtuyeESgMwTYVEtxmsNGY+qit4QYT/MIYTOTPeA==}
@@ -4178,6 +4159,12 @@ packages:
compare-versions@6.1.1:
resolution: {integrity: sha512-4hm4VPpIecmlg59CHXnRDnqGplJFrbLG4aFEl5vl6cK1u76ws3LLvX7ikFnTDl5vo39sjWD6AaDPYodJp/NNHg==}
compute-gcd@1.2.1:
resolution: {integrity: sha512-TwMbxBNz0l71+8Sc4czv13h4kEqnchV9igQZBi6QUaz09dnz13juGnnaWWJTRsP3brxOoxeB4SA2WELLw1hCtg==}
compute-lcm@1.1.2:
resolution: {integrity: sha512-OFNPdQAXnQhDSKioX8/XYT6sdUlXwpeMjfd6ApxMJfyZ4GxmLR1xvMERctlYhlHwIiz6CSpBc2+qYKjHGZw4TQ==}
concat-map@0.0.1:
resolution: {integrity: sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==}
@@ -4842,9 +4829,6 @@ packages:
fast-uri@3.0.6:
resolution: {integrity: sha512-Atfo14OibSv5wAp4VWNsFYE1AchQRTv9cBGWET4pZWHzYshFSS9NQI6I57rdKn9croWVMbYFbLhJ+yJvmZIIHw==}
fast-uri@3.1.0:
resolution: {integrity: sha512-iPeeDKJSWf4IEOasVVrknXpaBV0IApz/gp7S2bb7Z4Lljbl2MGJRqInZiUrQwV16cpzw/D3S5j5Julj/gT52AA==}
fastq@1.19.1:
resolution: {integrity: sha512-GwLTyxkCXjXbxqIhTsMI2Nui8huMPtnxg7krajPJAjnEG/iiOS7i+zCtWGZR9G0NBKbXKh6X9m9UIsYX/N6vvQ==}
@@ -5493,6 +5477,13 @@ packages:
json-parse-even-better-errors@2.3.1:
resolution: {integrity: sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w==}
json-schema-compare@0.2.2:
resolution: {integrity: sha512-c4WYmDKyJXhs7WWvAWm3uIYnfyWFoIp+JEoX34rctVvEkMYCPGhXtvmFFXiffBbxfZsvQ0RNnV5H7GvDF5HCqQ==}
json-schema-merge-allof@0.8.1:
resolution: {integrity: sha512-CTUKmIlPJbsWfzRRnOXz+0MjIqvnleIXwFTzz+t9T86HnYX/Rozria6ZVGLktAU9e+NygNljveP+yxqtQp/Q4w==}
engines: {node: '>=12.0.0'}
json-schema-traverse@0.4.1:
resolution: {integrity: sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==}
@@ -5608,9 +5599,6 @@ packages:
lodash-es@4.17.21:
resolution: {integrity: sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw==}
lodash-es@4.17.22:
resolution: {integrity: sha512-XEawp1t0gxSi9x01glktRZ5HDy0HXqrM0x5pXQM98EaI0NxO6jVM7omDOxsuEo5UIASAnm2bRp1Jt/e0a2XU8Q==}
lodash.camelcase@4.3.0:
resolution: {integrity: sha512-TwuEnCnxbc3rAvhf/LbG7tJUDzhqXyFnv3dtzLOPgCG/hODL7WFnsbwktkD7yUV0RrreP/l1PALq/YSg6VvjlA==}
@@ -5700,14 +5688,11 @@ packages:
markdown-table@3.0.4:
resolution: {integrity: sha512-wiYz4+JrLyb/DqW2hkFJxP7Vd7JuTDm77fvbM8VfEQdmSMqcImWeeRbHwZjBjIFki/VaMK2BhFi7oUUZeM5bqw==}
markdown-to-jsx@8.0.0:
resolution: {integrity: sha512-hWEaRxeCDjes1CVUQqU+Ov0mCqBqkGhLKjL98KdbwHSgEWZZSJQeGlJQatVfeZ3RaxrfTrZZ3eczl2dhp5c/pA==}
markdown-to-jsx@7.7.13:
resolution: {integrity: sha512-DiueEq2bttFcSxUs85GJcQVrOr0+VVsPfj9AEUPqmExJ3f8P/iQNvZHltV4tm1XVhu1kl0vWBZWT3l99izRMaA==}
engines: {node: '>= 10'}
peerDependencies:
react: '>= 0.14.0'
peerDependenciesMeta:
react:
optional: true
math-intrinsics@1.1.0:
resolution: {integrity: sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==}
@@ -7593,6 +7578,21 @@ packages:
resolution: {integrity: sha512-b+1eJOlsR9K8HJpow9Ok3fiWOWSIcIzXodvv0rQjVoOVNpWMpxf1wZNpt4y9h10odCNrqnYp1OBzRktckBe3sA==}
hasBin: true
validate.io-array@1.0.6:
resolution: {integrity: sha512-DeOy7CnPEziggrOO5CZhVKJw6S3Yi7e9e65R1Nl/RTN1vTQKnzjfvks0/8kQ40FP/dsjRAOd4hxmJ7uLa6vxkg==}
validate.io-function@1.0.2:
resolution: {integrity: sha512-LlFybRJEriSuBnUhQyG5bwglhh50EpTL2ul23MPIuR1odjO7XaMLFV8vHGwp7AZciFxtYOeiSCT5st+XSPONiQ==}
validate.io-integer-array@1.0.0:
resolution: {integrity: sha512-mTrMk/1ytQHtCY0oNO3dztafHYyGU88KL+jRxWuzfOmQb+4qqnWmI+gykvGp8usKZOM0H7keJHEbRaFiYA0VrA==}
validate.io-integer@1.0.5:
resolution: {integrity: sha512-22izsYSLojN/P6bppBqhgUDjCkr5RY2jd+N2a3DCAUey8ydvrZ/OkGvFPR7qfOpwR2LC5p4Ngzxz36g5Vgr/hQ==}
validate.io-number@1.0.3:
resolution: {integrity: sha512-kRAyotcbNaSYoDnXvb4MHg/0a1egJdLwS6oJ38TJY7aw9n93Fl/3blIXdyYvPOp55CNxywooG/3BcrwNrBpcSg==}
validator@13.15.20:
resolution: {integrity: sha512-KxPOq3V2LmfQPP4eqf3Mq/zrT0Dqp2Vmx2Bn285LwVahLc+CsxOM0crBHczm8ijlcjZ0Q5Xd6LW3z3odTPnlrw==}
engines: {node: '>= 0.10'}
@@ -9903,25 +9903,6 @@ snapshots:
'@types/react': 18.3.17
'@types/react-dom': 18.3.5(@types/react@18.3.17)
'@radix-ui/react-slider@1.3.6(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)':
dependencies:
'@radix-ui/number': 1.1.1
'@radix-ui/primitive': 1.1.3
'@radix-ui/react-collection': 1.1.7(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
'@radix-ui/react-compose-refs': 1.1.2(@types/react@18.3.17)(react@18.3.1)
'@radix-ui/react-context': 1.1.2(@types/react@18.3.17)(react@18.3.1)
'@radix-ui/react-direction': 1.1.1(@types/react@18.3.17)(react@18.3.1)
'@radix-ui/react-primitive': 2.1.3(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
'@radix-ui/react-use-controllable-state': 1.2.2(@types/react@18.3.17)(react@18.3.1)
'@radix-ui/react-use-layout-effect': 1.1.1(@types/react@18.3.17)(react@18.3.1)
'@radix-ui/react-use-previous': 1.1.1(@types/react@18.3.17)(react@18.3.1)
'@radix-ui/react-use-size': 1.1.1(@types/react@18.3.17)(react@18.3.1)
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
optionalDependencies:
'@types/react': 18.3.17
'@types/react-dom': 18.3.5(@types/react@18.3.17)
'@radix-ui/react-slot@1.2.3(@types/react@18.3.17)(react@18.3.1)':
dependencies:
'@radix-ui/react-compose-refs': 1.1.2(@types/react@18.3.17)(react@18.3.1)
@@ -10084,28 +10065,27 @@ snapshots:
react: 18.3.1
react-redux: 9.2.0(@types/react@18.3.17)(react@18.3.1)(redux@5.0.1)
'@rjsf/core@6.1.2(@rjsf/utils@6.1.2(react@18.3.1))(react@18.3.1)':
'@rjsf/core@5.24.13(@rjsf/utils@5.24.13(react@18.3.1))(react@18.3.1)':
dependencies:
'@rjsf/utils': 6.1.2(react@18.3.1)
'@rjsf/utils': 5.24.13(react@18.3.1)
lodash: 4.17.21
lodash-es: 4.17.22
markdown-to-jsx: 8.0.0(react@18.3.1)
lodash-es: 4.17.21
markdown-to-jsx: 7.7.13(react@18.3.1)
prop-types: 15.8.1
react: 18.3.1
'@rjsf/utils@6.1.2(react@18.3.1)':
'@rjsf/utils@5.24.13(react@18.3.1)':
dependencies:
'@x0k/json-schema-merge': 1.0.2
fast-uri: 3.1.0
json-schema-merge-allof: 0.8.1
jsonpointer: 5.0.1
lodash: 4.17.21
lodash-es: 4.17.22
lodash-es: 4.17.21
react: 18.3.1
react-is: 18.3.1
'@rjsf/validator-ajv8@5.24.13(@rjsf/utils@6.1.2(react@18.3.1))':
'@rjsf/validator-ajv8@5.24.13(@rjsf/utils@5.24.13(react@18.3.1))':
dependencies:
'@rjsf/utils': 6.1.2(react@18.3.1)
'@rjsf/utils': 5.24.13(react@18.3.1)
ajv: 8.17.1
ajv-formats: 2.1.1(ajv@8.17.1)
lodash: 4.17.21
@@ -11522,10 +11502,6 @@ snapshots:
'@webassemblyjs/ast': 1.14.1
'@xtuc/long': 4.2.2
'@x0k/json-schema-merge@1.0.2':
dependencies:
'@types/json-schema': 7.0.15
'@xtuc/ieee754@1.2.0': {}
'@xtuc/long@4.2.2': {}
@@ -12065,6 +12041,19 @@ snapshots:
compare-versions@6.1.1: {}
compute-gcd@1.2.1:
dependencies:
validate.io-array: 1.0.6
validate.io-function: 1.0.2
validate.io-integer-array: 1.0.0
compute-lcm@1.1.2:
dependencies:
compute-gcd: 1.2.1
validate.io-array: 1.0.6
validate.io-function: 1.0.2
validate.io-integer-array: 1.0.0
concat-map@0.0.1: {}
concurrently@9.2.1:
@@ -12943,8 +12932,6 @@ snapshots:
fast-uri@3.0.6: {}
fast-uri@3.1.0: {}
fastq@1.19.1:
dependencies:
reusify: 1.1.0
@@ -13654,6 +13641,16 @@ snapshots:
json-parse-even-better-errors@2.3.1: {}
json-schema-compare@0.2.2:
dependencies:
lodash: 4.17.21
json-schema-merge-allof@0.8.1:
dependencies:
compute-lcm: 1.1.2
json-schema-compare: 0.2.2
lodash: 4.17.21
json-schema-traverse@0.4.1: {}
json-schema-traverse@1.0.0: {}
@@ -13769,8 +13766,6 @@ snapshots:
lodash-es@4.17.21: {}
lodash-es@4.17.22: {}
lodash.camelcase@4.3.0: {}
lodash.debounce@4.0.8: {}
@@ -13850,8 +13845,8 @@ snapshots:
markdown-table@3.0.4: {}
markdown-to-jsx@8.0.0(react@18.3.1):
optionalDependencies:
markdown-to-jsx@7.7.13(react@18.3.1):
dependencies:
react: 18.3.1
math-intrinsics@1.1.0: {}
@@ -16207,6 +16202,21 @@ snapshots:
uuid@9.0.1: {}
validate.io-array@1.0.6: {}
validate.io-function@1.0.2: {}
validate.io-integer-array@1.0.0:
dependencies:
validate.io-array: 1.0.6
validate.io-integer: 1.0.5
validate.io-integer@1.0.5:
dependencies:
validate.io-number: 1.0.3
validate.io-number@1.0.3: {}
validator@13.15.20: {}
vaul@1.1.2(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1):

View File

@@ -1,6 +1,6 @@
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs";
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
import { CredentialsInput } from "@/components/contextual/CredentialsInputs/CredentialsInputs";
import { useState } from "react";
import { getSchemaDefaultCredentials } from "../../helpers";
import { areAllCredentialsSet, getCredentialFields } from "./helpers";

View File

@@ -1,12 +1,12 @@
"use client";
import { RunAgentInputs } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentInputs/RunAgentInputs";
import {
Card,
CardContent,
CardHeader,
CardTitle,
} from "@/components/__legacy__/ui/card";
import { RunAgentInputs } from "@/components/contextual/RunAgentInputs/RunAgentInputs";
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { CircleNotchIcon } from "@phosphor-icons/react/dist/ssr";
import { Play } from "lucide-react";

View File

@@ -1,48 +0,0 @@
"use client";
import { ChatDrawer } from "@/components/contextual/Chat/ChatDrawer";
import { usePathname } from "next/navigation";
import { Children, ReactNode } from "react";
interface PlatformLayoutContentProps {
children: ReactNode;
}
export function PlatformLayoutContent({
children,
}: PlatformLayoutContentProps) {
const pathname = usePathname();
const isAuthPage =
pathname?.includes("/login") || pathname?.includes("/signup");
// Extract Navbar, AdminImpersonationBanner, and page content from children
const childrenArray = Children.toArray(children);
const navbar = childrenArray[0];
const adminBanner = childrenArray[1];
const pageContent = childrenArray.slice(2);
// For login/signup pages, use a simpler layout that doesn't interfere with centering
if (isAuthPage) {
return (
<main className="flex min-h-screen w-full flex-col">
{navbar}
{adminBanner}
<section className="flex-1">{pageContent}</section>
{/* ChatDrawer must always be rendered to maintain consistent hook count */}
<ChatDrawer />
</main>
);
}
// For logged-in pages, use the drawer layout
return (
<main className="flex h-screen w-full flex-col overflow-hidden">
{navbar}
{adminBanner}
<section className="flex min-h-0 flex-1 overflow-auto">
{pageContent}
</section>
<ChatDrawer />
</main>
);
}

View File

@@ -1,5 +1,8 @@
"use client";
import { Sidebar } from "@/components/__legacy__/Sidebar";
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
import { Cpu } from "@phosphor-icons/react";
import { IconSliders } from "@/components/__legacy__/ui/icons";
@@ -26,6 +29,11 @@ const sidebarLinkGroups = [
href: "/admin/execution-analytics",
icon: <FileText className="h-6 w-6" />,
},
{
text: "LLM Registry",
href: "/admin/llms",
icon: <Cpu size={24} />,
},
{
text: "Admin User Management",
href: "/admin/settings",

View File

@@ -0,0 +1,357 @@
"use server";
import { revalidatePath } from "next/cache";
// Generated API functions
import {
getV2ListLlmProviders,
postV2CreateLlmProvider,
getV2ListLlmModels,
postV2CreateLlmModel,
patchV2UpdateLlmModel,
patchV2ToggleLlmModelAvailability,
deleteV2DeleteLlmModelAndMigrateWorkflows,
getV2GetModelUsageCount,
getV2ListModelMigrations,
postV2RevertAModelMigration,
getV2ListModelCreators,
postV2CreateModelCreator,
patchV2UpdateModelCreator,
deleteV2DeleteModelCreator,
postV2SetRecommendedModel,
} from "@/app/api/__generated__/endpoints/admin/admin";
// Generated types
import type { LlmProvidersResponse } from "@/app/api/__generated__/models/llmProvidersResponse";
import type { LlmModelsResponse } from "@/app/api/__generated__/models/llmModelsResponse";
import type { UpsertLlmProviderRequest } from "@/app/api/__generated__/models/upsertLlmProviderRequest";
import type { CreateLlmModelRequest } from "@/app/api/__generated__/models/createLlmModelRequest";
import type { UpdateLlmModelRequest } from "@/app/api/__generated__/models/updateLlmModelRequest";
import type { ToggleLlmModelRequest } from "@/app/api/__generated__/models/toggleLlmModelRequest";
import type { LlmMigrationsResponse } from "@/app/api/__generated__/models/llmMigrationsResponse";
import type { LlmCreatorsResponse } from "@/app/api/__generated__/models/llmCreatorsResponse";
import type { UpsertLlmCreatorRequest } from "@/app/api/__generated__/models/upsertLlmCreatorRequest";
import type { LlmModelUsageResponse } from "@/app/api/__generated__/models/llmModelUsageResponse";
const ADMIN_LLM_PATH = "/admin/llms";
// =============================================================================
// Provider Actions
// =============================================================================
export async function fetchLlmProviders(): Promise<LlmProvidersResponse> {
const response = await getV2ListLlmProviders({ include_models: true });
if (response.status !== 200) {
throw new Error("Failed to fetch LLM providers");
}
return response.data;
}
export async function createLlmProviderAction(formData: FormData) {
const payload: UpsertLlmProviderRequest = {
name: String(formData.get("name") || "").trim(),
display_name: String(formData.get("display_name") || "").trim(),
description: formData.get("description")
? String(formData.get("description"))
: undefined,
default_credential_provider: formData.get("default_credential_provider")
? String(formData.get("default_credential_provider")).trim()
: undefined,
default_credential_id: undefined,
default_credential_type: "api_key",
supports_tools: formData.get("supports_tools") === "on",
supports_json_output: formData.get("supports_json_output") === "on",
supports_reasoning: formData.get("supports_reasoning") === "on",
supports_parallel_tool: formData.get("supports_parallel_tool") === "on",
metadata: {},
};
const response = await postV2CreateLlmProvider(payload);
if (response.status !== 200) {
throw new Error("Failed to create LLM provider");
}
revalidatePath(ADMIN_LLM_PATH);
}
// =============================================================================
// Model Actions
// =============================================================================
export async function fetchLlmModels(): Promise<LlmModelsResponse> {
const response = await getV2ListLlmModels();
if (response.status !== 200) {
throw new Error("Failed to fetch LLM models");
}
return response.data;
}
export async function createLlmModelAction(formData: FormData) {
const providerId = String(formData.get("provider_id"));
const creatorId = formData.get("creator_id");
// Fetch provider to get default credentials
const providersResponse = await getV2ListLlmProviders({
include_models: false,
});
if (providersResponse.status !== 200) {
throw new Error("Failed to fetch providers");
}
const provider = providersResponse.data.providers.find(
(p) => p.id === providerId,
);
if (!provider) {
throw new Error("Provider not found");
}
const payload: CreateLlmModelRequest = {
slug: String(formData.get("slug") || "").trim(),
display_name: String(formData.get("display_name") || "").trim(),
description: formData.get("description")
? String(formData.get("description"))
: undefined,
provider_id: providerId,
creator_id: creatorId ? String(creatorId) : undefined,
context_window: Number(formData.get("context_window") || 0),
max_output_tokens: formData.get("max_output_tokens")
? Number(formData.get("max_output_tokens"))
: undefined,
is_enabled: formData.get("is_enabled") === "on",
capabilities: {},
metadata: {},
costs: [
{
credit_cost: Number(formData.get("credit_cost") || 0),
credential_provider:
provider.default_credential_provider || provider.name,
credential_id: provider.default_credential_id || undefined,
credential_type: provider.default_credential_type || "api_key",
metadata: {},
},
],
};
const response = await postV2CreateLlmModel(payload);
if (response.status !== 200) {
throw new Error("Failed to create LLM model");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function updateLlmModelAction(formData: FormData) {
const modelId = String(formData.get("model_id"));
const creatorId = formData.get("creator_id");
const payload: UpdateLlmModelRequest = {
display_name: formData.get("display_name")
? String(formData.get("display_name"))
: undefined,
description: formData.get("description")
? String(formData.get("description"))
: undefined,
provider_id: formData.get("provider_id")
? String(formData.get("provider_id"))
: undefined,
creator_id: creatorId ? String(creatorId) : undefined,
context_window: formData.get("context_window")
? Number(formData.get("context_window"))
: undefined,
max_output_tokens: formData.get("max_output_tokens")
? Number(formData.get("max_output_tokens"))
: undefined,
is_enabled: formData.get("is_enabled")
? formData.get("is_enabled") === "on"
: undefined,
costs: formData.get("credit_cost")
? [
{
credit_cost: Number(formData.get("credit_cost")),
credential_provider: String(
formData.get("credential_provider") || "",
).trim(),
credential_id: formData.get("credential_id")
? String(formData.get("credential_id"))
: undefined,
credential_type: formData.get("credential_type")
? String(formData.get("credential_type"))
: undefined,
metadata: {},
},
]
: undefined,
};
const response = await patchV2UpdateLlmModel(modelId, payload);
if (response.status !== 200) {
throw new Error("Failed to update LLM model");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function toggleLlmModelAction(formData: FormData): Promise<void> {
const modelId = String(formData.get("model_id"));
const shouldEnable = formData.get("is_enabled") === "true";
const migrateToSlug = formData.get("migrate_to_slug");
const migrationReason = formData.get("migration_reason");
const customCreditCost = formData.get("custom_credit_cost");
const payload: ToggleLlmModelRequest = {
is_enabled: shouldEnable,
migrate_to_slug: migrateToSlug ? String(migrateToSlug) : undefined,
migration_reason: migrationReason ? String(migrationReason) : undefined,
custom_credit_cost: customCreditCost ? Number(customCreditCost) : undefined,
};
const response = await patchV2ToggleLlmModelAvailability(modelId, payload);
if (response.status !== 200) {
throw new Error("Failed to toggle LLM model");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function deleteLlmModelAction(formData: FormData): Promise<void> {
const modelId = String(formData.get("model_id"));
// Guard before converting: String(null) === "null", which is truthy and would
// silently defeat this required-field check when the form omits the value.
const replacementRaw = formData.get("replacement_model_slug");
if (!replacementRaw) {
throw new Error("Replacement model is required");
}
const replacementModelSlug = String(replacementRaw);
const response = await deleteV2DeleteLlmModelAndMigrateWorkflows(modelId, {
replacement_model_slug: replacementModelSlug,
});
if (response.status !== 200) {
throw new Error("Failed to delete model");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function fetchLlmModelUsage(
modelId: string,
): Promise<LlmModelUsageResponse> {
const response = await getV2GetModelUsageCount(modelId);
if (response.status !== 200) {
throw new Error("Failed to fetch model usage");
}
return response.data;
}
// =============================================================================
// Migration Actions
// =============================================================================
export async function fetchLlmMigrations(
includeReverted: boolean = false,
): Promise<LlmMigrationsResponse> {
const response = await getV2ListModelMigrations({
include_reverted: includeReverted,
});
if (response.status !== 200) {
throw new Error("Failed to fetch migrations");
}
return response.data;
}
export async function revertLlmMigrationAction(
formData: FormData,
): Promise<void> {
const migrationId = String(formData.get("migration_id"));
const response = await postV2RevertAModelMigration(migrationId, null);
if (response.status !== 200) {
throw new Error("Failed to revert migration");
}
revalidatePath(ADMIN_LLM_PATH);
}
// =============================================================================
// Creator Actions
// =============================================================================
export async function fetchLlmCreators(): Promise<LlmCreatorsResponse> {
const response = await getV2ListModelCreators();
if (response.status !== 200) {
throw new Error("Failed to fetch creators");
}
return response.data;
}
export async function createLlmCreatorAction(
formData: FormData,
): Promise<void> {
const payload: UpsertLlmCreatorRequest = {
name: String(formData.get("name") || "").trim(),
display_name: String(formData.get("display_name") || "").trim(),
description: formData.get("description")
? String(formData.get("description"))
: undefined,
website_url: formData.get("website_url")
? String(formData.get("website_url")).trim()
: undefined,
logo_url: formData.get("logo_url")
? String(formData.get("logo_url")).trim()
: undefined,
metadata: {},
};
const response = await postV2CreateModelCreator(payload);
if (response.status !== 200) {
throw new Error("Failed to create creator");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function updateLlmCreatorAction(
formData: FormData,
): Promise<void> {
const creatorId = String(formData.get("creator_id"));
const payload: UpsertLlmCreatorRequest = {
name: String(formData.get("name") || "").trim(),
display_name: String(formData.get("display_name") || "").trim(),
description: formData.get("description")
? String(formData.get("description"))
: undefined,
website_url: formData.get("website_url")
? String(formData.get("website_url")).trim()
: undefined,
logo_url: formData.get("logo_url")
? String(formData.get("logo_url")).trim()
: undefined,
metadata: {},
};
const response = await patchV2UpdateModelCreator(creatorId, payload);
if (response.status !== 200) {
throw new Error("Failed to update creator");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function deleteLlmCreatorAction(
formData: FormData,
): Promise<void> {
const creatorId = String(formData.get("creator_id"));
const response = await deleteV2DeleteModelCreator(creatorId);
if (response.status !== 200) {
throw new Error("Failed to delete creator");
}
revalidatePath(ADMIN_LLM_PATH);
}
// =============================================================================
// Recommended Model Actions
// =============================================================================
export async function setRecommendedModelAction(
formData: FormData,
): Promise<void> {
const modelId = String(formData.get("model_id"));
const response = await postV2SetRecommendedModel({ model_id: modelId });
if (response.status !== 200) {
throw new Error("Failed to set recommended model");
}
revalidatePath(ADMIN_LLM_PATH);
}
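
The boolean fields parsed above arrive through a hidden-input-plus-checkbox convention used by the modals later in this diff: the hidden input always submits "off", and a checked checkbox appends "on" after it. A minimal sketch of the resulting parsing rule (the helper name is illustrative, not part of this change):

// Unchecked checkboxes are omitted from FormData entirely, which is why
// each modal pairs the checkbox with a hidden input of the same name.
// FormData.getAll() preserves submission order, so when the box is
// checked its "on" follows the hidden "off" and the last value wins.
function readCheckboxField(formData: FormData, name: string): boolean {
  const values = formData.getAll(name);
  return values.at(-1) === "on";
}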

View File

@@ -0,0 +1,147 @@
"use client";
import { useState } from "react";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import { createLlmCreatorAction } from "../actions";
import { useRouter } from "next/navigation";
export function AddCreatorModal() {
const [open, setOpen] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
async function handleSubmit(formData: FormData) {
setIsSubmitting(true);
setError(null);
try {
await createLlmCreatorAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to create creator");
} finally {
setIsSubmitting(false);
}
}
return (
<Dialog
title="Add Creator"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "512px" }}
>
<Dialog.Trigger>
<Button variant="primary" size="small">
Add Creator
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
Add a new model creator (the organization that made/trained the
model).
</div>
<form action={handleSubmit} className="space-y-4">
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label
htmlFor="name"
className="text-sm font-medium text-foreground"
>
Name (slug) <span className="text-destructive">*</span>
</label>
<input
id="name"
required
name="name"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="openai"
/>
<p className="text-xs text-muted-foreground">
Lowercase identifier (e.g., openai, meta, anthropic)
</p>
</div>
<div className="space-y-2">
<label
htmlFor="display_name"
className="text-sm font-medium text-foreground"
>
Display Name <span className="text-destructive">*</span>
</label>
<input
id="display_name"
required
name="display_name"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="OpenAI"
/>
</div>
</div>
<div className="space-y-2">
<label
htmlFor="description"
className="text-sm font-medium text-foreground"
>
Description
</label>
<textarea
id="description"
name="description"
rows={2}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="Creator of GPT models..."
/>
</div>
<div className="space-y-2">
<label
htmlFor="website_url"
className="text-sm font-medium text-foreground"
>
Website URL
</label>
<input
id="website_url"
name="website_url"
type="url"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="https://openai.com"
/>
</div>
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
type="button"
onClick={() => {
setOpen(false);
setError(null);
}}
disabled={isSubmitting}
>
Cancel
</Button>
<Button
variant="primary"
size="small"
type="submit"
disabled={isSubmitting}
>
{isSubmitting ? "Creating..." : "Add Creator"}
</Button>
</Dialog.Footer>
</form>
</Dialog.Content>
</Dialog>
);
}
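
Every modal in this diff repeats the same submit flow: set a loading flag, clear the error, call the server action, close and refresh on success, surface the message on failure. If that were ever factored out, a shared hook could look roughly like this (hypothetical sketch, not part of the change):

import { useState } from "react";
import { useRouter } from "next/navigation";

// Hypothetical helper capturing the submit state each modal manages by hand.
export function useModalSubmit(action: (fd: FormData) => Promise<void>) {
  const [isSubmitting, setIsSubmitting] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const router = useRouter();

  async function submit(formData: FormData, onSuccess: () => void) {
    setIsSubmitting(true);
    setError(null);
    try {
      await action(formData);
      onSuccess(); // e.g. close the dialog
      router.refresh(); // pick up data revalidated by the action
    } catch (err) {
      setError(err instanceof Error ? err.message : "Request failed");
    } finally {
      setIsSubmitting(false);
    }
  }

  return { submit, isSubmitting, error, setError };
}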

View File

@@ -0,0 +1,314 @@
"use client";
import { useState } from "react";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
import { createLlmModelAction } from "../actions";
import { useRouter } from "next/navigation";
interface Props {
providers: LlmProvider[];
creators: LlmModelCreator[];
}
export function AddModelModal({ providers, creators }: Props) {
const [open, setOpen] = useState(false);
const [selectedCreatorId, setSelectedCreatorId] = useState("");
const [isSubmitting, setIsSubmitting] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
async function handleSubmit(formData: FormData) {
setIsSubmitting(true);
setError(null);
try {
await createLlmModelAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to create model");
} finally {
setIsSubmitting(false);
}
}
// When provider changes, auto-select matching creator if one exists
function handleProviderChange(providerId: string) {
const provider = providers.find((p) => p.id === providerId);
if (provider) {
// Find creator with same name as provider (e.g., "openai" -> "openai")
const matchingCreator = creators.find((c) => c.name === provider.name);
if (matchingCreator) {
setSelectedCreatorId(matchingCreator.id);
} else {
// No matching creator (e.g., OpenRouter hosts other creators' models)
setSelectedCreatorId("");
}
}
}
return (
<Dialog
title="Add Model"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
>
<Dialog.Trigger>
<Button variant="primary" size="small">
Add Model
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
Register a new model slug, metadata, and pricing.
</div>
<form action={handleSubmit} className="space-y-6">
{/* Basic Information */}
<div className="space-y-4">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Basic Information
</h3>
<p className="text-xs text-muted-foreground">
Core model details
</p>
</div>
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label
htmlFor="slug"
className="text-sm font-medium text-foreground"
>
Model Slug <span className="text-destructive">*</span>
</label>
<input
id="slug"
required
name="slug"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="gpt-4.1-mini-2025-04-14"
/>
</div>
<div className="space-y-2">
<label
htmlFor="display_name"
className="text-sm font-medium text-foreground"
>
Display Name <span className="text-destructive">*</span>
</label>
<input
id="display_name"
required
name="display_name"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="GPT 4.1 Mini"
/>
</div>
</div>
<div className="space-y-2">
<label
htmlFor="description"
className="text-sm font-medium text-foreground"
>
Description
</label>
<textarea
id="description"
name="description"
rows={3}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="Optional description..."
/>
</div>
</div>
{/* Model Configuration */}
<div className="space-y-4 border-t border-border pt-6">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Model Configuration
</h3>
<p className="text-xs text-muted-foreground">
Model capabilities and limits
</p>
</div>
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label
htmlFor="provider_id"
className="text-sm font-medium text-foreground"
>
Provider <span className="text-destructive">*</span>
</label>
<select
id="provider_id"
required
name="provider_id"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
defaultValue=""
onChange={(e) => handleProviderChange(e.target.value)}
>
<option value="" disabled>
Select provider
</option>
{providers.map((provider) => (
<option key={provider.id} value={provider.id}>
{provider.display_name} ({provider.name})
</option>
))}
</select>
<p className="text-xs text-muted-foreground">
Who hosts/serves the model
</p>
</div>
<div className="space-y-2">
<label
htmlFor="creator_id"
className="text-sm font-medium text-foreground"
>
Creator
</label>
<select
id="creator_id"
name="creator_id"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
value={selectedCreatorId}
onChange={(e) => setSelectedCreatorId(e.target.value)}
>
<option value="">No creator selected</option>
{creators.map((creator) => (
<option key={creator.id} value={creator.id}>
{creator.display_name} ({creator.name})
</option>
))}
</select>
<p className="text-xs text-muted-foreground">
Who made/trained the model (e.g., OpenAI, Meta)
</p>
</div>
</div>
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label
htmlFor="context_window"
className="text-sm font-medium text-foreground"
>
Context Window <span className="text-destructive">*</span>
</label>
<input
id="context_window"
required
type="number"
name="context_window"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="128000"
min={1}
/>
</div>
<div className="space-y-2">
<label
htmlFor="max_output_tokens"
className="text-sm font-medium text-foreground"
>
Max Output Tokens
</label>
<input
id="max_output_tokens"
type="number"
name="max_output_tokens"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="16384"
min={1}
/>
</div>
</div>
</div>
{/* Pricing */}
<div className="space-y-4 border-t border-border pt-6">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">Pricing</h3>
<p className="text-xs text-muted-foreground">
Credit cost per run (credentials are managed via the provider)
</p>
</div>
<div className="grid gap-4 sm:grid-cols-1">
<div className="space-y-2">
<label
htmlFor="credit_cost"
className="text-sm font-medium text-foreground"
>
Credit Cost <span className="text-destructive">*</span>
</label>
<input
id="credit_cost"
required
type="number"
name="credit_cost"
step="1"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="5"
min={0}
/>
</div>
</div>
<p className="text-xs text-muted-foreground">
Credit cost is always in platform credits. Credentials are
inherited from the selected provider.
</p>
</div>
{/* Enabled Toggle */}
<div className="flex items-center gap-3 border-t border-border pt-6">
<input type="hidden" name="is_enabled" value="off" />
<input
id="is_enabled"
type="checkbox"
name="is_enabled"
defaultChecked
className="h-4 w-4 rounded border-input"
/>
<label
htmlFor="is_enabled"
className="text-sm font-medium text-foreground"
>
Enabled by default
</label>
</div>
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
type="button"
onClick={() => {
setOpen(false);
setError(null);
}}
disabled={isSubmitting}
>
Cancel
</Button>
<Button
variant="primary"
size="small"
type="submit"
disabled={isSubmitting}
>
{isSubmitting ? "Creating..." : "Save Model"}
</Button>
</Dialog.Footer>
</form>
</Dialog.Content>
</Dialog>
);
}
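
For reference, this form produces FormData along these lines, which createLlmModelAction (earlier in this diff) turns into a CreateLlmModelRequest; the values are illustrative:

// Illustrative only: the action can also be exercised with hand-built
// FormData, e.g. from a test. providerId is assumed to be a known id.
const fd = new FormData();
fd.set("slug", "gpt-4.1-mini-2025-04-14");
fd.set("display_name", "GPT 4.1 Mini");
fd.set("provider_id", providerId);
fd.set("context_window", "128000");
fd.set("credit_cost", "5");
fd.append("is_enabled", "off"); // hidden input
fd.append("is_enabled", "on"); // checkbox, checked
await createLlmModelAction(fd); // throws if the API responds non-200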

View File

@@ -0,0 +1,268 @@
"use client";
import { useState } from "react";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import { createLlmProviderAction } from "../actions";
import { useRouter } from "next/navigation";
export function AddProviderModal() {
const [open, setOpen] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
async function handleSubmit(formData: FormData) {
setIsSubmitting(true);
setError(null);
try {
await createLlmProviderAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(
err instanceof Error ? err.message : "Failed to create provider",
);
} finally {
setIsSubmitting(false);
}
}
return (
<Dialog
title="Add Provider"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
>
<Dialog.Trigger>
<Button variant="primary" size="small">
Add Provider
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
Define a new upstream provider and default credential information.
</div>
{/* Setup Instructions */}
<div className="mb-6 rounded-lg border border-primary/30 bg-primary/5 p-4">
<div className="space-y-2">
<h4 className="text-sm font-semibold text-foreground">
Before Adding a Provider
</h4>
<p className="text-xs text-muted-foreground">
To use a new provider, you must first configure its credentials in
the backend:
</p>
<ol className="list-inside list-decimal space-y-1 text-xs text-muted-foreground">
<li>
Add the credential to{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono">
backend/integrations/credentials_store.py
</code>{" "}
with a UUID, provider name, and settings secret reference
</li>
<li>
Add it to the{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono">
PROVIDER_CREDENTIALS
</code>{" "}
dictionary in{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono">
backend/data/block_cost_config.py
</code>
</li>
<li>
Use the <strong>same provider name</strong> in the
&quot;Credential Provider&quot; field below that matches the key
in{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono">
PROVIDER_CREDENTIALS
</code>
</li>
</ol>
</div>
</div>
<form action={handleSubmit} className="space-y-6">
{/* Basic Information */}
<div className="space-y-4">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Basic Information
</h3>
<p className="text-xs text-muted-foreground">
Core provider details
</p>
</div>
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label
htmlFor="name"
className="text-sm font-medium text-foreground"
>
Provider Slug <span className="text-destructive">*</span>
</label>
<input
id="name"
required
name="name"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="e.g. openai"
/>
</div>
<div className="space-y-2">
<label
htmlFor="display_name"
className="text-sm font-medium text-foreground"
>
Display Name <span className="text-destructive">*</span>
</label>
<input
id="display_name"
required
name="display_name"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="OpenAI"
/>
</div>
</div>
<div className="space-y-2">
<label
htmlFor="description"
className="text-sm font-medium text-foreground"
>
Description
</label>
<textarea
id="description"
name="description"
rows={3}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="Optional description..."
/>
</div>
</div>
{/* Default Credentials */}
<div className="space-y-4 border-t border-border pt-6">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Default Credentials
</h3>
<p className="text-xs text-muted-foreground">
Credential provider name that matches the key in{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
PROVIDER_CREDENTIALS
</code>
</p>
</div>
<div className="space-y-2">
<label
htmlFor="default_credential_provider"
className="text-sm font-medium text-foreground"
>
Credential Provider <span className="text-destructive">*</span>
</label>
<input
id="default_credential_provider"
name="default_credential_provider"
required
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="openai"
/>
<p className="text-xs text-muted-foreground">
<strong>Important:</strong> This must exactly match the key in
the{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
PROVIDER_CREDENTIALS
</code>{" "}
dictionary in{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
block_cost_config.py
</code>
. Common values: &quot;openai&quot;, &quot;anthropic&quot;,
&quot;groq&quot;, &quot;open_router&quot;, etc.
</p>
</div>
</div>
{/* Capabilities */}
<div className="space-y-4 border-t border-border pt-6">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Capabilities
</h3>
<p className="text-xs text-muted-foreground">
Provider feature flags
</p>
</div>
<div className="grid gap-3 sm:grid-cols-2">
{[
{ name: "supports_tools", label: "Supports tools" },
{ name: "supports_json_output", label: "Supports JSON output" },
{ name: "supports_reasoning", label: "Supports reasoning" },
{
name: "supports_parallel_tool",
label: "Supports parallel tool calls",
},
].map(({ name, label }) => (
<div
key={name}
className="flex items-center gap-3 rounded-md border border-border bg-muted/30 px-4 py-3 transition-colors hover:bg-muted/50"
>
<input type="hidden" name={name} value="off" />
<input
id={name}
type="checkbox"
name={name}
defaultChecked={
name !== "supports_reasoning" &&
name !== "supports_parallel_tool"
}
className="h-4 w-4 rounded border-input"
/>
<label
htmlFor={name}
className="text-sm font-medium text-foreground"
>
{label}
</label>
</div>
))}
</div>
</div>
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
type="button"
onClick={() => {
setOpen(false);
setError(null);
}}
disabled={isSubmitting}
>
Cancel
</Button>
<Button
variant="primary"
size="small"
type="submit"
disabled={isSubmitting}
>
{isSubmitting ? "Creating..." : "Save Provider"}
</Button>
</Dialog.Footer>
</form>
</Dialog.Content>
</Dialog>
);
}
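
Since the "Credential Provider" value must match a PROVIDER_CREDENTIALS key exactly, a lightweight pre-submit guard could catch obvious mismatches. This sketch assumes backend keys are lowercase snake_case identifiers (e.g. "openai", "open_router"):

// Sketch: rough client-side check for the credential provider slug.
// Assumption: backend keys are lowercase snake_case identifiers.
function looksLikeProviderKey(value: string): boolean {
  return /^[a-z][a-z0-9_]*$/.test(value.trim());
}

looksLikeProviderKey("open_router"); // true
looksLikeProviderKey("OpenRouter"); // false, would not match a backend key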

View File

@@ -0,0 +1,195 @@
"use client";
import { useState } from "react";
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/atoms/Table/Table";
import { Button } from "@/components/atoms/Button/Button";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { updateLlmCreatorAction } from "../actions";
import { useRouter } from "next/navigation";
import { DeleteCreatorModal } from "./DeleteCreatorModal";
export function CreatorsTable({ creators }: { creators: LlmModelCreator[] }) {
if (!creators.length) {
return (
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
No creators registered yet.
</div>
);
}
return (
<div className="rounded-lg border">
<Table>
<TableHeader>
<TableRow>
<TableHead>Creator</TableHead>
<TableHead>Description</TableHead>
<TableHead>Website</TableHead>
<TableHead>Actions</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{creators.map((creator) => (
<TableRow key={creator.id}>
<TableCell>
<div className="font-medium">{creator.display_name}</div>
<div className="text-xs text-muted-foreground">
{creator.name}
</div>
</TableCell>
<TableCell>
<span className="text-sm text-muted-foreground">
{creator.description || "—"}
</span>
</TableCell>
<TableCell>
{creator.website_url ? (
<a
href={creator.website_url}
target="_blank"
rel="noopener noreferrer"
className="text-sm text-primary hover:underline"
>
{(() => {
try {
return new URL(creator.website_url).hostname;
} catch {
return creator.website_url;
}
})()}
</a>
) : (
<span className="text-muted-foreground"></span>
)}
</TableCell>
<TableCell>
<div className="flex items-center justify-end gap-2">
<EditCreatorModal creator={creator} />
<DeleteCreatorModal creator={creator} />
</div>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
</div>
);
}
function EditCreatorModal({ creator }: { creator: LlmModelCreator }) {
const [open, setOpen] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
async function handleSubmit(formData: FormData) {
setIsSubmitting(true);
setError(null);
try {
await updateLlmCreatorAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to update creator");
} finally {
setIsSubmitting(false);
}
}
return (
<Dialog
title="Edit Creator"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "512px" }}
>
<Dialog.Trigger>
<Button variant="outline" size="small" className="min-w-0">
Edit
</Button>
</Dialog.Trigger>
<Dialog.Content>
<form action={handleSubmit} className="space-y-4">
<input type="hidden" name="creator_id" value={creator.id} />
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label className="text-sm font-medium">Name (slug)</label>
<input
required
name="name"
defaultValue={creator.name}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
/>
</div>
<div className="space-y-2">
<label className="text-sm font-medium">Display Name</label>
<input
required
name="display_name"
defaultValue={creator.display_name}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
/>
</div>
</div>
<div className="space-y-2">
<label className="text-sm font-medium">Description</label>
<textarea
name="description"
rows={2}
defaultValue={creator.description ?? ""}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
/>
</div>
<div className="space-y-2">
<label className="text-sm font-medium">Website URL</label>
<input
name="website_url"
type="url"
defaultValue={creator.website_url ?? ""}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
/>
</div>
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
type="button"
onClick={() => {
setOpen(false);
setError(null);
}}
disabled={isSubmitting}
>
Cancel
</Button>
<Button
variant="primary"
size="small"
type="submit"
disabled={isSubmitting}
>
{isSubmitting ? "Updating..." : "Update"}
</Button>
</Dialog.Footer>
</form>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -0,0 +1,107 @@
"use client";
import { useState } from "react";
import { useRouter } from "next/navigation";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
import { deleteLlmCreatorAction } from "../actions";
export function DeleteCreatorModal({ creator }: { creator: LlmModelCreator }) {
const [open, setOpen] = useState(false);
const [isDeleting, setIsDeleting] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
async function handleDelete(formData: FormData) {
setIsDeleting(true);
setError(null);
try {
await deleteLlmCreatorAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to delete creator");
} finally {
setIsDeleting(false);
}
}
return (
<Dialog
title="Delete Creator"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "480px" }}
>
<Dialog.Trigger>
<Button
type="button"
variant="outline"
size="small"
className="min-w-0 text-destructive hover:bg-destructive/10"
>
Delete
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="space-y-4">
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
<div className="flex items-start gap-3">
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
</div>
<div className="text-sm text-foreground">
<p className="font-semibold">You are about to delete:</p>
<p className="mt-1">
<span className="font-medium">{creator.display_name}</span>{" "}
<span className="text-muted-foreground">
({creator.name})
</span>
</p>
<p className="mt-2 text-muted-foreground">
Models using this creator will have their creator field
cleared. This is safe and won&apos;t affect model
functionality.
</p>
</div>
</div>
</div>
<form action={handleDelete} className="space-y-4">
<input type="hidden" name="creator_id" value={creator.id} />
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
onClick={() => {
setOpen(false);
setError(null);
}}
disabled={isDeleting}
type="button"
>
Cancel
</Button>
<Button
type="submit"
variant="primary"
size="small"
disabled={isDeleting}
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
>
{isDeleting ? "Deleting..." : "Delete Creator"}
</Button>
</Dialog.Footer>
</form>
</div>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -0,0 +1,201 @@
"use client";
import { useState } from "react";
import { useRouter } from "next/navigation";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
import { deleteLlmModelAction, fetchLlmModelUsage } from "../actions";
export function DeleteModelModal({
model,
availableModels,
}: {
model: LlmModel;
availableModels: LlmModel[];
}) {
const router = useRouter();
const [open, setOpen] = useState(false);
const [selectedReplacement, setSelectedReplacement] = useState<string>("");
const [isDeleting, setIsDeleting] = useState(false);
const [error, setError] = useState<string | null>(null);
const [usageCount, setUsageCount] = useState<number | null>(null);
const [usageLoading, setUsageLoading] = useState(false);
const [usageError, setUsageError] = useState<string | null>(null);
// Filter out the current model and disabled models from replacement options
const replacementOptions = availableModels.filter(
(m) => m.id !== model.id && m.is_enabled,
);
async function fetchUsage() {
setUsageLoading(true);
setUsageError(null);
try {
const usage = await fetchLlmModelUsage(model.id);
setUsageCount(usage.node_count);
} catch (err) {
console.error("Failed to fetch model usage:", err);
setUsageError("Failed to load usage count");
setUsageCount(null);
} finally {
setUsageLoading(false);
}
}
async function handleDelete(formData: FormData) {
setIsDeleting(true);
setError(null);
try {
await deleteLlmModelAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to delete model");
} finally {
setIsDeleting(false);
}
}
return (
<Dialog
title="Delete Model"
controlled={{
isOpen: open,
set: async (isOpen) => {
setOpen(isOpen);
if (isOpen) {
setUsageCount(null);
setUsageError(null);
setError(null);
setSelectedReplacement("");
await fetchUsage();
}
},
}}
styling={{ maxWidth: "600px" }}
>
<Dialog.Trigger>
<Button
type="button"
variant="outline"
size="small"
className="min-w-0 text-destructive hover:bg-destructive/10"
>
Delete
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
This action cannot be undone. All workflows using this model will be
migrated to the replacement model you select.
</div>
<div className="space-y-4">
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
<div className="flex items-start gap-3">
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
</div>
<div className="text-sm text-foreground">
<p className="font-semibold">You are about to delete:</p>
<p className="mt-1">
<span className="font-medium">{model.display_name}</span>{" "}
<span className="text-muted-foreground">({model.slug})</span>
</p>
{usageLoading && (
<p className="mt-2 text-muted-foreground">
Loading usage count...
</p>
)}
{usageError && (
<p className="mt-2 text-destructive">{usageError}</p>
)}
{!usageLoading && !usageError && usageCount !== null && (
<p className="mt-2 font-semibold">
Impact: {usageCount} block{usageCount !== 1 ? "s" : ""}{" "}
currently use this model
</p>
)}
<p className="mt-2 text-muted-foreground">
All workflows currently using this model will be automatically
updated to use the replacement model you choose below.
</p>
</div>
</div>
</div>
<form action={handleDelete} className="space-y-4">
<input type="hidden" name="model_id" value={model.id} />
<input
type="hidden"
name="replacement_model_slug"
value={selectedReplacement}
/>
<label className="text-sm font-medium">
<span className="mb-2 block">
Select Replacement Model{" "}
<span className="text-destructive">*</span>
</span>
<select
required
value={selectedReplacement}
onChange={(e) => setSelectedReplacement(e.target.value)}
className="w-full rounded border border-input bg-background p-2 text-sm"
>
<option value="">-- Choose a replacement model --</option>
{replacementOptions.map((m) => (
<option key={m.id} value={m.slug}>
{m.display_name} ({m.slug})
</option>
))}
</select>
{replacementOptions.length === 0 && (
<p className="mt-2 text-xs text-destructive">
No replacement models available. You must have at least one
other enabled model before deleting this one.
</p>
)}
</label>
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
type="button"
onClick={() => {
setOpen(false);
setSelectedReplacement("");
setError(null);
}}
disabled={isDeleting}
>
Cancel
</Button>
<Button
type="submit"
variant="primary"
size="small"
disabled={
!selectedReplacement ||
isDeleting ||
replacementOptions.length === 0
}
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
>
{isDeleting ? "Deleting..." : "Delete and Migrate"}
</Button>
</Dialog.Footer>
</form>
</div>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -0,0 +1,287 @@
"use client";
import { useState } from "react";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
import { toggleLlmModelAction, fetchLlmModelUsage } from "../actions";
export function DisableModelModal({
model,
availableModels,
}: {
model: LlmModel;
availableModels: LlmModel[];
}) {
const [open, setOpen] = useState(false);
const [isDisabling, setIsDisabling] = useState(false);
const [error, setError] = useState<string | null>(null);
const [usageCount, setUsageCount] = useState<number | null>(null);
const [selectedMigration, setSelectedMigration] = useState<string>("");
const [wantsMigration, setWantsMigration] = useState(false);
const [migrationReason, setMigrationReason] = useState("");
const [customCreditCost, setCustomCreditCost] = useState<string>("");
const [usageError, setUsageError] = useState<string | null>(null);
// Filter out the current model and disabled models from replacement options
const migrationOptions = availableModels.filter(
(m) => m.id !== model.id && m.is_enabled,
);
async function fetchUsage() {
setUsageError(null);
try {
const usage = await fetchLlmModelUsage(model.id);
setUsageCount(usage.node_count);
} catch {
// Surface the failure instead of leaving the dialog stuck on "Loading".
setUsageError("Failed to load usage count");
setUsageCount(null);
}
}
async function handleDisable(formData: FormData) {
setIsDisabling(true);
setError(null);
try {
await toggleLlmModelAction(formData);
setOpen(false);
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to disable model");
} finally {
setIsDisabling(false);
}
}
function resetState() {
setError(null);
setSelectedMigration("");
setWantsMigration(false);
setMigrationReason("");
setCustomCreditCost("");
}
const hasUsage = usageCount !== null && usageCount > 0;
return (
<Dialog
title="Disable Model"
controlled={{
isOpen: open,
set: async (isOpen) => {
setOpen(isOpen);
if (isOpen) {
setUsageCount(null);
resetState();
await fetchUsage();
}
},
}}
styling={{ maxWidth: "600px" }}
>
<Dialog.Trigger>
<Button
type="button"
variant="outline"
size="small"
className="min-w-0"
>
Disable
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
Disabling a model will hide it from users when creating new workflows.
</div>
<div className="space-y-4">
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
<div className="flex items-start gap-3">
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
</div>
<div className="text-sm text-foreground">
<p className="font-semibold">You are about to disable:</p>
<p className="mt-1">
<span className="font-medium">{model.display_name}</span>{" "}
<span className="text-muted-foreground">({model.slug})</span>
</p>
{usageError ? (
<p className="mt-2 text-destructive">{usageError}</p>
) : usageCount === null ? (
<p className="mt-2 text-muted-foreground">
Loading usage data...
</p>
) : usageCount > 0 ? (
<p className="mt-2 font-semibold">
Impact: {usageCount} block{usageCount !== 1 ? "s" : ""}{" "}
currently use this model
</p>
) : (
<p className="mt-2 text-muted-foreground">
No workflows are currently using this model.
</p>
)}
</div>
</div>
</div>
{hasUsage && (
<div className="space-y-4 rounded-lg border border-border bg-muted/50 p-4">
<label className="flex items-start gap-3">
<input
type="checkbox"
checked={wantsMigration}
onChange={(e) => {
setWantsMigration(e.target.checked);
if (!e.target.checked) {
setSelectedMigration("");
}
}}
className="mt-1"
/>
<div className="text-sm">
<span className="font-medium">
Migrate existing workflows to another model
</span>
<p className="mt-1 text-muted-foreground">
Creates a revertible migration record. If unchecked,
existing workflows will use automatic fallback to an enabled
model from the same provider.
</p>
</div>
</label>
{wantsMigration && (
<div className="space-y-4 border-t border-border pt-4">
<label className="block text-sm font-medium">
<span className="mb-2 block">
Replacement Model{" "}
<span className="text-destructive">*</span>
</span>
<select
required
value={selectedMigration}
onChange={(e) => setSelectedMigration(e.target.value)}
className="w-full rounded border border-input bg-background p-2 text-sm"
>
<option value="">-- Choose a replacement model --</option>
{migrationOptions.map((m) => (
<option key={m.id} value={m.slug}>
{m.display_name} ({m.slug})
</option>
))}
</select>
{migrationOptions.length === 0 && (
<p className="mt-2 text-xs text-destructive">
No other enabled models available for migration.
</p>
)}
</label>
<label className="block text-sm font-medium">
<span className="mb-2 block">
Migration Reason{" "}
<span className="font-normal text-muted-foreground">
(optional)
</span>
</span>
<input
type="text"
value={migrationReason}
onChange={(e) => setMigrationReason(e.target.value)}
placeholder="e.g., Provider outage, Cost reduction"
className="w-full rounded border border-input bg-background p-2 text-sm"
/>
<p className="mt-1 text-xs text-muted-foreground">
Helps track why the migration was made
</p>
</label>
<label className="block text-sm font-medium">
<span className="mb-2 block">
Custom Credit Cost{" "}
<span className="font-normal text-muted-foreground">
(optional)
</span>
</span>
<input
type="number"
min="0"
value={customCreditCost}
onChange={(e) => setCustomCreditCost(e.target.value)}
placeholder="Leave blank to use target model's cost"
className="w-full rounded border border-input bg-background p-2 text-sm"
/>
<p className="mt-1 text-xs text-muted-foreground">
Override pricing for migrated workflows. When set, billing
will use this cost instead of the target model&apos;s cost.
</p>
</label>
</div>
)}
</div>
)}
<form action={handleDisable} className="space-y-4">
<input type="hidden" name="model_id" value={model.id} />
<input type="hidden" name="is_enabled" value="false" />
{wantsMigration && selectedMigration && (
<>
<input
type="hidden"
name="migrate_to_slug"
value={selectedMigration}
/>
{migrationReason && (
<input
type="hidden"
name="migration_reason"
value={migrationReason}
/>
)}
{customCreditCost && (
<input
type="hidden"
name="custom_credit_cost"
value={customCreditCost}
/>
)}
</>
)}
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
onClick={() => {
setOpen(false);
resetState();
}}
disabled={isDisabling}
>
Cancel
</Button>
<Button
type="submit"
variant="primary"
size="small"
disabled={
isDisabling ||
(wantsMigration && !selectedMigration) ||
(usageCount === null && !usageError)
}
>
{isDisabling
? "Disabling..."
: wantsMigration && selectedMigration
? "Disable & Migrate"
: "Disable Model"}
</Button>
</Dialog.Footer>
</form>
</div>
</Dialog.Content>
</Dialog>
);
}
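
Putting the hidden inputs together, a "Disable & Migrate" submit produces FormData that toggleLlmModelAction (earlier in this diff) maps onto a ToggleLlmModelRequest; the values here are illustrative:

// Illustrative mapping from the form fields to the toggle payload.
const fd = new FormData();
fd.set("model_id", model.id);
fd.set("is_enabled", "false");
fd.set("migrate_to_slug", "gpt-4.1-mini"); // hypothetical replacement slug
fd.set("migration_reason", "Provider outage");
fd.set("custom_credit_cost", "3"); // optional billing override
// toggleLlmModelAction builds:
// {
//   is_enabled: false,
//   migrate_to_slug: "gpt-4.1-mini",
//   migration_reason: "Provider outage",
//   custom_credit_cost: 3,
// }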

View File

@@ -0,0 +1,213 @@
"use client";
import { useState } from "react";
import { useRouter } from "next/navigation";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
import { updateLlmModelAction } from "../actions";
export function EditModelModal({
model,
providers,
creators,
}: {
model: LlmModel;
providers: LlmProvider[];
creators: LlmModelCreator[];
}) {
const router = useRouter();
const [open, setOpen] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
const [error, setError] = useState<string | null>(null);
const cost = model.costs?.[0];
const provider = providers.find((p) => p.id === model.provider_id);
async function handleSubmit(formData: FormData) {
setIsSubmitting(true);
setError(null);
try {
await updateLlmModelAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to update model");
} finally {
setIsSubmitting(false);
}
}
return (
<Dialog
title="Edit Model"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
>
<Dialog.Trigger>
<Button variant="outline" size="small" className="min-w-0">
Edit
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
Update model metadata and pricing information.
</div>
{error && (
<div className="mb-4 rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<form action={handleSubmit} className="space-y-4">
<input type="hidden" name="model_id" value={model.id} />
<div className="grid gap-4 md:grid-cols-2">
<label className="text-sm font-medium">
Display Name
<input
required
name="display_name"
defaultValue={model.display_name}
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
/>
</label>
<label className="text-sm font-medium">
Provider
<select
required
name="provider_id"
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
defaultValue={model.provider_id}
>
{providers.map((p) => (
<option key={p.id} value={p.id}>
{p.display_name} ({p.name})
</option>
))}
</select>
<span className="text-xs text-muted-foreground">
Who hosts/serves the model
</span>
</label>
</div>
<div className="grid gap-4 md:grid-cols-2">
<label className="text-sm font-medium">
Creator
<select
name="creator_id"
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
defaultValue={model.creator_id ?? ""}
>
<option value="">No creator selected</option>
{creators.map((c) => (
<option key={c.id} value={c.id}>
{c.display_name} ({c.name})
</option>
))}
</select>
<span className="text-xs text-muted-foreground">
Who made/trained the model (e.g., OpenAI, Meta)
</span>
</label>
</div>
<label className="text-sm font-medium">
Description
<textarea
name="description"
rows={2}
defaultValue={model.description ?? ""}
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
placeholder="Optional description..."
/>
</label>
<div className="grid gap-4 md:grid-cols-2">
<label className="text-sm font-medium">
Context Window
<input
required
type="number"
name="context_window"
defaultValue={model.context_window}
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
min={1}
/>
</label>
<label className="text-sm font-medium">
Max Output Tokens
<input
type="number"
name="max_output_tokens"
defaultValue={model.max_output_tokens ?? undefined}
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
min={1}
/>
</label>
</div>
<div className="grid gap-4 md:grid-cols-2">
<label className="text-sm font-medium">
Credit Cost
<input
required
type="number"
name="credit_cost"
defaultValue={cost?.credit_cost ?? 0}
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
min={0}
/>
<span className="text-xs text-muted-foreground">
Credits charged per run
</span>
</label>
<label className="text-sm font-medium">
Credential Provider
<select
required
name="credential_provider"
defaultValue={cost?.credential_provider ?? provider?.name ?? ""}
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
>
<option value="" disabled>
Select provider
</option>
{providers.map((p) => (
<option key={p.id} value={p.name}>
{p.display_name} ({p.name})
</option>
))}
</select>
<span className="text-xs text-muted-foreground">
Must match a key in PROVIDER_CREDENTIALS
</span>
</label>
</div>
{/* Hidden defaults for credential_type */}
<input type="hidden" name="credential_type" value="api_key" />
<Dialog.Footer>
<Button
variant="ghost"
size="small"
onClick={() => setOpen(false)}
disabled={isSubmitting}
>
Cancel
</Button>
<Button
variant="primary"
size="small"
type="submit"
disabled={isSubmitting}
>
{isSubmitting ? "Updating..." : "Update Model"}
</Button>
</Dialog.Footer>
</form>
</Dialog.Content>
</Dialog>
);
}
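
Because updateLlmModelAction only includes fields that are present in the FormData, partial updates are possible even though this form submits everything; a sketch with illustrative values:

// Only display_name is sent here, so context_window, costs, and the rest
// stay undefined in UpdateLlmModelRequest and the backend leaves them as-is.
const fd = new FormData();
fd.set("model_id", model.id);
fd.set("display_name", "GPT 4.1 Mini (renamed)");
await updateLlmModelAction(fd);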

View File

@@ -0,0 +1,131 @@
"use client";
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
import type { LlmModelMigration } from "@/app/api/__generated__/models/llmModelMigration";
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
import { ErrorBoundary } from "@/components/molecules/ErrorBoundary/ErrorBoundary";
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { AddProviderModal } from "./AddProviderModal";
import { AddModelModal } from "./AddModelModal";
import { AddCreatorModal } from "./AddCreatorModal";
import { ProviderList } from "./ProviderList";
import { ModelsTable } from "./ModelsTable";
import { MigrationsTable } from "./MigrationsTable";
import { CreatorsTable } from "./CreatorsTable";
import { RecommendedModelSelector } from "./RecommendedModelSelector";
interface Props {
providers: LlmProvider[];
models: LlmModel[];
migrations: LlmModelMigration[];
creators: LlmModelCreator[];
}
function AdminErrorFallback() {
return (
<div className="mx-auto max-w-xl p-6">
<ErrorCard
responseError={{
message:
"An error occurred while loading the LLM Registry. Please refresh the page.",
}}
context="llm-registry"
onRetry={() => window.location.reload()}
/>
</div>
);
}
export function LlmRegistryDashboard({
providers,
models,
migrations,
creators,
}: Props) {
return (
<ErrorBoundary fallback={<AdminErrorFallback />} context="llm-registry">
<div className="mx-auto p-6">
<div className="flex flex-col gap-6">
{/* Header */}
<div>
<h1 className="text-3xl font-bold">LLM Registry</h1>
<p className="text-muted-foreground">
Manage providers, creators, models, and credit pricing
</p>
</div>
{/* Active Migrations Section - Only show if there are migrations */}
{migrations.length > 0 && (
<div className="rounded-lg border border-primary/30 bg-primary/5 p-6 shadow-sm">
<div className="mb-4">
<h2 className="text-xl font-semibold">Active Migrations</h2>
<p className="mt-1 text-sm text-muted-foreground">
These migrations can be reverted to restore workflows to their
original model
</p>
</div>
<MigrationsTable migrations={migrations} />
</div>
)}
{/* Providers & Creators Section - Side by Side */}
<div className="grid gap-6 lg:grid-cols-2">
{/* Providers */}
<div className="rounded-lg border bg-card p-6 shadow-sm">
<div className="mb-4 flex items-center justify-between">
<div>
<h2 className="text-xl font-semibold">Providers</h2>
<p className="mt-1 text-sm text-muted-foreground">
Who hosts/serves the models
</p>
</div>
<AddProviderModal />
</div>
<ProviderList providers={providers} />
</div>
{/* Creators */}
<div className="rounded-lg border bg-card p-6 shadow-sm">
<div className="mb-4 flex items-center justify-between">
<div>
<h2 className="text-xl font-semibold">Creators</h2>
<p className="mt-1 text-sm text-muted-foreground">
Who made/trained the models
</p>
</div>
<AddCreatorModal />
</div>
<CreatorsTable creators={creators} />
</div>
</div>
{/* Models Section */}
<div className="rounded-lg border bg-card p-6 shadow-sm">
<div className="mb-4 flex items-center justify-between">
<div>
<h2 className="text-xl font-semibold">Models</h2>
<p className="mt-1 text-sm text-muted-foreground">
Toggle availability, adjust context windows, and update credit
pricing
</p>
</div>
<AddModelModal providers={providers} creators={creators} />
</div>
{/* Recommended Model Selector */}
<div className="mb-6">
<RecommendedModelSelector models={models} />
</div>
<ModelsTable
models={models}
providers={providers}
creators={creators}
/>
</div>
</div>
</div>
</ErrorBoundary>
);
}

View File

@@ -0,0 +1,133 @@
"use client";
import { useState } from "react";
import type { LlmModelMigration } from "@/app/api/__generated__/models/llmModelMigration";
import { Button } from "@/components/atoms/Button/Button";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/atoms/Table/Table";
import { revertLlmMigrationAction } from "../actions";
export function MigrationsTable({
migrations,
}: {
migrations: LlmModelMigration[];
}) {
if (!migrations.length) {
return (
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
No active migrations. Migrations are created when you disable a model
with the &quot;Migrate existing workflows&quot; option.
</div>
);
}
return (
<div className="rounded-lg border">
<Table>
<TableHeader>
<TableRow>
<TableHead>Migration</TableHead>
<TableHead>Reason</TableHead>
<TableHead>Nodes Affected</TableHead>
<TableHead>Custom Cost</TableHead>
<TableHead>Created</TableHead>
<TableHead className="text-right">Actions</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{migrations.map((migration) => (
<MigrationRow key={migration.id} migration={migration} />
))}
</TableBody>
</Table>
</div>
);
}
function MigrationRow({ migration }: { migration: LlmModelMigration }) {
const [isReverting, setIsReverting] = useState(false);
const [error, setError] = useState<string | null>(null);
async function handleRevert(formData: FormData) {
setIsReverting(true);
setError(null);
try {
await revertLlmMigrationAction(formData);
} catch (err) {
setError(
err instanceof Error ? err.message : "Failed to revert migration",
);
} finally {
setIsReverting(false);
}
}
const createdDate = new Date(migration.created_at);
return (
<>
<TableRow>
<TableCell>
<div className="text-sm">
<span className="font-medium">{migration.source_model_slug}</span>
<span className="mx-2 text-muted-foreground"></span>
<span className="font-medium">{migration.target_model_slug}</span>
</div>
</TableCell>
<TableCell>
<div className="text-sm text-muted-foreground">
{migration.reason || "—"}
</div>
</TableCell>
<TableCell>
<div className="text-sm">{migration.node_count}</div>
</TableCell>
<TableCell>
<div className="text-sm">
{migration.custom_credit_cost !== null &&
migration.custom_credit_cost !== undefined
? `${migration.custom_credit_cost} credits`
: "—"}
</div>
</TableCell>
<TableCell>
<div className="text-sm text-muted-foreground">
{createdDate.toLocaleDateString()}{" "}
{createdDate.toLocaleTimeString([], {
hour: "2-digit",
minute: "2-digit",
})}
</div>
</TableCell>
<TableCell className="text-right">
<form action={handleRevert} className="inline">
<input type="hidden" name="migration_id" value={migration.id} />
<Button
type="submit"
variant="outline"
size="small"
disabled={isReverting}
>
{isReverting ? "Reverting..." : "Revert"}
</Button>
</form>
</TableCell>
</TableRow>
{error && (
<TableRow>
<TableCell colSpan={6}>
<div className="rounded border border-destructive/30 bg-destructive/10 p-2 text-sm text-destructive">
{error}
</div>
</TableCell>
</TableRow>
)}
</>
);
}

View File

@@ -0,0 +1,172 @@
"use client";
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/atoms/Table/Table";
import { Button } from "@/components/atoms/Button/Button";
import { toggleLlmModelAction } from "../actions";
import { DeleteModelModal } from "./DeleteModelModal";
import { DisableModelModal } from "./DisableModelModal";
import { EditModelModal } from "./EditModelModal";
import { Star } from "@phosphor-icons/react";
export function ModelsTable({
models,
providers,
creators,
}: {
models: LlmModel[];
providers: LlmProvider[];
creators: LlmModelCreator[];
}) {
if (!models.length) {
return (
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
No models registered yet.
</div>
);
}
const providerLookup = new Map(
providers.map((provider) => [provider.id, provider]),
);
return (
<div className="rounded-lg border">
<Table>
<TableHeader>
<TableRow>
<TableHead>Model</TableHead>
<TableHead>Provider</TableHead>
<TableHead>Creator</TableHead>
<TableHead>Context Window</TableHead>
<TableHead>Max Output</TableHead>
<TableHead>Cost</TableHead>
<TableHead>Status</TableHead>
<TableHead>Actions</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{models.map((model) => {
const cost = model.costs?.[0];
const provider = providerLookup.get(model.provider_id);
return (
<TableRow
key={model.id}
className={model.is_enabled ? "" : "opacity-60"}
>
<TableCell>
<div className="font-medium">{model.display_name}</div>
<div className="text-xs text-muted-foreground">
{model.slug}
</div>
</TableCell>
<TableCell>
{provider ? (
<>
<div>{provider.display_name}</div>
<div className="text-xs text-muted-foreground">
{provider.name}
</div>
</>
) : (
model.provider_id
)}
</TableCell>
<TableCell>
{model.creator ? (
<>
<div>{model.creator.display_name}</div>
<div className="text-xs text-muted-foreground">
{model.creator.name}
</div>
</>
) : (
<span className="text-muted-foreground"></span>
)}
</TableCell>
<TableCell>{model.context_window.toLocaleString()}</TableCell>
<TableCell>
{model.max_output_tokens
? model.max_output_tokens.toLocaleString()
: "—"}
</TableCell>
<TableCell>
{cost ? (
<>
<div className="font-medium">
{cost.credit_cost} credits
</div>
<div className="text-xs text-muted-foreground">
{cost.credential_provider}
</div>
</>
) : (
"—"
)}
</TableCell>
<TableCell>
<div className="flex flex-col gap-1">
<span
className={`inline-flex rounded-full px-2.5 py-1 text-xs font-semibold ${
model.is_enabled
? "bg-primary/10 text-primary"
: "bg-muted text-muted-foreground"
}`}
>
{model.is_enabled ? "Enabled" : "Disabled"}
</span>
{model.is_recommended && (
<span className="inline-flex items-center gap-1 rounded-full bg-amber-500/10 px-2.5 py-1 text-xs font-semibold text-amber-600 dark:text-amber-400">
<Star size={12} weight="fill" />
Recommended
</span>
)}
</div>
</TableCell>
<TableCell>
<div className="flex items-center justify-end gap-2">
{model.is_enabled ? (
<DisableModelModal
model={model}
availableModels={models}
/>
) : (
<EnableModelButton modelId={model.id} />
)}
<EditModelModal
model={model}
providers={providers}
creators={creators}
/>
<DeleteModelModal model={model} availableModels={models} />
</div>
</TableCell>
</TableRow>
);
})}
</TableBody>
</Table>
</div>
);
}
function EnableModelButton({ modelId }: { modelId: string }) {
return (
<form action={toggleLlmModelAction} className="inline">
<input type="hidden" name="model_id" value={modelId} />
<input type="hidden" name="is_enabled" value="true" />
<Button type="submit" variant="outline" size="small" className="min-w-0">
Enable
</Button>
</form>
);
}

Some files were not shown because too many files have changed in this diff.