mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-25 15:08:07 -05:00
Compare commits
12 Commits
add-llm-ma
...
abhi/show-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cb852a947b | ||
|
|
10cc347563 | ||
|
|
71c0f909f3 | ||
|
|
9d9fea700b | ||
|
|
4e25b1d0b2 | ||
|
|
2cd9ec5106 | ||
|
|
9a6e17ff52 | ||
|
|
fb58827c61 | ||
|
|
595f3508c1 | ||
|
|
7892590b12 | ||
|
|
82d7134fc6 | ||
|
|
90466908a8 |
@@ -122,24 +122,6 @@ class ConnectionManager:
|
||||
|
||||
return len(connections)
|
||||
|
||||
async def broadcast_to_all(self, *, method: WSMethod, data: dict) -> int:
|
||||
"""Broadcast a message to all active websocket connections."""
|
||||
message = WSMessage(
|
||||
method=method,
|
||||
data=data,
|
||||
).model_dump_json()
|
||||
|
||||
connections = tuple(self.active_connections)
|
||||
if not connections:
|
||||
return 0
|
||||
|
||||
await asyncio.gather(
|
||||
*(connection.send_text(message) for connection in connections),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
return len(connections)
|
||||
|
||||
async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
|
||||
if channel_key not in self.subscriptions:
|
||||
self.subscriptions[channel_key] = set()
|
||||
|
||||
@@ -176,64 +176,30 @@ async def get_execution_analytics_config(
|
||||
# Return with provider prefix for clarity
|
||||
return f"{provider_name}: {model_name}"
|
||||
|
||||
# Get all models from the registry (dynamic, not hardcoded enum)
|
||||
from backend.data import llm_registry
|
||||
from backend.server.v2.llm import db as llm_db
|
||||
|
||||
# Get the recommended model from the database (configurable via admin UI)
|
||||
recommended_model_slug = await llm_db.get_recommended_model_slug()
|
||||
|
||||
# Build the available models list
|
||||
first_enabled_slug = None
|
||||
for registry_model in llm_registry.iter_dynamic_models():
|
||||
# Only include enabled models in the list
|
||||
if not registry_model.is_enabled:
|
||||
continue
|
||||
|
||||
# Track first enabled model as fallback
|
||||
if first_enabled_slug is None:
|
||||
first_enabled_slug = registry_model.slug
|
||||
|
||||
model_enum = LlmModel(registry_model.slug) # Create enum instance from slug
|
||||
label = generate_model_label(model_enum)
|
||||
# Include all LlmModel values (no more filtering by hardcoded list)
|
||||
recommended_model = LlmModel.GPT4O_MINI.value
|
||||
for model in LlmModel:
|
||||
label = generate_model_label(model)
|
||||
# Add "(Recommended)" suffix to the recommended model
|
||||
if registry_model.slug == recommended_model_slug:
|
||||
if model.value == recommended_model:
|
||||
label += " (Recommended)"
|
||||
|
||||
available_models.append(
|
||||
ModelInfo(
|
||||
value=registry_model.slug,
|
||||
value=model.value,
|
||||
label=label,
|
||||
provider=registry_model.metadata.provider,
|
||||
provider=model.provider,
|
||||
)
|
||||
)
|
||||
|
||||
# Sort models by provider and name for better UX
|
||||
available_models.sort(key=lambda x: (x.provider, x.label))
|
||||
|
||||
# Handle case where no models are available
|
||||
if not available_models:
|
||||
logger.warning(
|
||||
"No enabled LLM models found in registry. "
|
||||
"Ensure models are configured and enabled in the LLM Registry."
|
||||
)
|
||||
# Provide a placeholder entry so admins see meaningful feedback
|
||||
available_models.append(
|
||||
ModelInfo(
|
||||
value="",
|
||||
label="No models available - configure in LLM Registry",
|
||||
provider="none",
|
||||
)
|
||||
)
|
||||
|
||||
# Use the DB recommended model, or fallback to first enabled model
|
||||
final_recommended = recommended_model_slug or first_enabled_slug or ""
|
||||
|
||||
return ExecutionAnalyticsConfig(
|
||||
available_models=available_models,
|
||||
default_system_prompt=DEFAULT_SYSTEM_PROMPT,
|
||||
default_user_prompt=DEFAULT_USER_PROMPT,
|
||||
recommended_model=final_recommended,
|
||||
recommended_model=recommended_model,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,595 +0,0 @@
|
||||
import logging
|
||||
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block_cost_config import refresh_llm_costs
|
||||
from backend.server.v2.llm import db as llm_db
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = fastapi.APIRouter(
|
||||
tags=["llm", "admin"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
async def _refresh_runtime_state() -> None:
|
||||
"""Refresh the LLM registry and clear all related caches to ensure real-time updates."""
|
||||
logger.info("Refreshing LLM registry runtime state...")
|
||||
try:
|
||||
# Refresh registry from database
|
||||
await llm_registry.refresh_llm_registry()
|
||||
refresh_llm_costs()
|
||||
|
||||
# Clear block schema caches so they're regenerated with updated model options
|
||||
from backend.data.block import BlockSchema
|
||||
|
||||
BlockSchema.clear_all_schema_caches()
|
||||
logger.info("Cleared all block schema caches")
|
||||
|
||||
# Clear the /blocks endpoint cache so frontend gets updated schemas
|
||||
try:
|
||||
from backend.api.features.v1 import _get_cached_blocks
|
||||
|
||||
_get_cached_blocks.cache_clear()
|
||||
logger.info("Cleared /blocks endpoint cache")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to clear /blocks cache: %s", e)
|
||||
|
||||
# Clear the v2 builder caches (if they exist)
|
||||
try:
|
||||
from backend.api.features.builder import db as builder_db
|
||||
|
||||
if hasattr(builder_db, "_get_all_providers"):
|
||||
builder_db._get_all_providers.cache_clear()
|
||||
logger.info("Cleared v2 builder providers cache")
|
||||
if hasattr(builder_db, "_build_cached_search_results"):
|
||||
builder_db._build_cached_search_results.cache_clear()
|
||||
logger.info("Cleared v2 builder search results cache")
|
||||
except Exception as e:
|
||||
logger.debug("Could not clear v2 builder cache: %s", e)
|
||||
|
||||
# Notify all executor services to refresh their registry cache
|
||||
from backend.data.llm_registry import publish_registry_refresh_notification
|
||||
|
||||
await publish_registry_refresh_notification()
|
||||
logger.info("Published registry refresh notification")
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"LLM runtime state refresh failed; caches may be stale: %s", exc
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/providers",
|
||||
summary="List LLM providers",
|
||||
response_model=llm_model.LlmProvidersResponse,
|
||||
)
|
||||
async def list_llm_providers(include_models: bool = True):
|
||||
providers = await llm_db.list_providers(include_models=include_models)
|
||||
return llm_model.LlmProvidersResponse(providers=providers)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/providers",
|
||||
summary="Create LLM provider",
|
||||
response_model=llm_model.LlmProvider,
|
||||
)
|
||||
async def create_llm_provider(request: llm_model.UpsertLlmProviderRequest):
|
||||
provider = await llm_db.upsert_provider(request=request)
|
||||
await _refresh_runtime_state()
|
||||
return provider
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/providers/{provider_id}",
|
||||
summary="Update LLM provider",
|
||||
response_model=llm_model.LlmProvider,
|
||||
)
|
||||
async def update_llm_provider(
|
||||
provider_id: str,
|
||||
request: llm_model.UpsertLlmProviderRequest,
|
||||
):
|
||||
provider = await llm_db.upsert_provider(request=request, provider_id=provider_id)
|
||||
await _refresh_runtime_state()
|
||||
return provider
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/providers/{provider_id}",
|
||||
summary="Delete LLM provider",
|
||||
response_model=dict,
|
||||
)
|
||||
async def delete_llm_provider(provider_id: str):
|
||||
"""
|
||||
Delete an LLM provider.
|
||||
|
||||
A provider can only be deleted if it has no associated models.
|
||||
Delete all models from the provider first before deleting the provider.
|
||||
"""
|
||||
try:
|
||||
await llm_db.delete_provider(provider_id)
|
||||
await _refresh_runtime_state()
|
||||
logger.info("Deleted LLM provider '%s'", provider_id)
|
||||
return {"success": True, "message": "Provider deleted successfully"}
|
||||
except ValueError as e:
|
||||
logger.warning("Failed to delete provider '%s': %s", provider_id, e)
|
||||
raise fastapi.HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Failed to delete provider '%s': %s", provider_id, e)
|
||||
raise fastapi.HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/models",
|
||||
summary="List LLM models",
|
||||
response_model=llm_model.LlmModelsResponse,
|
||||
)
|
||||
async def list_llm_models(
|
||||
provider_id: str | None = fastapi.Query(default=None),
|
||||
page: int = fastapi.Query(default=1, ge=1, description="Page number (1-indexed)"),
|
||||
page_size: int = fastapi.Query(
|
||||
default=50, ge=1, le=100, description="Number of models per page"
|
||||
),
|
||||
):
|
||||
return await llm_db.list_models(
|
||||
provider_id=provider_id, page=page, page_size=page_size
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/models",
|
||||
summary="Create LLM model",
|
||||
response_model=llm_model.LlmModel,
|
||||
)
|
||||
async def create_llm_model(request: llm_model.CreateLlmModelRequest):
|
||||
model = await llm_db.create_model(request=request)
|
||||
await _refresh_runtime_state()
|
||||
return model
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/models/{model_id}",
|
||||
summary="Update LLM model",
|
||||
response_model=llm_model.LlmModel,
|
||||
)
|
||||
async def update_llm_model(
|
||||
model_id: str,
|
||||
request: llm_model.UpdateLlmModelRequest,
|
||||
):
|
||||
model = await llm_db.update_model(model_id=model_id, request=request)
|
||||
await _refresh_runtime_state()
|
||||
return model
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/models/{model_id}/toggle",
|
||||
summary="Toggle LLM model availability",
|
||||
response_model=llm_model.ToggleLlmModelResponse,
|
||||
)
|
||||
async def toggle_llm_model(
|
||||
model_id: str,
|
||||
request: llm_model.ToggleLlmModelRequest,
|
||||
):
|
||||
"""
|
||||
Toggle a model's enabled status, optionally migrating workflows when disabling.
|
||||
|
||||
If disabling a model and `migrate_to_slug` is provided, all workflows using
|
||||
this model will be migrated to the specified replacement model before disabling.
|
||||
A migration record is created which can be reverted later using the revert endpoint.
|
||||
|
||||
Optional fields:
|
||||
- `migration_reason`: Reason for the migration (e.g., "Provider outage")
|
||||
- `custom_credit_cost`: Custom pricing override for billing during migration
|
||||
"""
|
||||
try:
|
||||
result = await llm_db.toggle_model(
|
||||
model_id=model_id,
|
||||
is_enabled=request.is_enabled,
|
||||
migrate_to_slug=request.migrate_to_slug,
|
||||
migration_reason=request.migration_reason,
|
||||
custom_credit_cost=request.custom_credit_cost,
|
||||
)
|
||||
await _refresh_runtime_state()
|
||||
if result.nodes_migrated > 0:
|
||||
logger.info(
|
||||
"Toggled model '%s' to %s and migrated %d nodes to '%s' (migration_id=%s)",
|
||||
result.model.slug,
|
||||
"enabled" if request.is_enabled else "disabled",
|
||||
result.nodes_migrated,
|
||||
result.migrated_to_slug,
|
||||
result.migration_id,
|
||||
)
|
||||
return result
|
||||
except ValueError as exc:
|
||||
logger.warning("Model toggle validation failed: %s", exc)
|
||||
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to toggle LLM model %s: %s", model_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to toggle model availability",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/models/{model_id}/usage",
|
||||
summary="Get model usage count",
|
||||
response_model=llm_model.LlmModelUsageResponse,
|
||||
)
|
||||
async def get_llm_model_usage(model_id: str):
|
||||
"""Get the number of workflow nodes using this model."""
|
||||
try:
|
||||
return await llm_db.get_model_usage(model_id=model_id)
|
||||
except ValueError as exc:
|
||||
raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to get model usage %s: %s", model_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to get model usage",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/models/{model_id}",
|
||||
summary="Delete LLM model and migrate workflows",
|
||||
response_model=llm_model.DeleteLlmModelResponse,
|
||||
)
|
||||
async def delete_llm_model(
|
||||
model_id: str,
|
||||
replacement_model_slug: str | None = fastapi.Query(
|
||||
default=None,
|
||||
description="Slug of the model to migrate existing workflows to (required only if workflows use this model)",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Delete a model and optionally migrate workflows using it to a replacement model.
|
||||
|
||||
If no workflows are using this model, it can be deleted without providing a
|
||||
replacement. If workflows exist, replacement_model_slug is required.
|
||||
|
||||
This endpoint:
|
||||
1. Counts how many workflow nodes use the model being deleted
|
||||
2. If nodes exist, validates the replacement model and migrates them
|
||||
3. Deletes the model record
|
||||
4. Refreshes all caches and notifies executors
|
||||
|
||||
Example: DELETE /admin/llm/models/{id}?replacement_model_slug=gpt-4o
|
||||
Example (no usage): DELETE /admin/llm/models/{id}
|
||||
"""
|
||||
try:
|
||||
result = await llm_db.delete_model(
|
||||
model_id=model_id, replacement_model_slug=replacement_model_slug
|
||||
)
|
||||
await _refresh_runtime_state()
|
||||
logger.info(
|
||||
"Deleted model '%s' and migrated %d nodes to '%s'",
|
||||
result.deleted_model_slug,
|
||||
result.nodes_migrated,
|
||||
result.replacement_model_slug,
|
||||
)
|
||||
return result
|
||||
except ValueError as exc:
|
||||
# Validation errors (model not found, replacement invalid, etc.)
|
||||
logger.warning("Model deletion validation failed: %s", exc)
|
||||
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to delete LLM model %s: %s", model_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to delete model and migrate workflows",
|
||||
) from exc
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Migration Management Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/migrations",
|
||||
summary="List model migrations",
|
||||
response_model=llm_model.LlmMigrationsResponse,
|
||||
)
|
||||
async def list_llm_migrations(
|
||||
include_reverted: bool = fastapi.Query(
|
||||
default=False, description="Include reverted migrations in the list"
|
||||
),
|
||||
):
|
||||
"""
|
||||
List all model migrations.
|
||||
|
||||
Migrations are created when disabling a model with the migrate_to_slug option.
|
||||
They can be reverted to restore the original model configuration.
|
||||
"""
|
||||
try:
|
||||
migrations = await llm_db.list_migrations(include_reverted=include_reverted)
|
||||
return llm_model.LlmMigrationsResponse(migrations=migrations)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to list migrations: %s", exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to list migrations",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/migrations/{migration_id}",
|
||||
summary="Get migration details",
|
||||
response_model=llm_model.LlmModelMigration,
|
||||
)
|
||||
async def get_llm_migration(migration_id: str):
|
||||
"""Get details of a specific migration."""
|
||||
try:
|
||||
migration = await llm_db.get_migration(migration_id)
|
||||
if not migration:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404, detail=f"Migration '{migration_id}' not found"
|
||||
)
|
||||
return migration
|
||||
except fastapi.HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to get migration %s: %s", migration_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to get migration",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/migrations/{migration_id}/revert",
|
||||
summary="Revert a model migration",
|
||||
response_model=llm_model.RevertMigrationResponse,
|
||||
)
|
||||
async def revert_llm_migration(
|
||||
migration_id: str,
|
||||
request: llm_model.RevertMigrationRequest | None = None,
|
||||
):
|
||||
"""
|
||||
Revert a model migration, restoring affected workflows to their original model.
|
||||
|
||||
This only reverts the specific nodes that were part of the migration.
|
||||
The source model must exist for the revert to succeed.
|
||||
|
||||
Options:
|
||||
- `re_enable_source_model`: Whether to re-enable the source model if disabled (default: True)
|
||||
|
||||
Response includes:
|
||||
- `nodes_reverted`: Number of nodes successfully reverted
|
||||
- `nodes_already_changed`: Number of nodes that were modified since migration (not reverted)
|
||||
- `source_model_re_enabled`: Whether the source model was re-enabled
|
||||
|
||||
Requirements:
|
||||
- Migration must not already be reverted
|
||||
- Source model must exist
|
||||
"""
|
||||
try:
|
||||
re_enable = request.re_enable_source_model if request else True
|
||||
result = await llm_db.revert_migration(
|
||||
migration_id,
|
||||
re_enable_source_model=re_enable,
|
||||
)
|
||||
await _refresh_runtime_state()
|
||||
logger.info(
|
||||
"Reverted migration '%s': %d nodes restored from '%s' to '%s' "
|
||||
"(%d already changed, source re-enabled=%s)",
|
||||
migration_id,
|
||||
result.nodes_reverted,
|
||||
result.target_model_slug,
|
||||
result.source_model_slug,
|
||||
result.nodes_already_changed,
|
||||
result.source_model_re_enabled,
|
||||
)
|
||||
return result
|
||||
except ValueError as exc:
|
||||
logger.warning("Migration revert validation failed: %s", exc)
|
||||
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to revert migration %s: %s", migration_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to revert migration",
|
||||
) from exc
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Creator Management Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/creators",
|
||||
summary="List model creators",
|
||||
response_model=llm_model.LlmCreatorsResponse,
|
||||
)
|
||||
async def list_llm_creators():
|
||||
"""
|
||||
List all model creators.
|
||||
|
||||
Creators are organizations that create/train models (e.g., OpenAI, Meta, Anthropic).
|
||||
This is distinct from providers who host/serve the models (e.g., OpenRouter).
|
||||
"""
|
||||
try:
|
||||
creators = await llm_db.list_creators()
|
||||
return llm_model.LlmCreatorsResponse(creators=creators)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to list creators: %s", exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to list creators",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/creators/{creator_id}",
|
||||
summary="Get creator details",
|
||||
response_model=llm_model.LlmModelCreator,
|
||||
)
|
||||
async def get_llm_creator(creator_id: str):
|
||||
"""Get details of a specific model creator."""
|
||||
try:
|
||||
creator = await llm_db.get_creator(creator_id)
|
||||
if not creator:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404, detail=f"Creator '{creator_id}' not found"
|
||||
)
|
||||
return creator
|
||||
except fastapi.HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to get creator %s: %s", creator_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to get creator",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/creators",
|
||||
summary="Create model creator",
|
||||
response_model=llm_model.LlmModelCreator,
|
||||
)
|
||||
async def create_llm_creator(request: llm_model.UpsertLlmCreatorRequest):
|
||||
"""
|
||||
Create a new model creator.
|
||||
|
||||
A creator represents an organization that creates/trains AI models,
|
||||
such as OpenAI, Anthropic, Meta, or Google.
|
||||
"""
|
||||
try:
|
||||
creator = await llm_db.upsert_creator(request=request)
|
||||
await _refresh_runtime_state()
|
||||
logger.info("Created model creator '%s' (%s)", creator.display_name, creator.id)
|
||||
return creator
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to create creator: %s", exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to create creator",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/creators/{creator_id}",
|
||||
summary="Update model creator",
|
||||
response_model=llm_model.LlmModelCreator,
|
||||
)
|
||||
async def update_llm_creator(
|
||||
creator_id: str,
|
||||
request: llm_model.UpsertLlmCreatorRequest,
|
||||
):
|
||||
"""Update an existing model creator."""
|
||||
try:
|
||||
creator = await llm_db.upsert_creator(request=request, creator_id=creator_id)
|
||||
await _refresh_runtime_state()
|
||||
logger.info("Updated model creator '%s' (%s)", creator.display_name, creator_id)
|
||||
return creator
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to update creator %s: %s", creator_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to update creator",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/creators/{creator_id}",
|
||||
summary="Delete model creator",
|
||||
response_model=dict,
|
||||
)
|
||||
async def delete_llm_creator(creator_id: str):
|
||||
"""
|
||||
Delete a model creator.
|
||||
|
||||
This will remove the creator association from all models that reference it
|
||||
(sets creatorId to NULL), but will not delete the models themselves.
|
||||
"""
|
||||
try:
|
||||
await llm_db.delete_creator(creator_id)
|
||||
await _refresh_runtime_state()
|
||||
logger.info("Deleted model creator '%s'", creator_id)
|
||||
return {"success": True, "message": f"Creator '{creator_id}' deleted"}
|
||||
except ValueError as exc:
|
||||
logger.warning("Creator deletion validation failed: %s", exc)
|
||||
raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to delete creator %s: %s", creator_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to delete creator",
|
||||
) from exc
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Recommended Model Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/recommended-model",
|
||||
summary="Get recommended model",
|
||||
response_model=llm_model.RecommendedModelResponse,
|
||||
)
|
||||
async def get_recommended_model():
|
||||
"""
|
||||
Get the currently recommended LLM model.
|
||||
|
||||
The recommended model is shown to users as the default/suggested option
|
||||
in model selection dropdowns.
|
||||
"""
|
||||
try:
|
||||
model = await llm_db.get_recommended_model()
|
||||
return llm_model.RecommendedModelResponse(
|
||||
model=model,
|
||||
slug=model.slug if model else None,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to get recommended model: %s", exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to get recommended model",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/recommended-model",
|
||||
summary="Set recommended model",
|
||||
response_model=llm_model.SetRecommendedModelResponse,
|
||||
)
|
||||
async def set_recommended_model(request: llm_model.SetRecommendedModelRequest):
|
||||
"""
|
||||
Set a model as the recommended model.
|
||||
|
||||
This clears the recommended flag from any other model and sets it on
|
||||
the specified model. The model must be enabled to be set as recommended.
|
||||
|
||||
The recommended model is displayed to users as the default/suggested
|
||||
option in model selection dropdowns throughout the platform.
|
||||
"""
|
||||
try:
|
||||
model, previous_slug = await llm_db.set_recommended_model(request.model_id)
|
||||
await _refresh_runtime_state()
|
||||
logger.info(
|
||||
"Set recommended model to '%s' (previous: %s)",
|
||||
model.slug,
|
||||
previous_slug or "none",
|
||||
)
|
||||
return llm_model.SetRecommendedModelResponse(
|
||||
model=model,
|
||||
previous_recommended_slug=previous_slug,
|
||||
message=f"Model '{model.display_name}' is now the recommended model",
|
||||
)
|
||||
except ValueError as exc:
|
||||
logger.warning("Set recommended model validation failed: %s", exc)
|
||||
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to set recommended model: %s", exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to set recommended model",
|
||||
) from exc
|
||||
@@ -1,491 +0,0 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.api.features.admin.llm_routes as llm_routes
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
from backend.util.models import Pagination
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(llm_routes.router, prefix="/admin/llm")
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
"""Setup admin auth overrides for all tests in this module"""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_list_llm_providers_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful listing of LLM providers"""
|
||||
# Mock the database function
|
||||
mock_providers = [
|
||||
{
|
||||
"id": "provider-1",
|
||||
"name": "openai",
|
||||
"display_name": "OpenAI",
|
||||
"description": "OpenAI LLM provider",
|
||||
"supports_tools": True,
|
||||
"supports_json_output": True,
|
||||
"supports_reasoning": False,
|
||||
"supports_parallel_tool": True,
|
||||
"metadata": {},
|
||||
"models": [],
|
||||
},
|
||||
{
|
||||
"id": "provider-2",
|
||||
"name": "anthropic",
|
||||
"display_name": "Anthropic",
|
||||
"description": "Anthropic LLM provider",
|
||||
"supports_tools": True,
|
||||
"supports_json_output": True,
|
||||
"supports_reasoning": False,
|
||||
"supports_parallel_tool": True,
|
||||
"metadata": {},
|
||||
"models": [],
|
||||
},
|
||||
]
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.list_providers",
|
||||
new=AsyncMock(return_value=mock_providers),
|
||||
)
|
||||
|
||||
response = client.get("/admin/llm/providers")
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert len(response_data["providers"]) == 2
|
||||
assert response_data["providers"][0]["name"] == "openai"
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"list_llm_providers_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_list_llm_models_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful listing of LLM models with pagination"""
|
||||
# Mock the database function - now returns LlmModelsResponse
|
||||
mock_model = llm_model.LlmModel(
|
||||
id="model-1",
|
||||
slug="gpt-4o",
|
||||
display_name="GPT-4o",
|
||||
description="GPT-4 Optimized",
|
||||
provider_id="provider-1",
|
||||
context_window=128000,
|
||||
max_output_tokens=16384,
|
||||
is_enabled=True,
|
||||
capabilities={},
|
||||
metadata={},
|
||||
costs=[
|
||||
llm_model.LlmModelCost(
|
||||
id="cost-1",
|
||||
credit_cost=10,
|
||||
credential_provider="openai",
|
||||
metadata={},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
mock_response = llm_model.LlmModelsResponse(
|
||||
models=[mock_model],
|
||||
pagination=Pagination(
|
||||
total_items=1,
|
||||
total_pages=1,
|
||||
current_page=1,
|
||||
page_size=50,
|
||||
),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.list_models",
|
||||
new=AsyncMock(return_value=mock_response),
|
||||
)
|
||||
|
||||
response = client.get("/admin/llm/models")
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert len(response_data["models"]) == 1
|
||||
assert response_data["models"][0]["slug"] == "gpt-4o"
|
||||
assert response_data["pagination"]["total_items"] == 1
|
||||
assert response_data["pagination"]["page_size"] == 50
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"list_llm_models_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_create_llm_provider_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful creation of LLM provider"""
|
||||
mock_provider = {
|
||||
"id": "new-provider-id",
|
||||
"name": "groq",
|
||||
"display_name": "Groq",
|
||||
"description": "Groq LLM provider",
|
||||
"supports_tools": True,
|
||||
"supports_json_output": True,
|
||||
"supports_reasoning": False,
|
||||
"supports_parallel_tool": False,
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.upsert_provider",
|
||||
new=AsyncMock(return_value=mock_provider),
|
||||
)
|
||||
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"name": "groq",
|
||||
"display_name": "Groq",
|
||||
"description": "Groq LLM provider",
|
||||
"supports_tools": True,
|
||||
"supports_json_output": True,
|
||||
"supports_reasoning": False,
|
||||
"supports_parallel_tool": False,
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
response = client.post("/admin/llm/providers", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["name"] == "groq"
|
||||
assert response_data["display_name"] == "Groq"
|
||||
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"create_llm_provider_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_create_llm_model_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful creation of LLM model"""
|
||||
mock_model = {
|
||||
"id": "new-model-id",
|
||||
"slug": "gpt-4.1-mini",
|
||||
"display_name": "GPT-4.1 Mini",
|
||||
"description": "Latest GPT-4.1 Mini model",
|
||||
"provider_id": "provider-1",
|
||||
"context_window": 128000,
|
||||
"max_output_tokens": 16384,
|
||||
"is_enabled": True,
|
||||
"capabilities": {},
|
||||
"metadata": {},
|
||||
"costs": [
|
||||
{
|
||||
"id": "cost-id",
|
||||
"credit_cost": 5,
|
||||
"credential_provider": "openai",
|
||||
"metadata": {},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.create_model",
|
||||
new=AsyncMock(return_value=mock_model),
|
||||
)
|
||||
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"slug": "gpt-4.1-mini",
|
||||
"display_name": "GPT-4.1 Mini",
|
||||
"description": "Latest GPT-4.1 Mini model",
|
||||
"provider_id": "provider-1",
|
||||
"context_window": 128000,
|
||||
"max_output_tokens": 16384,
|
||||
"is_enabled": True,
|
||||
"capabilities": {},
|
||||
"metadata": {},
|
||||
"costs": [
|
||||
{
|
||||
"credit_cost": 5,
|
||||
"credential_provider": "openai",
|
||||
"metadata": {},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
response = client.post("/admin/llm/models", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["slug"] == "gpt-4.1-mini"
|
||||
assert response_data["is_enabled"] is True
|
||||
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"create_llm_model_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_update_llm_model_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful update of LLM model"""
|
||||
mock_model = {
|
||||
"id": "model-1",
|
||||
"slug": "gpt-4o",
|
||||
"display_name": "GPT-4o Updated",
|
||||
"description": "Updated description",
|
||||
"provider_id": "provider-1",
|
||||
"context_window": 256000,
|
||||
"max_output_tokens": 32768,
|
||||
"is_enabled": True,
|
||||
"capabilities": {},
|
||||
"metadata": {},
|
||||
"costs": [
|
||||
{
|
||||
"id": "cost-1",
|
||||
"credit_cost": 15,
|
||||
"credential_provider": "openai",
|
||||
"metadata": {},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.update_model",
|
||||
new=AsyncMock(return_value=mock_model),
|
||||
)
|
||||
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"display_name": "GPT-4o Updated",
|
||||
"description": "Updated description",
|
||||
"context_window": 256000,
|
||||
"max_output_tokens": 32768,
|
||||
}
|
||||
|
||||
response = client.patch("/admin/llm/models/model-1", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["display_name"] == "GPT-4o Updated"
|
||||
assert response_data["context_window"] == 256000
|
||||
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"update_llm_model_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_toggle_llm_model_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful toggling of LLM model enabled status"""
|
||||
# Create a proper mock model object
|
||||
mock_model = llm_model.LlmModel(
|
||||
id="model-1",
|
||||
slug="gpt-4o",
|
||||
display_name="GPT-4o",
|
||||
description="GPT-4 Optimized",
|
||||
provider_id="provider-1",
|
||||
context_window=128000,
|
||||
max_output_tokens=16384,
|
||||
is_enabled=False,
|
||||
capabilities={},
|
||||
metadata={},
|
||||
costs=[],
|
||||
)
|
||||
|
||||
# Create a proper ToggleLlmModelResponse
|
||||
mock_response = llm_model.ToggleLlmModelResponse(
|
||||
model=mock_model,
|
||||
nodes_migrated=0,
|
||||
migrated_to_slug=None,
|
||||
migration_id=None,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.toggle_model",
|
||||
new=AsyncMock(return_value=mock_response),
|
||||
)
|
||||
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
request_data = {"is_enabled": False}
|
||||
|
||||
response = client.patch("/admin/llm/models/model-1/toggle", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["model"]["is_enabled"] is False
|
||||
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"toggle_llm_model_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_delete_llm_model_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful deletion of LLM model with migration"""
|
||||
# Create a proper DeleteLlmModelResponse
|
||||
mock_response = llm_model.DeleteLlmModelResponse(
|
||||
deleted_model_slug="gpt-3.5-turbo",
|
||||
deleted_model_display_name="GPT-3.5 Turbo",
|
||||
replacement_model_slug="gpt-4o-mini",
|
||||
nodes_migrated=42,
|
||||
message="Successfully deleted model 'GPT-3.5 Turbo' (gpt-3.5-turbo) "
|
||||
"and migrated 42 workflow node(s) to 'gpt-4o-mini'.",
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.delete_model",
|
||||
new=AsyncMock(return_value=mock_response),
|
||||
)
|
||||
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
response = client.delete(
|
||||
"/admin/llm/models/model-1?replacement_model_slug=gpt-4o-mini"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["deleted_model_slug"] == "gpt-3.5-turbo"
|
||||
assert response_data["nodes_migrated"] == 42
|
||||
assert response_data["replacement_model_slug"] == "gpt-4o-mini"
|
||||
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"delete_llm_model_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_delete_llm_model_validation_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Test deletion fails with proper error when validation fails"""
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.delete_model",
|
||||
new=AsyncMock(side_effect=ValueError("Replacement model 'invalid' not found")),
|
||||
)
|
||||
|
||||
response = client.delete("/admin/llm/models/model-1?replacement_model_slug=invalid")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Replacement model 'invalid' not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_delete_llm_model_no_replacement_with_usage(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Test deletion fails when nodes exist but no replacement is provided"""
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.delete_model",
|
||||
new=AsyncMock(
|
||||
side_effect=ValueError(
|
||||
"Cannot delete model 'test-model': 5 workflow node(s) are using it. "
|
||||
"Please provide a replacement_model_slug to migrate them."
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
response = client.delete("/admin/llm/models/model-1")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "workflow node(s) are using it" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_delete_llm_model_no_replacement_no_usage(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Test deletion succeeds when no nodes use the model and no replacement is provided"""
|
||||
mock_response = llm_model.DeleteLlmModelResponse(
|
||||
deleted_model_slug="unused-model",
|
||||
deleted_model_display_name="Unused Model",
|
||||
replacement_model_slug=None,
|
||||
nodes_migrated=0,
|
||||
message="Successfully deleted model 'Unused Model' (unused-model). No workflows were using this model.",
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.delete_model",
|
||||
new=AsyncMock(return_value=mock_response),
|
||||
)
|
||||
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
response = client.delete("/admin/llm/models/model-1")
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["deleted_model_slug"] == "unused-model"
|
||||
assert response_data["nodes_migrated"] == 0
|
||||
assert response_data["replacement_model_slug"] is None
|
||||
mock_refresh.assert_called_once()
|
||||
@@ -15,7 +15,6 @@ from backend.blocks import load_all_blocks
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.data.llm_registry import get_all_model_slugs_for_validation
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
@@ -32,14 +31,7 @@ from .model import (
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_llm_models() -> list[str]:
|
||||
"""Get LLM model names for search matching from the registry."""
|
||||
return [
|
||||
slug.lower().replace("-", " ") for slug in get_all_model_slugs_for_validation()
|
||||
]
|
||||
|
||||
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
|
||||
|
||||
MAX_LIBRARY_AGENT_RESULTS = 100
|
||||
MAX_MARKETPLACE_AGENT_RESULTS = 100
|
||||
@@ -504,8 +496,8 @@ async def _get_static_counts():
|
||||
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
|
||||
for field in schema_cls.model_fields.values():
|
||||
if field.annotation == LlmModel:
|
||||
# Check if query matches any value in llm_models from registry
|
||||
if any(query in name for name in _get_llm_models()):
|
||||
# Check if query matches any value in llm_models
|
||||
if any(query in name for name in llm_models):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@@ -1,29 +1,28 @@
|
||||
"""Agent generator package - Creates agents from natural language."""
|
||||
|
||||
from .core import (
|
||||
apply_agent_patch,
|
||||
AgentGeneratorNotConfiguredError,
|
||||
decompose_goal,
|
||||
generate_agent,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
json_to_graph,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .fixer import apply_all_fixes
|
||||
from .utils import get_blocks_info
|
||||
from .validator import validate_agent
|
||||
from .service import health_check as check_external_service_health
|
||||
from .service import is_external_service_configured
|
||||
|
||||
__all__ = [
|
||||
# Core functions
|
||||
"decompose_goal",
|
||||
"generate_agent",
|
||||
"generate_agent_patch",
|
||||
"apply_agent_patch",
|
||||
"save_agent_to_library",
|
||||
"get_agent_as_json",
|
||||
# Fixer
|
||||
"apply_all_fixes",
|
||||
# Validator
|
||||
"validate_agent",
|
||||
# Utils
|
||||
"get_blocks_info",
|
||||
"json_to_graph",
|
||||
# Exceptions
|
||||
"AgentGeneratorNotConfiguredError",
|
||||
# Service
|
||||
"is_external_service_configured",
|
||||
"check_external_service_health",
|
||||
]
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
"""OpenRouter client configuration for agent generation."""
|
||||
|
||||
import os
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# Configuration - use OPEN_ROUTER_API_KEY for consistency with chat/config.py
|
||||
OPENROUTER_API_KEY = os.getenv("OPEN_ROUTER_API_KEY")
|
||||
AGENT_GENERATOR_MODEL = os.getenv("AGENT_GENERATOR_MODEL", "anthropic/claude-opus-4.5")
|
||||
|
||||
# OpenRouter client (OpenAI-compatible API)
|
||||
_client: AsyncOpenAI | None = None
|
||||
|
||||
|
||||
def get_client() -> AsyncOpenAI:
|
||||
"""Get or create the OpenRouter client."""
|
||||
global _client
|
||||
if _client is None:
|
||||
if not OPENROUTER_API_KEY:
|
||||
raise ValueError("OPENROUTER_API_KEY environment variable is required")
|
||||
_client = AsyncOpenAI(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key=OPENROUTER_API_KEY,
|
||||
)
|
||||
return _client
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Core agent generation functions."""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
@@ -9,13 +7,35 @@ from typing import Any
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
|
||||
from .client import AGENT_GENERATOR_MODEL, get_client
|
||||
from .prompts import DECOMPOSITION_PROMPT, GENERATION_PROMPT, PATCH_PROMPT
|
||||
from .utils import get_block_summaries, parse_json_from_llm
|
||||
from .service import (
|
||||
decompose_goal_external,
|
||||
generate_agent_external,
|
||||
generate_agent_patch_external,
|
||||
is_external_service_configured,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentGeneratorNotConfiguredError(Exception):
|
||||
"""Raised when the external Agent Generator service is not configured."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def _check_service_configured() -> None:
|
||||
"""Check if the external Agent Generator service is configured.
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the service is not configured.
|
||||
"""
|
||||
if not is_external_service_configured():
|
||||
raise AgentGeneratorNotConfiguredError(
|
||||
"Agent Generator service is not configured. "
|
||||
"Set AGENTGENERATOR_HOST environment variable to enable agent generation."
|
||||
)
|
||||
|
||||
|
||||
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
|
||||
"""Break down a goal into steps or return clarifying questions.
|
||||
|
||||
@@ -28,40 +48,13 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any]
|
||||
- {"type": "clarifying_questions", "questions": [...]}
|
||||
- {"type": "instructions", "steps": [...]}
|
||||
Or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
client = get_client()
|
||||
prompt = DECOMPOSITION_PROMPT.format(block_summaries=get_block_summaries())
|
||||
|
||||
full_description = description
|
||||
if context:
|
||||
full_description = f"{description}\n\nAdditional context:\n{context}"
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=AGENT_GENERATOR_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": full_description},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
logger.error("LLM returned empty content for decomposition")
|
||||
return None
|
||||
|
||||
result = parse_json_from_llm(content)
|
||||
|
||||
if result is None:
|
||||
logger.error(f"Failed to parse decomposition response: {content[:200]}")
|
||||
return None
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error decomposing goal: {e}")
|
||||
return None
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for decompose_goal")
|
||||
return await decompose_goal_external(description, context)
|
||||
|
||||
|
||||
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
||||
@@ -72,31 +65,14 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
||||
|
||||
Returns:
|
||||
Agent JSON dict or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
client = get_client()
|
||||
prompt = GENERATION_PROMPT.format(block_summaries=get_block_summaries())
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=AGENT_GENERATOR_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": json.dumps(instructions, indent=2)},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
logger.error("LLM returned empty content for agent generation")
|
||||
return None
|
||||
|
||||
result = parse_json_from_llm(content)
|
||||
|
||||
if result is None:
|
||||
logger.error(f"Failed to parse agent JSON: {content[:200]}")
|
||||
return None
|
||||
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent")
|
||||
result = await generate_agent_external(instructions)
|
||||
if result:
|
||||
# Ensure required fields
|
||||
if "id" not in result:
|
||||
result["id"] = str(uuid.uuid4())
|
||||
@@ -104,12 +80,7 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
||||
result["version"] = 1
|
||||
if "is_active" not in result:
|
||||
result["is_active"] = True
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating agent: {e}")
|
||||
return None
|
||||
return result
|
||||
|
||||
|
||||
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||
@@ -284,108 +255,23 @@ async def get_agent_as_json(
|
||||
async def generate_agent_patch(
|
||||
update_request: str, current_agent: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Generate a patch to update an existing agent.
|
||||
"""Update an existing agent using natural language.
|
||||
|
||||
The external Agent Generator service handles:
|
||||
- Generating the patch
|
||||
- Applying the patch
|
||||
- Fixing and validating the result
|
||||
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
|
||||
Returns:
|
||||
Patch dict or clarifying questions, or None on error
|
||||
Updated agent JSON, clarifying questions dict, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
client = get_client()
|
||||
prompt = PATCH_PROMPT.format(
|
||||
current_agent=json.dumps(current_agent, indent=2),
|
||||
block_summaries=get_block_summaries(),
|
||||
)
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=AGENT_GENERATOR_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": update_request},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
logger.error("LLM returned empty content for patch generation")
|
||||
return None
|
||||
|
||||
return parse_json_from_llm(content)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating patch: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def apply_agent_patch(
|
||||
current_agent: dict[str, Any], patch: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Apply a patch to an existing agent.
|
||||
|
||||
Args:
|
||||
current_agent: Current agent JSON
|
||||
patch: Patch dict with operations
|
||||
|
||||
Returns:
|
||||
Updated agent JSON
|
||||
"""
|
||||
agent = copy.deepcopy(current_agent)
|
||||
patches = patch.get("patches", [])
|
||||
|
||||
for p in patches:
|
||||
patch_type = p.get("type")
|
||||
|
||||
if patch_type == "modify":
|
||||
node_id = p.get("node_id")
|
||||
changes = p.get("changes", {})
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
if node["id"] == node_id:
|
||||
_deep_update(node, changes)
|
||||
logger.debug(f"Modified node {node_id}")
|
||||
break
|
||||
|
||||
elif patch_type == "add":
|
||||
new_nodes = p.get("new_nodes", [])
|
||||
new_links = p.get("new_links", [])
|
||||
|
||||
agent["nodes"] = agent.get("nodes", []) + new_nodes
|
||||
agent["links"] = agent.get("links", []) + new_links
|
||||
logger.debug(f"Added {len(new_nodes)} nodes, {len(new_links)} links")
|
||||
|
||||
elif patch_type == "remove":
|
||||
node_ids_to_remove = set(p.get("node_ids", []))
|
||||
link_ids_to_remove = set(p.get("link_ids", []))
|
||||
|
||||
# Remove nodes
|
||||
agent["nodes"] = [
|
||||
n for n in agent.get("nodes", []) if n["id"] not in node_ids_to_remove
|
||||
]
|
||||
|
||||
# Remove links (both explicit and those referencing removed nodes)
|
||||
agent["links"] = [
|
||||
link
|
||||
for link in agent.get("links", [])
|
||||
if link["id"] not in link_ids_to_remove
|
||||
and link["source_id"] not in node_ids_to_remove
|
||||
and link["sink_id"] not in node_ids_to_remove
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
f"Removed {len(node_ids_to_remove)} nodes, {len(link_ids_to_remove)} links"
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def _deep_update(target: dict, source: dict) -> None:
|
||||
"""Recursively update a dict with another dict."""
|
||||
for key, value in source.items():
|
||||
if key in target and isinstance(target[key], dict) and isinstance(value, dict):
|
||||
_deep_update(target[key], value)
|
||||
else:
|
||||
target[key] = value
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
||||
return await generate_agent_patch_external(update_request, current_agent)
|
||||
|
||||
@@ -1,606 +0,0 @@
|
||||
"""Agent fixer - Fixes common LLM generation errors."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from .utils import (
|
||||
ADDTODICTIONARY_BLOCK_ID,
|
||||
ADDTOLIST_BLOCK_ID,
|
||||
CODE_EXECUTION_BLOCK_ID,
|
||||
CONDITION_BLOCK_ID,
|
||||
CREATEDICT_BLOCK_ID,
|
||||
CREATELIST_BLOCK_ID,
|
||||
DATA_SAMPLING_BLOCK_ID,
|
||||
DOUBLE_CURLY_BRACES_BLOCK_IDS,
|
||||
GET_CURRENT_DATE_BLOCK_ID,
|
||||
STORE_VALUE_BLOCK_ID,
|
||||
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
|
||||
get_blocks_info,
|
||||
is_valid_uuid,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def fix_agent_ids(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix invalid UUIDs in agent and link IDs."""
|
||||
# Fix agent ID
|
||||
if not is_valid_uuid(agent.get("id", "")):
|
||||
agent["id"] = str(uuid.uuid4())
|
||||
logger.debug(f"Fixed agent ID: {agent['id']}")
|
||||
|
||||
# Fix node IDs
|
||||
id_mapping = {} # Old ID -> New ID
|
||||
for node in agent.get("nodes", []):
|
||||
if not is_valid_uuid(node.get("id", "")):
|
||||
old_id = node.get("id", "")
|
||||
new_id = str(uuid.uuid4())
|
||||
id_mapping[old_id] = new_id
|
||||
node["id"] = new_id
|
||||
logger.debug(f"Fixed node ID: {old_id} -> {new_id}")
|
||||
|
||||
# Fix link IDs and update references
|
||||
for link in agent.get("links", []):
|
||||
if not is_valid_uuid(link.get("id", "")):
|
||||
link["id"] = str(uuid.uuid4())
|
||||
logger.debug(f"Fixed link ID: {link['id']}")
|
||||
|
||||
# Update source/sink IDs if they were remapped
|
||||
if link.get("source_id") in id_mapping:
|
||||
link["source_id"] = id_mapping[link["source_id"]]
|
||||
if link.get("sink_id") in id_mapping:
|
||||
link["sink_id"] = id_mapping[link["sink_id"]]
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_double_curly_braces(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix single curly braces to double in template blocks."""
|
||||
for node in agent.get("nodes", []):
|
||||
if node.get("block_id") not in DOUBLE_CURLY_BRACES_BLOCK_IDS:
|
||||
continue
|
||||
|
||||
input_data = node.get("input_default", {})
|
||||
for key in ("prompt", "format"):
|
||||
if key in input_data and isinstance(input_data[key], str):
|
||||
original = input_data[key]
|
||||
# Fix simple variable references: {var} -> {{var}}
|
||||
fixed = re.sub(
|
||||
r"(?<!\{)\{([a-zA-Z_][a-zA-Z0-9_]*)\}(?!\})",
|
||||
r"{{\1}}",
|
||||
original,
|
||||
)
|
||||
if fixed != original:
|
||||
input_data[key] = fixed
|
||||
logger.debug(f"Fixed curly braces in {key}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_storevalue_before_condition(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Add StoreValueBlock before ConditionBlock if needed for value2."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
|
||||
# Find all ConditionBlock nodes
|
||||
condition_node_ids = {
|
||||
node["id"] for node in nodes if node.get("block_id") == CONDITION_BLOCK_ID
|
||||
}
|
||||
|
||||
if not condition_node_ids:
|
||||
return agent
|
||||
|
||||
new_nodes = []
|
||||
new_links = []
|
||||
processed_conditions = set()
|
||||
|
||||
for link in links:
|
||||
sink_id = link.get("sink_id")
|
||||
sink_name = link.get("sink_name")
|
||||
|
||||
# Check if this link goes to a ConditionBlock's value2
|
||||
if sink_id in condition_node_ids and sink_name == "value2":
|
||||
source_node = next(
|
||||
(n for n in nodes if n["id"] == link.get("source_id")), None
|
||||
)
|
||||
|
||||
# Skip if source is already a StoreValueBlock
|
||||
if source_node and source_node.get("block_id") == STORE_VALUE_BLOCK_ID:
|
||||
continue
|
||||
|
||||
# Skip if we already processed this condition
|
||||
if sink_id in processed_conditions:
|
||||
continue
|
||||
|
||||
processed_conditions.add(sink_id)
|
||||
|
||||
# Create StoreValueBlock
|
||||
store_node_id = str(uuid.uuid4())
|
||||
store_node = {
|
||||
"id": store_node_id,
|
||||
"block_id": STORE_VALUE_BLOCK_ID,
|
||||
"input_default": {"data": None},
|
||||
"metadata": {"position": {"x": 0, "y": -100}},
|
||||
}
|
||||
new_nodes.append(store_node)
|
||||
|
||||
# Create link: original source -> StoreValueBlock
|
||||
new_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": link["source_id"],
|
||||
"source_name": link["source_name"],
|
||||
"sink_id": store_node_id,
|
||||
"sink_name": "input",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
|
||||
# Update original link: StoreValueBlock -> ConditionBlock
|
||||
link["source_id"] = store_node_id
|
||||
link["source_name"] = "output"
|
||||
|
||||
logger.debug(f"Added StoreValueBlock before ConditionBlock {sink_id}")
|
||||
|
||||
if new_nodes:
|
||||
agent["nodes"] = nodes + new_nodes
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_addtolist_blocks(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix AddToList blocks by adding prerequisite empty AddToList block.
|
||||
|
||||
When an AddToList block is found:
|
||||
1. Checks if there's a CreateListBlock before it
|
||||
2. Removes CreateListBlock if linked directly to AddToList
|
||||
3. Adds an empty AddToList block before the original
|
||||
4. Ensures the original has a self-referencing link
|
||||
"""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
new_nodes = []
|
||||
original_addtolist_ids = set()
|
||||
nodes_to_remove = set()
|
||||
links_to_remove = []
|
||||
|
||||
# First pass: identify CreateListBlock nodes to remove
|
||||
for link in links:
|
||||
source_node = next(
|
||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
||||
)
|
||||
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
|
||||
|
||||
if (
|
||||
source_node
|
||||
and sink_node
|
||||
and source_node.get("block_id") == CREATELIST_BLOCK_ID
|
||||
and sink_node.get("block_id") == ADDTOLIST_BLOCK_ID
|
||||
):
|
||||
nodes_to_remove.add(source_node.get("id"))
|
||||
links_to_remove.append(link)
|
||||
logger.debug(f"Removing CreateListBlock {source_node.get('id')}")
|
||||
|
||||
# Second pass: process AddToList blocks
|
||||
filtered_nodes = []
|
||||
for node in nodes:
|
||||
if node.get("id") in nodes_to_remove:
|
||||
continue
|
||||
|
||||
if node.get("block_id") == ADDTOLIST_BLOCK_ID:
|
||||
original_addtolist_ids.add(node.get("id"))
|
||||
node_id = node.get("id")
|
||||
pos = node.get("metadata", {}).get("position", {"x": 0, "y": 0})
|
||||
|
||||
# Check if already has prerequisite
|
||||
has_prereq = any(
|
||||
link.get("sink_id") == node_id
|
||||
and link.get("sink_name") == "list"
|
||||
and link.get("source_name") == "updated_list"
|
||||
for link in links
|
||||
)
|
||||
|
||||
if not has_prereq:
|
||||
# Remove links to "list" input (except self-reference)
|
||||
for link in links:
|
||||
if (
|
||||
link.get("sink_id") == node_id
|
||||
and link.get("sink_name") == "list"
|
||||
and link.get("source_id") != node_id
|
||||
and link not in links_to_remove
|
||||
):
|
||||
links_to_remove.append(link)
|
||||
|
||||
# Create prerequisite AddToList block
|
||||
prereq_id = str(uuid.uuid4())
|
||||
prereq_node = {
|
||||
"id": prereq_id,
|
||||
"block_id": ADDTOLIST_BLOCK_ID,
|
||||
"input_default": {"list": [], "entry": None, "entries": []},
|
||||
"metadata": {
|
||||
"position": {"x": pos.get("x", 0) - 800, "y": pos.get("y", 0)}
|
||||
},
|
||||
}
|
||||
new_nodes.append(prereq_node)
|
||||
|
||||
# Link prerequisite to original
|
||||
links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": prereq_id,
|
||||
"source_name": "updated_list",
|
||||
"sink_id": node_id,
|
||||
"sink_name": "list",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
logger.debug(f"Added prerequisite AddToList block for {node_id}")
|
||||
|
||||
filtered_nodes.append(node)
|
||||
|
||||
# Remove marked links
|
||||
filtered_links = [link for link in links if link not in links_to_remove]
|
||||
|
||||
# Add self-referencing links for original AddToList blocks
|
||||
for node in filtered_nodes + new_nodes:
|
||||
if (
|
||||
node.get("block_id") == ADDTOLIST_BLOCK_ID
|
||||
and node.get("id") in original_addtolist_ids
|
||||
):
|
||||
node_id = node.get("id")
|
||||
has_self_ref = any(
|
||||
link["source_id"] == node_id
|
||||
and link["sink_id"] == node_id
|
||||
and link["source_name"] == "updated_list"
|
||||
and link["sink_name"] == "list"
|
||||
for link in filtered_links
|
||||
)
|
||||
if not has_self_ref:
|
||||
filtered_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": node_id,
|
||||
"source_name": "updated_list",
|
||||
"sink_id": node_id,
|
||||
"sink_name": "list",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
logger.debug(f"Added self-reference for AddToList {node_id}")
|
||||
|
||||
agent["nodes"] = filtered_nodes + new_nodes
|
||||
agent["links"] = filtered_links
|
||||
return agent
|
||||
|
||||
|
||||
def fix_addtodictionary_blocks(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix AddToDictionary blocks by removing empty CreateDictionary nodes."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
nodes_to_remove = set()
|
||||
links_to_remove = []
|
||||
|
||||
for link in links:
|
||||
source_node = next(
|
||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
||||
)
|
||||
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
|
||||
|
||||
if (
|
||||
source_node
|
||||
and sink_node
|
||||
and source_node.get("block_id") == CREATEDICT_BLOCK_ID
|
||||
and sink_node.get("block_id") == ADDTODICTIONARY_BLOCK_ID
|
||||
):
|
||||
nodes_to_remove.add(source_node.get("id"))
|
||||
links_to_remove.append(link)
|
||||
logger.debug(f"Removing CreateDictionary {source_node.get('id')}")
|
||||
|
||||
agent["nodes"] = [n for n in nodes if n.get("id") not in nodes_to_remove]
|
||||
agent["links"] = [link for link in links if link not in links_to_remove]
|
||||
return agent
|
||||
|
||||
|
||||
def fix_code_execution_output(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix CodeExecutionBlock output: change 'response' to 'stdout_logs'."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
|
||||
for link in links:
|
||||
source_node = next(
|
||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
||||
)
|
||||
if (
|
||||
source_node
|
||||
and source_node.get("block_id") == CODE_EXECUTION_BLOCK_ID
|
||||
and link.get("source_name") == "response"
|
||||
):
|
||||
link["source_name"] = "stdout_logs"
|
||||
logger.debug("Fixed CodeExecutionBlock output: response -> stdout_logs")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_data_sampling_sample_size(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix DataSamplingBlock by setting sample_size to 1 as default."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
links_to_remove = []
|
||||
|
||||
for node in nodes:
|
||||
if node.get("block_id") == DATA_SAMPLING_BLOCK_ID:
|
||||
node_id = node.get("id")
|
||||
input_default = node.get("input_default", {})
|
||||
|
||||
# Remove links to sample_size
|
||||
for link in links:
|
||||
if (
|
||||
link.get("sink_id") == node_id
|
||||
and link.get("sink_name") == "sample_size"
|
||||
):
|
||||
links_to_remove.append(link)
|
||||
|
||||
# Set default
|
||||
input_default["sample_size"] = 1
|
||||
node["input_default"] = input_default
|
||||
logger.debug(f"Fixed DataSamplingBlock {node_id} sample_size to 1")
|
||||
|
||||
if links_to_remove:
|
||||
agent["links"] = [link for link in links if link not in links_to_remove]
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_node_x_coordinates(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix node x-coordinates to ensure 800+ unit spacing between linked nodes."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
node_lookup = {n.get("id"): n for n in nodes}
|
||||
|
||||
for link in links:
|
||||
source_id = link.get("source_id")
|
||||
sink_id = link.get("sink_id")
|
||||
|
||||
source_node = node_lookup.get(source_id)
|
||||
sink_node = node_lookup.get(sink_id)
|
||||
|
||||
if not source_node or not sink_node:
|
||||
continue
|
||||
|
||||
source_pos = source_node.get("metadata", {}).get("position", {})
|
||||
sink_pos = sink_node.get("metadata", {}).get("position", {})
|
||||
|
||||
source_x = source_pos.get("x", 0)
|
||||
sink_x = sink_pos.get("x", 0)
|
||||
|
||||
if abs(sink_x - source_x) < 800:
|
||||
new_x = source_x + 800
|
||||
if "metadata" not in sink_node:
|
||||
sink_node["metadata"] = {}
|
||||
if "position" not in sink_node["metadata"]:
|
||||
sink_node["metadata"]["position"] = {}
|
||||
sink_node["metadata"]["position"]["x"] = new_x
|
||||
logger.debug(f"Fixed node {sink_id} x: {sink_x} -> {new_x}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_getcurrentdate_offset(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix GetCurrentDateBlock offset to ensure it's positive."""
|
||||
for node in agent.get("nodes", []):
|
||||
if node.get("block_id") == GET_CURRENT_DATE_BLOCK_ID:
|
||||
input_default = node.get("input_default", {})
|
||||
if "offset" in input_default:
|
||||
offset = input_default["offset"]
|
||||
if isinstance(offset, (int, float)) and offset < 0:
|
||||
input_default["offset"] = abs(offset)
|
||||
logger.debug(f"Fixed offset: {offset} -> {abs(offset)}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_ai_model_parameter(
|
||||
agent: dict[str, Any],
|
||||
blocks_info: list[dict[str, Any]],
|
||||
default_model: str = "gpt-4o",
|
||||
) -> dict[str, Any]:
|
||||
"""Add default model parameter to AI blocks if missing."""
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
block = block_map.get(block_id)
|
||||
|
||||
if not block:
|
||||
continue
|
||||
|
||||
# Check if block has AI category
|
||||
categories = block.get("categories", [])
|
||||
is_ai_block = any(
|
||||
cat.get("category") == "AI" for cat in categories if isinstance(cat, dict)
|
||||
)
|
||||
|
||||
if is_ai_block:
|
||||
input_default = node.get("input_default", {})
|
||||
if "model" not in input_default:
|
||||
input_default["model"] = default_model
|
||||
node["input_default"] = input_default
|
||||
logger.debug(
|
||||
f"Added model '{default_model}' to AI block {node.get('id')}"
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_link_static_properties(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Fix is_static property based on source block's staticOutput."""
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
source_node = node_lookup.get(link.get("source_id"))
|
||||
if not source_node:
|
||||
continue
|
||||
|
||||
source_block = block_map.get(source_node.get("block_id"))
|
||||
if not source_block:
|
||||
continue
|
||||
|
||||
static_output = source_block.get("staticOutput", False)
|
||||
if link.get("is_static") != static_output:
|
||||
link["is_static"] = static_output
|
||||
logger.debug(f"Fixed link {link.get('id')} is_static to {static_output}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_data_type_mismatch(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Fix data type mismatches by inserting UniversalTypeConverterBlock."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in nodes}
|
||||
|
||||
def get_property_type(schema: dict, name: str) -> str | None:
|
||||
if "_#_" in name:
|
||||
parent, child = name.split("_#_", 1)
|
||||
parent_schema = schema.get(parent, {})
|
||||
if "properties" in parent_schema:
|
||||
return parent_schema["properties"].get(child, {}).get("type")
|
||||
return None
|
||||
return schema.get(name, {}).get("type")
|
||||
|
||||
def are_types_compatible(src: str, sink: str) -> bool:
|
||||
if {src, sink} <= {"integer", "number"}:
|
||||
return True
|
||||
return src == sink
|
||||
|
||||
type_mapping = {
|
||||
"string": "string",
|
||||
"text": "string",
|
||||
"integer": "number",
|
||||
"number": "number",
|
||||
"float": "number",
|
||||
"boolean": "boolean",
|
||||
"bool": "boolean",
|
||||
"array": "list",
|
||||
"list": "list",
|
||||
"object": "dictionary",
|
||||
"dict": "dictionary",
|
||||
"dictionary": "dictionary",
|
||||
}
|
||||
|
||||
new_links = []
|
||||
nodes_to_add = []
|
||||
|
||||
for link in links:
|
||||
source_node = node_lookup.get(link.get("source_id"))
|
||||
sink_node = node_lookup.get(link.get("sink_id"))
|
||||
|
||||
if not source_node or not sink_node:
|
||||
new_links.append(link)
|
||||
continue
|
||||
|
||||
source_block = block_map.get(source_node.get("block_id"))
|
||||
sink_block = block_map.get(sink_node.get("block_id"))
|
||||
|
||||
if not source_block or not sink_block:
|
||||
new_links.append(link)
|
||||
continue
|
||||
|
||||
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
|
||||
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
|
||||
|
||||
source_type = get_property_type(source_outputs, link.get("source_name", ""))
|
||||
sink_type = get_property_type(sink_inputs, link.get("sink_name", ""))
|
||||
|
||||
if (
|
||||
source_type
|
||||
and sink_type
|
||||
and not are_types_compatible(source_type, sink_type)
|
||||
):
|
||||
# Insert type converter
|
||||
converter_id = str(uuid.uuid4())
|
||||
target_type = type_mapping.get(sink_type, sink_type)
|
||||
|
||||
converter_node = {
|
||||
"id": converter_id,
|
||||
"block_id": UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
|
||||
"input_default": {"type": target_type},
|
||||
"metadata": {"position": {"x": 0, "y": 100}},
|
||||
}
|
||||
nodes_to_add.append(converter_node)
|
||||
|
||||
# source -> converter
|
||||
new_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": link["source_id"],
|
||||
"source_name": link["source_name"],
|
||||
"sink_id": converter_id,
|
||||
"sink_name": "value",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
|
||||
# converter -> sink
|
||||
new_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": converter_id,
|
||||
"source_name": "value",
|
||||
"sink_id": link["sink_id"],
|
||||
"sink_name": link["sink_name"],
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(f"Inserted type converter: {source_type} -> {target_type}")
|
||||
else:
|
||||
new_links.append(link)
|
||||
|
||||
if nodes_to_add:
|
||||
agent["nodes"] = nodes + nodes_to_add
|
||||
agent["links"] = new_links
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def apply_all_fixes(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Apply all fixes to an agent JSON.
|
||||
|
||||
Args:
|
||||
agent: Agent JSON dict
|
||||
blocks_info: Optional list of block info dicts for advanced fixes
|
||||
|
||||
Returns:
|
||||
Fixed agent JSON
|
||||
"""
|
||||
# Basic fixes (no block info needed)
|
||||
agent = fix_agent_ids(agent)
|
||||
agent = fix_double_curly_braces(agent)
|
||||
agent = fix_storevalue_before_condition(agent)
|
||||
agent = fix_addtolist_blocks(agent)
|
||||
agent = fix_addtodictionary_blocks(agent)
|
||||
agent = fix_code_execution_output(agent)
|
||||
agent = fix_data_sampling_sample_size(agent)
|
||||
agent = fix_node_x_coordinates(agent)
|
||||
agent = fix_getcurrentdate_offset(agent)
|
||||
|
||||
# Advanced fixes (require block info)
|
||||
if blocks_info is None:
|
||||
blocks_info = get_blocks_info()
|
||||
|
||||
agent = fix_ai_model_parameter(agent, blocks_info)
|
||||
agent = fix_link_static_properties(agent, blocks_info)
|
||||
agent = fix_data_type_mismatch(agent, blocks_info)
|
||||
|
||||
return agent
|
||||
@@ -1,225 +0,0 @@
|
||||
"""Prompt templates for agent generation."""
|
||||
|
||||
DECOMPOSITION_PROMPT = """
|
||||
You are an expert AutoGPT Workflow Decomposer. Your task is to analyze a user's high-level goal and break it down into a clear, step-by-step plan using the available blocks.
|
||||
|
||||
Each step should represent a distinct, automatable action suitable for execution by an AI automation system.
|
||||
|
||||
---
|
||||
|
||||
FIRST: Analyze the user's goal and determine:
|
||||
1) Design-time configuration (fixed settings that won't change per run)
|
||||
2) Runtime inputs (values the agent's end-user will provide each time it runs)
|
||||
|
||||
For anything that can vary per run (email addresses, names, dates, search terms, etc.):
|
||||
- DO NOT ask for the actual value
|
||||
- Instead, define it as an Agent Input with a clear name, type, and description
|
||||
|
||||
Only ask clarifying questions about design-time config that affects how you build the workflow:
|
||||
- Which external service to use (e.g., "Gmail vs Outlook", "Notion vs Google Docs")
|
||||
- Required formats or structures (e.g., "CSV, JSON, or PDF output?")
|
||||
- Business rules that must be hard-coded
|
||||
|
||||
IMPORTANT CLARIFICATIONS POLICY:
|
||||
- Ask no more than five essential questions
|
||||
- Do not ask for concrete values that can be provided at runtime as Agent Inputs
|
||||
- Do not ask for API keys or credentials; the platform handles those directly
|
||||
- If there is enough information to infer reasonable defaults, prefer to propose defaults
|
||||
|
||||
---
|
||||
|
||||
GUIDELINES:
|
||||
1. List each step as a numbered item
|
||||
2. Describe the action clearly and specify inputs/outputs
|
||||
3. Ensure steps are in logical, sequential order
|
||||
4. Mention block names naturally (e.g., "Use GetWeatherByLocationBlock to...")
|
||||
5. Help the user reach their goal efficiently
|
||||
|
||||
---
|
||||
|
||||
RULES:
|
||||
1. OUTPUT FORMAT: Only output either clarifying questions OR step-by-step instructions, not both
|
||||
2. USE ONLY THE BLOCKS PROVIDED
|
||||
3. ALL required_input fields must be provided
|
||||
4. Data types of linked properties must match
|
||||
5. Write expert-level prompts for AI-related blocks
|
||||
|
||||
---
|
||||
|
||||
CRITICAL BLOCK RESTRICTIONS:
|
||||
1. AddToListBlock: Outputs updated list EVERY addition, not after all additions
|
||||
2. SendEmailBlock: Draft the email for user review; set SMTP config based on email type
|
||||
3. ConditionBlock: value2 is reference, value1 is contrast
|
||||
4. CodeExecutionBlock: DO NOT USE - use AI blocks instead
|
||||
5. ReadCsvBlock: Only use the 'rows' output, not 'row'
|
||||
|
||||
---
|
||||
|
||||
OUTPUT FORMAT:
|
||||
|
||||
If more information is needed:
|
||||
```json
|
||||
{{
|
||||
"type": "clarifying_questions",
|
||||
"questions": [
|
||||
{{
|
||||
"question": "Which email provider should be used? (Gmail, Outlook, custom SMTP)",
|
||||
"keyword": "email_provider",
|
||||
"example": "Gmail"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
If ready to proceed:
|
||||
```json
|
||||
{{
|
||||
"type": "instructions",
|
||||
"steps": [
|
||||
{{
|
||||
"step_number": 1,
|
||||
"block_name": "AgentShortTextInputBlock",
|
||||
"description": "Get the URL of the content to analyze.",
|
||||
"inputs": [{{"name": "name", "value": "URL"}}],
|
||||
"outputs": [{{"name": "result", "description": "The URL entered by user"}}]
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
AVAILABLE BLOCKS:
|
||||
{block_summaries}
|
||||
"""
|
||||
|
||||
GENERATION_PROMPT = """
|
||||
You are an expert AI workflow builder. Generate a valid agent JSON from the given instructions.
|
||||
|
||||
---
|
||||
|
||||
NODES:
|
||||
Each node must include:
|
||||
- `id`: Unique UUID v4 (e.g. `a8f5b1e2-c3d4-4e5f-8a9b-0c1d2e3f4a5b`)
|
||||
- `block_id`: The block identifier (must match an Allowed Block)
|
||||
- `input_default`: Dict of inputs (can be empty if no static inputs needed)
|
||||
- `metadata`: Must contain:
|
||||
- `position`: {{"x": number, "y": number}} - adjacent nodes should differ by 800+ in X
|
||||
- `customized_name`: Clear name describing this block's purpose in the workflow
|
||||
|
||||
---
|
||||
|
||||
LINKS:
|
||||
Each link connects a source node's output to a sink node's input:
|
||||
- `id`: MUST be UUID v4 (NOT "link-1", "link-2", etc.)
|
||||
- `source_id`: ID of the source node
|
||||
- `source_name`: Output field name from the source block
|
||||
- `sink_id`: ID of the sink node
|
||||
- `sink_name`: Input field name on the sink block
|
||||
- `is_static`: true only if source block has static_output: true
|
||||
|
||||
CRITICAL: All IDs must be valid UUID v4 format!
|
||||
|
||||
---
|
||||
|
||||
AGENT (GRAPH):
|
||||
Wrap nodes and links in:
|
||||
- `id`: UUID of the agent
|
||||
- `name`: Short, generic name (avoid specific company names, URLs)
|
||||
- `description`: Short, generic description
|
||||
- `nodes`: List of all nodes
|
||||
- `links`: List of all links
|
||||
- `version`: 1
|
||||
- `is_active`: true
|
||||
|
||||
---
|
||||
|
||||
TIPS:
|
||||
- All required_input fields must be provided via input_default or a valid link
|
||||
- Ensure consistent source_id and sink_id references
|
||||
- Avoid dangling links
|
||||
- Input/output pins must match block schemas
|
||||
- Do not invent unknown block_ids
|
||||
|
||||
---
|
||||
|
||||
ALLOWED BLOCKS:
|
||||
{block_summaries}
|
||||
|
||||
---
|
||||
|
||||
Generate the complete agent JSON. Output ONLY valid JSON, no explanation.
|
||||
"""
|
||||
|
||||
PATCH_PROMPT = """
|
||||
You are an expert at modifying AutoGPT agent workflows. Given the current agent and a modification request, generate a JSON patch to update the agent.
|
||||
|
||||
CURRENT AGENT:
|
||||
{current_agent}
|
||||
|
||||
AVAILABLE BLOCKS:
|
||||
{block_summaries}
|
||||
|
||||
---
|
||||
|
||||
PATCH FORMAT:
|
||||
Return a JSON object with the following structure:
|
||||
|
||||
```json
|
||||
{{
|
||||
"type": "patch",
|
||||
"intent": "Brief description of what the patch does",
|
||||
"patches": [
|
||||
{{
|
||||
"type": "modify",
|
||||
"node_id": "uuid-of-node-to-modify",
|
||||
"changes": {{
|
||||
"input_default": {{"field": "new_value"}},
|
||||
"metadata": {{"customized_name": "New Name"}}
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"type": "add",
|
||||
"new_nodes": [
|
||||
{{
|
||||
"id": "new-uuid",
|
||||
"block_id": "block-uuid",
|
||||
"input_default": {{}},
|
||||
"metadata": {{"position": {{"x": 0, "y": 0}}, "customized_name": "Name"}}
|
||||
}}
|
||||
],
|
||||
"new_links": [
|
||||
{{
|
||||
"id": "link-uuid",
|
||||
"source_id": "source-node-id",
|
||||
"source_name": "output_field",
|
||||
"sink_id": "sink-node-id",
|
||||
"sink_name": "input_field"
|
||||
}}
|
||||
]
|
||||
}},
|
||||
{{
|
||||
"type": "remove",
|
||||
"node_ids": ["uuid-of-node-to-remove"],
|
||||
"link_ids": ["uuid-of-link-to-remove"]
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
If you need more information, return:
|
||||
```json
|
||||
{{
|
||||
"type": "clarifying_questions",
|
||||
"questions": [
|
||||
{{
|
||||
"question": "What specific change do you want?",
|
||||
"keyword": "change_type",
|
||||
"example": "Add error handling"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
Generate the minimal patch needed. Output ONLY valid JSON.
|
||||
"""
|
||||
@@ -0,0 +1,269 @@
|
||||
"""External Agent Generator service client.
|
||||
|
||||
This module provides a client for communicating with the external Agent Generator
|
||||
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
|
||||
will delegate to the external service instead of using the built-in LLM-based implementation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_client: httpx.AsyncClient | None = None
|
||||
_settings: Settings | None = None
|
||||
|
||||
|
||||
def _get_settings() -> Settings:
|
||||
"""Get or create settings singleton."""
|
||||
global _settings
|
||||
if _settings is None:
|
||||
_settings = Settings()
|
||||
return _settings
|
||||
|
||||
|
||||
def is_external_service_configured() -> bool:
|
||||
"""Check if external Agent Generator service is configured."""
|
||||
settings = _get_settings()
|
||||
return bool(settings.config.agentgenerator_host)
|
||||
|
||||
|
||||
def _get_base_url() -> str:
|
||||
"""Get the base URL for the external service."""
|
||||
settings = _get_settings()
|
||||
host = settings.config.agentgenerator_host
|
||||
port = settings.config.agentgenerator_port
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def _get_client() -> httpx.AsyncClient:
|
||||
"""Get or create the HTTP client for the external service."""
|
||||
global _client
|
||||
if _client is None:
|
||||
settings = _get_settings()
|
||||
_client = httpx.AsyncClient(
|
||||
base_url=_get_base_url(),
|
||||
timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
|
||||
)
|
||||
return _client
|
||||
|
||||
|
||||
async def decompose_goal_external(
|
||||
description: str, context: str = ""
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to decompose a goal.
|
||||
|
||||
Args:
|
||||
description: Natural language goal description
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
|
||||
Returns:
|
||||
Dict with either:
|
||||
- {"type": "clarifying_questions", "questions": [...]}
|
||||
- {"type": "instructions", "steps": [...]}
|
||||
- {"type": "unachievable_goal", ...}
|
||||
- {"type": "vague_goal", ...}
|
||||
Or None on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
# Build the request payload
|
||||
payload: dict[str, Any] = {"description": description}
|
||||
if context:
|
||||
# The external service uses user_instruction for additional context
|
||||
payload["user_instruction"] = context
|
||||
|
||||
try:
|
||||
response = await client.post("/api/decompose-description", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
logger.error(f"External service returned error: {data.get('error')}")
|
||||
return None
|
||||
|
||||
# Map the response to the expected format
|
||||
response_type = data.get("type")
|
||||
if response_type == "instructions":
|
||||
return {"type": "instructions", "steps": data.get("steps", [])}
|
||||
elif response_type == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
elif response_type == "unachievable_goal":
|
||||
return {
|
||||
"type": "unachievable_goal",
|
||||
"reason": data.get("reason"),
|
||||
"suggested_goal": data.get("suggested_goal"),
|
||||
}
|
||||
elif response_type == "vague_goal":
|
||||
return {
|
||||
"type": "vague_goal",
|
||||
"suggested_goal": data.get("suggested_goal"),
|
||||
}
|
||||
else:
|
||||
logger.error(
|
||||
f"Unknown response type from external service: {response_type}"
|
||||
)
|
||||
return None
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error calling external agent generator: {e}")
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error calling external agent generator: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error calling external agent generator: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def generate_agent_external(
|
||||
instructions: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate an agent from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
|
||||
Returns:
|
||||
Agent JSON dict or None on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
try:
|
||||
response = await client.post(
|
||||
"/api/generate-agent", json={"instructions": instructions}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
logger.error(f"External service returned error: {data.get('error')}")
|
||||
return None
|
||||
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error calling external agent generator: {e}")
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error calling external agent generator: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error calling external agent generator: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def generate_agent_patch_external(
|
||||
update_request: str, current_agent: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate a patch for an existing agent.
|
||||
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict, or None on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
try:
|
||||
response = await client.post(
|
||||
"/api/update-agent",
|
||||
json={
|
||||
"update_request": update_request,
|
||||
"current_agent_json": current_agent,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
logger.error(f"External service returned error: {data.get('error')}")
|
||||
return None
|
||||
|
||||
# Check if it's clarifying questions
|
||||
if data.get("type") == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
|
||||
# Otherwise return the updated agent JSON
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error calling external agent generator: {e}")
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error calling external agent generator: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error calling external agent generator: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
||||
"""Get available blocks from the external service.
|
||||
|
||||
Returns:
|
||||
List of block info dicts or None on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
try:
|
||||
response = await client.get("/api/blocks")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
logger.error("External service returned error getting blocks")
|
||||
return None
|
||||
|
||||
return data.get("blocks", [])
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error getting blocks from external service: {e}")
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error getting blocks from external service: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error getting blocks from external service: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def health_check() -> bool:
|
||||
"""Check if the external service is healthy.
|
||||
|
||||
Returns:
|
||||
True if healthy, False otherwise
|
||||
"""
|
||||
if not is_external_service_configured():
|
||||
return False
|
||||
|
||||
client = _get_client()
|
||||
|
||||
try:
|
||||
response = await client.get("/health")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("status") == "healthy" and data.get("blocks_loaded", False)
|
||||
except Exception as e:
|
||||
logger.warning(f"External agent generator health check failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def close_client() -> None:
|
||||
"""Close the HTTP client."""
|
||||
global _client
|
||||
if _client is not None:
|
||||
await _client.aclose()
|
||||
_client = None
|
||||
@@ -1,213 +0,0 @@
|
||||
"""Utilities for agent generation."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from backend.data.block import get_blocks
|
||||
|
||||
# UUID validation regex
|
||||
UUID_REGEX = re.compile(
|
||||
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$"
|
||||
)
|
||||
|
||||
# Block IDs for various fixes
|
||||
STORE_VALUE_BLOCK_ID = "1ff065e9-88e8-4358-9d82-8dc91f622ba9"
|
||||
CONDITION_BLOCK_ID = "715696a0-e1da-45c8-b209-c2fa9c3b0be6"
|
||||
ADDTOLIST_BLOCK_ID = "aeb08fc1-2fc1-4141-bc8e-f758f183a822"
|
||||
ADDTODICTIONARY_BLOCK_ID = "31d1064e-7446-4693-a7d4-65e5ca1180d1"
|
||||
CREATELIST_BLOCK_ID = "a912d5c7-6e00-4542-b2a9-8034136930e4"
|
||||
CREATEDICT_BLOCK_ID = "b924ddf4-de4f-4b56-9a85-358930dcbc91"
|
||||
CODE_EXECUTION_BLOCK_ID = "0b02b072-abe7-11ef-8372-fb5d162dd712"
|
||||
DATA_SAMPLING_BLOCK_ID = "4a448883-71fa-49cf-91cf-70d793bd7d87"
|
||||
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID = "95d1b990-ce13-4d88-9737-ba5c2070c97b"
|
||||
GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"
|
||||
|
||||
DOUBLE_CURLY_BRACES_BLOCK_IDS = [
|
||||
"44f6c8ad-d75c-4ae1-8209-aad1c0326928", # FillTextTemplateBlock
|
||||
"6ab085e2-20b3-4055-bc3e-08036e01eca6",
|
||||
"90f8c45e-e983-4644-aa0b-b4ebe2f531bc",
|
||||
"363ae599-353e-4804-937e-b2ee3cef3da4", # AgentOutputBlock
|
||||
"3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||
"db7d8f02-2f44-4c55-ab7a-eae0941f0c30",
|
||||
"3a7c4b8d-6e2f-4a5d-b9c1-f8d23c5a9b0e",
|
||||
"ed1ae7a0-b770-4089-b520-1f0005fad19a",
|
||||
"a892b8d9-3e4e-4e9c-9c1e-75f8efcf1bfa",
|
||||
"b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1",
|
||||
"716a67b3-6760-42e7-86dc-18645c6e00fc",
|
||||
"530cf046-2ce0-4854-ae2c-659db17c7a46",
|
||||
"ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
||||
"1f292d4a-41a4-4977-9684-7c8d560b9f91", # LLM blocks
|
||||
"32a87eab-381e-4dd4-bdb8-4c47151be35a",
|
||||
]
|
||||
|
||||
|
||||
def is_valid_uuid(value: str) -> bool:
|
||||
"""Check if a string is a valid UUID v4."""
|
||||
return isinstance(value, str) and UUID_REGEX.match(value) is not None
|
||||
|
||||
|
||||
def _compact_schema(schema: dict) -> dict[str, str]:
|
||||
"""Extract compact type info from a JSON schema properties dict.
|
||||
|
||||
Returns a dict of {field_name: type_string} for essential info only.
|
||||
"""
|
||||
props = schema.get("properties", {})
|
||||
result = {}
|
||||
|
||||
for name, prop in props.items():
|
||||
# Skip internal/complex fields
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
|
||||
# Get type string
|
||||
type_str = prop.get("type", "any")
|
||||
|
||||
# Handle anyOf/oneOf (optional types)
|
||||
if "anyOf" in prop:
|
||||
types = [t.get("type", "?") for t in prop["anyOf"] if t.get("type")]
|
||||
type_str = "|".join(types) if types else "any"
|
||||
elif "allOf" in prop:
|
||||
type_str = "object"
|
||||
|
||||
# Add array item type if present
|
||||
if type_str == "array" and "items" in prop:
|
||||
items = prop["items"]
|
||||
if isinstance(items, dict):
|
||||
item_type = items.get("type", "any")
|
||||
type_str = f"array[{item_type}]"
|
||||
|
||||
result[name] = type_str
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_block_summaries(include_schemas: bool = True) -> str:
|
||||
"""Generate compact block summaries for prompts.
|
||||
|
||||
Args:
|
||||
include_schemas: Whether to include input/output type info
|
||||
|
||||
Returns:
|
||||
Formatted string of block summaries (compact format)
|
||||
"""
|
||||
blocks = get_blocks()
|
||||
summaries = []
|
||||
|
||||
for block_id, block_cls in blocks.items():
|
||||
block = block_cls()
|
||||
name = block.name
|
||||
desc = getattr(block, "description", "") or ""
|
||||
|
||||
# Truncate description
|
||||
if len(desc) > 150:
|
||||
desc = desc[:147] + "..."
|
||||
|
||||
if not include_schemas:
|
||||
summaries.append(f"- {name} (id: {block_id}): {desc}")
|
||||
else:
|
||||
# Compact format with type info only
|
||||
inputs = {}
|
||||
outputs = {}
|
||||
required = []
|
||||
|
||||
if hasattr(block, "input_schema"):
|
||||
try:
|
||||
schema = block.input_schema.jsonschema()
|
||||
inputs = _compact_schema(schema)
|
||||
required = schema.get("required", [])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if hasattr(block, "output_schema"):
|
||||
try:
|
||||
schema = block.output_schema.jsonschema()
|
||||
outputs = _compact_schema(schema)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Build compact line format
|
||||
# Format: NAME (id): desc | in: {field:type, ...} [required] | out: {field:type}
|
||||
in_str = ", ".join(f"{k}:{v}" for k, v in inputs.items())
|
||||
out_str = ", ".join(f"{k}:{v}" for k, v in outputs.items())
|
||||
req_str = f" req=[{','.join(required)}]" if required else ""
|
||||
|
||||
static = " [static]" if getattr(block, "static_output", False) else ""
|
||||
|
||||
line = f"- {name} (id: {block_id}): {desc}"
|
||||
if in_str:
|
||||
line += f"\n in: {{{in_str}}}{req_str}"
|
||||
if out_str:
|
||||
line += f"\n out: {{{out_str}}}{static}"
|
||||
|
||||
summaries.append(line)
|
||||
|
||||
return "\n".join(summaries)
|
||||
|
||||
|
||||
def get_blocks_info() -> list[dict[str, Any]]:
|
||||
"""Get block information with schemas for validation and fixing."""
|
||||
blocks = get_blocks()
|
||||
blocks_info = []
|
||||
for block_id, block_cls in blocks.items():
|
||||
block = block_cls()
|
||||
blocks_info.append(
|
||||
{
|
||||
"id": block_id,
|
||||
"name": block.name,
|
||||
"description": getattr(block, "description", ""),
|
||||
"categories": getattr(block, "categories", []),
|
||||
"staticOutput": getattr(block, "static_output", False),
|
||||
"inputSchema": (
|
||||
block.input_schema.jsonschema()
|
||||
if hasattr(block, "input_schema")
|
||||
else {}
|
||||
),
|
||||
"outputSchema": (
|
||||
block.output_schema.jsonschema()
|
||||
if hasattr(block, "output_schema")
|
||||
else {}
|
||||
),
|
||||
}
|
||||
)
|
||||
return blocks_info
|
||||
|
||||
|
||||
def parse_json_from_llm(text: str) -> dict[str, Any] | None:
|
||||
"""Extract JSON from LLM response (handles markdown code blocks)."""
|
||||
if not text:
|
||||
return None
|
||||
|
||||
# Try fenced code block
|
||||
match = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, re.IGNORECASE)
|
||||
if match:
|
||||
try:
|
||||
return json.loads(match.group(1).strip())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try raw text
|
||||
try:
|
||||
return json.loads(text.strip())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try finding {...} span
|
||||
start = text.find("{")
|
||||
end = text.rfind("}")
|
||||
if start != -1 and end > start:
|
||||
try:
|
||||
return json.loads(text[start : end + 1])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try finding [...] span
|
||||
start = text.find("[")
|
||||
end = text.rfind("]")
|
||||
if start != -1 and end > start:
|
||||
try:
|
||||
return json.loads(text[start : end + 1])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return None
|
||||
@@ -1,279 +0,0 @@
|
||||
"""Agent validator - Validates agent structure and connections."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from .utils import get_blocks_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentValidator:
|
||||
"""Validator for AutoGPT agents with detailed error reporting."""
|
||||
|
||||
def __init__(self):
|
||||
self.errors: list[str] = []
|
||||
|
||||
def add_error(self, error: str) -> None:
|
||||
"""Add an error message."""
|
||||
self.errors.append(error)
|
||||
|
||||
def validate_block_existence(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate all block IDs exist in the blocks library."""
|
||||
valid = True
|
||||
valid_block_ids = {b.get("id") for b in blocks_info if b.get("id")}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
node_id = node.get("id")
|
||||
|
||||
if not block_id:
|
||||
self.add_error(f"Node '{node_id}' is missing 'block_id' field.")
|
||||
valid = False
|
||||
continue
|
||||
|
||||
if block_id not in valid_block_ids:
|
||||
self.add_error(
|
||||
f"Node '{node_id}' references block_id '{block_id}' which does not exist."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_link_node_references(self, agent: dict[str, Any]) -> bool:
|
||||
"""Validate all node IDs referenced in links exist."""
|
||||
valid = True
|
||||
valid_node_ids = {n.get("id") for n in agent.get("nodes", []) if n.get("id")}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
link_id = link.get("id", "Unknown")
|
||||
source_id = link.get("source_id")
|
||||
sink_id = link.get("sink_id")
|
||||
|
||||
if not source_id:
|
||||
self.add_error(f"Link '{link_id}' is missing 'source_id'.")
|
||||
valid = False
|
||||
elif source_id not in valid_node_ids:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' references non-existent source_id '{source_id}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
if not sink_id:
|
||||
self.add_error(f"Link '{link_id}' is missing 'sink_id'.")
|
||||
valid = False
|
||||
elif sink_id not in valid_node_ids:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' references non-existent sink_id '{sink_id}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_required_inputs(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate required inputs are provided."""
|
||||
valid = True
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
block = block_map.get(block_id)
|
||||
|
||||
if not block:
|
||||
continue
|
||||
|
||||
required_inputs = block.get("inputSchema", {}).get("required", [])
|
||||
input_defaults = node.get("input_default", {})
|
||||
node_id = node.get("id")
|
||||
|
||||
# Get linked inputs
|
||||
linked_inputs = {
|
||||
link["sink_name"]
|
||||
for link in agent.get("links", [])
|
||||
if link.get("sink_id") == node_id
|
||||
}
|
||||
|
||||
for req_input in required_inputs:
|
||||
if (
|
||||
req_input not in input_defaults
|
||||
and req_input not in linked_inputs
|
||||
and req_input != "credentials"
|
||||
):
|
||||
block_name = block.get("name", "Unknown Block")
|
||||
self.add_error(
|
||||
f"Node '{node_id}' ({block_name}) is missing required input '{req_input}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_data_type_compatibility(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate linked data types are compatible."""
|
||||
valid = True
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
||||
|
||||
def get_type(schema: dict, name: str) -> str | None:
|
||||
if "_#_" in name:
|
||||
parent, child = name.split("_#_", 1)
|
||||
parent_schema = schema.get(parent, {})
|
||||
if "properties" in parent_schema:
|
||||
return parent_schema["properties"].get(child, {}).get("type")
|
||||
return None
|
||||
return schema.get(name, {}).get("type")
|
||||
|
||||
def are_compatible(src: str, sink: str) -> bool:
|
||||
if {src, sink} <= {"integer", "number"}:
|
||||
return True
|
||||
return src == sink
|
||||
|
||||
for link in agent.get("links", []):
|
||||
source_node = node_lookup.get(link.get("source_id"))
|
||||
sink_node = node_lookup.get(link.get("sink_id"))
|
||||
|
||||
if not source_node or not sink_node:
|
||||
continue
|
||||
|
||||
source_block = block_map.get(source_node.get("block_id"))
|
||||
sink_block = block_map.get(sink_node.get("block_id"))
|
||||
|
||||
if not source_block or not sink_block:
|
||||
continue
|
||||
|
||||
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
|
||||
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
|
||||
|
||||
source_type = get_type(source_outputs, link.get("source_name", ""))
|
||||
sink_type = get_type(sink_inputs, link.get("sink_name", ""))
|
||||
|
||||
if source_type and sink_type and not are_compatible(source_type, sink_type):
|
||||
self.add_error(
|
||||
f"Type mismatch: {source_block.get('name')} output '{link['source_name']}' "
|
||||
f"({source_type}) -> {sink_block.get('name')} input '{link['sink_name']}' ({sink_type})."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_nested_sink_links(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate nested sink links (with _#_ notation)."""
|
||||
valid = True
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
sink_name = link.get("sink_name", "")
|
||||
|
||||
if "_#_" in sink_name:
|
||||
parent, child = sink_name.split("_#_", 1)
|
||||
|
||||
sink_node = node_lookup.get(link.get("sink_id"))
|
||||
if not sink_node:
|
||||
continue
|
||||
|
||||
block = block_map.get(sink_node.get("block_id"))
|
||||
if not block:
|
||||
continue
|
||||
|
||||
input_props = block.get("inputSchema", {}).get("properties", {})
|
||||
parent_schema = input_props.get(parent)
|
||||
|
||||
if not parent_schema:
|
||||
self.add_error(
|
||||
f"Invalid nested link '{sink_name}': parent '{parent}' not found."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
if not parent_schema.get("additionalProperties"):
|
||||
if not (
|
||||
isinstance(parent_schema, dict)
|
||||
and "properties" in parent_schema
|
||||
and child in parent_schema.get("properties", {})
|
||||
):
|
||||
self.add_error(
|
||||
f"Invalid nested link '{sink_name}': child '{child}' not found in '{parent}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_prompt_spaces(self, agent: dict[str, Any]) -> bool:
|
||||
"""Validate prompts don't have spaces in template variables."""
|
||||
valid = True
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
input_default = node.get("input_default", {})
|
||||
prompt = input_default.get("prompt", "")
|
||||
|
||||
if not isinstance(prompt, str):
|
||||
continue
|
||||
|
||||
# Find {{...}} with spaces
|
||||
matches = re.finditer(r"\{\{([^}]+)\}\}", prompt)
|
||||
for match in matches:
|
||||
content = match.group(1)
|
||||
if " " in content:
|
||||
self.add_error(
|
||||
f"Node '{node.get('id')}' has spaces in template variable: "
|
||||
f"'{{{{{content}}}}}' should be '{{{{{content.replace(' ', '_')}}}}}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Run all validations.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
self.errors = []
|
||||
|
||||
if blocks_info is None:
|
||||
blocks_info = get_blocks_info()
|
||||
|
||||
checks = [
|
||||
self.validate_block_existence(agent, blocks_info),
|
||||
self.validate_link_node_references(agent),
|
||||
self.validate_required_inputs(agent, blocks_info),
|
||||
self.validate_data_type_compatibility(agent, blocks_info),
|
||||
self.validate_nested_sink_links(agent, blocks_info),
|
||||
self.validate_prompt_spaces(agent),
|
||||
]
|
||||
|
||||
all_passed = all(checks)
|
||||
|
||||
if all_passed:
|
||||
logger.info("Agent validation successful")
|
||||
return True, None
|
||||
|
||||
error_message = "Agent validation failed:\n"
|
||||
for i, error in enumerate(self.errors, 1):
|
||||
error_message += f"{i}. {error}\n"
|
||||
|
||||
logger.warning(f"Agent validation failed with {len(self.errors)} errors")
|
||||
return False, error_message
|
||||
|
||||
|
||||
def validate_agent(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Convenience function to validate an agent.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
validator = AgentValidator()
|
||||
return validator.validate(agent, blocks_info)
|
||||
@@ -8,12 +8,10 @@ from langfuse import observe
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
apply_all_fixes,
|
||||
AgentGeneratorNotConfiguredError,
|
||||
decompose_goal,
|
||||
generate_agent,
|
||||
get_blocks_info,
|
||||
save_agent_to_library,
|
||||
validate_agent,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
@@ -27,9 +25,6 @@ from .models import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum retries for agent generation with validation feedback
|
||||
MAX_GENERATION_RETRIES = 2
|
||||
|
||||
|
||||
class CreateAgentTool(BaseTool):
|
||||
"""Tool for creating agents from natural language descriptions."""
|
||||
@@ -91,9 +86,8 @@ class CreateAgentTool(BaseTool):
|
||||
|
||||
Flow:
|
||||
1. Decompose the description into steps (may return clarifying questions)
|
||||
2. Generate agent JSON from the steps
|
||||
3. Apply fixes to correct common LLM errors
|
||||
4. Preview or save based on the save parameter
|
||||
2. Generate agent JSON (external service handles fixing and validation)
|
||||
3. Preview or save based on the save parameter
|
||||
"""
|
||||
description = kwargs.get("description", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
@@ -110,11 +104,13 @@ class CreateAgentTool(BaseTool):
|
||||
# Step 1: Decompose goal into steps
|
||||
try:
|
||||
decomposition_result = await decompose_goal(description, context)
|
||||
except ValueError as e:
|
||||
# Handle missing API key or configuration errors
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=f"Agent generation is not configured: {str(e)}",
|
||||
error="configuration_error",
|
||||
message=(
|
||||
"Agent generation is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -171,72 +167,32 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 2: Generate agent JSON with retry on validation failure
|
||||
blocks_info = get_blocks_info()
|
||||
agent_json = None
|
||||
validation_errors = None
|
||||
|
||||
for attempt in range(MAX_GENERATION_RETRIES + 1):
|
||||
# Generate agent (include validation errors from previous attempt)
|
||||
if attempt == 0:
|
||||
agent_json = await generate_agent(decomposition_result)
|
||||
else:
|
||||
# Retry with validation error feedback
|
||||
logger.info(
|
||||
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
|
||||
)
|
||||
retry_instructions = {
|
||||
**decomposition_result,
|
||||
"previous_errors": validation_errors,
|
||||
"retry_instructions": (
|
||||
"The previous generation had validation errors. "
|
||||
"Please fix these issues in the new generation:\n"
|
||||
f"{validation_errors}"
|
||||
),
|
||||
}
|
||||
agent_json = await generate_agent(retry_instructions)
|
||||
|
||||
if agent_json is None:
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate the agent. Please try again.",
|
||||
error="Generation failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# Step 3: Apply fixes to correct common errors
|
||||
agent_json = apply_all_fixes(agent_json, blocks_info)
|
||||
|
||||
# Step 4: Validate the agent
|
||||
is_valid, validation_errors = validate_agent(agent_json, blocks_info)
|
||||
|
||||
if is_valid:
|
||||
logger.info(f"Agent generated successfully on attempt {attempt + 1}")
|
||||
break
|
||||
|
||||
logger.warning(
|
||||
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
|
||||
# Step 2: Generate agent JSON (external service handles fixing and validation)
|
||||
try:
|
||||
agent_json = await generate_agent(decomposition_result)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent generation is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
# Return error with validation details
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Generated agent has validation errors after {MAX_GENERATION_RETRIES + 1} attempts. "
|
||||
f"Please try rephrasing your request or simplify the workflow."
|
||||
),
|
||||
error="validation_failed",
|
||||
details={"validation_errors": validation_errors},
|
||||
session_id=session_id,
|
||||
)
|
||||
if agent_json is None:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate the agent. Please try again.",
|
||||
error="Generation failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_name = agent_json.get("name", "Generated Agent")
|
||||
agent_description = agent_json.get("description", "")
|
||||
node_count = len(agent_json.get("nodes", []))
|
||||
link_count = len(agent_json.get("links", []))
|
||||
|
||||
# Step 4: Preview or save
|
||||
# Step 3: Preview or save
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
|
||||
@@ -8,13 +8,10 @@ from langfuse import observe
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
apply_agent_patch,
|
||||
apply_all_fixes,
|
||||
AgentGeneratorNotConfiguredError,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
get_blocks_info,
|
||||
save_agent_to_library,
|
||||
validate_agent,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
@@ -28,9 +25,6 @@ from .models import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum retries for patch generation with validation feedback
|
||||
MAX_GENERATION_RETRIES = 2
|
||||
|
||||
|
||||
class EditAgentTool(BaseTool):
|
||||
"""Tool for editing existing agents using natural language."""
|
||||
@@ -43,7 +37,7 @@ class EditAgentTool(BaseTool):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit an existing agent from the user's library using natural language. "
|
||||
"Generates a patch to update the agent while preserving unchanged parts."
|
||||
"Generates updates to the agent while preserving unchanged parts."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -98,9 +92,8 @@ class EditAgentTool(BaseTool):
|
||||
|
||||
Flow:
|
||||
1. Fetch the current agent
|
||||
2. Generate a patch based on the requested changes
|
||||
3. Apply the patch to create an updated agent
|
||||
4. Preview or save based on the save parameter
|
||||
2. Generate updated agent (external service handles fixing and validation)
|
||||
3. Preview or save based on the save parameter
|
||||
"""
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
changes = kwargs.get("changes", "").strip()
|
||||
@@ -137,121 +130,58 @@ class EditAgentTool(BaseTool):
|
||||
if context:
|
||||
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
||||
|
||||
# Step 2: Generate patch with retry on validation failure
|
||||
blocks_info = get_blocks_info()
|
||||
updated_agent = None
|
||||
validation_errors = None
|
||||
intent = "Applied requested changes"
|
||||
|
||||
for attempt in range(MAX_GENERATION_RETRIES + 1):
|
||||
# Generate patch (include validation errors from previous attempt)
|
||||
try:
|
||||
if attempt == 0:
|
||||
patch_result = await generate_agent_patch(
|
||||
update_request, current_agent
|
||||
)
|
||||
else:
|
||||
# Retry with validation error feedback
|
||||
logger.info(
|
||||
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
|
||||
)
|
||||
retry_request = (
|
||||
f"{update_request}\n\n"
|
||||
f"IMPORTANT: The previous edit had validation errors. "
|
||||
f"Please fix these issues:\n{validation_errors}"
|
||||
)
|
||||
patch_result = await generate_agent_patch(
|
||||
retry_request, current_agent
|
||||
)
|
||||
except ValueError as e:
|
||||
# Handle missing API key or configuration errors
|
||||
return ErrorResponse(
|
||||
message=f"Agent generation is not configured: {str(e)}",
|
||||
error="configuration_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if patch_result is None:
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate changes. Please try rephrasing.",
|
||||
error="Patch generation failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# Check if LLM returned clarifying questions
|
||||
if patch_result.get("type") == "clarifying_questions":
|
||||
questions = patch_result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information about the changes. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 3: Apply patch and fixes
|
||||
try:
|
||||
updated_agent = apply_agent_patch(current_agent, patch_result)
|
||||
updated_agent = apply_all_fixes(updated_agent, blocks_info)
|
||||
except Exception as e:
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to apply changes: {str(e)}",
|
||||
error="patch_apply_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
validation_errors = str(e)
|
||||
continue
|
||||
|
||||
# Step 4: Validate the updated agent
|
||||
is_valid, validation_errors = validate_agent(updated_agent, blocks_info)
|
||||
|
||||
if is_valid:
|
||||
logger.info(f"Agent edited successfully on attempt {attempt + 1}")
|
||||
intent = patch_result.get("intent", "Applied requested changes")
|
||||
break
|
||||
|
||||
logger.warning(
|
||||
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
|
||||
# Step 2: Generate updated agent (external service handles fixing and validation)
|
||||
try:
|
||||
result = await generate_agent_patch(update_request, current_agent)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent editing is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
# Return error with validation details
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Updated agent has validation errors after "
|
||||
f"{MAX_GENERATION_RETRIES + 1} attempts. "
|
||||
f"Please try rephrasing your request or simplify the changes."
|
||||
),
|
||||
error="validation_failed",
|
||||
details={"validation_errors": validation_errors},
|
||||
session_id=session_id,
|
||||
)
|
||||
if result is None:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate changes. Please try rephrasing.",
|
||||
error="Update generation failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# At this point, updated_agent is guaranteed to be set (we return on all failure paths)
|
||||
assert updated_agent is not None
|
||||
# Check if LLM returned clarifying questions
|
||||
if result.get("type") == "clarifying_questions":
|
||||
questions = result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information about the changes. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Result is the updated agent JSON
|
||||
updated_agent = result
|
||||
|
||||
agent_name = updated_agent.get("name", "Updated Agent")
|
||||
agent_description = updated_agent.get("description", "")
|
||||
node_count = len(updated_agent.get("nodes", []))
|
||||
link_count = len(updated_agent.get("links", []))
|
||||
|
||||
# Step 5: Preview or save
|
||||
# Step 3: Preview or save
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've updated the agent. Changes: {intent}. "
|
||||
f"I've updated the agent. "
|
||||
f"The agent now has {node_count} blocks. "
|
||||
f"Review it and call edit_agent with save=true to save the changes."
|
||||
),
|
||||
@@ -277,10 +207,7 @@ class EditAgentTool(BaseTool):
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
message=(
|
||||
f"Updated agent '{created_graph.name}' has been saved to your library! "
|
||||
f"Changes: {intent}"
|
||||
),
|
||||
message=f"Updated agent '{created_graph.name}' has been saved to your library!",
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
|
||||
@@ -29,7 +29,7 @@ def mock_embedding_functions():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_agent(setup_test_data):
|
||||
"""Test that the run_agent tool successfully executes an approved agent"""
|
||||
# Use test data from fixture
|
||||
@@ -70,7 +70,7 @@ async def test_run_agent(setup_test_data):
|
||||
assert result_data["graph_name"] == "Test Agent"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_agent_missing_inputs(setup_test_data):
|
||||
"""Test that the run_agent tool returns error when inputs are missing"""
|
||||
# Use test data from fixture
|
||||
@@ -106,7 +106,7 @@ async def test_run_agent_missing_inputs(setup_test_data):
|
||||
assert "message" in result_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_agent_invalid_agent_id(setup_test_data):
|
||||
"""Test that the run_agent tool returns error for invalid agent ID"""
|
||||
# Use test data from fixture
|
||||
@@ -141,7 +141,7 @@ async def test_run_agent_invalid_agent_id(setup_test_data):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_agent_with_llm_credentials(setup_llm_test_data):
|
||||
"""Test that run_agent works with an agent requiring LLM credentials"""
|
||||
# Use test data from fixture
|
||||
@@ -185,7 +185,7 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data):
|
||||
assert result_data["graph_name"] == "LLM Test Agent"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_data):
|
||||
"""Test that run_agent returns available inputs when called without inputs or use_defaults."""
|
||||
user = setup_test_data["user"]
|
||||
@@ -219,7 +219,7 @@ async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_da
|
||||
assert "inputs" in result_data["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_agent_with_use_defaults(setup_test_data):
|
||||
"""Test that run_agent executes successfully with use_defaults=True."""
|
||||
user = setup_test_data["user"]
|
||||
@@ -251,7 +251,7 @@ async def test_run_agent_with_use_defaults(setup_test_data):
|
||||
assert result_data["graph_id"] == graph.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
|
||||
"""Test that run_agent returns setup_requirements when credentials are missing."""
|
||||
user = setup_firecrawl_test_data["user"]
|
||||
@@ -285,7 +285,7 @@ async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
|
||||
assert len(setup_info["user_readiness"]["missing_credentials"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_agent_invalid_slug_format(setup_test_data):
|
||||
"""Test that run_agent returns error for invalid slug format (no slash)."""
|
||||
user = setup_test_data["user"]
|
||||
@@ -313,7 +313,7 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
|
||||
assert "username/agent-name" in result_data["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_agent_unauthenticated():
|
||||
"""Test that run_agent returns need_login for unauthenticated users."""
|
||||
tool = RunAgentTool()
|
||||
@@ -340,7 +340,7 @@ async def test_run_agent_unauthenticated():
|
||||
assert "sign in" in result_data["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_agent_schedule_without_cron(setup_test_data):
|
||||
"""Test that run_agent returns error when scheduling without cron expression."""
|
||||
user = setup_test_data["user"]
|
||||
@@ -372,7 +372,7 @@ async def test_run_agent_schedule_without_cron(setup_test_data):
|
||||
assert "cron" in result_data["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_agent_schedule_without_name(setup_test_data):
|
||||
"""Test that run_agent returns error when scheduling without schedule_name."""
|
||||
user = setup_test_data["user"]
|
||||
|
||||
@@ -23,6 +23,7 @@ class PendingHumanReviewModel(BaseModel):
|
||||
id: Unique identifier for the review record
|
||||
user_id: ID of the user who must perform the review
|
||||
node_exec_id: ID of the node execution that created this review
|
||||
node_id: ID of the node definition (for grouping reviews from same node)
|
||||
graph_exec_id: ID of the graph execution containing the node
|
||||
graph_id: ID of the graph template being executed
|
||||
graph_version: Version number of the graph template
|
||||
@@ -37,6 +38,10 @@ class PendingHumanReviewModel(BaseModel):
|
||||
"""
|
||||
|
||||
node_exec_id: str = Field(description="Node execution ID (primary key)")
|
||||
node_id: str = Field(
|
||||
description="Node definition ID (for grouping)",
|
||||
default="", # Temporary default for test compatibility
|
||||
)
|
||||
user_id: str = Field(description="User ID associated with the review")
|
||||
graph_exec_id: str = Field(description="Graph execution ID")
|
||||
graph_id: str = Field(description="Graph ID")
|
||||
@@ -66,7 +71,9 @@ class PendingHumanReviewModel(BaseModel):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, review: "PendingHumanReview") -> "PendingHumanReviewModel":
|
||||
def from_db(
|
||||
cls, review: "PendingHumanReview", node_id: str
|
||||
) -> "PendingHumanReviewModel":
|
||||
"""
|
||||
Convert a database model to a response model.
|
||||
|
||||
@@ -74,9 +81,14 @@ class PendingHumanReviewModel(BaseModel):
|
||||
payload, instructions, and editable flag.
|
||||
|
||||
Handles invalid data gracefully by using safe defaults.
|
||||
|
||||
Args:
|
||||
review: Database review object
|
||||
node_id: Node definition ID (fetched from NodeExecution)
|
||||
"""
|
||||
return cls(
|
||||
node_exec_id=review.nodeExecId,
|
||||
node_id=node_id,
|
||||
user_id=review.userId,
|
||||
graph_exec_id=review.graphExecId,
|
||||
graph_id=review.graphId,
|
||||
@@ -107,6 +119,13 @@ class ReviewItem(BaseModel):
|
||||
reviewed_data: SafeJsonData | None = Field(
|
||||
None, description="Optional edited data (ignored if approved=False)"
|
||||
)
|
||||
auto_approve_future: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true and this review is approved, future executions of this same "
|
||||
"block (node) will be automatically approved. This only affects approved reviews."
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("reviewed_data")
|
||||
@classmethod
|
||||
@@ -174,6 +193,9 @@ class ReviewRequest(BaseModel):
|
||||
This request must include ALL pending reviews for a graph execution.
|
||||
Each review will be either approved (with optional data modifications)
|
||||
or rejected (data ignored). The execution will resume only after ALL reviews are processed.
|
||||
|
||||
Each review item can individually specify whether to auto-approve future executions
|
||||
of the same block via the `auto_approve_future` field on ReviewItem.
|
||||
"""
|
||||
|
||||
reviews: List[ReviewItem] = Field(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,17 +1,27 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import Any, List
|
||||
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, HTTPException, Query, Security, status
|
||||
from prisma.enums import ReviewStatus
|
||||
|
||||
from backend.data.execution import get_graph_execution_meta
|
||||
from backend.data.execution import (
|
||||
ExecutionContext,
|
||||
ExecutionStatus,
|
||||
get_graph_execution_meta,
|
||||
)
|
||||
from backend.data.graph import get_graph_settings
|
||||
from backend.data.human_review import (
|
||||
create_auto_approval_record,
|
||||
get_pending_reviews_by_node_exec_ids,
|
||||
get_pending_reviews_for_execution,
|
||||
get_pending_reviews_for_user,
|
||||
has_pending_reviews_for_graph_exec,
|
||||
process_all_reviews_for_execution,
|
||||
)
|
||||
from backend.data.model import USER_TIMEZONE_NOT_SET
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.executor.utils import add_graph_execution
|
||||
|
||||
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
|
||||
@@ -127,17 +137,70 @@ async def process_review_action(
|
||||
detail="At least one review must be provided",
|
||||
)
|
||||
|
||||
# Build review decisions map
|
||||
# Batch fetch all requested reviews
|
||||
reviews_map = await get_pending_reviews_by_node_exec_ids(
|
||||
list(all_request_node_ids), user_id
|
||||
)
|
||||
|
||||
# Validate all reviews were found
|
||||
missing_ids = all_request_node_ids - set(reviews_map.keys())
|
||||
if missing_ids:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No pending review found for node execution(s): {', '.join(missing_ids)}",
|
||||
)
|
||||
|
||||
# Validate all reviews belong to the same execution
|
||||
graph_exec_ids = {review.graph_exec_id for review in reviews_map.values()}
|
||||
if len(graph_exec_ids) > 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="All reviews in a single request must belong to the same execution.",
|
||||
)
|
||||
|
||||
graph_exec_id = next(iter(graph_exec_ids))
|
||||
|
||||
# Validate execution status before processing reviews
|
||||
graph_exec_meta = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
|
||||
if not graph_exec_meta:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
|
||||
# Only allow processing reviews if execution is paused for review
|
||||
# or incomplete (partial execution with some reviews already processed)
|
||||
if graph_exec_meta.status not in (
|
||||
ExecutionStatus.REVIEW,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
|
||||
f"Reviews can only be processed when execution is paused (REVIEW status). "
|
||||
f"Current status: {graph_exec_meta.status}",
|
||||
)
|
||||
|
||||
# Build review decisions map and track which reviews requested auto-approval
|
||||
# Auto-approved reviews use original data (no modifications allowed)
|
||||
review_decisions = {}
|
||||
auto_approve_requests = {} # Map node_exec_id -> auto_approve_future flag
|
||||
|
||||
for review in request.reviews:
|
||||
review_status = (
|
||||
ReviewStatus.APPROVED if review.approved else ReviewStatus.REJECTED
|
||||
)
|
||||
# If this review requested auto-approval, don't allow data modifications
|
||||
reviewed_data = None if review.auto_approve_future else review.reviewed_data
|
||||
review_decisions[review.node_exec_id] = (
|
||||
review_status,
|
||||
review.reviewed_data,
|
||||
reviewed_data,
|
||||
review.message,
|
||||
)
|
||||
auto_approve_requests[review.node_exec_id] = review.auto_approve_future
|
||||
|
||||
# Process all reviews
|
||||
updated_reviews = await process_all_reviews_for_execution(
|
||||
@@ -145,6 +208,87 @@ async def process_review_action(
|
||||
review_decisions=review_decisions,
|
||||
)
|
||||
|
||||
# Create auto-approval records for approved reviews that requested it
|
||||
# Deduplicate by node_id to avoid race conditions when multiple reviews
|
||||
# for the same node are processed in parallel
|
||||
async def create_auto_approval_for_node(
|
||||
node_id: str, review_result
|
||||
) -> tuple[str, bool]:
|
||||
"""
|
||||
Create auto-approval record for a node.
|
||||
Returns (node_id, success) tuple for tracking failures.
|
||||
"""
|
||||
try:
|
||||
await create_auto_approval_record(
|
||||
user_id=user_id,
|
||||
graph_exec_id=review_result.graph_exec_id,
|
||||
graph_id=review_result.graph_id,
|
||||
graph_version=review_result.graph_version,
|
||||
node_id=node_id,
|
||||
payload=review_result.payload,
|
||||
)
|
||||
return (node_id, True)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create auto-approval record for node {node_id}",
|
||||
exc_info=e,
|
||||
)
|
||||
return (node_id, False)
|
||||
|
||||
# Collect node_exec_ids that need auto-approval
|
||||
node_exec_ids_needing_auto_approval = [
|
||||
node_exec_id
|
||||
for node_exec_id, review_result in updated_reviews.items()
|
||||
if review_result.status == ReviewStatus.APPROVED
|
||||
and auto_approve_requests.get(node_exec_id, False)
|
||||
]
|
||||
|
||||
# Batch-fetch node executions to get node_ids
|
||||
nodes_needing_auto_approval: dict[str, Any] = {}
|
||||
if node_exec_ids_needing_auto_approval:
|
||||
from backend.data.execution import get_node_executions
|
||||
|
||||
node_execs = await get_node_executions(
|
||||
graph_exec_id=graph_exec_id, include_exec_data=False
|
||||
)
|
||||
node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
|
||||
|
||||
for node_exec_id in node_exec_ids_needing_auto_approval:
|
||||
node_exec = node_exec_map.get(node_exec_id)
|
||||
if node_exec:
|
||||
review_result = updated_reviews[node_exec_id]
|
||||
# Use the first approved review for this node (deduplicate by node_id)
|
||||
if node_exec.node_id not in nodes_needing_auto_approval:
|
||||
nodes_needing_auto_approval[node_exec.node_id] = review_result
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to create auto-approval record for {node_exec_id}: "
|
||||
f"Node execution not found. This may indicate a race condition "
|
||||
f"or data inconsistency."
|
||||
)
|
||||
|
||||
# Execute all auto-approval creations in parallel (deduplicated by node_id)
|
||||
auto_approval_results = await asyncio.gather(
|
||||
*[
|
||||
create_auto_approval_for_node(node_id, review_result)
|
||||
for node_id, review_result in nodes_needing_auto_approval.items()
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
# Count auto-approval failures
|
||||
auto_approval_failed_count = 0
|
||||
for result in auto_approval_results:
|
||||
if isinstance(result, Exception):
|
||||
# Unexpected exception during auto-approval creation
|
||||
auto_approval_failed_count += 1
|
||||
logger.error(
|
||||
f"Unexpected exception during auto-approval creation: {result}"
|
||||
)
|
||||
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
|
||||
# Auto-approval creation failed (returned False)
|
||||
auto_approval_failed_count += 1
|
||||
|
||||
# Count results
|
||||
approved_count = sum(
|
||||
1
|
||||
@@ -157,30 +301,53 @@ async def process_review_action(
|
||||
if review.status == ReviewStatus.REJECTED
|
||||
)
|
||||
|
||||
# Resume execution if we processed some reviews
|
||||
# Resume execution only if ALL pending reviews for this execution have been processed
|
||||
if updated_reviews:
|
||||
# Get graph execution ID from any processed review
|
||||
first_review = next(iter(updated_reviews.values()))
|
||||
graph_exec_id = first_review.graph_exec_id
|
||||
|
||||
# Check if any pending reviews remain for this execution
|
||||
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
|
||||
|
||||
if not still_has_pending:
|
||||
# Resume execution
|
||||
# Get the graph_id from any processed review
|
||||
first_review = next(iter(updated_reviews.values()))
|
||||
|
||||
try:
|
||||
# Fetch user and settings to build complete execution context
|
||||
user = await get_user_by_id(user_id)
|
||||
settings = await get_graph_settings(
|
||||
user_id=user_id, graph_id=first_review.graph_id
|
||||
)
|
||||
|
||||
# Preserve user's timezone preference when resuming execution
|
||||
user_timezone = (
|
||||
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
||||
)
|
||||
|
||||
execution_context = ExecutionContext(
|
||||
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
|
||||
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
|
||||
user_timezone=user_timezone,
|
||||
)
|
||||
|
||||
await add_graph_execution(
|
||||
graph_id=first_review.graph_id,
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
execution_context=execution_context,
|
||||
)
|
||||
logger.info(f"Resumed execution {graph_exec_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to resume execution {graph_exec_id}: {str(e)}")
|
||||
|
||||
# Build error message if auto-approvals failed
|
||||
error_message = None
|
||||
if auto_approval_failed_count > 0:
|
||||
error_message = (
|
||||
f"{auto_approval_failed_count} auto-approval setting(s) could not be saved. "
|
||||
f"You may need to manually approve these reviews in future executions."
|
||||
)
|
||||
|
||||
return ReviewResponse(
|
||||
approved_count=approved_count,
|
||||
rejected_count=rejected_count,
|
||||
failed_count=0,
|
||||
error=None,
|
||||
failed_count=auto_approval_failed_count,
|
||||
error=error_message,
|
||||
)
|
||||
|
||||
@@ -583,7 +583,13 @@ async def update_library_agent(
|
||||
)
|
||||
update_fields["isDeleted"] = is_deleted
|
||||
if settings is not None:
|
||||
update_fields["settings"] = SafeJson(settings.model_dump())
|
||||
existing_agent = await get_library_agent(id=library_agent_id, user_id=user_id)
|
||||
current_settings_dict = (
|
||||
existing_agent.settings.model_dump() if existing_agent.settings else {}
|
||||
)
|
||||
new_settings = settings.model_dump(exclude_unset=True)
|
||||
merged_settings = {**current_settings_dict, **new_settings}
|
||||
update_fields["settings"] = SafeJson(merged_settings)
|
||||
|
||||
try:
|
||||
# If graph_version is provided, update to that specific version
|
||||
|
||||
@@ -20,6 +20,7 @@ from typing import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from prisma.enums import APIKeyPermission
|
||||
from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
|
||||
@@ -38,13 +39,13 @@ keysmith = APIKeySmith()
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture(scope="session")
|
||||
def test_user_id() -> str:
|
||||
"""Test user ID for OAuth tests."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest_asyncio.fixture(scope="session", loop_scope="session")
|
||||
async def test_user(server, test_user_id: str):
|
||||
"""Create a test user in the database."""
|
||||
await PrismaUser.prisma().create(
|
||||
@@ -67,7 +68,7 @@ async def test_user(server, test_user_id: str):
|
||||
await PrismaUser.prisma().delete(where={"id": test_user_id})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest_asyncio.fixture
|
||||
async def test_oauth_app(test_user: str):
|
||||
"""Create a test OAuth application in the database."""
|
||||
app_id = str(uuid.uuid4())
|
||||
@@ -122,7 +123,7 @@ def pkce_credentials() -> tuple[str, str]:
|
||||
return generate_pkce()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server, test_user: str) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||
"""
|
||||
Create an async HTTP client that talks directly to the FastAPI app.
|
||||
@@ -287,7 +288,7 @@ async def test_authorize_invalid_client_returns_error(
|
||||
assert query_params["error"][0] == "invalid_client"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest_asyncio.fixture
|
||||
async def inactive_oauth_app(test_user: str):
|
||||
"""Create an inactive test OAuth application in the database."""
|
||||
app_id = str(uuid.uuid4())
|
||||
@@ -1004,7 +1005,7 @@ async def test_token_refresh_revoked(
|
||||
assert "revoked" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest_asyncio.fixture
|
||||
async def other_oauth_app(test_user: str):
|
||||
"""Create a second OAuth application for cross-app tests."""
|
||||
app_id = str(uuid.uuid4())
|
||||
|
||||
@@ -1552,7 +1552,7 @@ async def review_store_submission(
|
||||
|
||||
# Generate embedding for approved listing (blocking - admin operation)
|
||||
# Inside transaction: if embedding fails, entire transaction rolls back
|
||||
embedding_success = await ensure_embedding(
|
||||
await ensure_embedding(
|
||||
version_id=store_listing_version_id,
|
||||
name=store_listing_version.name,
|
||||
description=store_listing_version.description,
|
||||
@@ -1560,12 +1560,6 @@ async def review_store_submission(
|
||||
categories=store_listing_version.categories or [],
|
||||
tx=tx,
|
||||
)
|
||||
if not embedding_success:
|
||||
raise ValueError(
|
||||
f"Failed to generate embedding for listing {store_listing_version_id}. "
|
||||
"This is likely due to OpenAI API being unavailable. "
|
||||
"Please try again later or contact support if the issue persists."
|
||||
)
|
||||
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": store_listing_version.StoreListing.id},
|
||||
|
||||
@@ -21,7 +21,6 @@ from backend.util.json import dumps
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# OpenAI embedding model configuration
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
# Embedding dimension for the model above
|
||||
@@ -63,49 +62,42 @@ def build_searchable_text(
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
async def generate_embedding(text: str) -> list[float] | None:
|
||||
async def generate_embedding(text: str) -> list[float]:
|
||||
"""
|
||||
Generate embedding for text using OpenAI API.
|
||||
|
||||
Returns None if embedding generation fails.
|
||||
Fail-fast: no retries to maintain consistency with approval flow.
|
||||
Raises exceptions on failure - caller should handle.
|
||||
"""
|
||||
try:
|
||||
client = get_openai_client()
|
||||
if not client:
|
||||
logger.error("openai_internal_api_key not set, cannot generate embedding")
|
||||
return None
|
||||
client = get_openai_client()
|
||||
if not client:
|
||||
raise RuntimeError("openai_internal_api_key not set, cannot generate embedding")
|
||||
|
||||
# Truncate text to token limit using tiktoken
|
||||
# Character-based truncation is insufficient because token ratios vary by content type
|
||||
enc = encoding_for_model(EMBEDDING_MODEL)
|
||||
tokens = enc.encode(text)
|
||||
if len(tokens) > EMBEDDING_MAX_TOKENS:
|
||||
tokens = tokens[:EMBEDDING_MAX_TOKENS]
|
||||
truncated_text = enc.decode(tokens)
|
||||
logger.info(
|
||||
f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
|
||||
)
|
||||
else:
|
||||
truncated_text = text
|
||||
|
||||
start_time = time.time()
|
||||
response = await client.embeddings.create(
|
||||
model=EMBEDDING_MODEL,
|
||||
input=truncated_text,
|
||||
)
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
|
||||
embedding = response.data[0].embedding
|
||||
# Truncate text to token limit using tiktoken
|
||||
# Character-based truncation is insufficient because token ratios vary by content type
|
||||
enc = encoding_for_model(EMBEDDING_MODEL)
|
||||
tokens = enc.encode(text)
|
||||
if len(tokens) > EMBEDDING_MAX_TOKENS:
|
||||
tokens = tokens[:EMBEDDING_MAX_TOKENS]
|
||||
truncated_text = enc.decode(tokens)
|
||||
logger.info(
|
||||
f"Generated embedding: {len(embedding)} dims, "
|
||||
f"{len(tokens)} tokens, {latency_ms:.0f}ms"
|
||||
f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
|
||||
)
|
||||
return embedding
|
||||
else:
|
||||
truncated_text = text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate embedding: {e}")
|
||||
return None
|
||||
start_time = time.time()
|
||||
response = await client.embeddings.create(
|
||||
model=EMBEDDING_MODEL,
|
||||
input=truncated_text,
|
||||
)
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
|
||||
embedding = response.data[0].embedding
|
||||
logger.info(
|
||||
f"Generated embedding: {len(embedding)} dims, "
|
||||
f"{len(tokens)} tokens, {latency_ms:.0f}ms"
|
||||
)
|
||||
return embedding
|
||||
|
||||
|
||||
async def store_embedding(
|
||||
@@ -144,48 +136,45 @@ async def store_content_embedding(
|
||||
|
||||
New function for unified content embedding storage.
|
||||
Uses raw SQL since Prisma doesn't natively support pgvector.
|
||||
|
||||
Raises exceptions on failure - caller should handle.
|
||||
"""
|
||||
try:
|
||||
client = tx if tx else prisma.get_client()
|
||||
client = tx if tx else prisma.get_client()
|
||||
|
||||
# Convert embedding to PostgreSQL vector format
|
||||
embedding_str = embedding_to_vector_string(embedding)
|
||||
metadata_json = dumps(metadata or {})
|
||||
# Convert embedding to PostgreSQL vector format
|
||||
embedding_str = embedding_to_vector_string(embedding)
|
||||
metadata_json = dumps(metadata or {})
|
||||
|
||||
# Upsert the embedding
|
||||
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
|
||||
# Use unqualified ::vector - pgvector is in search_path on all environments
|
||||
await execute_raw_with_schema(
|
||||
"""
|
||||
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
|
||||
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
|
||||
)
|
||||
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
|
||||
ON CONFLICT ("contentType", "contentId", "userId")
|
||||
DO UPDATE SET
|
||||
"embedding" = $4::vector,
|
||||
"searchableText" = $5,
|
||||
"metadata" = $6::jsonb,
|
||||
"updatedAt" = NOW()
|
||||
WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
|
||||
AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
|
||||
AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
|
||||
""",
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
embedding_str,
|
||||
searchable_text,
|
||||
metadata_json,
|
||||
client=client,
|
||||
# Upsert the embedding
|
||||
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
|
||||
# Use unqualified ::vector - pgvector is in search_path on all environments
|
||||
await execute_raw_with_schema(
|
||||
"""
|
||||
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
|
||||
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
|
||||
)
|
||||
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
|
||||
ON CONFLICT ("contentType", "contentId", "userId")
|
||||
DO UPDATE SET
|
||||
"embedding" = $4::vector,
|
||||
"searchableText" = $5,
|
||||
"metadata" = $6::jsonb,
|
||||
"updatedAt" = NOW()
|
||||
WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
|
||||
AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
|
||||
AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
|
||||
""",
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
embedding_str,
|
||||
searchable_text,
|
||||
metadata_json,
|
||||
client=client,
|
||||
)
|
||||
|
||||
logger.info(f"Stored embedding for {content_type}:{content_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store embedding for {content_type}:{content_id}: {e}")
|
||||
return False
|
||||
logger.info(f"Stored embedding for {content_type}:{content_id}")
|
||||
return True
|
||||
|
||||
|
||||
async def get_embedding(version_id: str) -> dict[str, Any] | None:
|
||||
@@ -217,34 +206,31 @@ async def get_content_embedding(
|
||||
|
||||
New function for unified content embedding retrieval.
|
||||
Returns dict with contentType, contentId, embedding, timestamps or None if not found.
|
||||
|
||||
Raises exceptions on failure - caller should handle.
|
||||
"""
|
||||
try:
|
||||
result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT
|
||||
"contentType",
|
||||
"contentId",
|
||||
"userId",
|
||||
"embedding"::text as "embedding",
|
||||
"searchableText",
|
||||
"metadata",
|
||||
"createdAt",
|
||||
"updatedAt"
|
||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
|
||||
""",
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
)
|
||||
result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT
|
||||
"contentType",
|
||||
"contentId",
|
||||
"userId",
|
||||
"embedding"::text as "embedding",
|
||||
"searchableText",
|
||||
"metadata",
|
||||
"createdAt",
|
||||
"updatedAt"
|
||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
|
||||
""",
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
if result and len(result) > 0:
|
||||
return result[0]
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get embedding for {content_type}:{content_id}: {e}")
|
||||
return None
|
||||
if result and len(result) > 0:
|
||||
return result[0]
|
||||
return None
|
||||
|
||||
|
||||
async def ensure_embedding(
|
||||
@@ -272,46 +258,38 @@ async def ensure_embedding(
|
||||
tx: Optional transaction client
|
||||
|
||||
Returns:
|
||||
True if embedding exists/was created, False on failure
|
||||
True if embedding exists/was created
|
||||
|
||||
Raises exceptions on failure - caller should handle.
|
||||
"""
|
||||
try:
|
||||
# Check if embedding already exists
|
||||
if not force:
|
||||
existing = await get_embedding(version_id)
|
||||
if existing and existing.get("embedding"):
|
||||
logger.debug(f"Embedding for version {version_id} already exists")
|
||||
return True
|
||||
# Check if embedding already exists
|
||||
if not force:
|
||||
existing = await get_embedding(version_id)
|
||||
if existing and existing.get("embedding"):
|
||||
logger.debug(f"Embedding for version {version_id} already exists")
|
||||
return True
|
||||
|
||||
# Build searchable text for embedding
|
||||
searchable_text = build_searchable_text(
|
||||
name, description, sub_heading, categories
|
||||
)
|
||||
# Build searchable text for embedding
|
||||
searchable_text = build_searchable_text(name, description, sub_heading, categories)
|
||||
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
if embedding is None:
|
||||
logger.warning(f"Could not generate embedding for version {version_id}")
|
||||
return False
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
|
||||
# Store the embedding with metadata using new function
|
||||
metadata = {
|
||||
"name": name,
|
||||
"subHeading": sub_heading,
|
||||
"categories": categories,
|
||||
}
|
||||
return await store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id=version_id,
|
||||
embedding=embedding,
|
||||
searchable_text=searchable_text,
|
||||
metadata=metadata,
|
||||
user_id=None, # Store agents are public
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
|
||||
return False
|
||||
# Store the embedding with metadata using new function
|
||||
metadata = {
|
||||
"name": name,
|
||||
"subHeading": sub_heading,
|
||||
"categories": categories,
|
||||
}
|
||||
return await store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id=version_id,
|
||||
embedding=embedding,
|
||||
searchable_text=searchable_text,
|
||||
metadata=metadata,
|
||||
user_id=None, # Store agents are public
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
|
||||
async def delete_embedding(version_id: str) -> bool:
|
||||
@@ -521,6 +499,24 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
|
||||
success = sum(1 for result in results if result is True)
|
||||
failed = len(results) - success
|
||||
|
||||
# Aggregate unique errors to avoid Sentry spam
|
||||
if failed > 0:
|
||||
# Group errors by type and message
|
||||
error_summary: dict[str, int] = {}
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
error_key = f"{type(result).__name__}: {str(result)}"
|
||||
error_summary[error_key] = error_summary.get(error_key, 0) + 1
|
||||
|
||||
# Log aggregated error summary
|
||||
error_details = ", ".join(
|
||||
f"{error} ({count}x)" for error, count in error_summary.items()
|
||||
)
|
||||
logger.error(
|
||||
f"{content_type.value}: {failed}/{len(results)} embeddings failed. "
|
||||
f"Errors: {error_details}"
|
||||
)
|
||||
|
||||
results_by_type[content_type.value] = {
|
||||
"processed": len(missing_items),
|
||||
"success": success,
|
||||
@@ -557,11 +553,12 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
async def embed_query(query: str) -> list[float] | None:
|
||||
async def embed_query(query: str) -> list[float]:
|
||||
"""
|
||||
Generate embedding for a search query.
|
||||
|
||||
Same as generate_embedding but with clearer intent.
|
||||
Raises exceptions on failure - caller should handle.
|
||||
"""
|
||||
return await generate_embedding(query)
|
||||
|
||||
@@ -594,40 +591,30 @@ async def ensure_content_embedding(
|
||||
tx: Optional transaction client
|
||||
|
||||
Returns:
|
||||
True if embedding exists/was created, False on failure
|
||||
True if embedding exists/was created
|
||||
|
||||
Raises exceptions on failure - caller should handle.
|
||||
"""
|
||||
try:
|
||||
# Check if embedding already exists
|
||||
if not force:
|
||||
existing = await get_content_embedding(content_type, content_id, user_id)
|
||||
if existing and existing.get("embedding"):
|
||||
logger.debug(
|
||||
f"Embedding for {content_type}:{content_id} already exists"
|
||||
)
|
||||
return True
|
||||
# Check if embedding already exists
|
||||
if not force:
|
||||
existing = await get_content_embedding(content_type, content_id, user_id)
|
||||
if existing and existing.get("embedding"):
|
||||
logger.debug(f"Embedding for {content_type}:{content_id} already exists")
|
||||
return True
|
||||
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
if embedding is None:
|
||||
logger.warning(
|
||||
f"Could not generate embedding for {content_type}:{content_id}"
|
||||
)
|
||||
return False
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
|
||||
# Store the embedding
|
||||
return await store_content_embedding(
|
||||
content_type=content_type,
|
||||
content_id=content_id,
|
||||
embedding=embedding,
|
||||
searchable_text=searchable_text,
|
||||
metadata=metadata or {},
|
||||
user_id=user_id,
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}")
|
||||
return False
|
||||
# Store the embedding
|
||||
return await store_content_embedding(
|
||||
content_type=content_type,
|
||||
content_id=content_id,
|
||||
embedding=embedding,
|
||||
searchable_text=searchable_text,
|
||||
metadata=metadata or {},
|
||||
user_id=user_id,
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
|
||||
async def cleanup_orphaned_embeddings() -> dict[str, Any]:
|
||||
@@ -854,9 +841,8 @@ async def semantic_search(
|
||||
limit = 100
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = await embed_query(query)
|
||||
|
||||
if query_embedding is not None:
|
||||
try:
|
||||
query_embedding = await embed_query(query)
|
||||
# Semantic search with embeddings
|
||||
embedding_str = embedding_to_vector_string(query_embedding)
|
||||
|
||||
@@ -907,24 +893,21 @@ async def semantic_search(
|
||||
"""
|
||||
)
|
||||
|
||||
try:
|
||||
results = await query_raw_with_schema(sql, *params)
|
||||
return [
|
||||
{
|
||||
"content_id": row["content_id"],
|
||||
"content_type": row["content_type"],
|
||||
"searchable_text": row["searchable_text"],
|
||||
"metadata": row["metadata"],
|
||||
"similarity": float(row["similarity"]),
|
||||
}
|
||||
for row in results
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Semantic search failed: {e}")
|
||||
# Fall through to lexical search below
|
||||
results = await query_raw_with_schema(sql, *params)
|
||||
return [
|
||||
{
|
||||
"content_id": row["content_id"],
|
||||
"content_type": row["content_type"],
|
||||
"searchable_text": row["searchable_text"],
|
||||
"metadata": row["metadata"],
|
||||
"similarity": float(row["similarity"]),
|
||||
}
|
||||
for row in results
|
||||
]
|
||||
except Exception as e:
|
||||
logger.warning(f"Semantic search failed, falling back to lexical search: {e}")
|
||||
|
||||
# Fallback to lexical search if embeddings unavailable
|
||||
logger.warning("Falling back to lexical search (embeddings unavailable)")
|
||||
|
||||
params_lexical: list[Any] = [limit]
|
||||
user_filter = ""
|
||||
|
||||
@@ -298,17 +298,16 @@ async def test_schema_handling_error_cases():
|
||||
mock_client.execute_raw.side_effect = Exception("Database error")
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id="test-id",
|
||||
embedding=[0.1] * EMBEDDING_DIM,
|
||||
searchable_text="test",
|
||||
metadata=None,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Should return False on error, not raise
|
||||
assert result is False
|
||||
# Should raise exception on error
|
||||
with pytest.raises(Exception, match="Database error"):
|
||||
await embeddings.store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id="test-id",
|
||||
embedding=[0.1] * EMBEDDING_DIM,
|
||||
searchable_text="test",
|
||||
metadata=None,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -80,9 +80,8 @@ async def test_generate_embedding_no_api_key():
|
||||
) as mock_get_client:
|
||||
mock_get_client.return_value = None
|
||||
|
||||
result = await embeddings.generate_embedding("test text")
|
||||
|
||||
assert result is None
|
||||
with pytest.raises(RuntimeError, match="openai_internal_api_key not set"):
|
||||
await embeddings.generate_embedding("test text")
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@@ -97,9 +96,8 @@ async def test_generate_embedding_api_error():
|
||||
) as mock_get_client:
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = await embeddings.generate_embedding("test text")
|
||||
|
||||
assert result is None
|
||||
with pytest.raises(Exception, match="API Error"):
|
||||
await embeddings.generate_embedding("test text")
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@@ -173,11 +171,10 @@ async def test_store_embedding_database_error(mocker):
|
||||
|
||||
embedding = [0.1, 0.2, 0.3]
|
||||
|
||||
result = await embeddings.store_embedding(
|
||||
version_id="test-version-id", embedding=embedding, tx=mock_client
|
||||
)
|
||||
|
||||
assert result is False
|
||||
with pytest.raises(Exception, match="Database error"):
|
||||
await embeddings.store_embedding(
|
||||
version_id="test-version-id", embedding=embedding, tx=mock_client
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@@ -277,17 +274,16 @@ async def test_ensure_embedding_create_new(mock_get, mock_store, mock_generate):
|
||||
async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
|
||||
"""Test ensure_embedding when generation fails."""
|
||||
mock_get.return_value = None
|
||||
mock_generate.return_value = None
|
||||
mock_generate.side_effect = Exception("Generation failed")
|
||||
|
||||
result = await embeddings.ensure_embedding(
|
||||
version_id="test-id",
|
||||
name="Test",
|
||||
description="Test description",
|
||||
sub_heading="Test heading",
|
||||
categories=["test"],
|
||||
)
|
||||
|
||||
assert result is False
|
||||
with pytest.raises(Exception, match="Generation failed"):
|
||||
await embeddings.ensure_embedding(
|
||||
version_id="test-id",
|
||||
name="Test",
|
||||
description="Test description",
|
||||
sub_heading="Test heading",
|
||||
categories=["test"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
|
||||
@@ -186,13 +186,12 @@ async def unified_hybrid_search(
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = await embed_query(query)
|
||||
|
||||
# Graceful degradation if embedding unavailable
|
||||
if query_embedding is None or not query_embedding:
|
||||
# Generate query embedding with graceful degradation
|
||||
try:
|
||||
query_embedding = await embed_query(query)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to generate query embedding - falling back to lexical-only search. "
|
||||
f"Failed to generate query embedding - falling back to lexical-only search: {e}. "
|
||||
"Check that openai_internal_api_key is configured and OpenAI API is accessible."
|
||||
)
|
||||
query_embedding = [0.0] * EMBEDDING_DIM
|
||||
@@ -464,13 +463,12 @@ async def hybrid_search(
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = await embed_query(query)
|
||||
|
||||
# Graceful degradation
|
||||
if query_embedding is None or not query_embedding:
|
||||
# Generate query embedding with graceful degradation
|
||||
try:
|
||||
query_embedding = await embed_query(query)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to generate query embedding - falling back to lexical-only search."
|
||||
f"Failed to generate query embedding - falling back to lexical-only search: {e}"
|
||||
)
|
||||
query_embedding = [0.0] * EMBEDDING_DIM
|
||||
total_non_semantic = (
|
||||
|
||||
@@ -172,8 +172,8 @@ async def test_hybrid_search_without_embeddings():
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
# Simulate embedding failure
|
||||
mock_embed.return_value = None
|
||||
# Simulate embedding failure by raising exception
|
||||
mock_embed.side_effect = Exception("Embedding generation failed")
|
||||
mock_query.return_value = mock_results
|
||||
|
||||
# Should NOT raise - graceful degradation
|
||||
@@ -613,7 +613,9 @@ async def test_unified_hybrid_search_graceful_degradation():
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_query.return_value = mock_results
|
||||
mock_embed.return_value = None # Embedding failure
|
||||
mock_embed.side_effect = Exception(
|
||||
"Embedding generation failed"
|
||||
) # Embedding failure
|
||||
|
||||
# Should NOT raise - graceful degradation
|
||||
results, total = await unified_hybrid_search(
|
||||
|
||||
@@ -393,7 +393,6 @@ async def get_creators(
|
||||
@router.get(
|
||||
"/creator/{username}",
|
||||
summary="Get creator details",
|
||||
operation_id="getV2GetCreatorDetails",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.CreatorDetails,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,6 @@ from prisma.errors import PrismaError
|
||||
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.llm_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.builder
|
||||
import backend.api.features.builder.routes
|
||||
@@ -38,11 +37,9 @@ import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.server.v2.llm.routes as public_llm_routes
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block_cost_config import refresh_llm_costs
|
||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
@@ -112,27 +109,11 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Refresh LLM registry before initializing blocks so blocks can use registry data
|
||||
await llm_registry.refresh_llm_registry()
|
||||
refresh_llm_costs()
|
||||
|
||||
# Clear block schema caches so they're regenerated with updated discriminator_mapping
|
||||
from backend.data.block import BlockSchema
|
||||
|
||||
BlockSchema.clear_all_schema_caches()
|
||||
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
# migrate_llm_models uses registry default model
|
||||
from backend.blocks.llm import LlmModel
|
||||
|
||||
default_model_slug = llm_registry.get_default_model_slug()
|
||||
if default_model_slug:
|
||||
await backend.data.graph.migrate_llm_models(LlmModel(default_model_slug))
|
||||
else:
|
||||
logger.warning("Skipping LLM model migration: no default model available")
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
with launch_darkly_context():
|
||||
@@ -317,16 +298,6 @@ app.include_router(
|
||||
tags=["v2", "executions", "review"],
|
||||
prefix="/api/review",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.llm_routes.router,
|
||||
tags=["v2", "admin", "llm"],
|
||||
prefix="/api/llm/admin",
|
||||
)
|
||||
app.include_router(
|
||||
public_llm_routes.router,
|
||||
tags=["v2", "llm"],
|
||||
prefix="/api",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.library.routes.router, tags=["v2"], prefix="/api/library"
|
||||
)
|
||||
|
||||
@@ -77,39 +77,7 @@ async def event_broadcaster(manager: ConnectionManager):
|
||||
payload=notification.payload,
|
||||
)
|
||||
|
||||
async def registry_refresh_worker():
|
||||
"""Listen for LLM registry refresh notifications and broadcast to all clients."""
|
||||
from backend.data.llm_registry import REGISTRY_REFRESH_CHANNEL
|
||||
from backend.data.redis_client import connect_async
|
||||
|
||||
redis = await connect_async()
|
||||
pubsub = redis.pubsub()
|
||||
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
|
||||
logger.info(
|
||||
"Subscribed to LLM registry refresh notifications for WebSocket broadcast"
|
||||
)
|
||||
|
||||
async for message in pubsub.listen():
|
||||
if (
|
||||
message["type"] == "message"
|
||||
and message["channel"] == REGISTRY_REFRESH_CHANNEL
|
||||
):
|
||||
logger.info(
|
||||
"Broadcasting LLM registry refresh to all WebSocket clients"
|
||||
)
|
||||
await manager.broadcast_to_all(
|
||||
method=WSMethod.NOTIFICATION,
|
||||
data={
|
||||
"type": "LLM_REGISTRY_REFRESH",
|
||||
"event": "registry_updated",
|
||||
},
|
||||
)
|
||||
|
||||
await asyncio.gather(
|
||||
execution_worker(),
|
||||
notification_worker(),
|
||||
registry_refresh_worker(),
|
||||
)
|
||||
await asyncio.gather(execution_worker(), notification_worker())
|
||||
|
||||
|
||||
async def authenticate_websocket(websocket: WebSocket) -> str:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks.llm import (
|
||||
DEFAULT_LLM_MODEL,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
AIBlockBase,
|
||||
@@ -9,7 +10,6 @@ from backend.blocks.llm import (
|
||||
LlmModel,
|
||||
LLMResponse,
|
||||
llm_call,
|
||||
llm_model_schema_extra,
|
||||
)
|
||||
from backend.data.block import (
|
||||
BlockCategory,
|
||||
@@ -50,10 +50,9 @@ class AIConditionBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=LlmModel.default,
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for evaluating the condition.",
|
||||
advanced=False,
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
|
||||
@@ -83,7 +82,7 @@ class AIConditionBlock(AIBlockBase):
|
||||
"condition": "the input is an email address",
|
||||
"yes_value": "Valid email",
|
||||
"no_value": "Not an email",
|
||||
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
|
||||
@@ -116,6 +116,7 @@ class PrintToConsoleBlock(Block):
|
||||
input_schema=PrintToConsoleBlock.Input,
|
||||
output_schema=PrintToConsoleBlock.Output,
|
||||
test_input={"text": "Hello, World!"},
|
||||
is_sensitive_action=True,
|
||||
test_output=[
|
||||
("output", "Hello, World!"),
|
||||
("status", "printed"),
|
||||
|
||||
659
autogpt_platform/backend/backend/blocks/claude_code.py
Normal file
659
autogpt_platform/backend/backend/blocks/claude_code.py
Normal file
@@ -0,0 +1,659 @@
|
||||
import json
|
||||
import shlex
|
||||
import uuid
|
||||
from typing import Literal, Optional
|
||||
|
||||
from e2b import AsyncSandbox as BaseAsyncSandbox
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
|
||||
class ClaudeCodeExecutionError(Exception):
|
||||
"""Exception raised when Claude Code execution fails.
|
||||
|
||||
Carries the sandbox_id so it can be returned to the user for cleanup
|
||||
when dispose_sandbox=False.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str, sandbox_id: str = ""):
|
||||
super().__init__(message)
|
||||
self.sandbox_id = sandbox_id
|
||||
|
||||
|
||||
# Test credentials for E2B
|
||||
TEST_E2B_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="e2b",
|
||||
api_key=SecretStr("mock-e2b-api-key"),
|
||||
title="Mock E2B API key",
|
||||
expires_at=None,
|
||||
)
|
||||
TEST_E2B_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_E2B_CREDENTIALS.provider,
|
||||
"id": TEST_E2B_CREDENTIALS.id,
|
||||
"type": TEST_E2B_CREDENTIALS.type,
|
||||
"title": TEST_E2B_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
# Test credentials for Anthropic
|
||||
TEST_ANTHROPIC_CREDENTIALS = APIKeyCredentials(
|
||||
id="2e568a2b-b2ea-475a-8564-9a676bf31c56",
|
||||
provider="anthropic",
|
||||
api_key=SecretStr("mock-anthropic-api-key"),
|
||||
title="Mock Anthropic API key",
|
||||
expires_at=None,
|
||||
)
|
||||
TEST_ANTHROPIC_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_ANTHROPIC_CREDENTIALS.provider,
|
||||
"id": TEST_ANTHROPIC_CREDENTIALS.id,
|
||||
"type": TEST_ANTHROPIC_CREDENTIALS.type,
|
||||
"title": TEST_ANTHROPIC_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class ClaudeCodeBlock(Block):
|
||||
"""
|
||||
Execute tasks using Claude Code (Anthropic's AI coding assistant) in an E2B sandbox.
|
||||
|
||||
Claude Code can create files, install tools, run commands, and perform complex
|
||||
coding tasks autonomously within a secure sandbox environment.
|
||||
"""
|
||||
|
||||
# Use base template - we'll install Claude Code ourselves for latest version
|
||||
DEFAULT_TEMPLATE = "base"
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
e2b_credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.E2B], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description=(
|
||||
"API key for the E2B platform to create the sandbox. "
|
||||
"Get one on the [e2b website](https://e2b.dev/docs)"
|
||||
),
|
||||
)
|
||||
|
||||
anthropic_credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.ANTHROPIC], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description=(
|
||||
"API key for Anthropic to power Claude Code. "
|
||||
"Get one at [Anthropic's website](https://console.anthropic.com)"
|
||||
),
|
||||
)
|
||||
|
||||
prompt: str = SchemaField(
|
||||
description=(
|
||||
"The task or instruction for Claude Code to execute. "
|
||||
"Claude Code can create files, install packages, run commands, "
|
||||
"and perform complex coding tasks."
|
||||
),
|
||||
placeholder="Create a hello world index.html file",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
timeout: int = SchemaField(
|
||||
description=(
|
||||
"Sandbox timeout in seconds. Claude Code tasks can take "
|
||||
"a while, so set this appropriately for your task complexity. "
|
||||
"Note: This only applies when creating a new sandbox. "
|
||||
"When reconnecting to an existing sandbox via sandbox_id, "
|
||||
"the original timeout is retained."
|
||||
),
|
||||
default=300, # 5 minutes default
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
setup_commands: list[str] = SchemaField(
|
||||
description=(
|
||||
"Optional shell commands to run before executing Claude Code. "
|
||||
"Useful for installing dependencies or setting up the environment."
|
||||
),
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
working_directory: str = SchemaField(
|
||||
description="Working directory for Claude Code to operate in.",
|
||||
default="/home/user",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Session/continuation support
|
||||
session_id: str = SchemaField(
|
||||
description=(
|
||||
"Session ID to resume a previous conversation. "
|
||||
"Leave empty for a new conversation. "
|
||||
"Use the session_id from a previous run to continue that conversation."
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
sandbox_id: str = SchemaField(
|
||||
description=(
|
||||
"Sandbox ID to reconnect to an existing sandbox. "
|
||||
"Required when resuming a session (along with session_id). "
|
||||
"Use the sandbox_id from a previous run where dispose_sandbox was False."
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
conversation_history: str = SchemaField(
|
||||
description=(
|
||||
"Previous conversation history to continue from. "
|
||||
"Use this to restore context on a fresh sandbox if the previous one timed out. "
|
||||
"Pass the conversation_history output from a previous run."
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
dispose_sandbox: bool = SchemaField(
|
||||
description=(
|
||||
"Whether to dispose of the sandbox immediately after execution. "
|
||||
"Set to False if you want to continue the conversation later "
|
||||
"(you'll need both sandbox_id and session_id from the output)."
|
||||
),
|
||||
default=True,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class FileOutput(BaseModel):
|
||||
"""A file extracted from the sandbox."""
|
||||
|
||||
path: str
|
||||
relative_path: str # Path relative to working directory (for GitHub, etc.)
|
||||
name: str
|
||||
content: str
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
response: str = SchemaField(
|
||||
description="The output/response from Claude Code execution"
|
||||
)
|
||||
files: list["ClaudeCodeBlock.FileOutput"] = SchemaField(
|
||||
description=(
|
||||
"List of text files created/modified by Claude Code during this execution. "
|
||||
"Each file has 'path', 'relative_path', 'name', and 'content' fields."
|
||||
)
|
||||
)
|
||||
conversation_history: str = SchemaField(
|
||||
description=(
|
||||
"Full conversation history including this turn. "
|
||||
"Pass this to conversation_history input to continue on a fresh sandbox "
|
||||
"if the previous sandbox timed out."
|
||||
)
|
||||
)
|
||||
session_id: str = SchemaField(
|
||||
description=(
|
||||
"Session ID for this conversation. "
|
||||
"Pass this back along with sandbox_id to continue the conversation."
|
||||
)
|
||||
)
|
||||
sandbox_id: Optional[str] = SchemaField(
|
||||
description=(
|
||||
"ID of the sandbox instance. "
|
||||
"Pass this back along with session_id to continue the conversation. "
|
||||
"This is None if dispose_sandbox was True (sandbox was disposed)."
|
||||
),
|
||||
default=None,
|
||||
)
|
||||
error: str = SchemaField(description="Error message if execution failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4e34f4a5-9b89-4326-ba77-2dd6750b7194",
|
||||
description=(
|
||||
"Execute tasks using Claude Code in an E2B sandbox. "
|
||||
"Claude Code can create files, install tools, run commands, "
|
||||
"and perform complex coding tasks autonomously."
|
||||
),
|
||||
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.AI},
|
||||
input_schema=ClaudeCodeBlock.Input,
|
||||
output_schema=ClaudeCodeBlock.Output,
|
||||
test_credentials={
|
||||
"e2b_credentials": TEST_E2B_CREDENTIALS,
|
||||
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS,
|
||||
},
|
||||
test_input={
|
||||
"e2b_credentials": TEST_E2B_CREDENTIALS_INPUT,
|
||||
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS_INPUT,
|
||||
"prompt": "Create a hello world HTML file",
|
||||
"timeout": 300,
|
||||
"setup_commands": [],
|
||||
"working_directory": "/home/user",
|
||||
"session_id": "",
|
||||
"sandbox_id": "",
|
||||
"conversation_history": "",
|
||||
"dispose_sandbox": True,
|
||||
},
|
||||
test_output=[
|
||||
("response", "Created index.html with hello world content"),
|
||||
(
|
||||
"files",
|
||||
[
|
||||
{
|
||||
"path": "/home/user/index.html",
|
||||
"relative_path": "index.html",
|
||||
"name": "index.html",
|
||||
"content": "<html>Hello World</html>",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"conversation_history",
|
||||
"User: Create a hello world HTML file\n"
|
||||
"Claude: Created index.html with hello world content",
|
||||
),
|
||||
("session_id", str),
|
||||
("sandbox_id", None), # None because dispose_sandbox=True in test_input
|
||||
],
|
||||
test_mock={
|
||||
"execute_claude_code": lambda *args, **kwargs: (
|
||||
"Created index.html with hello world content", # response
|
||||
[
|
||||
ClaudeCodeBlock.FileOutput(
|
||||
path="/home/user/index.html",
|
||||
relative_path="index.html",
|
||||
name="index.html",
|
||||
content="<html>Hello World</html>",
|
||||
)
|
||||
], # files
|
||||
"User: Create a hello world HTML file\n"
|
||||
"Claude: Created index.html with hello world content", # conversation_history
|
||||
"test-session-id", # session_id
|
||||
"sandbox_id", # sandbox_id
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
async def execute_claude_code(
|
||||
self,
|
||||
e2b_api_key: str,
|
||||
anthropic_api_key: str,
|
||||
prompt: str,
|
||||
timeout: int,
|
||||
setup_commands: list[str],
|
||||
working_directory: str,
|
||||
session_id: str,
|
||||
existing_sandbox_id: str,
|
||||
conversation_history: str,
|
||||
dispose_sandbox: bool,
|
||||
) -> tuple[str, list["ClaudeCodeBlock.FileOutput"], str, str, str]:
|
||||
"""
|
||||
Execute Claude Code in an E2B sandbox.
|
||||
|
||||
Returns:
|
||||
Tuple of (response, files, conversation_history, session_id, sandbox_id)
|
||||
"""
|
||||
|
||||
# Validate that sandbox_id is provided when resuming a session
|
||||
if session_id and not existing_sandbox_id:
|
||||
raise ValueError(
|
||||
"sandbox_id is required when resuming a session with session_id. "
|
||||
"The session state is stored in the original sandbox. "
|
||||
"If the sandbox has timed out, use conversation_history instead "
|
||||
"to restore context on a fresh sandbox."
|
||||
)
|
||||
|
||||
sandbox = None
|
||||
sandbox_id = ""
|
||||
|
||||
try:
|
||||
# Either reconnect to existing sandbox or create a new one
|
||||
if existing_sandbox_id:
|
||||
# Reconnect to existing sandbox for conversation continuation
|
||||
sandbox = await BaseAsyncSandbox.connect(
|
||||
sandbox_id=existing_sandbox_id,
|
||||
api_key=e2b_api_key,
|
||||
)
|
||||
else:
|
||||
# Create new sandbox
|
||||
sandbox = await BaseAsyncSandbox.create(
|
||||
template=self.DEFAULT_TEMPLATE,
|
||||
api_key=e2b_api_key,
|
||||
timeout=timeout,
|
||||
envs={"ANTHROPIC_API_KEY": anthropic_api_key},
|
||||
)
|
||||
|
||||
# Install Claude Code from npm (ensures we get the latest version)
|
||||
install_result = await sandbox.commands.run(
|
||||
"npm install -g @anthropic-ai/claude-code@latest",
|
||||
timeout=120, # 2 min timeout for install
|
||||
)
|
||||
if install_result.exit_code != 0:
|
||||
raise Exception(
|
||||
f"Failed to install Claude Code: {install_result.stderr}"
|
||||
)
|
||||
|
||||
# Run any user-provided setup commands
|
||||
for cmd in setup_commands:
|
||||
setup_result = await sandbox.commands.run(cmd)
|
||||
if setup_result.exit_code != 0:
|
||||
raise Exception(
|
||||
f"Setup command failed: {cmd}\n"
|
||||
f"Exit code: {setup_result.exit_code}\n"
|
||||
f"Stdout: {setup_result.stdout}\n"
|
||||
f"Stderr: {setup_result.stderr}"
|
||||
)
|
||||
|
||||
# Capture sandbox_id immediately after creation/connection
|
||||
# so it's available for error recovery if dispose_sandbox=False
|
||||
sandbox_id = sandbox.sandbox_id
|
||||
|
||||
# Generate or use provided session ID
|
||||
current_session_id = session_id if session_id else str(uuid.uuid4())
|
||||
|
||||
# Build base Claude flags
|
||||
base_flags = "-p --dangerously-skip-permissions --output-format json"
|
||||
|
||||
# Add conversation history context if provided (for fresh sandbox continuation)
|
||||
history_flag = ""
|
||||
if conversation_history and not session_id:
|
||||
# Inject previous conversation as context via system prompt
|
||||
# Use consistent escaping via _escape_prompt helper
|
||||
escaped_history = self._escape_prompt(
|
||||
f"Previous conversation context: {conversation_history}"
|
||||
)
|
||||
history_flag = f" --append-system-prompt {escaped_history}"
|
||||
|
||||
# Build Claude command based on whether we're resuming or starting new
|
||||
# Use shlex.quote for working_directory and session IDs to prevent injection
|
||||
safe_working_dir = shlex.quote(working_directory)
|
||||
if session_id:
|
||||
# Resuming existing session (sandbox still alive)
|
||||
safe_session_id = shlex.quote(session_id)
|
||||
claude_command = (
|
||||
f"cd {safe_working_dir} && "
|
||||
f"echo {self._escape_prompt(prompt)} | "
|
||||
f"claude --resume {safe_session_id} {base_flags}"
|
||||
)
|
||||
else:
|
||||
# New session with specific ID
|
||||
safe_current_session_id = shlex.quote(current_session_id)
|
||||
claude_command = (
|
||||
f"cd {safe_working_dir} && "
|
||||
f"echo {self._escape_prompt(prompt)} | "
|
||||
f"claude --session-id {safe_current_session_id} {base_flags}{history_flag}"
|
||||
)
|
||||
|
||||
# Capture timestamp before running Claude Code to filter files later
|
||||
# Capture timestamp 1 second in the past to avoid race condition with file creation
|
||||
timestamp_result = await sandbox.commands.run(
|
||||
"date -u -d '1 second ago' +%Y-%m-%dT%H:%M:%S"
|
||||
)
|
||||
if timestamp_result.exit_code != 0:
|
||||
raise RuntimeError(
|
||||
f"Failed to capture timestamp: {timestamp_result.stderr}"
|
||||
)
|
||||
start_timestamp = (
|
||||
timestamp_result.stdout.strip() if timestamp_result.stdout else None
|
||||
)
|
||||
|
||||
result = await sandbox.commands.run(
|
||||
claude_command,
|
||||
timeout=0, # No command timeout - let sandbox timeout handle it
|
||||
)
|
||||
|
||||
# Check for command failure
|
||||
if result.exit_code != 0:
|
||||
error_msg = result.stderr or result.stdout or "Unknown error"
|
||||
raise Exception(
|
||||
f"Claude Code command failed with exit code {result.exit_code}:\n"
|
||||
f"{error_msg}"
|
||||
)
|
||||
|
||||
raw_output = result.stdout or ""
|
||||
|
||||
# Parse JSON output to extract response and build conversation history
|
||||
response = ""
|
||||
new_conversation_history = conversation_history or ""
|
||||
|
||||
try:
|
||||
# The JSON output contains the result
|
||||
output_data = json.loads(raw_output)
|
||||
response = output_data.get("result", raw_output)
|
||||
|
||||
# Build conversation history entry
|
||||
turn_entry = f"User: {prompt}\nClaude: {response}"
|
||||
if new_conversation_history:
|
||||
new_conversation_history = (
|
||||
f"{new_conversation_history}\n\n{turn_entry}"
|
||||
)
|
||||
else:
|
||||
new_conversation_history = turn_entry
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# If not valid JSON, use raw output
|
||||
response = raw_output
|
||||
turn_entry = f"User: {prompt}\nClaude: {response}"
|
||||
if new_conversation_history:
|
||||
new_conversation_history = (
|
||||
f"{new_conversation_history}\n\n{turn_entry}"
|
||||
)
|
||||
else:
|
||||
new_conversation_history = turn_entry
|
||||
|
||||
# Extract files created/modified during this run
|
||||
files = await self._extract_files(
|
||||
sandbox, working_directory, start_timestamp
|
||||
)
|
||||
|
||||
return (
|
||||
response,
|
||||
files,
|
||||
new_conversation_history,
|
||||
current_session_id,
|
||||
sandbox_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Wrap exception with sandbox_id so caller can access/cleanup
|
||||
# the preserved sandbox when dispose_sandbox=False
|
||||
raise ClaudeCodeExecutionError(str(e), sandbox_id) from e
|
||||
|
||||
finally:
|
||||
if dispose_sandbox and sandbox:
|
||||
await sandbox.kill()
|
||||
|
||||
async def _extract_files(
|
||||
self,
|
||||
sandbox: BaseAsyncSandbox,
|
||||
working_directory: str,
|
||||
since_timestamp: str | None = None,
|
||||
) -> list["ClaudeCodeBlock.FileOutput"]:
|
||||
"""
|
||||
Extract text files created/modified during this Claude Code execution.
|
||||
|
||||
Args:
|
||||
sandbox: The E2B sandbox instance
|
||||
working_directory: Directory to search for files
|
||||
since_timestamp: ISO timestamp - only return files modified after this time
|
||||
|
||||
Returns:
|
||||
List of FileOutput objects with path, relative_path, name, and content
|
||||
"""
|
||||
files: list[ClaudeCodeBlock.FileOutput] = []
|
||||
|
||||
# Text file extensions we can safely read as text
|
||||
text_extensions = {
|
||||
".txt",
|
||||
".md",
|
||||
".html",
|
||||
".htm",
|
||||
".css",
|
||||
".js",
|
||||
".ts",
|
||||
".jsx",
|
||||
".tsx",
|
||||
".json",
|
||||
".xml",
|
||||
".yaml",
|
||||
".yml",
|
||||
".toml",
|
||||
".ini",
|
||||
".cfg",
|
||||
".conf",
|
||||
".py",
|
||||
".rb",
|
||||
".php",
|
||||
".java",
|
||||
".c",
|
||||
".cpp",
|
||||
".h",
|
||||
".hpp",
|
||||
".cs",
|
||||
".go",
|
||||
".rs",
|
||||
".swift",
|
||||
".kt",
|
||||
".scala",
|
||||
".sh",
|
||||
".bash",
|
||||
".zsh",
|
||||
".sql",
|
||||
".graphql",
|
||||
".env",
|
||||
".gitignore",
|
||||
".dockerfile",
|
||||
"Dockerfile",
|
||||
".vue",
|
||||
".svelte",
|
||||
".astro",
|
||||
".mdx",
|
||||
".rst",
|
||||
".tex",
|
||||
".csv",
|
||||
".log",
|
||||
}
|
||||
|
||||
try:
|
||||
# List files recursively using find command
|
||||
# Exclude node_modules and .git directories, but allow hidden files
|
||||
# like .env and .gitignore (they're filtered by text_extensions later)
|
||||
# Filter by timestamp to only get files created/modified during this run
|
||||
safe_working_dir = shlex.quote(working_directory)
|
||||
timestamp_filter = ""
|
||||
if since_timestamp:
|
||||
timestamp_filter = f"-newermt {shlex.quote(since_timestamp)} "
|
||||
find_result = await sandbox.commands.run(
|
||||
f"find {safe_working_dir} -type f "
|
||||
f"{timestamp_filter}"
|
||||
f"-not -path '*/node_modules/*' "
|
||||
f"-not -path '*/.git/*' "
|
||||
f"2>/dev/null"
|
||||
)
|
||||
|
||||
if find_result.stdout:
|
||||
for file_path in find_result.stdout.strip().split("\n"):
|
||||
if not file_path:
|
||||
continue
|
||||
|
||||
# Check if it's a text file we can read
|
||||
is_text = any(
|
||||
file_path.endswith(ext) for ext in text_extensions
|
||||
) or file_path.endswith("Dockerfile")
|
||||
|
||||
if is_text:
|
||||
try:
|
||||
content = await sandbox.files.read(file_path)
|
||||
# Handle bytes or string
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8", errors="replace")
|
||||
|
||||
# Extract filename from path
|
||||
file_name = file_path.split("/")[-1]
|
||||
|
||||
# Calculate relative path by stripping working directory
|
||||
relative_path = file_path
|
||||
if file_path.startswith(working_directory):
|
||||
relative_path = file_path[len(working_directory) :]
|
||||
# Remove leading slash if present
|
||||
if relative_path.startswith("/"):
|
||||
relative_path = relative_path[1:]
|
||||
|
||||
files.append(
|
||||
ClaudeCodeBlock.FileOutput(
|
||||
path=file_path,
|
||||
relative_path=relative_path,
|
||||
name=file_name,
|
||||
content=content,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# Skip files that can't be read
|
||||
pass
|
||||
|
||||
except Exception:
|
||||
# If file extraction fails, return empty results
|
||||
pass
|
||||
|
||||
return files
|
||||
|
||||
def _escape_prompt(self, prompt: str) -> str:
|
||||
"""Escape the prompt for safe shell execution."""
|
||||
# Use single quotes and escape any single quotes in the prompt
|
||||
escaped = prompt.replace("'", "'\"'\"'")
|
||||
return f"'{escaped}'"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
e2b_credentials: APIKeyCredentials,
|
||||
anthropic_credentials: APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
(
|
||||
response,
|
||||
files,
|
||||
conversation_history,
|
||||
session_id,
|
||||
sandbox_id,
|
||||
) = await self.execute_claude_code(
|
||||
e2b_api_key=e2b_credentials.api_key.get_secret_value(),
|
||||
anthropic_api_key=anthropic_credentials.api_key.get_secret_value(),
|
||||
prompt=input_data.prompt,
|
||||
timeout=input_data.timeout,
|
||||
setup_commands=input_data.setup_commands,
|
||||
working_directory=input_data.working_directory,
|
||||
session_id=input_data.session_id,
|
||||
existing_sandbox_id=input_data.sandbox_id,
|
||||
conversation_history=input_data.conversation_history,
|
||||
dispose_sandbox=input_data.dispose_sandbox,
|
||||
)
|
||||
|
||||
yield "response", response
|
||||
# Always yield files (empty list if none) to match Output schema
|
||||
yield "files", [f.model_dump() for f in files]
|
||||
# Always yield conversation_history so user can restore context on fresh sandbox
|
||||
yield "conversation_history", conversation_history
|
||||
# Always yield session_id so user can continue conversation
|
||||
yield "session_id", session_id
|
||||
# Always yield sandbox_id (None if disposed) to match Output schema
|
||||
yield "sandbox_id", sandbox_id if not input_data.dispose_sandbox else None
|
||||
|
||||
except ClaudeCodeExecutionError as e:
|
||||
yield "error", str(e)
|
||||
# If sandbox was preserved (dispose_sandbox=False), yield sandbox_id
|
||||
# so user can reconnect to or clean up the orphaned sandbox
|
||||
if not input_data.dispose_sandbox and e.sandbox_id:
|
||||
yield "sandbox_id", e.sandbox_id
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -9,7 +9,7 @@ from typing import Any, Optional
|
||||
from prisma.enums import ReviewStatus
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.execution import ExecutionContext, ExecutionStatus
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.human_review import ReviewResult
|
||||
from backend.executor.manager import async_update_node_execution_status
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
@@ -28,6 +28,11 @@ class ReviewDecision(BaseModel):
|
||||
class HITLReviewHelper:
|
||||
"""Helper class for Human-In-The-Loop review operations."""
|
||||
|
||||
@staticmethod
|
||||
async def check_approval(**kwargs) -> Optional[ReviewResult]:
|
||||
"""Check if there's an existing approval for this node execution."""
|
||||
return await get_database_manager_async_client().check_approval(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]:
|
||||
"""Create or retrieve a human review from the database."""
|
||||
@@ -55,11 +60,11 @@ class HITLReviewHelper:
|
||||
async def _handle_review_request(
|
||||
input_data: Any,
|
||||
user_id: str,
|
||||
node_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: ExecutionContext,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
) -> Optional[ReviewResult]:
|
||||
@@ -69,11 +74,11 @@ class HITLReviewHelper:
|
||||
Args:
|
||||
input_data: The input data to be reviewed
|
||||
user_id: ID of the user requesting the review
|
||||
node_id: ID of the node in the graph definition
|
||||
node_exec_id: ID of the node execution
|
||||
graph_exec_id: ID of the graph execution
|
||||
graph_id: ID of the graph
|
||||
graph_version: Version of the graph
|
||||
execution_context: Current execution context
|
||||
block_name: Name of the block requesting review
|
||||
editable: Whether the reviewer can edit the data
|
||||
|
||||
@@ -83,15 +88,41 @@ class HITLReviewHelper:
|
||||
Raises:
|
||||
Exception: If review creation or status update fails
|
||||
"""
|
||||
# Skip review if safe mode is disabled - return auto-approved result
|
||||
if not execution_context.human_in_the_loop_safe_mode:
|
||||
# Note: Safe mode checks (human_in_the_loop_safe_mode, sensitive_action_safe_mode)
|
||||
# are handled by the caller:
|
||||
# - HITL blocks check human_in_the_loop_safe_mode in their run() method
|
||||
# - Sensitive action blocks check sensitive_action_safe_mode in is_block_exec_need_review()
|
||||
# This function only handles checking for existing approvals.
|
||||
|
||||
# Check if this node has already been approved (normal or auto-approval)
|
||||
if approval_result := await HITLReviewHelper.check_approval(
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
node_id=node_id,
|
||||
user_id=user_id,
|
||||
input_data=input_data,
|
||||
):
|
||||
logger.info(
|
||||
f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled"
|
||||
f"Block {block_name} skipping review for node {node_exec_id} - "
|
||||
f"found existing approval"
|
||||
)
|
||||
# Return a new ReviewResult with the current node_exec_id but approved status
|
||||
# For auto-approvals, always use current input_data
|
||||
# For normal approvals, use approval_result.data unless it's None
|
||||
is_auto_approval = approval_result.node_exec_id != node_exec_id
|
||||
approved_data = (
|
||||
input_data
|
||||
if is_auto_approval
|
||||
else (
|
||||
approval_result.data
|
||||
if approval_result.data is not None
|
||||
else input_data
|
||||
)
|
||||
)
|
||||
return ReviewResult(
|
||||
data=input_data,
|
||||
data=approved_data,
|
||||
status=ReviewStatus.APPROVED,
|
||||
message="Auto-approved (safe mode disabled)",
|
||||
message=approval_result.message,
|
||||
processed=True,
|
||||
node_exec_id=node_exec_id,
|
||||
)
|
||||
@@ -103,7 +134,7 @@ class HITLReviewHelper:
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
input_data=input_data,
|
||||
message=f"Review required for {block_name} execution",
|
||||
message=block_name, # Use block_name directly as the message
|
||||
editable=editable,
|
||||
)
|
||||
|
||||
@@ -129,11 +160,11 @@ class HITLReviewHelper:
|
||||
async def handle_review_decision(
|
||||
input_data: Any,
|
||||
user_id: str,
|
||||
node_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: ExecutionContext,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
) -> Optional[ReviewDecision]:
|
||||
@@ -143,11 +174,11 @@ class HITLReviewHelper:
|
||||
Args:
|
||||
input_data: The input data to be reviewed
|
||||
user_id: ID of the user requesting the review
|
||||
node_id: ID of the node in the graph definition
|
||||
node_exec_id: ID of the node execution
|
||||
graph_exec_id: ID of the graph execution
|
||||
graph_id: ID of the graph
|
||||
graph_version: Version of the graph
|
||||
execution_context: Current execution context
|
||||
block_name: Name of the block requesting review
|
||||
editable: Whether the reviewer can edit the data
|
||||
|
||||
@@ -158,11 +189,11 @@ class HITLReviewHelper:
|
||||
review_result = await HITLReviewHelper._handle_review_request(
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
node_id=node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_context=execution_context,
|
||||
block_name=block_name,
|
||||
editable=editable,
|
||||
)
|
||||
|
||||
@@ -97,6 +97,7 @@ class HumanInTheLoopBlock(Block):
|
||||
input_data: Input,
|
||||
*,
|
||||
user_id: str,
|
||||
node_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
@@ -115,12 +116,12 @@ class HumanInTheLoopBlock(Block):
|
||||
decision = await self.handle_review_decision(
|
||||
input_data=input_data.data,
|
||||
user_id=user_id,
|
||||
node_id=node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_context=execution_context,
|
||||
block_name=self.name,
|
||||
block_name=input_data.name, # Use user-provided name instead of block type
|
||||
editable=input_data.editable,
|
||||
)
|
||||
|
||||
|
||||
@@ -4,19 +4,17 @@ import logging
|
||||
import re
|
||||
import secrets
|
||||
from abc import ABC
|
||||
from enum import Enum
|
||||
from enum import Enum, EnumMeta
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Iterable, List, Literal, Optional
|
||||
from typing import Any, Iterable, List, Literal, NamedTuple, Optional
|
||||
|
||||
import anthropic
|
||||
import ollama
|
||||
import openai
|
||||
from anthropic.types import ToolParam
|
||||
from groq import AsyncGroq
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, SecretStr
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -24,7 +22,6 @@ from backend.data.block import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.llm_registry import ModelMetadata
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -69,123 +66,114 @@ TEST_CREDENTIALS_INPUT = {
|
||||
|
||||
|
||||
def AICredentialsField() -> AICredentials:
|
||||
"""
|
||||
Returns a CredentialsField for LLM providers.
|
||||
The discriminator_mapping will be refreshed when the schema is generated
|
||||
if it's empty, ensuring the LLM registry is loaded.
|
||||
"""
|
||||
# Get the mapping now - it may be empty initially, but will be refreshed
|
||||
# when the schema is generated via CredentialsMetaInput._add_json_schema_extra
|
||||
mapping = llm_registry.get_llm_discriminator_mapping()
|
||||
|
||||
return CredentialsField(
|
||||
description="API key for the LLM provider.",
|
||||
discriminator="model",
|
||||
discriminator_mapping=mapping, # May be empty initially, refreshed later
|
||||
discriminator_mapping={
|
||||
model.value: model.metadata.provider for model in LlmModel
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def llm_model_schema_extra() -> dict[str, Any]:
|
||||
return {"options": llm_registry.get_llm_model_schema_options()}
|
||||
class ModelMetadata(NamedTuple):
|
||||
provider: str
|
||||
context_window: int
|
||||
max_output_tokens: int | None
|
||||
display_name: str
|
||||
provider_name: str
|
||||
creator_name: str
|
||||
price_tier: Literal[1, 2, 3]
|
||||
|
||||
|
||||
class LlmModelMeta(type):
|
||||
"""
|
||||
Metaclass for LlmModel that enables attribute-style access to dynamic models.
|
||||
|
||||
This allows code like `LlmModel.GPT4O` to work by converting the attribute
|
||||
name to a slug format:
|
||||
- GPT4O -> gpt-4o
|
||||
- GPT4O_MINI -> gpt-4o-mini
|
||||
- CLAUDE_3_5_SONNET -> claude-3-5-sonnet
|
||||
"""
|
||||
|
||||
def __getattr__(cls, name: str):
|
||||
# Don't intercept private/dunder attributes
|
||||
if name.startswith("_"):
|
||||
raise AttributeError(f"type object 'LlmModel' has no attribute '{name}'")
|
||||
|
||||
# Convert attribute name to slug format:
|
||||
# 1. Lowercase: GPT4O -> gpt4o
|
||||
# 2. Underscores to hyphens: GPT4O_MINI -> gpt4o-mini
|
||||
slug = name.lower().replace("_", "-")
|
||||
|
||||
# Check for exact match in registry first (e.g., "o1" stays "o1")
|
||||
registry_slugs = llm_registry.get_dynamic_model_slugs()
|
||||
if slug in registry_slugs:
|
||||
return cls(slug)
|
||||
|
||||
# If no exact match, try inserting hyphen between letter and digit
|
||||
# e.g., gpt4o -> gpt-4o
|
||||
transformed_slug = re.sub(r"([a-z])(\d)", r"\1-\2", slug)
|
||||
return cls(transformed_slug)
|
||||
|
||||
def __iter__(cls):
|
||||
"""Iterate over all models from the registry.
|
||||
|
||||
Yields LlmModel instances for each model in the dynamic registry.
|
||||
Used by __get_pydantic_json_schema__ to build model metadata.
|
||||
"""
|
||||
for model in llm_registry.iter_dynamic_models():
|
||||
yield cls(model.slug)
|
||||
class LlmModelMeta(EnumMeta):
|
||||
pass
|
||||
|
||||
|
||||
class LlmModel(str, metaclass=LlmModelMeta):
|
||||
"""
|
||||
Dynamic LLM model type that accepts any model slug from the registry.
|
||||
|
||||
This is a string subclass (not an Enum) that allows any model slug value.
|
||||
All models are managed via the LLM Registry in the database.
|
||||
|
||||
Usage:
|
||||
model = LlmModel("gpt-4o") # Direct construction
|
||||
model = LlmModel.GPT4O # Attribute access (converted to "gpt-4o")
|
||||
model.value # Returns the slug string
|
||||
model.provider # Returns the provider from registry
|
||||
"""
|
||||
|
||||
def __new__(cls, value: str):
|
||||
if isinstance(value, LlmModel):
|
||||
return value
|
||||
return str.__new__(cls, value)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, source_type: Any, handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""
|
||||
Tell Pydantic how to validate LlmModel.
|
||||
|
||||
Accepts strings and converts them to LlmModel instances.
|
||||
"""
|
||||
return core_schema.no_info_after_validator_function(
|
||||
cls, # The validator function (LlmModel constructor)
|
||||
core_schema.str_schema(), # Accept string input
|
||||
serialization=core_schema.to_string_ser_schema(), # Serialize as string
|
||||
)
|
||||
|
||||
@property
|
||||
def value(self) -> str:
|
||||
"""Return the model slug (for compatibility with enum-style access)."""
|
||||
return str(self)
|
||||
|
||||
@classmethod
|
||||
def default(cls) -> "LlmModel":
|
||||
"""
|
||||
Get the default model from the registry.
|
||||
|
||||
Returns the recommended model if set, otherwise gpt-4o if available
|
||||
and enabled, otherwise the first enabled model from the registry.
|
||||
Falls back to "gpt-4o" if registry is empty (e.g., at module import time).
|
||||
"""
|
||||
from backend.data.llm_registry import get_default_model_slug
|
||||
|
||||
slug = get_default_model_slug()
|
||||
if slug is None:
|
||||
# Registry is empty (e.g., at module import time before DB connection).
|
||||
# Fall back to gpt-4o for backward compatibility.
|
||||
slug = "gpt-4o"
|
||||
return cls(slug)
|
||||
class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
# OpenAI models
|
||||
O3_MINI = "o3-mini"
|
||||
O3 = "o3-2025-04-16"
|
||||
O1 = "o1"
|
||||
O1_MINI = "o1-mini"
|
||||
# GPT-5 models
|
||||
GPT5_2 = "gpt-5.2-2025-12-11"
|
||||
GPT5_1 = "gpt-5.1-2025-11-13"
|
||||
GPT5 = "gpt-5-2025-08-07"
|
||||
GPT5_MINI = "gpt-5-mini-2025-08-07"
|
||||
GPT5_NANO = "gpt-5-nano-2025-08-07"
|
||||
GPT5_CHAT = "gpt-5-chat-latest"
|
||||
GPT41 = "gpt-4.1-2025-04-14"
|
||||
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
|
||||
GPT4O_MINI = "gpt-4o-mini"
|
||||
GPT4O = "gpt-4o"
|
||||
GPT4_TURBO = "gpt-4-turbo"
|
||||
GPT3_5_TURBO = "gpt-3.5-turbo"
|
||||
# Anthropic models
|
||||
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
|
||||
CLAUDE_4_OPUS = "claude-opus-4-20250514"
|
||||
CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
|
||||
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
||||
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
|
||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||
# AI/ML API models
|
||||
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
||||
AIML_API_LLAMA3_1_70B = "nvidia/llama-3.1-nemotron-70b-instruct"
|
||||
AIML_API_LLAMA3_3_70B = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
|
||||
AIML_API_META_LLAMA_3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
|
||||
AIML_API_LLAMA_3_2_3B = "meta-llama/Llama-3.2-3B-Instruct-Turbo"
|
||||
# Groq models
|
||||
LLAMA3_3_70B = "llama-3.3-70b-versatile"
|
||||
LLAMA3_1_8B = "llama-3.1-8b-instant"
|
||||
# Ollama models
|
||||
OLLAMA_LLAMA3_3 = "llama3.3"
|
||||
OLLAMA_LLAMA3_2 = "llama3.2"
|
||||
OLLAMA_LLAMA3_8B = "llama3"
|
||||
OLLAMA_LLAMA3_405B = "llama3.1:405b"
|
||||
OLLAMA_DOLPHIN = "dolphin-mistral:latest"
|
||||
# OpenRouter models
|
||||
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
|
||||
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
|
||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
|
||||
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
|
||||
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
|
||||
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
|
||||
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
|
||||
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
|
||||
MISTRAL_NEMO = "mistralai/mistral-nemo"
|
||||
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
|
||||
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
|
||||
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
|
||||
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
|
||||
PERPLEXITY_SONAR = "perplexity/sonar"
|
||||
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
|
||||
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
|
||||
AMAZON_NOVA_LITE_V1 = "amazon/nova-lite-v1"
|
||||
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
|
||||
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
|
||||
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
|
||||
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
|
||||
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
|
||||
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
|
||||
GROK_4 = "x-ai/grok-4"
|
||||
GROK_4_FAST = "x-ai/grok-4-fast"
|
||||
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
|
||||
GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
|
||||
KIMI_K2 = "moonshotai/kimi-k2"
|
||||
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
|
||||
QWEN3_CODER = "qwen/qwen3-coder"
|
||||
# Llama API models
|
||||
LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8"
|
||||
LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
|
||||
LLAMA_API_LLAMA3_3_8B = "Llama-3.3-8B-Instruct"
|
||||
LLAMA_API_LLAMA3_3_70B = "Llama-3.3-70B-Instruct"
|
||||
# v0 by Vercel models
|
||||
V0_1_5_MD = "v0-1.5-md"
|
||||
V0_1_5_LG = "v0-1.5-lg"
|
||||
V0_1_0_MD = "v0-1.0-md"
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, schema, handler):
|
||||
@@ -193,15 +181,7 @@ class LlmModel(str, metaclass=LlmModelMeta):
|
||||
llm_model_metadata = {}
|
||||
for model in cls:
|
||||
model_name = model.value
|
||||
# Skip disabled models - only show enabled models in the picker
|
||||
if not llm_registry.is_model_enabled(model_name):
|
||||
continue
|
||||
# Use registry directly with None check to gracefully handle
|
||||
# missing metadata during startup/import before registry is populated
|
||||
metadata = llm_registry.get_llm_model_metadata(model_name)
|
||||
if metadata is None:
|
||||
# Skip models without metadata (registry not yet populated)
|
||||
continue
|
||||
metadata = model.metadata
|
||||
llm_model_metadata[model_name] = {
|
||||
"creator": metadata.creator_name,
|
||||
"creator_name": metadata.creator_name,
|
||||
@@ -217,12 +197,7 @@ class LlmModel(str, metaclass=LlmModelMeta):
|
||||
|
||||
@property
|
||||
def metadata(self) -> ModelMetadata:
|
||||
metadata = llm_registry.get_llm_model_metadata(self.value)
|
||||
if metadata:
|
||||
return metadata
|
||||
raise ValueError(
|
||||
f"Missing metadata for model: {self.value}. Model not found in LLM registry."
|
||||
)
|
||||
return MODEL_METADATA[self]
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
@@ -237,11 +212,300 @@ class LlmModel(str, metaclass=LlmModelMeta):
|
||||
return self.metadata.max_output_tokens
|
||||
|
||||
|
||||
# MODEL_METADATA removed - all models now come from the database via llm_registry
|
||||
MODEL_METADATA = {
|
||||
# https://platform.openai.com/docs/models
|
||||
LlmModel.O3: ModelMetadata("openai", 200000, 100000, "O3", "OpenAI", "OpenAI", 2),
|
||||
LlmModel.O3_MINI: ModelMetadata(
|
||||
"openai", 200000, 100000, "O3 Mini", "OpenAI", "OpenAI", 1
|
||||
), # o3-mini-2025-01-31
|
||||
LlmModel.O1: ModelMetadata(
|
||||
"openai", 200000, 100000, "O1", "OpenAI", "OpenAI", 3
|
||||
), # o1-2024-12-17
|
||||
LlmModel.O1_MINI: ModelMetadata(
|
||||
"openai", 128000, 65536, "O1 Mini", "OpenAI", "OpenAI", 2
|
||||
), # o1-mini-2024-09-12
|
||||
# GPT-5 models
|
||||
LlmModel.GPT5_2: ModelMetadata(
|
||||
"openai", 400000, 128000, "GPT-5.2", "OpenAI", "OpenAI", 3
|
||||
),
|
||||
LlmModel.GPT5_1: ModelMetadata(
|
||||
"openai", 400000, 128000, "GPT-5.1", "OpenAI", "OpenAI", 2
|
||||
),
|
||||
LlmModel.GPT5: ModelMetadata(
|
||||
"openai", 400000, 128000, "GPT-5", "OpenAI", "OpenAI", 1
|
||||
),
|
||||
LlmModel.GPT5_MINI: ModelMetadata(
|
||||
"openai", 400000, 128000, "GPT-5 Mini", "OpenAI", "OpenAI", 1
|
||||
),
|
||||
LlmModel.GPT5_NANO: ModelMetadata(
|
||||
"openai", 400000, 128000, "GPT-5 Nano", "OpenAI", "OpenAI", 1
|
||||
),
|
||||
LlmModel.GPT5_CHAT: ModelMetadata(
|
||||
"openai", 400000, 16384, "GPT-5 Chat Latest", "OpenAI", "OpenAI", 2
|
||||
),
|
||||
LlmModel.GPT41: ModelMetadata(
|
||||
"openai", 1047576, 32768, "GPT-4.1", "OpenAI", "OpenAI", 1
|
||||
),
|
||||
LlmModel.GPT41_MINI: ModelMetadata(
|
||||
"openai", 1047576, 32768, "GPT-4.1 Mini", "OpenAI", "OpenAI", 1
|
||||
),
|
||||
LlmModel.GPT4O_MINI: ModelMetadata(
|
||||
"openai", 128000, 16384, "GPT-4o Mini", "OpenAI", "OpenAI", 1
|
||||
), # gpt-4o-mini-2024-07-18
|
||||
LlmModel.GPT4O: ModelMetadata(
|
||||
"openai", 128000, 16384, "GPT-4o", "OpenAI", "OpenAI", 2
|
||||
), # gpt-4o-2024-08-06
|
||||
LlmModel.GPT4_TURBO: ModelMetadata(
|
||||
"openai", 128000, 4096, "GPT-4 Turbo", "OpenAI", "OpenAI", 3
|
||||
), # gpt-4-turbo-2024-04-09
|
||||
LlmModel.GPT3_5_TURBO: ModelMetadata(
|
||||
"openai", 16385, 4096, "GPT-3.5 Turbo", "OpenAI", "OpenAI", 1
|
||||
), # gpt-3.5-turbo-0125
|
||||
# https://docs.anthropic.com/en/docs/about-claude/models
|
||||
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 32000, "Claude Opus 4.1", "Anthropic", "Anthropic", 3
|
||||
), # claude-opus-4-1-20250805
|
||||
LlmModel.CLAUDE_4_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 32000, "Claude Opus 4", "Anthropic", "Anthropic", 3
|
||||
), # claude-4-opus-20250514
|
||||
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
|
||||
), # claude-4-sonnet-20250514
|
||||
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
|
||||
), # claude-opus-4-5-20251101
|
||||
LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Sonnet 4.5", "Anthropic", "Anthropic", 3
|
||||
), # claude-sonnet-4-5-20250929
|
||||
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
|
||||
), # claude-haiku-4-5-20251001
|
||||
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude 3.7 Sonnet", "Anthropic", "Anthropic", 2
|
||||
), # claude-3-7-sonnet-20250219
|
||||
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
|
||||
"anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
|
||||
), # claude-3-haiku-20240307
|
||||
# https://docs.aimlapi.com/api-overview/model-database/text-models
|
||||
LlmModel.AIML_API_QWEN2_5_72B: ModelMetadata(
|
||||
"aiml_api", 32000, 8000, "Qwen 2.5 72B Instruct Turbo", "AI/ML", "Qwen", 1
|
||||
),
|
||||
LlmModel.AIML_API_LLAMA3_1_70B: ModelMetadata(
|
||||
"aiml_api",
|
||||
128000,
|
||||
40000,
|
||||
"Llama 3.1 Nemotron 70B Instruct",
|
||||
"AI/ML",
|
||||
"Nvidia",
|
||||
1,
|
||||
),
|
||||
LlmModel.AIML_API_LLAMA3_3_70B: ModelMetadata(
|
||||
"aiml_api", 128000, None, "Llama 3.3 70B Instruct Turbo", "AI/ML", "Meta", 1
|
||||
),
|
||||
LlmModel.AIML_API_META_LLAMA_3_1_70B: ModelMetadata(
|
||||
"aiml_api", 131000, 2000, "Llama 3.1 70B Instruct Turbo", "AI/ML", "Meta", 1
|
||||
),
|
||||
LlmModel.AIML_API_LLAMA_3_2_3B: ModelMetadata(
|
||||
"aiml_api", 128000, None, "Llama 3.2 3B Instruct Turbo", "AI/ML", "Meta", 1
|
||||
),
|
||||
# https://console.groq.com/docs/models
|
||||
LlmModel.LLAMA3_3_70B: ModelMetadata(
|
||||
"groq", 128000, 32768, "Llama 3.3 70B Versatile", "Groq", "Meta", 1
|
||||
),
|
||||
LlmModel.LLAMA3_1_8B: ModelMetadata(
|
||||
"groq", 128000, 8192, "Llama 3.1 8B Instant", "Groq", "Meta", 1
|
||||
),
|
||||
# https://ollama.com/library
|
||||
LlmModel.OLLAMA_LLAMA3_3: ModelMetadata(
|
||||
"ollama", 8192, None, "Llama 3.3", "Ollama", "Meta", 1
|
||||
),
|
||||
LlmModel.OLLAMA_LLAMA3_2: ModelMetadata(
|
||||
"ollama", 8192, None, "Llama 3.2", "Ollama", "Meta", 1
|
||||
),
|
||||
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata(
|
||||
"ollama", 8192, None, "Llama 3", "Ollama", "Meta", 1
|
||||
),
|
||||
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata(
|
||||
"ollama", 8192, None, "Llama 3.1 405B", "Ollama", "Meta", 1
|
||||
),
|
||||
LlmModel.OLLAMA_DOLPHIN: ModelMetadata(
|
||||
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
|
||||
),
|
||||
# https://openrouter.ai/models
|
||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
1050000,
|
||||
8192,
|
||||
"Gemini 2.5 Pro Preview 03.25",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
|
||||
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
|
||||
),
|
||||
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
|
||||
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
|
||||
),
|
||||
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
|
||||
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
|
||||
),
|
||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65535,
|
||||
"Gemini 2.5 Flash Lite Preview 06.17",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
1,
|
||||
),
|
||||
LlmModel.GEMINI_2_0_FLASH_LITE: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
8192,
|
||||
"Gemini 2.0 Flash Lite 001",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
1,
|
||||
),
|
||||
LlmModel.MISTRAL_NEMO: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
|
||||
),
|
||||
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
|
||||
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
|
||||
),
|
||||
LlmModel.DEEPSEEK_R1_0528: ModelMetadata(
|
||||
"open_router", 163840, 163840, "DeepSeek R1 0528", "OpenRouter", "DeepSeek", 1
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR: ModelMetadata(
|
||||
"open_router", 127000, 8000, "Sonar", "OpenRouter", "Perplexity", 1
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
|
||||
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
16000,
|
||||
"Sonar Deep Research",
|
||||
"OpenRouter",
|
||||
"Perplexity",
|
||||
3,
|
||||
),
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata(
|
||||
"open_router",
|
||||
131000,
|
||||
4096,
|
||||
"Hermes 3 Llama 3.1 405B",
|
||||
"OpenRouter",
|
||||
"Nous Research",
|
||||
1,
|
||||
),
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata(
|
||||
"open_router",
|
||||
12288,
|
||||
12288,
|
||||
"Hermes 3 Llama 3.1 70B",
|
||||
"OpenRouter",
|
||||
"Nous Research",
|
||||
1,
|
||||
),
|
||||
LlmModel.OPENAI_GPT_OSS_120B: ModelMetadata(
|
||||
"open_router", 131072, 131072, "GPT-OSS 120B", "OpenRouter", "OpenAI", 1
|
||||
),
|
||||
LlmModel.OPENAI_GPT_OSS_20B: ModelMetadata(
|
||||
"open_router", 131072, 32768, "GPT-OSS 20B", "OpenRouter", "OpenAI", 1
|
||||
),
|
||||
LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata(
|
||||
"open_router", 300000, 5120, "Nova Lite V1", "OpenRouter", "Amazon", 1
|
||||
),
|
||||
LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata(
|
||||
"open_router", 128000, 5120, "Nova Micro V1", "OpenRouter", "Amazon", 1
|
||||
),
|
||||
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata(
|
||||
"open_router", 300000, 5120, "Nova Pro V1", "OpenRouter", "Amazon", 1
|
||||
),
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
|
||||
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
|
||||
),
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
|
||||
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
|
||||
),
|
||||
LlmModel.META_LLAMA_4_SCOUT: ModelMetadata(
|
||||
"open_router", 131072, 131072, "Llama 4 Scout", "OpenRouter", "Meta", 1
|
||||
),
|
||||
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
|
||||
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
|
||||
),
|
||||
LlmModel.GROK_4: ModelMetadata(
|
||||
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
|
||||
),
|
||||
LlmModel.GROK_4_FAST: ModelMetadata(
|
||||
"open_router", 2000000, 30000, "Grok 4 Fast", "OpenRouter", "xAI", 1
|
||||
),
|
||||
LlmModel.GROK_4_1_FAST: ModelMetadata(
|
||||
"open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1
|
||||
),
|
||||
LlmModel.GROK_CODE_FAST_1: ModelMetadata(
|
||||
"open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1
|
||||
),
|
||||
LlmModel.KIMI_K2: ModelMetadata(
|
||||
"open_router", 131000, 131000, "Kimi K2", "OpenRouter", "Moonshot AI", 1
|
||||
),
|
||||
LlmModel.QWEN3_235B_A22B_THINKING: ModelMetadata(
|
||||
"open_router",
|
||||
262144,
|
||||
262144,
|
||||
"Qwen 3 235B A22B Thinking 2507",
|
||||
"OpenRouter",
|
||||
"Qwen",
|
||||
1,
|
||||
),
|
||||
LlmModel.QWEN3_CODER: ModelMetadata(
|
||||
"open_router", 262144, 262144, "Qwen 3 Coder", "OpenRouter", "Qwen", 3
|
||||
),
|
||||
# Llama API models
|
||||
LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata(
|
||||
"llama_api",
|
||||
128000,
|
||||
4028,
|
||||
"Llama 4 Scout 17B 16E Instruct FP8",
|
||||
"Llama API",
|
||||
"Meta",
|
||||
1,
|
||||
),
|
||||
LlmModel.LLAMA_API_LLAMA4_MAVERICK: ModelMetadata(
|
||||
"llama_api",
|
||||
128000,
|
||||
4028,
|
||||
"Llama 4 Maverick 17B 128E Instruct FP8",
|
||||
"Llama API",
|
||||
"Meta",
|
||||
1,
|
||||
),
|
||||
LlmModel.LLAMA_API_LLAMA3_3_8B: ModelMetadata(
|
||||
"llama_api", 128000, 4028, "Llama 3.3 8B Instruct", "Llama API", "Meta", 1
|
||||
),
|
||||
LlmModel.LLAMA_API_LLAMA3_3_70B: ModelMetadata(
|
||||
"llama_api", 128000, 4028, "Llama 3.3 70B Instruct", "Llama API", "Meta", 1
|
||||
),
|
||||
# v0 by Vercel models
|
||||
LlmModel.V0_1_5_MD: ModelMetadata("v0", 128000, 64000, "v0 1.5 MD", "V0", "V0", 1),
|
||||
LlmModel.V0_1_5_LG: ModelMetadata("v0", 512000, 64000, "v0 1.5 LG", "V0", "V0", 1),
|
||||
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000, "v0 1.0 MD", "V0", "V0", 1),
|
||||
}
|
||||
|
||||
# Default model constant for backward compatibility
|
||||
# Uses the dynamic registry to get the default model
|
||||
DEFAULT_LLM_MODEL = LlmModel.default()
|
||||
DEFAULT_LLM_MODEL = LlmModel.GPT5_2
|
||||
|
||||
for model in LlmModel:
|
||||
if model not in MODEL_METADATA:
|
||||
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
@@ -334,10 +598,7 @@ def get_parallel_tool_calls_param(
|
||||
llm_model: LlmModel, parallel_tool_calls: bool | None
|
||||
):
|
||||
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
|
||||
# Check for o-series models (o1, o1-mini, o3-mini, etc.) which don't support
|
||||
# parallel tool calls. Use regex to avoid false positives like "openai/gpt-oss".
|
||||
is_o_series = re.match(r"^o\d", llm_model) is not None
|
||||
if is_o_series or parallel_tool_calls is None:
|
||||
if llm_model.startswith("o") or parallel_tool_calls is None:
|
||||
return openai.NOT_GIVEN
|
||||
return parallel_tool_calls
|
||||
|
||||
@@ -373,98 +634,19 @@ async def llm_call(
|
||||
- prompt_tokens: The number of tokens used in the prompt.
|
||||
- completion_tokens: The number of tokens used in the completion.
|
||||
"""
|
||||
# Get model metadata and check if enabled - with fallback support
|
||||
# The model we'll actually use (may differ if original is disabled)
|
||||
model_to_use = llm_model.value
|
||||
|
||||
# Check if model is in registry and if it's enabled
|
||||
from backend.data.llm_registry import (
|
||||
get_fallback_model_for_disabled,
|
||||
get_model_info,
|
||||
)
|
||||
|
||||
model_info = get_model_info(llm_model.value)
|
||||
|
||||
if model_info and not model_info.is_enabled:
|
||||
# Model is disabled - try to find a fallback from the same provider
|
||||
fallback = get_fallback_model_for_disabled(llm_model.value)
|
||||
if fallback:
|
||||
logger.warning(
|
||||
f"Model '{llm_model.value}' is disabled. Using fallback model '{fallback.slug}' from the same provider ({fallback.metadata.provider})."
|
||||
)
|
||||
model_to_use = fallback.slug
|
||||
# Use fallback model's metadata
|
||||
provider = fallback.metadata.provider
|
||||
context_window = fallback.metadata.context_window
|
||||
model_max_output = fallback.metadata.max_output_tokens or int(2**15)
|
||||
else:
|
||||
# No fallback available - raise error
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' is disabled and no fallback model "
|
||||
f"from the same provider is available. Please enable the model or "
|
||||
f"select a different model in the block configuration."
|
||||
)
|
||||
else:
|
||||
# Model is enabled or not in registry (legacy/static model)
|
||||
try:
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
except ValueError:
|
||||
# Model not in cache - try refreshing the registry once if we have DB access
|
||||
logger.warning(f"Model {llm_model.value} not found in registry cache")
|
||||
|
||||
# Try refreshing the registry if we have database access
|
||||
from backend.data.db import is_connected
|
||||
|
||||
if is_connected():
|
||||
try:
|
||||
logger.info(
|
||||
f"Refreshing LLM registry and retrying lookup for {llm_model.value}"
|
||||
)
|
||||
await llm_registry.refresh_llm_registry()
|
||||
# Try again after refresh
|
||||
try:
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
logger.info(
|
||||
f"Successfully loaded model {llm_model.value} metadata after registry refresh"
|
||||
)
|
||||
except ValueError:
|
||||
# Still not found after refresh
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' not found in registry after refresh. "
|
||||
"Please ensure the model is added and enabled in the LLM registry via the admin UI."
|
||||
)
|
||||
except Exception as refresh_exc:
|
||||
logger.error(f"Failed to refresh LLM registry: {refresh_exc}")
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' not found in registry and failed to refresh. "
|
||||
"Please ensure the model is added to the LLM registry via the admin UI."
|
||||
) from refresh_exc
|
||||
else:
|
||||
# No DB access (e.g., in executor without direct DB connection)
|
||||
# The registry should have been loaded on startup
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' not found in registry cache. "
|
||||
"The registry may need to be refreshed. Please contact support or try again later."
|
||||
)
|
||||
|
||||
# Create effective model for model-specific parameter resolution (e.g., o-series check)
|
||||
# This uses the resolved model_to_use which may differ from llm_model if fallback occurred
|
||||
effective_model = LlmModel(model_to_use)
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
|
||||
if compress_prompt_to_fit:
|
||||
prompt = compress_prompt(
|
||||
messages=prompt,
|
||||
target_tokens=context_window // 2,
|
||||
target_tokens=llm_model.context_window // 2,
|
||||
lossy_ok=True,
|
||||
)
|
||||
|
||||
# Calculate available tokens based on context window and input length
|
||||
estimated_input_tokens = estimate_token_count(prompt)
|
||||
# model_max_output already set above
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
user_max = max_tokens or model_max_output
|
||||
available_tokens = max(context_window - estimated_input_tokens, 0)
|
||||
max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
|
||||
@@ -475,14 +657,14 @@ async def llm_call(
|
||||
response_format = None
|
||||
|
||||
parallel_tool_calls = get_parallel_tool_calls_param(
|
||||
effective_model, parallel_tool_calls
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
if force_json_output:
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
response = await oai_client.chat.completions.create(
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_completion_tokens=max_tokens,
|
||||
@@ -529,7 +711,7 @@ async def llm_call(
|
||||
)
|
||||
try:
|
||||
resp = await client.messages.create(
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
@@ -593,7 +775,7 @@ async def llm_call(
|
||||
client = AsyncGroq(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = {"type": "json_object"} if force_json_output else None
|
||||
response = await client.chat.completions.create(
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
@@ -615,7 +797,7 @@ async def llm_call(
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
response = await client.generate(
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
prompt=f"{sys_messages}\n\n{usr_messages}",
|
||||
stream=False,
|
||||
options={"num_ctx": max_tokens},
|
||||
@@ -637,7 +819,7 @@ async def llm_call(
|
||||
)
|
||||
|
||||
parallel_tool_calls_param = get_parallel_tool_calls_param(
|
||||
effective_model, parallel_tool_calls
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
@@ -645,7 +827,7 @@ async def llm_call(
|
||||
"HTTP-Referer": "https://agpt.co",
|
||||
"X-Title": "AutoGPT",
|
||||
},
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
@@ -679,7 +861,7 @@ async def llm_call(
|
||||
)
|
||||
|
||||
parallel_tool_calls_param = get_parallel_tool_calls_param(
|
||||
effective_model, parallel_tool_calls
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
@@ -687,7 +869,7 @@ async def llm_call(
|
||||
"HTTP-Referer": "https://agpt.co",
|
||||
"X-Title": "AutoGPT",
|
||||
},
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
@@ -714,7 +896,7 @@ async def llm_call(
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "aiml_api":
|
||||
client = openai.AsyncOpenAI(
|
||||
client = openai.OpenAI(
|
||||
base_url="https://api.aimlapi.com/v2",
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
default_headers={
|
||||
@@ -724,8 +906,8 @@ async def llm_call(
|
||||
},
|
||||
)
|
||||
|
||||
completion = await client.chat.completions.create(
|
||||
model=model_to_use,
|
||||
completion = client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
@@ -753,11 +935,11 @@ async def llm_call(
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
parallel_tool_calls_param = get_parallel_tool_calls_param(
|
||||
effective_model, parallel_tool_calls
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
@@ -808,10 +990,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=LlmModel.default,
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
force_json_output: bool = SchemaField(
|
||||
title="Restrict LLM to pure JSON output",
|
||||
@@ -874,7 +1055,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
input_schema=AIStructuredResponseGeneratorBlock.Input,
|
||||
output_schema=AIStructuredResponseGeneratorBlock.Output,
|
||||
test_input={
|
||||
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expected_format": {
|
||||
"key1": "value1",
|
||||
@@ -1240,10 +1421,9 @@ class AITextGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=LlmModel.default,
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
sys_prompt: str = SchemaField(
|
||||
@@ -1337,9 +1517,8 @@ class AITextSummarizerBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=LlmModel.default,
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for summarizing the text.",
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
focus: str = SchemaField(
|
||||
title="Focus",
|
||||
@@ -1555,9 +1734,8 @@ class AIConversationBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=LlmModel.default,
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for the conversation.",
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
max_tokens: int | None = SchemaField(
|
||||
@@ -1594,7 +1772,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
},
|
||||
{"role": "user", "content": "Where was it played?"},
|
||||
],
|
||||
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
@@ -1657,10 +1835,9 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=LlmModel.default,
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for generating the list.",
|
||||
advanced=True,
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
max_retries: int = SchemaField(
|
||||
@@ -1715,7 +1892,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
|
||||
"fictional worlds."
|
||||
),
|
||||
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"max_retries": 3,
|
||||
"force_json_output": False,
|
||||
|
||||
@@ -226,10 +226,9 @@ class SmartDecisionMakerBlock(Block):
|
||||
)
|
||||
model: llm.LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=llm.LlmModel.default,
|
||||
default=llm.DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
json_schema_extra=llm.llm_model_schema_extra(),
|
||||
)
|
||||
credentials: llm.AICredentials = llm.AICredentialsField()
|
||||
multiple_tool_calls: bool = SchemaField(
|
||||
|
||||
@@ -10,13 +10,13 @@ import stagehand.main
|
||||
from stagehand import Stagehand
|
||||
|
||||
from backend.blocks.llm import (
|
||||
MODEL_METADATA,
|
||||
AICredentials,
|
||||
AICredentialsField,
|
||||
LlmModel,
|
||||
ModelMetadata,
|
||||
)
|
||||
from backend.blocks.stagehand._config import stagehand as stagehand_provider
|
||||
from backend.data import llm_registry
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -91,7 +91,7 @@ class StagehandRecommendedLlmModel(str, Enum):
|
||||
Returns the provider name for the model in the required format for Stagehand:
|
||||
provider/model_name
|
||||
"""
|
||||
model_metadata = self.metadata
|
||||
model_metadata = MODEL_METADATA[LlmModel(self.value)]
|
||||
model_name = self.value
|
||||
|
||||
if len(model_name.split("/")) == 1 and not self.value.startswith(
|
||||
@@ -107,23 +107,19 @@ class StagehandRecommendedLlmModel(str, Enum):
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
return self.metadata.provider
|
||||
return MODEL_METADATA[LlmModel(self.value)].provider
|
||||
|
||||
@property
|
||||
def metadata(self) -> ModelMetadata:
|
||||
metadata = llm_registry.get_llm_model_metadata(self.value)
|
||||
if metadata:
|
||||
return metadata
|
||||
# Fallback to LlmModel enum if registry lookup fails
|
||||
return LlmModel(self.value).metadata
|
||||
return MODEL_METADATA[LlmModel(self.value)]
|
||||
|
||||
@property
|
||||
def context_window(self) -> int:
|
||||
return self.metadata.context_window
|
||||
return MODEL_METADATA[LlmModel(self.value)].context_window
|
||||
|
||||
@property
|
||||
def max_output_tokens(self) -> int | None:
|
||||
return self.metadata.max_output_tokens
|
||||
return MODEL_METADATA[LlmModel(self.value)].max_output_tokens
|
||||
|
||||
|
||||
class StagehandObserveBlock(Block):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from backend.util.logging import configure_logging
|
||||
@@ -19,7 +19,7 @@ if not os.getenv("PRISMA_DEBUG"):
|
||||
prisma_logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@pytest_asyncio.fixture(scope="session", loop_scope="session")
|
||||
async def server():
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
@@ -27,7 +27,7 @@ async def server():
|
||||
yield server
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
|
||||
async def graph_cleanup(server):
|
||||
created_graph_ids = []
|
||||
original_create_graph = server.agent_server.test_create_graph
|
||||
|
||||
@@ -25,7 +25,6 @@ from prisma.models import AgentBlock
|
||||
from prisma.types import AgentBlockCreateInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.llm_registry import update_schema_with_llm_registry
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
@@ -144,59 +143,35 @@ class BlockInfo(BaseModel):
|
||||
|
||||
|
||||
class BlockSchema(BaseModel):
|
||||
cached_jsonschema: ClassVar[dict[str, Any] | None] = None
|
||||
|
||||
@classmethod
|
||||
def clear_schema_cache(cls) -> None:
|
||||
"""Clear the cached JSON schema for this class."""
|
||||
# Use None instead of {} because {} is truthy and would prevent regeneration
|
||||
cls.cached_jsonschema = None # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def clear_all_schema_caches() -> None:
|
||||
"""Clear cached JSON schemas for all BlockSchema subclasses."""
|
||||
|
||||
def clear_recursive(cls: type) -> None:
|
||||
"""Recursively clear cache for class and all subclasses."""
|
||||
if hasattr(cls, "clear_schema_cache"):
|
||||
cls.clear_schema_cache()
|
||||
for subclass in cls.__subclasses__():
|
||||
clear_recursive(subclass)
|
||||
|
||||
clear_recursive(BlockSchema)
|
||||
cached_jsonschema: ClassVar[dict[str, Any]]
|
||||
|
||||
@classmethod
|
||||
def jsonschema(cls) -> dict[str, Any]:
|
||||
# Generate schema if not cached
|
||||
if not cls.cached_jsonschema:
|
||||
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
|
||||
if cls.cached_jsonschema:
|
||||
return cls.cached_jsonschema
|
||||
|
||||
def ref_to_dict(obj):
|
||||
if isinstance(obj, dict):
|
||||
# OpenAPI <3.1 does not support sibling fields that has a $ref key
|
||||
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
|
||||
keys = {"allOf", "anyOf", "oneOf"}
|
||||
one_key = next(
|
||||
(k for k in keys if k in obj and len(obj[k]) == 1), None
|
||||
)
|
||||
if one_key:
|
||||
obj.update(obj[one_key][0])
|
||||
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
|
||||
|
||||
return {
|
||||
key: ref_to_dict(value)
|
||||
for key, value in obj.items()
|
||||
if not key.startswith("$") and key != one_key
|
||||
}
|
||||
elif isinstance(obj, list):
|
||||
return [ref_to_dict(item) for item in obj]
|
||||
def ref_to_dict(obj):
|
||||
if isinstance(obj, dict):
|
||||
# OpenAPI <3.1 does not support sibling fields that has a $ref key
|
||||
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
|
||||
keys = {"allOf", "anyOf", "oneOf"}
|
||||
one_key = next((k for k in keys if k in obj and len(obj[k]) == 1), None)
|
||||
if one_key:
|
||||
obj.update(obj[one_key][0])
|
||||
|
||||
return obj
|
||||
return {
|
||||
key: ref_to_dict(value)
|
||||
for key, value in obj.items()
|
||||
if not key.startswith("$") and key != one_key
|
||||
}
|
||||
elif isinstance(obj, list):
|
||||
return [ref_to_dict(item) for item in obj]
|
||||
|
||||
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
|
||||
return obj
|
||||
|
||||
# Always post-process to ensure LLM registry data is up-to-date
|
||||
# This refreshes model options and discriminator mappings even if schema was cached
|
||||
update_schema_with_llm_registry(cls.cached_jsonschema, cls)
|
||||
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
|
||||
|
||||
return cls.cached_jsonschema
|
||||
|
||||
@@ -259,7 +234,7 @@ class BlockSchema(BaseModel):
|
||||
super().__pydantic_init_subclass__(**kwargs)
|
||||
|
||||
# Reset cached JSON schema to prevent inheriting it from parent class
|
||||
cls.cached_jsonschema = None
|
||||
cls.cached_jsonschema = {}
|
||||
|
||||
credentials_fields = cls.get_credentials_fields()
|
||||
|
||||
@@ -466,6 +441,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
static_output: bool = False,
|
||||
block_type: BlockType = BlockType.STANDARD,
|
||||
webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
|
||||
is_sensitive_action: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the block with the given schema.
|
||||
@@ -498,8 +474,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.static_output = static_output
|
||||
self.block_type = block_type
|
||||
self.webhook_config = webhook_config
|
||||
self.is_sensitive_action = is_sensitive_action
|
||||
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
|
||||
self.is_sensitive_action: bool = False
|
||||
|
||||
if self.webhook_config:
|
||||
if isinstance(self.webhook_config, BlockWebhookConfig):
|
||||
@@ -647,6 +623,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
input_data: BlockInput,
|
||||
*,
|
||||
user_id: str,
|
||||
node_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
@@ -673,11 +650,11 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
decision = await HITLReviewHelper.handle_review_decision(
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
node_id=node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_context=execution_context,
|
||||
block_name=self.name,
|
||||
editable=True,
|
||||
)
|
||||
@@ -896,28 +873,6 @@ def is_block_auth_configured(
|
||||
|
||||
|
||||
async def initialize_blocks() -> None:
|
||||
# Refresh LLM registry before initializing blocks so blocks can use registry data
|
||||
# This ensures the registry cache is populated even in executor context
|
||||
try:
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block_cost_config import refresh_llm_costs
|
||||
|
||||
# Only refresh if we have DB access (check if Prisma is connected)
|
||||
from backend.data.db import is_connected
|
||||
|
||||
if is_connected():
|
||||
await llm_registry.refresh_llm_registry()
|
||||
refresh_llm_costs()
|
||||
logger.info("LLM registry refreshed during block initialization")
|
||||
else:
|
||||
logger.warning(
|
||||
"Prisma not connected, skipping LLM registry refresh during block initialization"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to refresh LLM registry during block initialization: %s", exc
|
||||
)
|
||||
|
||||
# First, sync all provider costs to blocks
|
||||
# Imported here to avoid circular import
|
||||
from backend.sdk.cost_integration import sync_all_provider_costs
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
from typing import Type
|
||||
|
||||
from backend.blocks.ai_image_customizer import AIImageCustomizerBlock, GeminiImageModel
|
||||
@@ -24,18 +23,19 @@ from backend.blocks.ideogram import IdeogramModelBlock
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
from backend.blocks.jina.search import ExtractWebsiteContentBlock, SearchTheWebBlock
|
||||
from backend.blocks.llm import (
|
||||
MODEL_METADATA,
|
||||
AIConversationBlock,
|
||||
AIListGeneratorBlock,
|
||||
AIStructuredResponseGeneratorBlock,
|
||||
AITextGeneratorBlock,
|
||||
AITextSummarizerBlock,
|
||||
LlmModel,
|
||||
)
|
||||
from backend.blocks.replicate.flux_advanced import ReplicateFluxAdvancedModelBlock
|
||||
from backend.blocks.replicate.replicate_block import ReplicateModelBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block import Block, BlockCost, BlockCostType
|
||||
from backend.integrations.credentials_store import (
|
||||
aiml_api_credentials,
|
||||
@@ -55,63 +55,210 @@ from backend.integrations.credentials_store import (
|
||||
v0_credentials,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# =============== Configure the cost for each LLM Model call =============== #
|
||||
|
||||
PROVIDER_CREDENTIALS = {
|
||||
"openai": openai_credentials,
|
||||
"anthropic": anthropic_credentials,
|
||||
"groq": groq_credentials,
|
||||
"open_router": open_router_credentials,
|
||||
"llama_api": llama_api_credentials,
|
||||
"aiml_api": aiml_api_credentials,
|
||||
"v0": v0_credentials,
|
||||
MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.O3: 4,
|
||||
LlmModel.O3_MINI: 2,
|
||||
LlmModel.O1: 16,
|
||||
LlmModel.O1_MINI: 4,
|
||||
# GPT-5 models
|
||||
LlmModel.GPT5_2: 6,
|
||||
LlmModel.GPT5_1: 5,
|
||||
LlmModel.GPT5: 2,
|
||||
LlmModel.GPT5_MINI: 1,
|
||||
LlmModel.GPT5_NANO: 1,
|
||||
LlmModel.GPT5_CHAT: 5,
|
||||
LlmModel.GPT41: 2,
|
||||
LlmModel.GPT41_MINI: 1,
|
||||
LlmModel.GPT4O_MINI: 1,
|
||||
LlmModel.GPT4O: 3,
|
||||
LlmModel.GPT4_TURBO: 10,
|
||||
LlmModel.GPT3_5_TURBO: 1,
|
||||
LlmModel.CLAUDE_4_1_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_SONNET: 5,
|
||||
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
||||
LlmModel.CLAUDE_4_5_OPUS: 14,
|
||||
LlmModel.CLAUDE_4_5_SONNET: 9,
|
||||
LlmModel.CLAUDE_3_7_SONNET: 5,
|
||||
LlmModel.CLAUDE_3_HAIKU: 1,
|
||||
LlmModel.AIML_API_QWEN2_5_72B: 1,
|
||||
LlmModel.AIML_API_LLAMA3_1_70B: 1,
|
||||
LlmModel.AIML_API_LLAMA3_3_70B: 1,
|
||||
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
|
||||
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
|
||||
LlmModel.LLAMA3_3_70B: 1,
|
||||
LlmModel.LLAMA3_1_8B: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_3: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_2: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_8B: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_405B: 1,
|
||||
LlmModel.OLLAMA_DOLPHIN: 1,
|
||||
LlmModel.OPENAI_GPT_OSS_120B: 1,
|
||||
LlmModel.OPENAI_GPT_OSS_20B: 1,
|
||||
LlmModel.GEMINI_2_5_PRO: 4,
|
||||
LlmModel.GEMINI_3_PRO_PREVIEW: 5,
|
||||
LlmModel.GEMINI_2_5_FLASH: 1,
|
||||
LlmModel.GEMINI_2_0_FLASH: 1,
|
||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
|
||||
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
|
||||
LlmModel.MISTRAL_NEMO: 1,
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: 1,
|
||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: 3,
|
||||
LlmModel.DEEPSEEK_CHAT: 2,
|
||||
LlmModel.DEEPSEEK_R1_0528: 1,
|
||||
LlmModel.PERPLEXITY_SONAR: 1,
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: 5,
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
|
||||
LlmModel.AMAZON_NOVA_LITE_V1: 1,
|
||||
LlmModel.AMAZON_NOVA_MICRO_V1: 1,
|
||||
LlmModel.AMAZON_NOVA_PRO_V1: 1,
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
|
||||
LlmModel.META_LLAMA_4_SCOUT: 1,
|
||||
LlmModel.META_LLAMA_4_MAVERICK: 1,
|
||||
LlmModel.LLAMA_API_LLAMA_4_SCOUT: 1,
|
||||
LlmModel.LLAMA_API_LLAMA4_MAVERICK: 1,
|
||||
LlmModel.LLAMA_API_LLAMA3_3_8B: 1,
|
||||
LlmModel.LLAMA_API_LLAMA3_3_70B: 1,
|
||||
LlmModel.GROK_4: 9,
|
||||
LlmModel.GROK_4_FAST: 1,
|
||||
LlmModel.GROK_4_1_FAST: 1,
|
||||
LlmModel.GROK_CODE_FAST_1: 1,
|
||||
LlmModel.KIMI_K2: 1,
|
||||
LlmModel.QWEN3_235B_A22B_THINKING: 1,
|
||||
LlmModel.QWEN3_CODER: 9,
|
||||
# v0 by Vercel models
|
||||
LlmModel.V0_1_5_MD: 1,
|
||||
LlmModel.V0_1_5_LG: 2,
|
||||
LlmModel.V0_1_0_MD: 1,
|
||||
}
|
||||
|
||||
# =============== Configure the cost for each LLM Model call =============== #
|
||||
# All LLM costs now come from the database via llm_registry
|
||||
|
||||
LLM_COST: list[BlockCost] = []
|
||||
for model in LlmModel:
|
||||
if model not in MODEL_COST:
|
||||
raise ValueError(f"Missing MODEL_COST for model: {model}")
|
||||
|
||||
|
||||
def _build_llm_costs_from_registry() -> list[BlockCost]:
|
||||
"""Build BlockCost list from all models in the LLM registry."""
|
||||
costs: list[BlockCost] = []
|
||||
for model in llm_registry.iter_dynamic_models():
|
||||
for cost in model.costs:
|
||||
credentials = PROVIDER_CREDENTIALS.get(cost.credential_provider)
|
||||
if not credentials:
|
||||
logger.warning(
|
||||
"Skipping cost entry for %s due to unknown credentials provider %s",
|
||||
model.slug,
|
||||
cost.credential_provider,
|
||||
)
|
||||
continue
|
||||
cost_filter = {
|
||||
"model": model.slug,
|
||||
LLM_COST = (
|
||||
# Anthropic Models
|
||||
[
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": credentials.id,
|
||||
"provider": credentials.provider,
|
||||
"type": credentials.type,
|
||||
"id": anthropic_credentials.id,
|
||||
"provider": anthropic_credentials.provider,
|
||||
"type": anthropic_credentials.type,
|
||||
},
|
||||
}
|
||||
costs.append(
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter=cost_filter,
|
||||
cost_amount=cost.credit_cost,
|
||||
)
|
||||
)
|
||||
return costs
|
||||
|
||||
|
||||
def refresh_llm_costs() -> None:
|
||||
"""Refresh LLM costs from the registry. All costs now come from the database."""
|
||||
LLM_COST.clear()
|
||||
LLM_COST.extend(_build_llm_costs_from_registry())
|
||||
|
||||
|
||||
# Initial load will happen after registry is refreshed at startup
|
||||
# Don't call refresh_llm_costs() here - it will be called after registry refresh
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "anthropic"
|
||||
]
|
||||
# OpenAI Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": openai_credentials.id,
|
||||
"provider": openai_credentials.provider,
|
||||
"type": openai_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "openai"
|
||||
]
|
||||
# Groq Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {"id": groq_credentials.id},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "groq"
|
||||
]
|
||||
# Open Router Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": open_router_credentials.id,
|
||||
"provider": open_router_credentials.provider,
|
||||
"type": open_router_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "open_router"
|
||||
]
|
||||
# Llama API Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": llama_api_credentials.id,
|
||||
"provider": llama_api_credentials.provider,
|
||||
"type": llama_api_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "llama_api"
|
||||
]
|
||||
# v0 by Vercel Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": v0_credentials.id,
|
||||
"provider": v0_credentials.provider,
|
||||
"type": v0_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "v0"
|
||||
]
|
||||
# AI/ML Api Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": aiml_api_credentials.id,
|
||||
"provider": aiml_api_credentials.provider,
|
||||
"type": aiml_api_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "aiml_api"
|
||||
]
|
||||
)
|
||||
|
||||
# =============== This is the exhaustive list of cost for each Block =============== #
|
||||
|
||||
|
||||
@@ -1511,10 +1511,8 @@ async def migrate_llm_models(migrate_to: LlmModel):
|
||||
if field.annotation == LlmModel:
|
||||
llm_model_fields[block.id] = field_name
|
||||
|
||||
# Get all model slugs from the registry (dynamic, not hardcoded enum)
|
||||
from backend.data import llm_registry
|
||||
|
||||
enum_values = list(llm_registry.get_all_model_slugs_for_validation())
|
||||
# Convert enum values to a list of strings for the SQL query
|
||||
enum_values = [v.value for v in LlmModel]
|
||||
escaped_enum_values = repr(tuple(enum_values)) # hack but works
|
||||
|
||||
# Update each block
|
||||
|
||||
@@ -6,10 +6,10 @@ Handles all database operations for pending human reviews.
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from prisma.enums import ReviewStatus
|
||||
from prisma.models import PendingHumanReview
|
||||
from prisma.models import AgentNodeExecution, PendingHumanReview
|
||||
from prisma.types import PendingHumanReviewUpdateInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -17,8 +17,12 @@ from backend.api.features.executions.review.model import (
|
||||
PendingHumanReviewModel,
|
||||
SafeJsonData,
|
||||
)
|
||||
from backend.data.execution import get_graph_execution_meta
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -32,6 +36,125 @@ class ReviewResult(BaseModel):
|
||||
node_exec_id: str
|
||||
|
||||
|
||||
def get_auto_approve_key(graph_exec_id: str, node_id: str) -> str:
|
||||
"""Generate the special nodeExecId key for auto-approval records."""
|
||||
return f"auto_approve_{graph_exec_id}_{node_id}"
|
||||
|
||||
|
||||
async def check_approval(
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
input_data: SafeJsonData | None = None,
|
||||
) -> Optional[ReviewResult]:
|
||||
"""
|
||||
Check if there's an existing approval for this node execution.
|
||||
|
||||
Checks both:
|
||||
1. Normal approval by node_exec_id (previous run of the same node execution)
|
||||
2. Auto-approval by special key pattern "auto_approve_{graph_exec_id}_{node_id}"
|
||||
|
||||
Args:
|
||||
node_exec_id: ID of the node execution
|
||||
graph_exec_id: ID of the graph execution
|
||||
node_id: ID of the node definition (not execution)
|
||||
user_id: ID of the user (for data isolation)
|
||||
input_data: Current input data (used for auto-approvals to avoid stale data)
|
||||
|
||||
Returns:
|
||||
ReviewResult if approval found (either normal or auto), None otherwise
|
||||
"""
|
||||
auto_approve_key = get_auto_approve_key(graph_exec_id, node_id)
|
||||
|
||||
# Check for either normal approval or auto-approval in a single query
|
||||
existing_review = await PendingHumanReview.prisma().find_first(
|
||||
where={
|
||||
"OR": [
|
||||
{"nodeExecId": node_exec_id},
|
||||
{"nodeExecId": auto_approve_key},
|
||||
],
|
||||
"status": ReviewStatus.APPROVED,
|
||||
"userId": user_id,
|
||||
},
|
||||
)
|
||||
|
||||
if existing_review:
|
||||
is_auto_approval = existing_review.nodeExecId == auto_approve_key
|
||||
logger.info(
|
||||
f"Found {'auto-' if is_auto_approval else ''}approval for node {node_id} "
|
||||
f"(exec: {node_exec_id}) in execution {graph_exec_id}"
|
||||
)
|
||||
# For auto-approvals, use current input_data to avoid replaying stale payload
|
||||
# For normal approvals, use the stored payload (which may have been edited)
|
||||
return ReviewResult(
|
||||
data=(
|
||||
input_data
|
||||
if is_auto_approval and input_data is not None
|
||||
else existing_review.payload
|
||||
),
|
||||
status=ReviewStatus.APPROVED,
|
||||
message=(
|
||||
"Auto-approved (user approved all future actions for this node)"
|
||||
if is_auto_approval
|
||||
else existing_review.reviewMessage or ""
|
||||
),
|
||||
processed=True,
|
||||
node_exec_id=existing_review.nodeExecId,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def create_auto_approval_record(
|
||||
user_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
node_id: str,
|
||||
payload: SafeJsonData,
|
||||
) -> None:
|
||||
"""
|
||||
Create an auto-approval record for a node in this execution.
|
||||
|
||||
This is stored as a PendingHumanReview with a special nodeExecId pattern
|
||||
and status=APPROVED, so future executions of the same node can skip review.
|
||||
|
||||
Raises:
|
||||
ValueError: If the graph execution doesn't belong to the user
|
||||
"""
|
||||
# Validate that the graph execution belongs to this user (defense in depth)
|
||||
graph_exec = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not graph_exec:
|
||||
raise ValueError(
|
||||
f"Graph execution {graph_exec_id} not found or doesn't belong to user {user_id}"
|
||||
)
|
||||
|
||||
auto_approve_key = get_auto_approve_key(graph_exec_id, node_id)
|
||||
|
||||
await PendingHumanReview.prisma().upsert(
|
||||
where={"nodeExecId": auto_approve_key},
|
||||
data={
|
||||
"create": {
|
||||
"nodeExecId": auto_approve_key,
|
||||
"userId": user_id,
|
||||
"graphExecId": graph_exec_id,
|
||||
"graphId": graph_id,
|
||||
"graphVersion": graph_version,
|
||||
"payload": SafeJson(payload),
|
||||
"instructions": "Auto-approval record",
|
||||
"editable": False,
|
||||
"status": ReviewStatus.APPROVED,
|
||||
"processed": True,
|
||||
"reviewedAt": datetime.now(timezone.utc),
|
||||
},
|
||||
"update": {}, # Already exists, no update needed
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def get_or_create_human_review(
|
||||
user_id: str,
|
||||
node_exec_id: str,
|
||||
@@ -108,6 +231,87 @@ async def get_or_create_human_review(
|
||||
)
|
||||
|
||||
|
||||
async def get_pending_review_by_node_exec_id(
|
||||
node_exec_id: str, user_id: str
|
||||
) -> Optional["PendingHumanReviewModel"]:
|
||||
"""
|
||||
Get a pending review by its node execution ID.
|
||||
|
||||
Args:
|
||||
node_exec_id: The node execution ID to look up
|
||||
user_id: User ID for authorization (only returns if review belongs to this user)
|
||||
|
||||
Returns:
|
||||
The pending review if found and belongs to user, None otherwise
|
||||
"""
|
||||
review = await PendingHumanReview.prisma().find_first(
|
||||
where={
|
||||
"nodeExecId": node_exec_id,
|
||||
"userId": user_id,
|
||||
"status": ReviewStatus.WAITING,
|
||||
}
|
||||
)
|
||||
|
||||
if not review:
|
||||
return None
|
||||
|
||||
# Local import to avoid event loop conflicts in tests
|
||||
from backend.data.execution import get_node_execution
|
||||
|
||||
node_exec = await get_node_execution(review.nodeExecId)
|
||||
node_id = node_exec.node_id if node_exec else review.nodeExecId
|
||||
return PendingHumanReviewModel.from_db(review, node_id=node_id)
|
||||
|
||||
|
||||
async def get_pending_reviews_by_node_exec_ids(
|
||||
node_exec_ids: list[str], user_id: str
|
||||
) -> dict[str, "PendingHumanReviewModel"]:
|
||||
"""
|
||||
Get multiple pending reviews by their node execution IDs in a single batch query.
|
||||
|
||||
Args:
|
||||
node_exec_ids: List of node execution IDs to look up
|
||||
user_id: User ID for authorization (only returns reviews belonging to this user)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping node_exec_id -> PendingHumanReviewModel for found reviews
|
||||
"""
|
||||
if not node_exec_ids:
|
||||
return {}
|
||||
|
||||
reviews = await PendingHumanReview.prisma().find_many(
|
||||
where={
|
||||
"nodeExecId": {"in": node_exec_ids},
|
||||
"userId": user_id,
|
||||
"status": ReviewStatus.WAITING,
|
||||
}
|
||||
)
|
||||
|
||||
if not reviews:
|
||||
return {}
|
||||
|
||||
# Batch fetch all node executions to avoid N+1 queries
|
||||
node_exec_ids_to_fetch = [review.nodeExecId for review in reviews]
|
||||
node_execs = await AgentNodeExecution.prisma().find_many(
|
||||
where={"id": {"in": node_exec_ids_to_fetch}},
|
||||
include={"Node": True},
|
||||
)
|
||||
|
||||
# Create mapping from node_exec_id to node_id
|
||||
node_exec_id_to_node_id = {
|
||||
node_exec.id: node_exec.agentNodeId for node_exec in node_execs
|
||||
}
|
||||
|
||||
result = {}
|
||||
for review in reviews:
|
||||
node_id = node_exec_id_to_node_id.get(review.nodeExecId, review.nodeExecId)
|
||||
result[review.nodeExecId] = PendingHumanReviewModel.from_db(
|
||||
review, node_id=node_id
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def has_pending_reviews_for_graph_exec(graph_exec_id: str) -> bool:
|
||||
"""
|
||||
Check if a graph execution has any pending reviews.
|
||||
@@ -137,8 +341,11 @@ async def get_pending_reviews_for_user(
|
||||
page_size: Number of reviews per page
|
||||
|
||||
Returns:
|
||||
List of pending review models
|
||||
List of pending review models with node_id included
|
||||
"""
|
||||
# Local import to avoid event loop conflicts in tests
|
||||
from backend.data.execution import get_node_execution
|
||||
|
||||
# Calculate offset for pagination
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
@@ -149,7 +356,14 @@ async def get_pending_reviews_for_user(
|
||||
take=page_size,
|
||||
)
|
||||
|
||||
return [PendingHumanReviewModel.from_db(review) for review in reviews]
|
||||
# Fetch node_id for each review from NodeExecution
|
||||
result = []
|
||||
for review in reviews:
|
||||
node_exec = await get_node_execution(review.nodeExecId)
|
||||
node_id = node_exec.node_id if node_exec else review.nodeExecId
|
||||
result.append(PendingHumanReviewModel.from_db(review, node_id=node_id))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def get_pending_reviews_for_execution(
|
||||
@@ -163,8 +377,11 @@ async def get_pending_reviews_for_execution(
|
||||
user_id: User ID for security validation
|
||||
|
||||
Returns:
|
||||
List of pending review models
|
||||
List of pending review models with node_id included
|
||||
"""
|
||||
# Local import to avoid event loop conflicts in tests
|
||||
from backend.data.execution import get_node_execution
|
||||
|
||||
reviews = await PendingHumanReview.prisma().find_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
@@ -174,7 +391,14 @@ async def get_pending_reviews_for_execution(
|
||||
order={"createdAt": "asc"},
|
||||
)
|
||||
|
||||
return [PendingHumanReviewModel.from_db(review) for review in reviews]
|
||||
# Fetch node_id for each review from NodeExecution
|
||||
result = []
|
||||
for review in reviews:
|
||||
node_exec = await get_node_execution(review.nodeExecId)
|
||||
node_id = node_exec.node_id if node_exec else review.nodeExecId
|
||||
result.append(PendingHumanReviewModel.from_db(review, node_id=node_id))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def process_all_reviews_for_execution(
|
||||
@@ -244,11 +468,19 @@ async def process_all_reviews_for_execution(
|
||||
# Note: Execution resumption is now handled at the API layer after ALL reviews
|
||||
# for an execution are processed (both approved and rejected)
|
||||
|
||||
# Return as dict for easy access
|
||||
return {
|
||||
review.nodeExecId: PendingHumanReviewModel.from_db(review)
|
||||
for review in updated_reviews
|
||||
}
|
||||
# Fetch node_id for each review and return as dict for easy access
|
||||
# Local import to avoid event loop conflicts in tests
|
||||
from backend.data.execution import get_node_execution
|
||||
|
||||
result = {}
|
||||
for review in updated_reviews:
|
||||
node_exec = await get_node_execution(review.nodeExecId)
|
||||
node_id = node_exec.node_id if node_exec else review.nodeExecId
|
||||
result[review.nodeExecId] = PendingHumanReviewModel.from_db(
|
||||
review, node_id=node_id
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def update_review_processed_status(node_exec_id: str, processed: bool) -> None:
|
||||
@@ -256,3 +488,44 @@ async def update_review_processed_status(node_exec_id: str, processed: bool) ->
|
||||
await PendingHumanReview.prisma().update(
|
||||
where={"nodeExecId": node_exec_id}, data={"processed": processed}
|
||||
)
|
||||
|
||||
|
||||
async def cancel_pending_reviews_for_execution(graph_exec_id: str, user_id: str) -> int:
|
||||
"""
|
||||
Cancel all pending reviews for a graph execution (e.g., when execution is stopped).
|
||||
|
||||
Marks all WAITING reviews as REJECTED with a message indicating the execution was stopped.
|
||||
|
||||
Args:
|
||||
graph_exec_id: The graph execution ID
|
||||
user_id: User ID who owns the execution (for security validation)
|
||||
|
||||
Returns:
|
||||
Number of reviews cancelled
|
||||
|
||||
Raises:
|
||||
ValueError: If the graph execution doesn't belong to the user
|
||||
"""
|
||||
# Validate user ownership before cancelling reviews
|
||||
graph_exec = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not graph_exec:
|
||||
raise ValueError(
|
||||
f"Graph execution {graph_exec_id} not found or doesn't belong to user {user_id}"
|
||||
)
|
||||
|
||||
result = await PendingHumanReview.prisma().update_many(
|
||||
where={
|
||||
"graphExecId": graph_exec_id,
|
||||
"userId": user_id,
|
||||
"status": ReviewStatus.WAITING,
|
||||
},
|
||||
data={
|
||||
"status": ReviewStatus.REJECTED,
|
||||
"reviewMessage": "Execution was stopped by user",
|
||||
"processed": True,
|
||||
"reviewedAt": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -36,7 +36,7 @@ def sample_db_review():
|
||||
return mock_review
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_get_or_create_human_review_new(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
sample_db_review,
|
||||
@@ -46,8 +46,8 @@ async def test_get_or_create_human_review_new(
|
||||
sample_db_review.status = ReviewStatus.WAITING
|
||||
sample_db_review.processed = False
|
||||
|
||||
mock_upsert = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
||||
mock_upsert.return_value.upsert = AsyncMock(return_value=sample_db_review)
|
||||
mock_prisma = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
||||
mock_prisma.return_value.upsert = AsyncMock(return_value=sample_db_review)
|
||||
|
||||
result = await get_or_create_human_review(
|
||||
user_id="test-user-123",
|
||||
@@ -64,7 +64,7 @@ async def test_get_or_create_human_review_new(
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_get_or_create_human_review_approved(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
sample_db_review,
|
||||
@@ -75,8 +75,8 @@ async def test_get_or_create_human_review_approved(
|
||||
sample_db_review.processed = False
|
||||
sample_db_review.reviewMessage = "Looks good"
|
||||
|
||||
mock_upsert = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
||||
mock_upsert.return_value.upsert = AsyncMock(return_value=sample_db_review)
|
||||
mock_prisma = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
||||
mock_prisma.return_value.upsert = AsyncMock(return_value=sample_db_review)
|
||||
|
||||
result = await get_or_create_human_review(
|
||||
user_id="test-user-123",
|
||||
@@ -96,7 +96,7 @@ async def test_get_or_create_human_review_approved(
|
||||
assert result.message == "Looks good"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_has_pending_reviews_for_graph_exec_true(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
):
|
||||
@@ -109,7 +109,7 @@ async def test_has_pending_reviews_for_graph_exec_true(
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_has_pending_reviews_for_graph_exec_false(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
):
|
||||
@@ -122,7 +122,7 @@ async def test_has_pending_reviews_for_graph_exec_false(
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_get_pending_reviews_for_user(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
sample_db_review,
|
||||
@@ -131,10 +131,19 @@ async def test_get_pending_reviews_for_user(
|
||||
mock_find_many = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
||||
mock_find_many.return_value.find_many = AsyncMock(return_value=[sample_db_review])
|
||||
|
||||
# Mock get_node_execution to return node with node_id (async function)
|
||||
mock_node_exec = Mock()
|
||||
mock_node_exec.node_id = "test_node_def_789"
|
||||
mocker.patch(
|
||||
"backend.data.execution.get_node_execution",
|
||||
new=AsyncMock(return_value=mock_node_exec),
|
||||
)
|
||||
|
||||
result = await get_pending_reviews_for_user("test_user", page=2, page_size=10)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].node_exec_id == "test_node_123"
|
||||
assert result[0].node_id == "test_node_def_789"
|
||||
|
||||
# Verify pagination parameters
|
||||
call_args = mock_find_many.return_value.find_many.call_args
|
||||
@@ -142,7 +151,7 @@ async def test_get_pending_reviews_for_user(
|
||||
assert call_args.kwargs["take"] == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_get_pending_reviews_for_execution(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
sample_db_review,
|
||||
@@ -151,12 +160,21 @@ async def test_get_pending_reviews_for_execution(
|
||||
mock_find_many = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
||||
mock_find_many.return_value.find_many = AsyncMock(return_value=[sample_db_review])
|
||||
|
||||
# Mock get_node_execution to return node with node_id (async function)
|
||||
mock_node_exec = Mock()
|
||||
mock_node_exec.node_id = "test_node_def_789"
|
||||
mocker.patch(
|
||||
"backend.data.execution.get_node_execution",
|
||||
new=AsyncMock(return_value=mock_node_exec),
|
||||
)
|
||||
|
||||
result = await get_pending_reviews_for_execution(
|
||||
"test_graph_exec_456", "test-user-123"
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].graph_exec_id == "test_graph_exec_456"
|
||||
assert result[0].node_id == "test_node_def_789"
|
||||
|
||||
# Verify it filters by execution and user
|
||||
call_args = mock_find_many.return_value.find_many.call_args
|
||||
@@ -166,7 +184,7 @@ async def test_get_pending_reviews_for_execution(
|
||||
assert where_clause["status"] == ReviewStatus.WAITING
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_process_all_reviews_for_execution_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
sample_db_review,
|
||||
@@ -201,6 +219,14 @@ async def test_process_all_reviews_for_execution_success(
|
||||
new=AsyncMock(return_value=[updated_review]),
|
||||
)
|
||||
|
||||
# Mock get_node_execution to return node with node_id (async function)
|
||||
mock_node_exec = Mock()
|
||||
mock_node_exec.node_id = "test_node_def_789"
|
||||
mocker.patch(
|
||||
"backend.data.execution.get_node_execution",
|
||||
new=AsyncMock(return_value=mock_node_exec),
|
||||
)
|
||||
|
||||
result = await process_all_reviews_for_execution(
|
||||
user_id="test-user-123",
|
||||
review_decisions={
|
||||
@@ -211,9 +237,10 @@ async def test_process_all_reviews_for_execution_success(
|
||||
assert len(result) == 1
|
||||
assert "test_node_123" in result
|
||||
assert result["test_node_123"].status == ReviewStatus.APPROVED
|
||||
assert result["test_node_123"].node_id == "test_node_def_789"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_process_all_reviews_for_execution_validation_errors(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
):
|
||||
@@ -233,7 +260,7 @@ async def test_process_all_reviews_for_execution_validation_errors(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_process_all_reviews_edit_permission_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
sample_db_review,
|
||||
@@ -259,7 +286,7 @@ async def test_process_all_reviews_edit_permission_error(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_process_all_reviews_mixed_approval_rejection(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
sample_db_review,
|
||||
@@ -329,6 +356,14 @@ async def test_process_all_reviews_mixed_approval_rejection(
|
||||
new=AsyncMock(return_value=[approved_review, rejected_review]),
|
||||
)
|
||||
|
||||
# Mock get_node_execution to return node with node_id (async function)
|
||||
mock_node_exec = Mock()
|
||||
mock_node_exec.node_id = "test_node_def_789"
|
||||
mocker.patch(
|
||||
"backend.data.execution.get_node_execution",
|
||||
new=AsyncMock(return_value=mock_node_exec),
|
||||
)
|
||||
|
||||
result = await process_all_reviews_for_execution(
|
||||
user_id="test-user-123",
|
||||
review_decisions={
|
||||
@@ -340,3 +375,5 @@ async def test_process_all_reviews_mixed_approval_rejection(
|
||||
assert len(result) == 2
|
||||
assert "test_node_123" in result
|
||||
assert "test_node_456" in result
|
||||
assert result["test_node_123"].node_id == "test_node_def_789"
|
||||
assert result["test_node_456"].node_id == "test_node_def_789"
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
"""
|
||||
LLM Registry module for managing LLM models, providers, and costs dynamically.
|
||||
|
||||
This module provides a database-driven registry system for LLM models,
|
||||
replacing hardcoded model configurations with a flexible admin-managed system.
|
||||
"""
|
||||
|
||||
from backend.data.llm_registry.model import ModelMetadata
|
||||
|
||||
# Re-export for backwards compatibility
|
||||
from backend.data.llm_registry.notifications import (
|
||||
REGISTRY_REFRESH_CHANNEL,
|
||||
publish_registry_refresh_notification,
|
||||
subscribe_to_registry_refresh,
|
||||
)
|
||||
from backend.data.llm_registry.registry import (
|
||||
RegistryModel,
|
||||
RegistryModelCost,
|
||||
RegistryModelCreator,
|
||||
get_all_model_slugs_for_validation,
|
||||
get_default_model_slug,
|
||||
get_dynamic_model_slugs,
|
||||
get_fallback_model_for_disabled,
|
||||
get_llm_discriminator_mapping,
|
||||
get_llm_model_cost,
|
||||
get_llm_model_metadata,
|
||||
get_llm_model_schema_options,
|
||||
get_model_info,
|
||||
is_model_enabled,
|
||||
iter_dynamic_models,
|
||||
refresh_llm_registry,
|
||||
register_static_costs,
|
||||
register_static_metadata,
|
||||
)
|
||||
from backend.data.llm_registry.schema_utils import (
|
||||
is_llm_model_field,
|
||||
refresh_llm_discriminator_mapping,
|
||||
refresh_llm_model_options,
|
||||
update_schema_with_llm_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Types
|
||||
"ModelMetadata",
|
||||
"RegistryModel",
|
||||
"RegistryModelCost",
|
||||
"RegistryModelCreator",
|
||||
# Registry functions
|
||||
"get_all_model_slugs_for_validation",
|
||||
"get_default_model_slug",
|
||||
"get_dynamic_model_slugs",
|
||||
"get_fallback_model_for_disabled",
|
||||
"get_llm_discriminator_mapping",
|
||||
"get_llm_model_cost",
|
||||
"get_llm_model_metadata",
|
||||
"get_llm_model_schema_options",
|
||||
"get_model_info",
|
||||
"is_model_enabled",
|
||||
"iter_dynamic_models",
|
||||
"refresh_llm_registry",
|
||||
"register_static_costs",
|
||||
"register_static_metadata",
|
||||
# Notifications
|
||||
"REGISTRY_REFRESH_CHANNEL",
|
||||
"publish_registry_refresh_notification",
|
||||
"subscribe_to_registry_refresh",
|
||||
# Schema utilities
|
||||
"is_llm_model_field",
|
||||
"refresh_llm_discriminator_mapping",
|
||||
"refresh_llm_model_options",
|
||||
"update_schema_with_llm_registry",
|
||||
]
|
||||
@@ -1,25 +0,0 @@
|
||||
"""Type definitions for LLM model metadata."""
|
||||
|
||||
from typing import Literal, NamedTuple
|
||||
|
||||
|
||||
class ModelMetadata(NamedTuple):
|
||||
"""Metadata for an LLM model.
|
||||
|
||||
Attributes:
|
||||
provider: The provider identifier (e.g., "openai", "anthropic")
|
||||
context_window: Maximum context window size in tokens
|
||||
max_output_tokens: Maximum output tokens (None if unlimited)
|
||||
display_name: Human-readable name for the model
|
||||
provider_name: Human-readable provider name (e.g., "OpenAI", "Anthropic")
|
||||
creator_name: Name of the organization that created the model
|
||||
price_tier: Relative cost tier (1=cheapest, 2=medium, 3=expensive)
|
||||
"""
|
||||
|
||||
provider: str
|
||||
context_window: int
|
||||
max_output_tokens: int | None
|
||||
display_name: str
|
||||
provider_name: str
|
||||
creator_name: str
|
||||
price_tier: Literal[1, 2, 3]
|
||||
@@ -1,89 +0,0 @@
|
||||
"""
|
||||
Redis pub/sub notifications for LLM registry updates.
|
||||
|
||||
When models are added/updated/removed via the admin UI, this module
|
||||
publishes notifications to Redis that all executor services subscribe to,
|
||||
ensuring they refresh their registry cache in real-time.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.data.redis_client import connect_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis channel name for LLM registry refresh notifications
|
||||
REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"
|
||||
|
||||
|
||||
async def publish_registry_refresh_notification() -> None:
|
||||
"""
|
||||
Publish a notification to Redis that the LLM registry has been updated.
|
||||
All executor services subscribed to this channel will refresh their registry.
|
||||
"""
|
||||
try:
|
||||
redis = await connect_async()
|
||||
await redis.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
|
||||
logger.info("Published LLM registry refresh notification to Redis")
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to publish LLM registry refresh notification: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def subscribe_to_registry_refresh(
|
||||
on_refresh: Any, # Async callable that takes no args
|
||||
) -> None:
|
||||
"""
|
||||
Subscribe to Redis notifications for LLM registry updates.
|
||||
This runs in a loop and processes messages as they arrive.
|
||||
|
||||
Args:
|
||||
on_refresh: Async callable to execute when a refresh notification is received
|
||||
"""
|
||||
try:
|
||||
redis = await connect_async()
|
||||
pubsub = redis.pubsub()
|
||||
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
|
||||
logger.info(
|
||||
"Subscribed to LLM registry refresh notifications on channel: %s",
|
||||
REGISTRY_REFRESH_CHANNEL,
|
||||
)
|
||||
|
||||
# Process messages in a loop
|
||||
while True:
|
||||
try:
|
||||
message = await pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=1.0
|
||||
)
|
||||
if (
|
||||
message
|
||||
and message["type"] == "message"
|
||||
and message["channel"] == REGISTRY_REFRESH_CHANNEL
|
||||
):
|
||||
logger.info("Received LLM registry refresh notification")
|
||||
try:
|
||||
await on_refresh()
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Error refreshing LLM registry from notification: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Error processing registry refresh message: %s", exc, exc_info=True
|
||||
)
|
||||
# Continue listening even if one message fails
|
||||
await asyncio.sleep(1)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to subscribe to LLM registry refresh notifications: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
@@ -1,388 +0,0 @@
|
||||
"""Core LLM registry implementation for managing models dynamically."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Iterable
|
||||
|
||||
import prisma.models
|
||||
|
||||
from backend.data.llm_registry.model import ModelMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _json_to_dict(value: Any) -> dict[str, Any]:
|
||||
"""Convert Prisma Json type to dict, with fallback to empty dict."""
|
||||
if value is None:
|
||||
return {}
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
# Prisma Json type should always be a dict at runtime
|
||||
return dict(value) if value else {}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryModelCost:
|
||||
"""Cost configuration for an LLM model."""
|
||||
|
||||
credit_cost: int
|
||||
credential_provider: str
|
||||
credential_id: str | None
|
||||
credential_type: str | None
|
||||
currency: str | None
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryModelCreator:
|
||||
"""Creator information for an LLM model."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: str | None
|
||||
website_url: str | None
|
||||
logo_url: str | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryModel:
|
||||
"""Represents a model in the LLM registry."""
|
||||
|
||||
slug: str
|
||||
display_name: str
|
||||
description: str | None
|
||||
metadata: ModelMetadata
|
||||
capabilities: dict[str, Any]
|
||||
extra_metadata: dict[str, Any]
|
||||
provider_display_name: str
|
||||
is_enabled: bool
|
||||
is_recommended: bool = False
|
||||
costs: tuple[RegistryModelCost, ...] = field(default_factory=tuple)
|
||||
creator: RegistryModelCreator | None = None
|
||||
|
||||
|
||||
_static_metadata: dict[str, ModelMetadata] = {}
|
||||
_static_costs: dict[str, int] = {}
|
||||
_dynamic_models: dict[str, RegistryModel] = {}
|
||||
_schema_options: list[dict[str, str]] = []
|
||||
_discriminator_mapping: dict[str, str] = {}
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def register_static_metadata(metadata: dict[Any, ModelMetadata]) -> None:
|
||||
"""Register static metadata for legacy models (deprecated)."""
|
||||
_static_metadata.update({str(key): value for key, value in metadata.items()})
|
||||
_refresh_cached_schema()
|
||||
|
||||
|
||||
def register_static_costs(costs: dict[Any, int]) -> None:
|
||||
"""Register static costs for legacy models (deprecated)."""
|
||||
_static_costs.update({str(key): value for key, value in costs.items()})
|
||||
|
||||
|
||||
def _build_schema_options() -> list[dict[str, str]]:
|
||||
"""Build schema options for model selection dropdown. Only includes enabled models."""
|
||||
options: list[dict[str, str]] = []
|
||||
# Only include enabled models in the dropdown options
|
||||
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
|
||||
if model.is_enabled:
|
||||
options.append(
|
||||
{
|
||||
"label": model.display_name,
|
||||
"value": model.slug,
|
||||
"group": model.metadata.provider,
|
||||
"description": model.description or "",
|
||||
}
|
||||
)
|
||||
|
||||
for slug, metadata in _static_metadata.items():
|
||||
if slug in _dynamic_models:
|
||||
continue
|
||||
options.append(
|
||||
{
|
||||
"label": slug,
|
||||
"value": slug,
|
||||
"group": metadata.provider,
|
||||
"description": "",
|
||||
}
|
||||
)
|
||||
return options
|
||||
|
||||
|
||||
async def refresh_llm_registry() -> None:
|
||||
"""Refresh the LLM registry from the database. Loads all models (enabled and disabled)."""
|
||||
async with _lock:
|
||||
try:
|
||||
records = await prisma.models.LlmModel.prisma().find_many(
|
||||
include={
|
||||
"Provider": True,
|
||||
"Costs": True,
|
||||
"Creator": True,
|
||||
}
|
||||
)
|
||||
logger.debug("Found %d LLM model records in database", len(records))
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to refresh LLM registry from DB: %s", exc, exc_info=True
|
||||
)
|
||||
return
|
||||
|
||||
dynamic: dict[str, RegistryModel] = {}
|
||||
for record in records:
|
||||
provider_name = (
|
||||
record.Provider.name if record.Provider else record.providerId
|
||||
)
|
||||
provider_display_name = (
|
||||
record.Provider.displayName if record.Provider else record.providerId
|
||||
)
|
||||
# Creator name: prefer Creator.name, fallback to provider display name
|
||||
creator_name = (
|
||||
record.Creator.name if record.Creator else provider_display_name
|
||||
)
|
||||
# Price tier: default to 1 (cheapest) if not set
|
||||
price_tier = getattr(record, "priceTier", 1) or 1
|
||||
# Clamp to valid range 1-3
|
||||
price_tier = max(1, min(3, price_tier))
|
||||
|
||||
metadata = ModelMetadata(
|
||||
provider=provider_name,
|
||||
context_window=record.contextWindow,
|
||||
max_output_tokens=record.maxOutputTokens,
|
||||
display_name=record.displayName,
|
||||
provider_name=provider_display_name,
|
||||
creator_name=creator_name,
|
||||
price_tier=price_tier, # type: ignore[arg-type]
|
||||
)
|
||||
costs = tuple(
|
||||
RegistryModelCost(
|
||||
credit_cost=cost.creditCost,
|
||||
credential_provider=cost.credentialProvider,
|
||||
credential_id=cost.credentialId,
|
||||
credential_type=cost.credentialType,
|
||||
currency=cost.currency,
|
||||
metadata=_json_to_dict(cost.metadata),
|
||||
)
|
||||
for cost in (record.Costs or [])
|
||||
)
|
||||
|
||||
# Map creator if present
|
||||
creator = None
|
||||
if record.Creator:
|
||||
creator = RegistryModelCreator(
|
||||
id=record.Creator.id,
|
||||
name=record.Creator.name,
|
||||
display_name=record.Creator.displayName,
|
||||
description=record.Creator.description,
|
||||
website_url=record.Creator.websiteUrl,
|
||||
logo_url=record.Creator.logoUrl,
|
||||
)
|
||||
|
||||
dynamic[record.slug] = RegistryModel(
|
||||
slug=record.slug,
|
||||
display_name=record.displayName,
|
||||
description=record.description,
|
||||
metadata=metadata,
|
||||
capabilities=_json_to_dict(record.capabilities),
|
||||
extra_metadata=_json_to_dict(record.metadata),
|
||||
provider_display_name=(
|
||||
record.Provider.displayName
|
||||
if record.Provider
|
||||
else record.providerId
|
||||
),
|
||||
is_enabled=record.isEnabled,
|
||||
is_recommended=record.isRecommended,
|
||||
costs=costs,
|
||||
creator=creator,
|
||||
)
|
||||
|
||||
# Atomic swap - build new structures then replace references
|
||||
# This ensures readers never see partially updated state
|
||||
global _dynamic_models
|
||||
_dynamic_models = dynamic
|
||||
_refresh_cached_schema()
|
||||
logger.info(
|
||||
"LLM registry refreshed with %s dynamic models (enabled: %s, disabled: %s)",
|
||||
len(dynamic),
|
||||
sum(1 for m in dynamic.values() if m.is_enabled),
|
||||
sum(1 for m in dynamic.values() if not m.is_enabled),
|
||||
)
|
||||
|
||||
|
||||
def _refresh_cached_schema() -> None:
|
||||
"""Refresh cached schema options and discriminator mapping."""
|
||||
global _schema_options, _discriminator_mapping
|
||||
|
||||
# Build new structures
|
||||
new_options = _build_schema_options()
|
||||
new_mapping = {
|
||||
slug: entry.metadata.provider for slug, entry in _dynamic_models.items()
|
||||
}
|
||||
for slug, metadata in _static_metadata.items():
|
||||
new_mapping.setdefault(slug, metadata.provider)
|
||||
|
||||
# Atomic swap - replace references to ensure readers see consistent state
|
||||
_schema_options = new_options
|
||||
_discriminator_mapping = new_mapping
|
||||
|
||||
|
||||
def get_llm_model_metadata(slug: str) -> ModelMetadata | None:
|
||||
"""Get model metadata by slug. Checks dynamic models first, then static metadata."""
|
||||
if slug in _dynamic_models:
|
||||
return _dynamic_models[slug].metadata
|
||||
return _static_metadata.get(slug)
|
||||
|
||||
|
||||
def get_llm_model_cost(slug: str) -> tuple[RegistryModelCost, ...]:
|
||||
"""Get model cost configuration by slug."""
|
||||
if slug in _dynamic_models:
|
||||
return _dynamic_models[slug].costs
|
||||
cost_value = _static_costs.get(slug)
|
||||
if cost_value is None:
|
||||
return tuple()
|
||||
return (
|
||||
RegistryModelCost(
|
||||
credit_cost=cost_value,
|
||||
credential_provider="static",
|
||||
credential_id=None,
|
||||
credential_type=None,
|
||||
currency=None,
|
||||
metadata={},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_llm_model_schema_options() -> list[dict[str, str]]:
|
||||
"""
|
||||
Get schema options for LLM model selection dropdown.
|
||||
|
||||
Returns a copy of cached schema options that are refreshed when the registry is
|
||||
updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
|
||||
"""
|
||||
# Return a copy to prevent external mutation
|
||||
return list(_schema_options)
|
||||
|
||||
|
||||
def get_llm_discriminator_mapping() -> dict[str, str]:
|
||||
"""
|
||||
Get discriminator mapping for LLM models.
|
||||
|
||||
Returns a copy of cached discriminator mapping that is refreshed when the registry
|
||||
is updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
|
||||
"""
|
||||
# Return a copy to prevent external mutation
|
||||
return dict(_discriminator_mapping)
|
||||
|
||||
|
||||
def get_dynamic_model_slugs() -> set[str]:
|
||||
"""Get all dynamic model slugs from the registry."""
|
||||
return set(_dynamic_models.keys())
|
||||
|
||||
|
||||
def get_all_model_slugs_for_validation() -> set[str]:
|
||||
"""
|
||||
Get ALL model slugs (both enabled and disabled) for validation purposes.
|
||||
|
||||
This is used for JSON schema enum validation - we need to accept any known
|
||||
model value (even disabled ones) so that existing graphs don't fail validation.
|
||||
The actual fallback/enforcement happens at runtime in llm_call().
|
||||
"""
|
||||
all_slugs = set(_dynamic_models.keys())
|
||||
all_slugs.update(_static_metadata.keys())
|
||||
return all_slugs
|
||||
|
||||
|
||||
def iter_dynamic_models() -> Iterable[RegistryModel]:
|
||||
"""Iterate over all dynamic models in the registry."""
|
||||
return tuple(_dynamic_models.values())
|
||||
|
||||
|
||||
def get_fallback_model_for_disabled(disabled_model_slug: str) -> RegistryModel | None:
|
||||
"""
|
||||
Find a fallback model when the requested model is disabled.
|
||||
|
||||
Looks for an enabled model from the same provider. Prefers models with
|
||||
similar names or capabilities if possible.
|
||||
|
||||
Args:
|
||||
disabled_model_slug: The slug of the disabled model
|
||||
|
||||
Returns:
|
||||
An enabled RegistryModel from the same provider, or None if no fallback found
|
||||
"""
|
||||
disabled_model = _dynamic_models.get(disabled_model_slug)
|
||||
if not disabled_model:
|
||||
return None
|
||||
|
||||
provider = disabled_model.metadata.provider
|
||||
|
||||
# Find all enabled models from the same provider
|
||||
candidates = [
|
||||
model
|
||||
for model in _dynamic_models.values()
|
||||
if model.is_enabled and model.metadata.provider == provider
|
||||
]
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
# Sort by: prefer models with similar context window, then by name
|
||||
candidates.sort(
|
||||
key=lambda m: (
|
||||
abs(m.metadata.context_window - disabled_model.metadata.context_window),
|
||||
m.display_name.lower(),
|
||||
)
|
||||
)
|
||||
|
||||
return candidates[0]
|
||||
|
||||
|
||||
def is_model_enabled(model_slug: str) -> bool:
|
||||
"""Check if a model is enabled in the registry."""
|
||||
model = _dynamic_models.get(model_slug)
|
||||
if not model:
|
||||
# Model not in registry - assume it's a static/legacy model and allow it
|
||||
return True
|
||||
return model.is_enabled
|
||||
|
||||
|
||||
def get_model_info(model_slug: str) -> RegistryModel | None:
|
||||
"""Get model info from the registry."""
|
||||
return _dynamic_models.get(model_slug)
|
||||
|
||||
|
||||
def get_default_model_slug() -> str | None:
|
||||
"""
|
||||
Get the default model slug to use for block defaults.
|
||||
|
||||
Returns the recommended model if set (configured via admin UI),
|
||||
otherwise returns the first enabled model alphabetically.
|
||||
Returns None if no models are available or enabled.
|
||||
"""
|
||||
# Return the recommended model if one is set and enabled
|
||||
for model in _dynamic_models.values():
|
||||
if model.is_recommended and model.is_enabled:
|
||||
return model.slug
|
||||
|
||||
# No recommended model set - find first enabled model alphabetically
|
||||
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
|
||||
if model.is_enabled:
|
||||
logger.warning(
|
||||
"No recommended model set, using '%s' as default",
|
||||
model.slug,
|
||||
)
|
||||
return model.slug
|
||||
|
||||
# No enabled models available
|
||||
if _dynamic_models:
|
||||
logger.error(
|
||||
"No enabled models found in registry (%d models registered but all disabled)",
|
||||
len(_dynamic_models),
|
||||
)
|
||||
else:
|
||||
logger.error("No models registered in LLM registry")
|
||||
|
||||
return None
|
||||
@@ -1,130 +0,0 @@
|
||||
"""
|
||||
Helper utilities for LLM registry integration with block schemas.
|
||||
|
||||
This module handles the dynamic injection of discriminator mappings
|
||||
and model options from the LLM registry into block schemas.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.data.llm_registry.registry import (
|
||||
get_all_model_slugs_for_validation,
|
||||
get_default_model_slug,
|
||||
get_llm_discriminator_mapping,
|
||||
get_llm_model_schema_options,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_llm_model_field(field_name: str, field_info: Any) -> bool:
|
||||
"""
|
||||
Check if a field is an LLM model selection field.
|
||||
|
||||
Returns True if the field has 'options' in json_schema_extra
|
||||
(set by llm_model_schema_extra() in blocks/llm.py).
|
||||
"""
|
||||
if not hasattr(field_info, "json_schema_extra"):
|
||||
return False
|
||||
|
||||
extra = field_info.json_schema_extra
|
||||
if isinstance(extra, dict):
|
||||
return "options" in extra
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def refresh_llm_model_options(field_schema: dict[str, Any]) -> None:
|
||||
"""
|
||||
Refresh LLM model options from the registry.
|
||||
|
||||
Updates 'options' (for frontend dropdown) to show only enabled models,
|
||||
but keeps the 'enum' (for validation) inclusive of ALL known models.
|
||||
|
||||
This is important because:
|
||||
- Options: What users see in the dropdown (enabled models only)
|
||||
- Enum: What values pass validation (all known models, including disabled)
|
||||
|
||||
Existing graphs may have disabled models selected - they should pass validation
|
||||
and the fallback logic in llm_call() will handle using an alternative model.
|
||||
"""
|
||||
fresh_options = get_llm_model_schema_options()
|
||||
if not fresh_options:
|
||||
return
|
||||
|
||||
# Update options array (UI dropdown) - only enabled models
|
||||
if "options" in field_schema:
|
||||
field_schema["options"] = fresh_options
|
||||
|
||||
all_known_slugs = get_all_model_slugs_for_validation()
|
||||
if all_known_slugs and "enum" in field_schema:
|
||||
existing_enum = set(field_schema.get("enum", []))
|
||||
combined_enum = existing_enum | all_known_slugs
|
||||
field_schema["enum"] = sorted(combined_enum)
|
||||
|
||||
# Set the default value from the registry (gpt-4o if available, else first enabled)
|
||||
# This ensures new blocks have a sensible default pre-selected
|
||||
default_slug = get_default_model_slug()
|
||||
if default_slug:
|
||||
field_schema["default"] = default_slug
|
||||
|
||||
|
||||
def refresh_llm_discriminator_mapping(field_schema: dict[str, Any]) -> None:
|
||||
"""
|
||||
Refresh discriminator_mapping for fields that use model-based discrimination.
|
||||
|
||||
The discriminator is already set when AICredentialsField() creates the field.
|
||||
We only need to refresh the mapping when models are added/removed.
|
||||
"""
|
||||
if field_schema.get("discriminator") != "model":
|
||||
return
|
||||
|
||||
# Always refresh the mapping to get latest models
|
||||
fresh_mapping = get_llm_discriminator_mapping()
|
||||
if fresh_mapping is not None:
|
||||
field_schema["discriminator_mapping"] = fresh_mapping
|
||||
|
||||
|
||||
def update_schema_with_llm_registry(
|
||||
schema: dict[str, Any], model_class: type | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Update a JSON schema with current LLM registry data.
|
||||
|
||||
Refreshes:
|
||||
1. Model options for LLM model selection fields (dropdown choices)
|
||||
2. Discriminator mappings for credentials fields (model → provider)
|
||||
|
||||
Args:
|
||||
schema: The JSON schema to update (mutated in-place)
|
||||
model_class: The Pydantic model class (optional, for field introspection)
|
||||
"""
|
||||
properties = schema.get("properties", {})
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
if not isinstance(field_schema, dict):
|
||||
continue
|
||||
|
||||
# Refresh model options for LLM model fields
|
||||
if model_class and hasattr(model_class, "model_fields"):
|
||||
field_info = model_class.model_fields.get(field_name)
|
||||
if field_info and is_llm_model_field(field_name, field_info):
|
||||
try:
|
||||
refresh_llm_model_options(field_schema)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to refresh LLM options for field %s: %s",
|
||||
field_name,
|
||||
exc,
|
||||
)
|
||||
|
||||
# Refresh discriminator mapping for fields that use model discrimination
|
||||
try:
|
||||
refresh_llm_discriminator_mapping(field_schema)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to refresh discriminator mapping for field %s: %s",
|
||||
field_name,
|
||||
exc,
|
||||
)
|
||||
@@ -40,7 +40,6 @@ from pydantic_core import (
|
||||
)
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.data.llm_registry import update_schema_with_llm_registry
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.json import loads as json_loads
|
||||
from backend.util.settings import Secrets
|
||||
@@ -545,9 +544,7 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
else:
|
||||
schema["credentials_provider"] = allowed_providers
|
||||
schema["credentials_types"] = model_class.allowed_cred_types()
|
||||
|
||||
# Ensure LLM discriminators are populated (delegates to shared helper)
|
||||
update_schema_with_llm_registry(schema, model_class)
|
||||
# Do not return anything, just mutate schema in place
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra=_add_json_schema_extra, # type: ignore
|
||||
@@ -696,20 +693,16 @@ def CredentialsField(
|
||||
This is enforced by the `BlockSchema` base class.
|
||||
"""
|
||||
|
||||
# Build field_schema_extra - always include discriminator and mapping if discriminator is set
|
||||
field_schema_extra: dict[str, Any] = {}
|
||||
|
||||
# Always include discriminator if provided
|
||||
if discriminator is not None:
|
||||
field_schema_extra["discriminator"] = discriminator
|
||||
# Always include discriminator_mapping when discriminator is set (even if empty initially)
|
||||
field_schema_extra["discriminator_mapping"] = discriminator_mapping or {}
|
||||
|
||||
# Include other optional fields (only if not None)
|
||||
if required_scopes:
|
||||
field_schema_extra["credentials_scopes"] = list(required_scopes)
|
||||
if discriminator_values:
|
||||
field_schema_extra["discriminator_values"] = discriminator_values
|
||||
field_schema_extra = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"credentials_scopes": list(required_scopes) or None,
|
||||
"discriminator": discriminator,
|
||||
"discriminator_mapping": discriminator_mapping,
|
||||
"discriminator_values": discriminator_values,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
# Merge any json_schema_extra passed in kwargs
|
||||
if "json_schema_extra" in kwargs:
|
||||
|
||||
@@ -50,6 +50,8 @@ from backend.data.graph import (
|
||||
validate_graph_execution_permissions,
|
||||
)
|
||||
from backend.data.human_review import (
|
||||
cancel_pending_reviews_for_execution,
|
||||
check_approval,
|
||||
get_or_create_human_review,
|
||||
has_pending_reviews_for_graph_exec,
|
||||
update_review_processed_status,
|
||||
@@ -190,6 +192,8 @@ class DatabaseManager(AppService):
|
||||
get_user_notification_preference = _(get_user_notification_preference)
|
||||
|
||||
# Human In The Loop
|
||||
cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
|
||||
check_approval = _(check_approval)
|
||||
get_or_create_human_review = _(get_or_create_human_review)
|
||||
has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec)
|
||||
update_review_processed_status = _(update_review_processed_status)
|
||||
@@ -313,6 +317,8 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
set_execution_kv_data = d.set_execution_kv_data
|
||||
|
||||
# Human In The Loop
|
||||
cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
|
||||
check_approval = d.check_approval
|
||||
get_or_create_human_review = d.get_or_create_human_review
|
||||
update_review_processed_status = d.update_review_processed_status
|
||||
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
"""
|
||||
Helper functions for LLM registry initialization in executor context.
|
||||
|
||||
These functions handle refreshing the LLM registry when the executor starts
|
||||
and subscribing to real-time updates via Redis pub/sub.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from backend.data import db, llm_registry
|
||||
from backend.data.block import BlockSchema, initialize_blocks
|
||||
from backend.data.block_cost_config import refresh_llm_costs
|
||||
from backend.data.llm_registry import subscribe_to_registry_refresh
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def initialize_registry_for_executor() -> None:
|
||||
"""
|
||||
Initialize blocks and refresh LLM registry in the executor context.
|
||||
|
||||
This must run in the executor's event loop to have access to the database.
|
||||
"""
|
||||
try:
|
||||
# Connect to database if not already connected
|
||||
if not db.is_connected():
|
||||
await db.connect()
|
||||
logger.info("[GraphExecutor] Connected to database for registry refresh")
|
||||
|
||||
# Initialize blocks (internally refreshes LLM registry and costs)
|
||||
await initialize_blocks()
|
||||
logger.info("[GraphExecutor] Blocks initialized")
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[GraphExecutor] Failed to refresh LLM registry on startup: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def refresh_registry_on_notification() -> None:
|
||||
"""Refresh LLM registry when notified via Redis pub/sub."""
|
||||
try:
|
||||
# Ensure DB is connected
|
||||
if not db.is_connected():
|
||||
await db.connect()
|
||||
|
||||
# Refresh registry and costs
|
||||
await llm_registry.refresh_llm_registry()
|
||||
refresh_llm_costs()
|
||||
|
||||
# Clear block schema caches so they regenerate with new model options
|
||||
BlockSchema.clear_all_schema_caches()
|
||||
|
||||
logger.info("[GraphExecutor] LLM registry refreshed from notification")
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"[GraphExecutor] Failed to refresh LLM registry from notification: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def subscribe_to_registry_updates() -> None:
|
||||
"""Subscribe to Redis pub/sub for LLM registry refresh notifications."""
|
||||
await subscribe_to_registry_refresh(refresh_registry_on_notification)
|
||||
@@ -702,20 +702,6 @@ class ExecutionProcessor:
|
||||
)
|
||||
self.node_execution_thread.start()
|
||||
self.node_evaluation_thread.start()
|
||||
|
||||
# Initialize LLM registry and subscribe to updates
|
||||
from backend.executor.llm_registry_init import (
|
||||
initialize_registry_for_executor,
|
||||
subscribe_to_registry_updates,
|
||||
)
|
||||
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
initialize_registry_for_executor(), self.node_execution_loop
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
subscribe_to_registry_updates(), self.node_execution_loop
|
||||
)
|
||||
|
||||
logger.info(f"[GraphExecutor] {self.tid} started")
|
||||
|
||||
@error_logged(swallow=False)
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import BaseModel, JsonValue, ValidationError
|
||||
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data import human_review as human_review_db
|
||||
from backend.data import onboarding as onboarding_db
|
||||
from backend.data import user as user_db
|
||||
from backend.data.block import (
|
||||
@@ -749,9 +750,27 @@ async def stop_graph_execution(
|
||||
if graph_exec.status in [
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
ExecutionStatus.REVIEW,
|
||||
]:
|
||||
# If the graph is still on the queue, we can prevent them from being executed
|
||||
# by setting the status to TERMINATED.
|
||||
# If the graph is queued/incomplete/paused for review, terminate immediately
|
||||
# No need to wait for executor since it's not actively running
|
||||
|
||||
# If graph is in REVIEW status, clean up pending reviews before terminating
|
||||
if graph_exec.status == ExecutionStatus.REVIEW:
|
||||
# Use human_review_db if Prisma connected, else database manager
|
||||
review_db = (
|
||||
human_review_db
|
||||
if prisma.is_connected()
|
||||
else get_database_manager_async_client()
|
||||
)
|
||||
# Mark all pending reviews as rejected/cancelled
|
||||
cancelled_count = await review_db.cancel_pending_reviews_for_execution(
|
||||
graph_exec_id, user_id
|
||||
)
|
||||
logger.info(
|
||||
f"Cancelled {cancelled_count} pending review(s) for stopped execution {graph_exec_id}"
|
||||
)
|
||||
|
||||
graph_exec.status = ExecutionStatus.TERMINATED
|
||||
|
||||
await asyncio.gather(
|
||||
@@ -887,9 +906,28 @@ async def add_graph_execution(
|
||||
nodes_to_skip=nodes_to_skip,
|
||||
execution_context=execution_context,
|
||||
)
|
||||
logger.info(f"Publishing execution {graph_exec.id} to execution queue")
|
||||
logger.info(f"Queueing execution {graph_exec.id}")
|
||||
|
||||
# Update execution status to QUEUED BEFORE publishing to prevent race condition
|
||||
# where two concurrent requests could both publish the same execution
|
||||
updated_exec = await edb.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.id,
|
||||
status=ExecutionStatus.QUEUED,
|
||||
)
|
||||
|
||||
# Verify the status update succeeded (prevents duplicate queueing in race conditions)
|
||||
# If another request already updated the status, this execution will not be QUEUED
|
||||
if not updated_exec or updated_exec.status != ExecutionStatus.QUEUED:
|
||||
logger.warning(
|
||||
f"Skipping queue publish for execution {graph_exec.id} - "
|
||||
f"status update failed or execution already queued by another request"
|
||||
)
|
||||
return graph_exec
|
||||
|
||||
graph_exec.status = ExecutionStatus.QUEUED
|
||||
|
||||
# Publish to execution queue for executor to pick up
|
||||
# This happens AFTER status update to ensure only one request publishes
|
||||
exec_queue = await get_async_execution_queue()
|
||||
await exec_queue.publish_message(
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
@@ -897,13 +935,6 @@ async def add_graph_execution(
|
||||
exchange=GRAPH_EXECUTION_EXCHANGE,
|
||||
)
|
||||
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
|
||||
|
||||
# Update execution status to QUEUED
|
||||
graph_exec.status = ExecutionStatus.QUEUED
|
||||
await edb.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.id,
|
||||
status=graph_exec.status,
|
||||
)
|
||||
except BaseException as e:
|
||||
err = str(e) or type(e).__name__
|
||||
if not graph_exec:
|
||||
|
||||
@@ -4,6 +4,7 @@ import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from backend.data.dynamic_fields import merge_execution_input, parse_execution_output
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.util.mock import MockObject
|
||||
|
||||
|
||||
@@ -346,6 +347,7 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec.id = "execution-id-123"
|
||||
mock_graph_exec.node_executions = [] # Add this to avoid AttributeError
|
||||
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
|
||||
mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()
|
||||
|
||||
# Mock the queue and event bus
|
||||
@@ -611,6 +613,7 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec.id = "execution-id-123"
|
||||
mock_graph_exec.node_executions = []
|
||||
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
|
||||
|
||||
# Track what's passed to to_graph_execution_entry
|
||||
captured_kwargs = {}
|
||||
@@ -670,3 +673,232 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
||||
# Verify nodes_to_skip was passed to to_graph_execution_entry
|
||||
assert "nodes_to_skip" in captured_kwargs
|
||||
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_graph_execution_in_review_status_cancels_pending_reviews(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""Test that stopping an execution in REVIEW status cancels pending reviews."""
|
||||
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||
from backend.executor.utils import stop_graph_execution
|
||||
|
||||
user_id = "test-user"
|
||||
graph_exec_id = "test-exec-123"
|
||||
|
||||
# Mock graph execution in REVIEW status
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionMeta)
|
||||
mock_graph_exec.id = graph_exec_id
|
||||
mock_graph_exec.status = ExecutionStatus.REVIEW
|
||||
|
||||
# Mock dependencies
|
||||
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||
mock_queue_client = mocker.AsyncMock()
|
||||
mock_get_queue.return_value = mock_queue_client
|
||||
|
||||
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||
mock_prisma.is_connected.return_value = True
|
||||
|
||||
mock_human_review_db = mocker.patch("backend.executor.utils.human_review_db")
|
||||
mock_human_review_db.cancel_pending_reviews_for_execution = mocker.AsyncMock(
|
||||
return_value=2 # 2 reviews cancelled
|
||||
)
|
||||
|
||||
mock_execution_db = mocker.patch("backend.executor.utils.execution_db")
|
||||
mock_execution_db.get_graph_execution_meta = mocker.AsyncMock(
|
||||
return_value=mock_graph_exec
|
||||
)
|
||||
mock_execution_db.update_graph_execution_stats = mocker.AsyncMock()
|
||||
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
mock_event_bus = mocker.MagicMock()
|
||||
mock_event_bus.publish = mocker.AsyncMock()
|
||||
mock_get_event_bus.return_value = mock_event_bus
|
||||
|
||||
mock_get_child_executions = mocker.patch(
|
||||
"backend.executor.utils._get_child_executions"
|
||||
)
|
||||
mock_get_child_executions.return_value = [] # No children
|
||||
|
||||
# Call stop_graph_execution with timeout to allow status check
|
||||
await stop_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
wait_timeout=1.0, # Wait to allow status check
|
||||
cascade=True,
|
||||
)
|
||||
|
||||
# Verify pending reviews were cancelled
|
||||
mock_human_review_db.cancel_pending_reviews_for_execution.assert_called_once_with(
|
||||
graph_exec_id, user_id
|
||||
)
|
||||
|
||||
# Verify execution status was updated to TERMINATED
|
||||
mock_execution_db.update_graph_execution_stats.assert_called_once()
|
||||
call_kwargs = mock_execution_db.update_graph_execution_stats.call_args[1]
|
||||
assert call_kwargs["graph_exec_id"] == graph_exec_id
|
||||
assert call_kwargs["status"] == ExecutionStatus.TERMINATED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_graph_execution_with_database_manager_when_prisma_disconnected(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""Test that stop uses database manager when Prisma is not connected."""
|
||||
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||
from backend.executor.utils import stop_graph_execution
|
||||
|
||||
user_id = "test-user"
|
||||
graph_exec_id = "test-exec-456"
|
||||
|
||||
# Mock graph execution in REVIEW status
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionMeta)
|
||||
mock_graph_exec.id = graph_exec_id
|
||||
mock_graph_exec.status = ExecutionStatus.REVIEW
|
||||
|
||||
# Mock dependencies
|
||||
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||
mock_queue_client = mocker.AsyncMock()
|
||||
mock_get_queue.return_value = mock_queue_client
|
||||
|
||||
# Prisma is NOT connected
|
||||
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||
mock_prisma.is_connected.return_value = False
|
||||
|
||||
# Mock database manager client
|
||||
mock_get_db_manager = mocker.patch(
|
||||
"backend.executor.utils.get_database_manager_async_client"
|
||||
)
|
||||
mock_db_manager = mocker.AsyncMock()
|
||||
mock_db_manager.get_graph_execution_meta = mocker.AsyncMock(
|
||||
return_value=mock_graph_exec
|
||||
)
|
||||
mock_db_manager.cancel_pending_reviews_for_execution = mocker.AsyncMock(
|
||||
return_value=3 # 3 reviews cancelled
|
||||
)
|
||||
mock_db_manager.update_graph_execution_stats = mocker.AsyncMock()
|
||||
mock_get_db_manager.return_value = mock_db_manager
|
||||
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
mock_event_bus = mocker.MagicMock()
|
||||
mock_event_bus.publish = mocker.AsyncMock()
|
||||
mock_get_event_bus.return_value = mock_event_bus
|
||||
|
||||
mock_get_child_executions = mocker.patch(
|
||||
"backend.executor.utils._get_child_executions"
|
||||
)
|
||||
mock_get_child_executions.return_value = [] # No children
|
||||
|
||||
# Call stop_graph_execution with timeout
|
||||
await stop_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
wait_timeout=1.0,
|
||||
cascade=True,
|
||||
)
|
||||
|
||||
# Verify database manager was used for cancel_pending_reviews
|
||||
mock_db_manager.cancel_pending_reviews_for_execution.assert_called_once_with(
|
||||
graph_exec_id, user_id
|
||||
)
|
||||
|
||||
# Verify execution status was updated via database manager
|
||||
mock_db_manager.update_graph_execution_stats.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_graph_execution_cascades_to_child_with_reviews(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""Test that stopping parent execution cascades to children and cancels their reviews."""
|
||||
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||
from backend.executor.utils import stop_graph_execution
|
||||
|
||||
user_id = "test-user"
|
||||
parent_exec_id = "parent-exec"
|
||||
child_exec_id = "child-exec"
|
||||
|
||||
# Mock parent execution in RUNNING status
|
||||
mock_parent_exec = mocker.MagicMock(spec=GraphExecutionMeta)
|
||||
mock_parent_exec.id = parent_exec_id
|
||||
mock_parent_exec.status = ExecutionStatus.RUNNING
|
||||
|
||||
# Mock child execution in REVIEW status
|
||||
mock_child_exec = mocker.MagicMock(spec=GraphExecutionMeta)
|
||||
mock_child_exec.id = child_exec_id
|
||||
mock_child_exec.status = ExecutionStatus.REVIEW
|
||||
|
||||
# Mock dependencies
|
||||
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||
mock_queue_client = mocker.AsyncMock()
|
||||
mock_get_queue.return_value = mock_queue_client
|
||||
|
||||
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||
mock_prisma.is_connected.return_value = True
|
||||
|
||||
mock_human_review_db = mocker.patch("backend.executor.utils.human_review_db")
|
||||
mock_human_review_db.cancel_pending_reviews_for_execution = mocker.AsyncMock(
|
||||
return_value=1 # 1 child review cancelled
|
||||
)
|
||||
|
||||
# Mock execution_db to return different status based on which execution is queried
|
||||
mock_execution_db = mocker.patch("backend.executor.utils.execution_db")
|
||||
|
||||
# Track call count to simulate status transition
|
||||
call_count = {"count": 0}
|
||||
|
||||
async def get_exec_meta_side_effect(execution_id, user_id):
|
||||
call_count["count"] += 1
|
||||
if execution_id == parent_exec_id:
|
||||
# After a few calls (child processing happens), transition parent to TERMINATED
|
||||
# This simulates the executor service processing the stop request
|
||||
if call_count["count"] > 3:
|
||||
mock_parent_exec.status = ExecutionStatus.TERMINATED
|
||||
return mock_parent_exec
|
||||
elif execution_id == child_exec_id:
|
||||
return mock_child_exec
|
||||
return None
|
||||
|
||||
mock_execution_db.get_graph_execution_meta = mocker.AsyncMock(
|
||||
side_effect=get_exec_meta_side_effect
|
||||
)
|
||||
mock_execution_db.update_graph_execution_stats = mocker.AsyncMock()
|
||||
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
mock_event_bus = mocker.MagicMock()
|
||||
mock_event_bus.publish = mocker.AsyncMock()
|
||||
mock_get_event_bus.return_value = mock_event_bus
|
||||
|
||||
# Mock _get_child_executions to return the child
|
||||
mock_get_child_executions = mocker.patch(
|
||||
"backend.executor.utils._get_child_executions"
|
||||
)
|
||||
|
||||
def get_children_side_effect(parent_id):
|
||||
if parent_id == parent_exec_id:
|
||||
return [mock_child_exec]
|
||||
return []
|
||||
|
||||
mock_get_child_executions.side_effect = get_children_side_effect
|
||||
|
||||
# Call stop_graph_execution on parent with cascade=True
|
||||
await stop_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_exec_id=parent_exec_id,
|
||||
wait_timeout=1.0,
|
||||
cascade=True,
|
||||
)
|
||||
|
||||
# Verify child reviews were cancelled
|
||||
mock_human_review_db.cancel_pending_reviews_for_execution.assert_called_once_with(
|
||||
child_exec_id, user_id
|
||||
)
|
||||
|
||||
# Verify both parent and child status updates
|
||||
assert mock_execution_db.update_graph_execution_stats.call_count >= 1
|
||||
|
||||
@@ -1,935 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Iterable, Sequence, cast
|
||||
|
||||
import prisma
|
||||
import prisma.models
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
from backend.util.models import Pagination
|
||||
|
||||
|
||||
def _json_dict(value: Any | None) -> dict[str, Any]:
|
||||
if not value:
|
||||
return {}
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
return {}
|
||||
|
||||
|
||||
def _map_cost(record: prisma.models.LlmModelCost) -> llm_model.LlmModelCost:
|
||||
return llm_model.LlmModelCost(
|
||||
id=record.id,
|
||||
unit=record.unit,
|
||||
credit_cost=record.creditCost,
|
||||
credential_provider=record.credentialProvider,
|
||||
credential_id=record.credentialId,
|
||||
credential_type=record.credentialType,
|
||||
currency=record.currency,
|
||||
metadata=_json_dict(record.metadata),
|
||||
)
|
||||
|
||||
|
||||
def _map_creator(
|
||||
record: prisma.models.LlmModelCreator,
|
||||
) -> llm_model.LlmModelCreator:
|
||||
return llm_model.LlmModelCreator(
|
||||
id=record.id,
|
||||
name=record.name,
|
||||
display_name=record.displayName,
|
||||
description=record.description,
|
||||
website_url=record.websiteUrl,
|
||||
logo_url=record.logoUrl,
|
||||
metadata=_json_dict(record.metadata),
|
||||
)
|
||||
|
||||
|
||||
def _map_model(record: prisma.models.LlmModel) -> llm_model.LlmModel:
|
||||
costs = []
|
||||
if record.Costs:
|
||||
costs = [_map_cost(cost) for cost in record.Costs]
|
||||
|
||||
creator = None
|
||||
if hasattr(record, "Creator") and record.Creator:
|
||||
creator = _map_creator(record.Creator)
|
||||
|
||||
return llm_model.LlmModel(
|
||||
id=record.id,
|
||||
slug=record.slug,
|
||||
display_name=record.displayName,
|
||||
description=record.description,
|
||||
provider_id=record.providerId,
|
||||
creator_id=record.creatorId,
|
||||
creator=creator,
|
||||
context_window=record.contextWindow,
|
||||
max_output_tokens=record.maxOutputTokens,
|
||||
is_enabled=record.isEnabled,
|
||||
is_recommended=record.isRecommended,
|
||||
capabilities=_json_dict(record.capabilities),
|
||||
metadata=_json_dict(record.metadata),
|
||||
costs=costs,
|
||||
)
|
||||
|
||||
|
||||
def _map_provider(record: prisma.models.LlmProvider) -> llm_model.LlmProvider:
|
||||
models: list[llm_model.LlmModel] = []
|
||||
if record.Models:
|
||||
models = [_map_model(model) for model in record.Models]
|
||||
|
||||
return llm_model.LlmProvider(
|
||||
id=record.id,
|
||||
name=record.name,
|
||||
display_name=record.displayName,
|
||||
description=record.description,
|
||||
default_credential_provider=record.defaultCredentialProvider,
|
||||
default_credential_id=record.defaultCredentialId,
|
||||
default_credential_type=record.defaultCredentialType,
|
||||
supports_tools=record.supportsTools,
|
||||
supports_json_output=record.supportsJsonOutput,
|
||||
supports_reasoning=record.supportsReasoning,
|
||||
supports_parallel_tool=record.supportsParallelTool,
|
||||
metadata=_json_dict(record.metadata),
|
||||
models=models,
|
||||
)
|
||||
|
||||
|
||||
async def list_providers(
|
||||
include_models: bool = True, enabled_only: bool = False
|
||||
) -> list[llm_model.LlmProvider]:
|
||||
"""
|
||||
List all LLM providers.
|
||||
|
||||
Args:
|
||||
include_models: Whether to include models for each provider
|
||||
enabled_only: If True, only include enabled models (for public routes)
|
||||
"""
|
||||
include: Any = None
|
||||
if include_models:
|
||||
model_where = {"isEnabled": True} if enabled_only else None
|
||||
include = {
|
||||
"Models": {
|
||||
"include": {"Costs": True, "Creator": True},
|
||||
"where": model_where,
|
||||
}
|
||||
}
|
||||
records = await prisma.models.LlmProvider.prisma().find_many(include=include)
|
||||
return [_map_provider(record) for record in records]
|
||||
|
||||
|
||||
async def upsert_provider(
|
||||
request: llm_model.UpsertLlmProviderRequest,
|
||||
provider_id: str | None = None,
|
||||
) -> llm_model.LlmProvider:
|
||||
data: Any = {
|
||||
"name": request.name,
|
||||
"displayName": request.display_name,
|
||||
"description": request.description,
|
||||
"defaultCredentialProvider": request.default_credential_provider,
|
||||
"defaultCredentialId": request.default_credential_id,
|
||||
"defaultCredentialType": request.default_credential_type,
|
||||
"supportsTools": request.supports_tools,
|
||||
"supportsJsonOutput": request.supports_json_output,
|
||||
"supportsReasoning": request.supports_reasoning,
|
||||
"supportsParallelTool": request.supports_parallel_tool,
|
||||
"metadata": prisma.Json(request.metadata or {}),
|
||||
}
|
||||
include: Any = {"Models": {"include": {"Costs": True, "Creator": True}}}
|
||||
if provider_id:
|
||||
record = await prisma.models.LlmProvider.prisma().update(
|
||||
where={"id": provider_id},
|
||||
data=data,
|
||||
include=include,
|
||||
)
|
||||
else:
|
||||
record = await prisma.models.LlmProvider.prisma().create(
|
||||
data=data,
|
||||
include=include,
|
||||
)
|
||||
if record is None:
|
||||
raise ValueError("Failed to create/update provider")
|
||||
return _map_provider(record)
|
||||
|
||||
|
||||
async def delete_provider(provider_id: str) -> bool:
|
||||
"""
|
||||
Delete an LLM provider.
|
||||
|
||||
A provider can only be deleted if it has no associated models.
|
||||
Due to onDelete: Restrict on LlmModel.Provider, the database will
|
||||
block deletion if models exist.
|
||||
|
||||
Args:
|
||||
provider_id: UUID of the provider to delete
|
||||
|
||||
Returns:
|
||||
True if deleted successfully
|
||||
|
||||
Raises:
|
||||
ValueError: If provider not found or has associated models
|
||||
"""
|
||||
# Check if provider exists
|
||||
provider = await prisma.models.LlmProvider.prisma().find_unique(
|
||||
where={"id": provider_id},
|
||||
include={"Models": True},
|
||||
)
|
||||
if not provider:
|
||||
raise ValueError(f"Provider with id '{provider_id}' not found")
|
||||
|
||||
# Check if provider has any models
|
||||
model_count = len(provider.Models) if provider.Models else 0
|
||||
if model_count > 0:
|
||||
raise ValueError(
|
||||
f"Cannot delete provider '{provider.displayName}' because it has "
|
||||
f"{model_count} model(s). Delete all models first."
|
||||
)
|
||||
|
||||
# Safe to delete
|
||||
await prisma.models.LlmProvider.prisma().delete(where={"id": provider_id})
|
||||
return True
|
||||
|
||||
|
||||
async def list_models(
|
||||
provider_id: str | None = None,
|
||||
enabled_only: bool = False,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> llm_model.LlmModelsResponse:
|
||||
"""
|
||||
List LLM models with pagination.
|
||||
|
||||
Args:
|
||||
provider_id: Optional filter by provider ID
|
||||
enabled_only: If True, only return enabled models (for public routes)
|
||||
page: Page number (1-indexed)
|
||||
page_size: Number of models per page
|
||||
"""
|
||||
where: Any = {}
|
||||
if provider_id:
|
||||
where["providerId"] = provider_id
|
||||
if enabled_only:
|
||||
where["isEnabled"] = True
|
||||
|
||||
# Get total count for pagination
|
||||
total_items = await prisma.models.LlmModel.prisma().count(
|
||||
where=where if where else None
|
||||
)
|
||||
|
||||
# Calculate pagination
|
||||
skip = (page - 1) * page_size
|
||||
total_pages = (total_items + page_size - 1) // page_size if total_items > 0 else 0
|
||||
|
||||
records = await prisma.models.LlmModel.prisma().find_many(
|
||||
where=where if where else None,
|
||||
include={"Costs": True, "Creator": True},
|
||||
skip=skip,
|
||||
take=page_size,
|
||||
)
|
||||
models = [_map_model(record) for record in records]
|
||||
|
||||
return llm_model.LlmModelsResponse(
|
||||
models=models,
|
||||
pagination=Pagination(
|
||||
total_items=total_items,
|
||||
total_pages=total_pages,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _cost_create_payload(
|
||||
costs: Sequence[llm_model.LlmModelCostInput],
|
||||
) -> dict[str, Iterable[dict[str, Any]]]:
|
||||
|
||||
create_items = []
|
||||
for cost in costs:
|
||||
item: dict[str, Any] = {
|
||||
"unit": cost.unit,
|
||||
"creditCost": cost.credit_cost,
|
||||
"credentialProvider": cost.credential_provider,
|
||||
}
|
||||
# Only include optional fields if they have values
|
||||
if cost.credential_id:
|
||||
item["credentialId"] = cost.credential_id
|
||||
if cost.credential_type:
|
||||
item["credentialType"] = cost.credential_type
|
||||
if cost.currency:
|
||||
item["currency"] = cost.currency
|
||||
# Handle metadata - use Prisma Json type
|
||||
if cost.metadata is not None and cost.metadata != {}:
|
||||
item["metadata"] = prisma.Json(cost.metadata)
|
||||
create_items.append(item)
|
||||
return {"create": create_items}
|
||||
|
||||
|
||||
async def create_model(
|
||||
request: llm_model.CreateLlmModelRequest,
|
||||
) -> llm_model.LlmModel:
|
||||
data: Any = {
|
||||
"slug": request.slug,
|
||||
"displayName": request.display_name,
|
||||
"description": request.description,
|
||||
"Provider": {"connect": {"id": request.provider_id}},
|
||||
"contextWindow": request.context_window,
|
||||
"maxOutputTokens": request.max_output_tokens,
|
||||
"isEnabled": request.is_enabled,
|
||||
"capabilities": prisma.Json(request.capabilities or {}),
|
||||
"metadata": prisma.Json(request.metadata or {}),
|
||||
"Costs": _cost_create_payload(request.costs),
|
||||
}
|
||||
if request.creator_id:
|
||||
data["Creator"] = {"connect": {"id": request.creator_id}}
|
||||
|
||||
record = await prisma.models.LlmModel.prisma().create(
|
||||
data=data,
|
||||
include={"Costs": True, "Creator": True, "Provider": True},
|
||||
)
|
||||
return _map_model(record)
|
||||
|
||||
|
||||
async def update_model(
|
||||
model_id: str,
|
||||
request: llm_model.UpdateLlmModelRequest,
|
||||
) -> llm_model.LlmModel:
|
||||
# Build scalar field updates (non-relation fields)
|
||||
scalar_data: Any = {}
|
||||
if request.display_name is not None:
|
||||
scalar_data["displayName"] = request.display_name
|
||||
if request.description is not None:
|
||||
scalar_data["description"] = request.description
|
||||
if request.context_window is not None:
|
||||
scalar_data["contextWindow"] = request.context_window
|
||||
if request.max_output_tokens is not None:
|
||||
scalar_data["maxOutputTokens"] = request.max_output_tokens
|
||||
if request.is_enabled is not None:
|
||||
scalar_data["isEnabled"] = request.is_enabled
|
||||
if request.capabilities is not None:
|
||||
scalar_data["capabilities"] = request.capabilities
|
||||
if request.metadata is not None:
|
||||
scalar_data["metadata"] = request.metadata
|
||||
# Foreign keys can be updated directly as scalar fields
|
||||
if request.provider_id is not None:
|
||||
scalar_data["providerId"] = request.provider_id
|
||||
if request.creator_id is not None:
|
||||
# Empty string means remove the creator
|
||||
scalar_data["creatorId"] = request.creator_id if request.creator_id else None
|
||||
|
||||
# If we have costs to update, we need to handle them separately
|
||||
# because nested writes have different constraints
|
||||
if request.costs is not None:
|
||||
# Wrap cost replacement in a transaction for atomicity
|
||||
async with transaction() as tx:
|
||||
# First update scalar fields
|
||||
if scalar_data:
|
||||
await tx.llmmodel.update(
|
||||
where={"id": model_id},
|
||||
data=scalar_data,
|
||||
)
|
||||
# Then handle costs: delete existing and create new
|
||||
await tx.llmmodelcost.delete_many(where={"llmModelId": model_id})
|
||||
if request.costs:
|
||||
cost_payload = _cost_create_payload(request.costs)
|
||||
for cost_item in cost_payload["create"]:
|
||||
cost_item["llmModelId"] = model_id
|
||||
await tx.llmmodelcost.create(data=cast(Any, cost_item))
|
||||
# Fetch the updated record (outside transaction)
|
||||
record = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model_id},
|
||||
include={"Costs": True, "Creator": True},
|
||||
)
|
||||
else:
|
||||
# No costs update - simple update
|
||||
record = await prisma.models.LlmModel.prisma().update(
|
||||
where={"id": model_id},
|
||||
data=scalar_data,
|
||||
include={"Costs": True, "Creator": True},
|
||||
)
|
||||
|
||||
if not record:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
return _map_model(record)
|
||||
|
||||
|
||||
async def toggle_model(
|
||||
model_id: str,
|
||||
is_enabled: bool,
|
||||
migrate_to_slug: str | None = None,
|
||||
migration_reason: str | None = None,
|
||||
custom_credit_cost: int | None = None,
|
||||
) -> llm_model.ToggleLlmModelResponse:
|
||||
"""
|
||||
Toggle a model's enabled status, optionally migrating workflows when disabling.
|
||||
|
||||
Args:
|
||||
model_id: UUID of the model to toggle
|
||||
is_enabled: New enabled status
|
||||
migrate_to_slug: If disabling and this is provided, migrate all workflows
|
||||
using this model to the specified replacement model
|
||||
migration_reason: Optional reason for the migration (e.g., "Provider outage")
|
||||
custom_credit_cost: Optional custom pricing override for migrated workflows.
|
||||
When set, the billing system should use this cost instead
|
||||
of the target model's cost for affected nodes.
|
||||
|
||||
Returns:
|
||||
ToggleLlmModelResponse with the updated model and optional migration stats
|
||||
"""
|
||||
import json
|
||||
|
||||
# Get the model being toggled
|
||||
model = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model_id}, include={"Costs": True}
|
||||
)
|
||||
if not model:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
|
||||
nodes_migrated = 0
|
||||
migration_id: str | None = None
|
||||
|
||||
# If disabling with migration, perform migration first
|
||||
if not is_enabled and migrate_to_slug:
|
||||
# Validate replacement model exists and is enabled
|
||||
replacement = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"slug": migrate_to_slug}
|
||||
)
|
||||
if not replacement:
|
||||
raise ValueError(f"Replacement model '{migrate_to_slug}' not found")
|
||||
if not replacement.isEnabled:
|
||||
raise ValueError(
|
||||
f"Replacement model '{migrate_to_slug}' is disabled. "
|
||||
f"Please enable it before using it as a replacement."
|
||||
)
|
||||
|
||||
# Perform all operations atomically within a single transaction
|
||||
# This ensures no nodes are missed between query and update
|
||||
async with transaction() as tx:
|
||||
# Get the IDs of nodes that will be migrated (inside transaction for consistency)
|
||||
node_ids_result = await tx.query_raw(
|
||||
"""
|
||||
SELECT id
|
||||
FROM "AgentNode"
|
||||
WHERE "constantInput"::jsonb->>'model' = $1
|
||||
FOR UPDATE
|
||||
""",
|
||||
model.slug,
|
||||
)
|
||||
migrated_node_ids = (
|
||||
[row["id"] for row in node_ids_result] if node_ids_result else []
|
||||
)
|
||||
nodes_migrated = len(migrated_node_ids)
|
||||
|
||||
if nodes_migrated > 0:
|
||||
# Update by IDs to ensure we only update the exact nodes we queried
|
||||
# Use JSON array and jsonb_array_elements_text for safe parameterization
|
||||
node_ids_json = json.dumps(migrated_node_ids)
|
||||
await tx.execute_raw(
|
||||
"""
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
to_jsonb($1::text)
|
||||
)
|
||||
WHERE id::text IN (
|
||||
SELECT jsonb_array_elements_text($2::jsonb)
|
||||
)
|
||||
""",
|
||||
migrate_to_slug,
|
||||
node_ids_json,
|
||||
)
|
||||
|
||||
record = await tx.llmmodel.update(
|
||||
where={"id": model_id},
|
||||
data={"isEnabled": is_enabled},
|
||||
include={"Costs": True},
|
||||
)
|
||||
|
||||
# Create migration record for revert capability
|
||||
if nodes_migrated > 0:
|
||||
migration_data: Any = {
|
||||
"sourceModelSlug": model.slug,
|
||||
"targetModelSlug": migrate_to_slug,
|
||||
"reason": migration_reason,
|
||||
"migratedNodeIds": json.dumps(migrated_node_ids),
|
||||
"nodeCount": nodes_migrated,
|
||||
"customCreditCost": custom_credit_cost,
|
||||
}
|
||||
migration_record = await tx.llmmodelmigration.create(
|
||||
data=migration_data
|
||||
)
|
||||
migration_id = migration_record.id
|
||||
else:
|
||||
# Simple toggle without migration
|
||||
record = await prisma.models.LlmModel.prisma().update(
|
||||
where={"id": model_id},
|
||||
data={"isEnabled": is_enabled},
|
||||
include={"Costs": True},
|
||||
)
|
||||
|
||||
if record is None:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
return llm_model.ToggleLlmModelResponse(
|
||||
model=_map_model(record),
|
||||
nodes_migrated=nodes_migrated,
|
||||
migrated_to_slug=migrate_to_slug if nodes_migrated > 0 else None,
|
||||
migration_id=migration_id,
|
||||
)
|
||||
|
||||
|
||||
async def get_model_usage(model_id: str) -> llm_model.LlmModelUsageResponse:
|
||||
"""Get usage count for a model."""
|
||||
import prisma as prisma_module
|
||||
|
||||
model = await prisma.models.LlmModel.prisma().find_unique(where={"id": model_id})
|
||||
if not model:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
|
||||
count_result = await prisma_module.get_client().query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM "AgentNode"
|
||||
WHERE "constantInput"::jsonb->>'model' = $1
|
||||
""",
|
||||
model.slug,
|
||||
)
|
||||
node_count = int(count_result[0]["count"]) if count_result else 0
|
||||
|
||||
return llm_model.LlmModelUsageResponse(model_slug=model.slug, node_count=node_count)
|
||||
|
||||
|
||||
async def delete_model(
|
||||
model_id: str, replacement_model_slug: str | None = None
|
||||
) -> llm_model.DeleteLlmModelResponse:
|
||||
"""
|
||||
Delete a model and optionally migrate all AgentNodes using it to a replacement model.
|
||||
|
||||
This performs an atomic operation within a database transaction:
|
||||
1. Validates the model exists
|
||||
2. Counts affected nodes
|
||||
3. If nodes exist, validates replacement model and migrates them
|
||||
4. Deletes the LlmModel record (CASCADE deletes costs)
|
||||
|
||||
Args:
|
||||
model_id: UUID of the model to delete
|
||||
replacement_model_slug: Slug of the model to migrate to (required only if nodes use this model)
|
||||
|
||||
Returns:
|
||||
DeleteLlmModelResponse with migration stats
|
||||
|
||||
Raises:
|
||||
ValueError: If model not found, nodes exist but no replacement provided,
|
||||
replacement not found, or replacement is disabled
|
||||
"""
|
||||
# 1. Get the model being deleted (validation - outside transaction)
|
||||
model = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model_id}, include={"Costs": True}
|
||||
)
|
||||
if not model:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
|
||||
deleted_slug = model.slug
|
||||
deleted_display_name = model.displayName
|
||||
|
||||
# 2. Count affected nodes first to determine if replacement is needed
|
||||
import prisma as prisma_module
|
||||
|
||||
count_result = await prisma_module.get_client().query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM "AgentNode"
|
||||
WHERE "constantInput"::jsonb->>'model' = $1
|
||||
""",
|
||||
deleted_slug,
|
||||
)
|
||||
nodes_to_migrate = int(count_result[0]["count"]) if count_result else 0
|
||||
|
||||
# 3. Validate replacement model only if there are nodes to migrate
|
||||
if nodes_to_migrate > 0:
|
||||
if not replacement_model_slug:
|
||||
raise ValueError(
|
||||
f"Cannot delete model '{deleted_slug}': {nodes_to_migrate} workflow node(s) "
|
||||
f"are using it. Please provide a replacement_model_slug to migrate them."
|
||||
)
|
||||
replacement = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"slug": replacement_model_slug}
|
||||
)
|
||||
if not replacement:
|
||||
raise ValueError(f"Replacement model '{replacement_model_slug}' not found")
|
||||
if not replacement.isEnabled:
|
||||
raise ValueError(
|
||||
f"Replacement model '{replacement_model_slug}' is disabled. "
|
||||
f"Please enable it before using it as a replacement."
|
||||
)
|
||||
|
||||
# 4. Perform migration (if needed) and deletion atomically within a transaction
|
||||
async with transaction() as tx:
|
||||
# Migrate all AgentNode.constantInput->model to replacement
|
||||
if nodes_to_migrate > 0 and replacement_model_slug:
|
||||
await tx.execute_raw(
|
||||
"""
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
to_jsonb($1::text)
|
||||
)
|
||||
WHERE "constantInput"::jsonb->>'model' = $2
|
||||
""",
|
||||
replacement_model_slug,
|
||||
deleted_slug,
|
||||
)
|
||||
|
||||
# Delete the model (CASCADE will delete costs automatically)
|
||||
await tx.llmmodel.delete(where={"id": model_id})
|
||||
|
||||
# Build appropriate message based on whether migration happened
|
||||
if nodes_to_migrate > 0:
|
||||
message = (
|
||||
f"Successfully deleted model '{deleted_display_name}' ({deleted_slug}) "
|
||||
f"and migrated {nodes_to_migrate} workflow node(s) to '{replacement_model_slug}'."
|
||||
)
|
||||
else:
|
||||
message = (
|
||||
f"Successfully deleted model '{deleted_display_name}' ({deleted_slug}). "
|
||||
f"No workflows were using this model."
|
||||
)
|
||||
|
||||
return llm_model.DeleteLlmModelResponse(
|
||||
deleted_model_slug=deleted_slug,
|
||||
deleted_model_display_name=deleted_display_name,
|
||||
replacement_model_slug=replacement_model_slug,
|
||||
nodes_migrated=nodes_to_migrate,
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
def _map_migration(
|
||||
record: prisma.models.LlmModelMigration,
|
||||
) -> llm_model.LlmModelMigration:
|
||||
return llm_model.LlmModelMigration(
|
||||
id=record.id,
|
||||
source_model_slug=record.sourceModelSlug,
|
||||
target_model_slug=record.targetModelSlug,
|
||||
reason=record.reason,
|
||||
node_count=record.nodeCount,
|
||||
custom_credit_cost=record.customCreditCost,
|
||||
is_reverted=record.isReverted,
|
||||
created_at=record.createdAt.isoformat(),
|
||||
reverted_at=record.revertedAt.isoformat() if record.revertedAt else None,
|
||||
)
|
||||
|
||||
|
||||
async def list_migrations(
|
||||
include_reverted: bool = False,
|
||||
) -> list[llm_model.LlmModelMigration]:
|
||||
"""
|
||||
List model migrations, optionally including reverted ones.
|
||||
|
||||
Args:
|
||||
include_reverted: If True, include reverted migrations. Default is False.
|
||||
|
||||
Returns:
|
||||
List of LlmModelMigration records
|
||||
"""
|
||||
where: Any = None if include_reverted else {"isReverted": False}
|
||||
records = await prisma.models.LlmModelMigration.prisma().find_many(
|
||||
where=where,
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
return [_map_migration(record) for record in records]
|
||||
|
||||
|
||||
async def get_migration(migration_id: str) -> llm_model.LlmModelMigration | None:
|
||||
"""Get a specific migration by ID."""
|
||||
record = await prisma.models.LlmModelMigration.prisma().find_unique(
|
||||
where={"id": migration_id}
|
||||
)
|
||||
return _map_migration(record) if record else None
|
||||
|
||||
|
||||
async def revert_migration(
|
||||
migration_id: str,
|
||||
re_enable_source_model: bool = True,
|
||||
) -> llm_model.RevertMigrationResponse:
|
||||
"""
|
||||
Revert a model migration, restoring affected nodes to their original model.
|
||||
|
||||
This only reverts the specific nodes that were migrated, not all nodes
|
||||
currently using the target model.
|
||||
|
||||
Args:
|
||||
migration_id: UUID of the migration to revert
|
||||
re_enable_source_model: Whether to re-enable the source model if it's disabled
|
||||
|
||||
Returns:
|
||||
RevertMigrationResponse with revert stats
|
||||
|
||||
Raises:
|
||||
ValueError: If migration not found, already reverted, or source model not available
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Get the migration record
|
||||
migration = await prisma.models.LlmModelMigration.prisma().find_unique(
|
||||
where={"id": migration_id}
|
||||
)
|
||||
if not migration:
|
||||
raise ValueError(f"Migration with id '{migration_id}' not found")
|
||||
|
||||
if migration.isReverted:
|
||||
raise ValueError(
|
||||
f"Migration '{migration_id}' has already been reverted "
|
||||
f"on {migration.revertedAt.isoformat() if migration.revertedAt else 'unknown date'}"
|
||||
)
|
||||
|
||||
# Check if source model exists
|
||||
source_model = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"slug": migration.sourceModelSlug}
|
||||
)
|
||||
if not source_model:
|
||||
raise ValueError(
|
||||
f"Source model '{migration.sourceModelSlug}' no longer exists. "
|
||||
f"Cannot revert migration."
|
||||
)
|
||||
|
||||
# Get the migrated node IDs (Prisma auto-parses JSONB to list)
|
||||
migrated_node_ids: list[str] = (
|
||||
migration.migratedNodeIds
|
||||
if isinstance(migration.migratedNodeIds, list)
|
||||
else json.loads(migration.migratedNodeIds) # type: ignore
|
||||
)
|
||||
if not migrated_node_ids:
|
||||
raise ValueError("No nodes to revert in this migration")
|
||||
|
||||
# Track if we need to re-enable the source model
|
||||
source_model_was_disabled = not source_model.isEnabled
|
||||
should_re_enable = source_model_was_disabled and re_enable_source_model
|
||||
source_model_re_enabled = False
|
||||
|
||||
# Perform revert atomically
|
||||
async with transaction() as tx:
|
||||
# Re-enable the source model if requested and it was disabled
|
||||
if should_re_enable:
|
||||
await tx.llmmodel.update(
|
||||
where={"id": source_model.id},
|
||||
data={"isEnabled": True},
|
||||
)
|
||||
source_model_re_enabled = True
|
||||
|
||||
# Update only the specific nodes that were migrated
|
||||
# We need to check that they still have the target model (haven't been changed since)
|
||||
# Use a single batch update for efficiency
|
||||
# Use JSON array and jsonb_array_elements_text for safe parameterization
|
||||
node_ids_json = json.dumps(migrated_node_ids)
|
||||
result = await tx.execute_raw(
|
||||
"""
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
to_jsonb($1::text)
|
||||
)
|
||||
WHERE id::text IN (
|
||||
SELECT jsonb_array_elements_text($2::jsonb)
|
||||
)
|
||||
AND "constantInput"::jsonb->>'model' = $3
|
||||
""",
|
||||
migration.sourceModelSlug,
|
||||
node_ids_json,
|
||||
migration.targetModelSlug,
|
||||
)
|
||||
nodes_reverted = result if result else 0
|
||||
|
||||
# Mark migration as reverted
|
||||
await tx.llmmodelmigration.update(
|
||||
where={"id": migration_id},
|
||||
data={
|
||||
"isReverted": True,
|
||||
"revertedAt": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
|
||||
# Calculate nodes that were already changed since migration
|
||||
nodes_already_changed = len(migrated_node_ids) - nodes_reverted
|
||||
|
||||
# Build appropriate message
|
||||
message_parts = [
|
||||
f"Successfully reverted migration: {nodes_reverted} node(s) restored "
|
||||
f"from '{migration.targetModelSlug}' to '{migration.sourceModelSlug}'."
|
||||
]
|
||||
if nodes_already_changed > 0:
|
||||
message_parts.append(
|
||||
f" {nodes_already_changed} node(s) were already changed and not reverted."
|
||||
)
|
||||
if source_model_re_enabled:
|
||||
message_parts.append(
|
||||
f" Model '{migration.sourceModelSlug}' has been re-enabled."
|
||||
)
|
||||
|
||||
return llm_model.RevertMigrationResponse(
|
||||
migration_id=migration_id,
|
||||
source_model_slug=migration.sourceModelSlug,
|
||||
target_model_slug=migration.targetModelSlug,
|
||||
nodes_reverted=nodes_reverted,
|
||||
nodes_already_changed=nodes_already_changed,
|
||||
source_model_re_enabled=source_model_re_enabled,
|
||||
message="".join(message_parts),
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Creator CRUD operations
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def list_creators() -> list[llm_model.LlmModelCreator]:
|
||||
"""List all LLM model creators."""
|
||||
records = await prisma.models.LlmModelCreator.prisma().find_many(
|
||||
order={"displayName": "asc"}
|
||||
)
|
||||
return [_map_creator(record) for record in records]
|
||||
|
||||
|
||||
async def get_creator(creator_id: str) -> llm_model.LlmModelCreator | None:
|
||||
"""Get a specific creator by ID."""
|
||||
record = await prisma.models.LlmModelCreator.prisma().find_unique(
|
||||
where={"id": creator_id}
|
||||
)
|
||||
return _map_creator(record) if record else None
|
||||
|
||||
|
||||
async def upsert_creator(
|
||||
request: llm_model.UpsertLlmCreatorRequest,
|
||||
creator_id: str | None = None,
|
||||
) -> llm_model.LlmModelCreator:
|
||||
"""Create or update a model creator."""
|
||||
data: Any = {
|
||||
"name": request.name,
|
||||
"displayName": request.display_name,
|
||||
"description": request.description,
|
||||
"websiteUrl": request.website_url,
|
||||
"logoUrl": request.logo_url,
|
||||
"metadata": prisma.Json(request.metadata or {}),
|
||||
}
|
||||
if creator_id:
|
||||
record = await prisma.models.LlmModelCreator.prisma().update(
|
||||
where={"id": creator_id},
|
||||
data=data,
|
||||
)
|
||||
else:
|
||||
record = await prisma.models.LlmModelCreator.prisma().create(data=data)
|
||||
if record is None:
|
||||
raise ValueError("Failed to create/update creator")
|
||||
return _map_creator(record)
|
||||
|
||||
|
||||
async def delete_creator(creator_id: str) -> bool:
|
||||
"""
|
||||
Delete a model creator.
|
||||
|
||||
This will set creatorId to NULL on all associated models (due to onDelete: SetNull).
|
||||
|
||||
Args:
|
||||
creator_id: UUID of the creator to delete
|
||||
|
||||
Returns:
|
||||
True if deleted successfully
|
||||
|
||||
Raises:
|
||||
ValueError: If creator not found
|
||||
"""
|
||||
creator = await prisma.models.LlmModelCreator.prisma().find_unique(
|
||||
where={"id": creator_id}
|
||||
)
|
||||
if not creator:
|
||||
raise ValueError(f"Creator with id '{creator_id}' not found")
|
||||
|
||||
await prisma.models.LlmModelCreator.prisma().delete(where={"id": creator_id})
|
||||
return True
|
||||
|
||||
|
||||
async def get_recommended_model() -> llm_model.LlmModel | None:
|
||||
"""
|
||||
Get the currently recommended LLM model.
|
||||
|
||||
Returns:
|
||||
The recommended model, or None if no model is marked as recommended.
|
||||
"""
|
||||
record = await prisma.models.LlmModel.prisma().find_first(
|
||||
where={"isRecommended": True, "isEnabled": True},
|
||||
include={"Costs": True, "Creator": True},
|
||||
)
|
||||
return _map_model(record) if record else None
|
||||
|
||||
|
||||
async def set_recommended_model(
|
||||
model_id: str,
|
||||
) -> tuple[llm_model.LlmModel, str | None]:
|
||||
"""
|
||||
Set a model as the recommended model.
|
||||
|
||||
This will clear the isRecommended flag from any other model and set it
|
||||
on the specified model. The model must be enabled.
|
||||
|
||||
Args:
|
||||
model_id: UUID of the model to set as recommended
|
||||
|
||||
Returns:
|
||||
Tuple of (the updated model, previous recommended model slug or None)
|
||||
|
||||
Raises:
|
||||
ValueError: If model not found or not enabled
|
||||
"""
|
||||
# First, verify the model exists and is enabled
|
||||
target_model = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model_id}
|
||||
)
|
||||
if not target_model:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
if not target_model.isEnabled:
|
||||
raise ValueError(
|
||||
f"Cannot set disabled model '{target_model.slug}' as recommended"
|
||||
)
|
||||
|
||||
# Get the current recommended model (if any)
|
||||
current_recommended = await prisma.models.LlmModel.prisma().find_first(
|
||||
where={"isRecommended": True}
|
||||
)
|
||||
previous_slug = current_recommended.slug if current_recommended else None
|
||||
|
||||
# Use a transaction to ensure atomicity
|
||||
async with transaction() as tx:
|
||||
# Clear isRecommended from all models
|
||||
await tx.llmmodel.update_many(
|
||||
where={"isRecommended": True},
|
||||
data={"isRecommended": False},
|
||||
)
|
||||
# Set the new recommended model
|
||||
await tx.llmmodel.update(
|
||||
where={"id": model_id},
|
||||
data={"isRecommended": True},
|
||||
)
|
||||
|
||||
# Fetch and return the updated model
|
||||
updated_record = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model_id},
|
||||
include={"Costs": True, "Creator": True},
|
||||
)
|
||||
if not updated_record:
|
||||
raise ValueError("Failed to fetch updated model")
|
||||
|
||||
return _map_model(updated_record), previous_slug
|
||||
|
||||
|
||||
async def get_recommended_model_slug() -> str | None:
|
||||
"""
|
||||
Get the slug of the currently recommended LLM model.
|
||||
|
||||
Returns:
|
||||
The slug of the recommended model, or None if no model is marked as recommended.
|
||||
"""
|
||||
record = await prisma.models.LlmModel.prisma().find_first(
|
||||
where={"isRecommended": True, "isEnabled": True},
|
||||
)
|
||||
return record.slug if record else None
|
||||
@@ -1,235 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
import prisma.enums
|
||||
import pydantic
|
||||
|
||||
from backend.util.models import Pagination
|
||||
|
||||
# Pattern for valid model slugs: alphanumeric start, then alphanumeric, dots, underscores, slashes, hyphens
|
||||
SLUG_PATTERN = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._/-]*$")
|
||||
|
||||
|
||||
class LlmModelCost(pydantic.BaseModel):
|
||||
id: str
|
||||
unit: prisma.enums.LlmCostUnit = prisma.enums.LlmCostUnit.RUN
|
||||
credit_cost: int
|
||||
credential_provider: str
|
||||
credential_id: Optional[str] = None
|
||||
credential_type: Optional[str] = None
|
||||
currency: Optional[str] = None
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class LlmModelCreator(pydantic.BaseModel):
|
||||
"""Represents the organization that created/trained the model (e.g., OpenAI, Meta)."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
website_url: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class LlmModel(pydantic.BaseModel):
|
||||
id: str
|
||||
slug: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
provider_id: str
|
||||
creator_id: Optional[str] = None
|
||||
creator: Optional[LlmModelCreator] = None
|
||||
context_window: int
|
||||
max_output_tokens: Optional[int] = None
|
||||
is_enabled: bool = True
|
||||
is_recommended: bool = False
|
||||
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
costs: list[LlmModelCost] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class LlmProvider(pydantic.BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
default_credential_provider: Optional[str] = None
|
||||
default_credential_id: Optional[str] = None
|
||||
default_credential_type: Optional[str] = None
|
||||
supports_tools: bool = True
|
||||
supports_json_output: bool = True
|
||||
supports_reasoning: bool = False
|
||||
supports_parallel_tool: bool = False
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
models: list[LlmModel] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class LlmProvidersResponse(pydantic.BaseModel):
|
||||
providers: list[LlmProvider]
|
||||
|
||||
|
||||
class LlmModelsResponse(pydantic.BaseModel):
|
||||
models: list[LlmModel]
|
||||
pagination: Optional[Pagination] = None
|
||||
|
||||
|
||||
class LlmCreatorsResponse(pydantic.BaseModel):
|
||||
creators: list[LlmModelCreator]
|
||||
|
||||
|
||||
class UpsertLlmProviderRequest(pydantic.BaseModel):
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
default_credential_provider: Optional[str] = None
|
||||
default_credential_id: Optional[str] = None
|
||||
default_credential_type: Optional[str] = "api_key"
|
||||
supports_tools: bool = True
|
||||
supports_json_output: bool = True
|
||||
supports_reasoning: bool = False
|
||||
supports_parallel_tool: bool = False
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class UpsertLlmCreatorRequest(pydantic.BaseModel):
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
website_url: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class LlmModelCostInput(pydantic.BaseModel):
|
||||
unit: prisma.enums.LlmCostUnit = prisma.enums.LlmCostUnit.RUN
|
||||
credit_cost: int
|
||||
credential_provider: str
|
||||
credential_id: Optional[str] = None
|
||||
credential_type: Optional[str] = "api_key"
|
||||
currency: Optional[str] = None
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class CreateLlmModelRequest(pydantic.BaseModel):
|
||||
slug: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
provider_id: str
|
||||
creator_id: Optional[str] = None
|
||||
context_window: int
|
||||
max_output_tokens: Optional[int] = None
|
||||
is_enabled: bool = True
|
||||
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
costs: list[LlmModelCostInput]
|
||||
|
||||
@pydantic.field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str) -> str:
|
||||
if not v or len(v) > 100:
|
||||
raise ValueError("Slug must be 1-100 characters")
|
||||
if not SLUG_PATTERN.match(v):
|
||||
raise ValueError(
|
||||
"Slug must start with alphanumeric and contain only "
|
||||
"alphanumeric characters, dots, underscores, slashes, or hyphens"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class UpdateLlmModelRequest(pydantic.BaseModel):
|
||||
display_name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
context_window: Optional[int] = None
|
||||
max_output_tokens: Optional[int] = None
|
||||
is_enabled: Optional[bool] = None
|
||||
capabilities: Optional[dict[str, Any]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
provider_id: Optional[str] = None
|
||||
creator_id: Optional[str] = None
|
||||
costs: Optional[list[LlmModelCostInput]] = None
|
||||
|
||||
|
||||
class ToggleLlmModelRequest(pydantic.BaseModel):
|
||||
is_enabled: bool
|
||||
migrate_to_slug: Optional[str] = None
|
||||
migration_reason: Optional[str] = None # e.g., "Provider outage"
|
||||
# Custom pricing override for migrated workflows. When set, billing should use
|
||||
# this cost instead of the target model's cost for affected nodes.
|
||||
# See LlmModelMigration in schema.prisma for full documentation.
|
||||
custom_credit_cost: Optional[int] = None
|
||||
|
||||
|
||||
class ToggleLlmModelResponse(pydantic.BaseModel):
|
||||
model: LlmModel
|
||||
nodes_migrated: int = 0
|
||||
migrated_to_slug: Optional[str] = None
|
||||
migration_id: Optional[str] = None # ID of the migration record for revert
|
||||
|
||||
|
||||
class DeleteLlmModelResponse(pydantic.BaseModel):
|
||||
deleted_model_slug: str
|
||||
deleted_model_display_name: str
|
||||
replacement_model_slug: Optional[str] = None
|
||||
nodes_migrated: int
|
||||
message: str
|
||||
|
||||
|
||||
class LlmModelUsageResponse(pydantic.BaseModel):
|
||||
model_slug: str
|
||||
node_count: int
|
||||
|
||||
|
||||
# Migration tracking models
|
||||
class LlmModelMigration(pydantic.BaseModel):
|
||||
id: str
|
||||
source_model_slug: str
|
||||
target_model_slug: str
|
||||
reason: Optional[str] = None
|
||||
node_count: int
|
||||
# Custom pricing override - billing should use this instead of target model's cost
|
||||
custom_credit_cost: Optional[int] = None
|
||||
is_reverted: bool = False
|
||||
created_at: datetime
|
||||
reverted_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class LlmMigrationsResponse(pydantic.BaseModel):
|
||||
migrations: list[LlmModelMigration]
|
||||
|
||||
|
||||
class RevertMigrationRequest(pydantic.BaseModel):
|
||||
re_enable_source_model: bool = (
|
||||
True # Whether to re-enable the source model if disabled
|
||||
)
|
||||
|
||||
|
||||
class RevertMigrationResponse(pydantic.BaseModel):
|
||||
migration_id: str
|
||||
source_model_slug: str
|
||||
target_model_slug: str
|
||||
nodes_reverted: int
|
||||
nodes_already_changed: int = (
|
||||
0 # Nodes that were modified since migration (not reverted)
|
||||
)
|
||||
source_model_re_enabled: bool = False # Whether the source model was re-enabled
|
||||
message: str
|
||||
|
||||
|
||||
class SetRecommendedModelRequest(pydantic.BaseModel):
|
||||
model_id: str
|
||||
|
||||
|
||||
class SetRecommendedModelResponse(pydantic.BaseModel):
|
||||
model: LlmModel
|
||||
previous_recommended_slug: Optional[str] = None
|
||||
message: str
|
||||
|
||||
|
||||
class RecommendedModelResponse(pydantic.BaseModel):
|
||||
model: Optional[LlmModel] = None
|
||||
slug: Optional[str] = None
|
||||
@@ -1,29 +0,0 @@
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
|
||||
from backend.server.v2.llm import db as llm_db
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
|
||||
router = fastapi.APIRouter(
|
||||
prefix="/llm",
|
||||
tags=["llm"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/models", response_model=llm_model.LlmModelsResponse)
|
||||
async def list_models(
|
||||
page: int = fastapi.Query(default=1, ge=1, description="Page number (1-indexed)"),
|
||||
page_size: int = fastapi.Query(
|
||||
default=50, ge=1, le=100, description="Number of models per page"
|
||||
),
|
||||
):
|
||||
"""List all enabled LLM models available to users."""
|
||||
return await llm_db.list_models(enabled_only=True, page=page, page_size=page_size)
|
||||
|
||||
|
||||
@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
|
||||
async def list_providers():
|
||||
"""List all LLM providers with their enabled models."""
|
||||
providers = await llm_db.list_providers(include_models=True, enabled_only=True)
|
||||
return llm_model.LlmProvidersResponse(providers=providers)
|
||||
@@ -350,6 +350,19 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="Whether to mark failed scans as clean or not",
|
||||
)
|
||||
|
||||
agentgenerator_host: str = Field(
|
||||
default="",
|
||||
description="The host for the Agent Generator service (empty to use built-in)",
|
||||
)
|
||||
agentgenerator_port: int = Field(
|
||||
default=8000,
|
||||
description="The port for the Agent Generator service",
|
||||
)
|
||||
agentgenerator_timeout: int = Field(
|
||||
default=120,
|
||||
description="The timeout in seconds for Agent Generator service requests",
|
||||
)
|
||||
|
||||
enable_example_blocks: bool = Field(
|
||||
default=False,
|
||||
description="Whether to enable example blocks in production",
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import time
|
||||
@@ -58,6 +59,11 @@ class SpinTestServer:
|
||||
self.db_api.__exit__(exc_type, exc_val, exc_tb)
|
||||
self.notif_manager.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
# Give services time to fully shut down
|
||||
# This prevents event loop issues where services haven't fully cleaned up
|
||||
# before the next test starts
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
def setup_dependency_overrides(self):
|
||||
# Override get_user_id for testing
|
||||
self.agent_server.set_test_dependency_overrides(
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
-- CreateEnum
|
||||
CREATE TYPE "LlmCostUnit" AS ENUM ('RUN', 'TOKENS');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmProvider" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"name" TEXT NOT NULL,
|
||||
"displayName" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"defaultCredentialProvider" TEXT,
|
||||
"defaultCredentialId" TEXT,
|
||||
"defaultCredentialType" TEXT,
|
||||
"supportsTools" BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
"supportsJsonOutput" BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
"supportsReasoning" BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
"supportsParallelTool" BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
|
||||
CONSTRAINT "LlmProvider_pkey" PRIMARY KEY ("id"),
|
||||
CONSTRAINT "LlmProvider_name_key" UNIQUE ("name")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmModel" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"slug" TEXT NOT NULL,
|
||||
"displayName" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"providerId" TEXT NOT NULL,
|
||||
"contextWindow" INTEGER NOT NULL,
|
||||
"maxOutputTokens" INTEGER,
|
||||
"isEnabled" BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
"capabilities" JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
|
||||
CONSTRAINT "LlmModel_pkey" PRIMARY KEY ("id"),
|
||||
CONSTRAINT "LlmModel_slug_key" UNIQUE ("slug")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmModelCost" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"unit" "LlmCostUnit" NOT NULL DEFAULT 'RUN',
|
||||
"creditCost" INTEGER NOT NULL,
|
||||
"credentialProvider" TEXT NOT NULL,
|
||||
"credentialId" TEXT,
|
||||
"credentialType" TEXT,
|
||||
"currency" TEXT,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
"llmModelId" TEXT NOT NULL,
|
||||
|
||||
CONSTRAINT "LlmModelCost_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModel_providerId_isEnabled_idx" ON "LlmModel"("providerId", "isEnabled");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModel_slug_idx" ON "LlmModel"("slug");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelCost_llmModelId_idx" ON "LlmModelCost"("llmModelId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelCost_credentialProvider_idx" ON "LlmModelCost"("credentialProvider");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "LlmModelCost_llmModelId_credentialProvider_unit_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "unit");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_providerId_fkey" FOREIGN KEY ("providerId") REFERENCES "LlmProvider"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LlmModelCost" ADD CONSTRAINT "LlmModelCost_llmModelId_fkey" FOREIGN KEY ("llmModelId") REFERENCES "LlmModel"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
@@ -1,226 +0,0 @@
|
||||
-- Seed LLM Registry from existing hard-coded data
|
||||
-- This migration populates the LlmProvider, LlmModel, and LlmModelCost tables
|
||||
-- with data from the existing MODEL_METADATA and MODEL_COST dictionaries
|
||||
|
||||
-- Insert Providers
|
||||
INSERT INTO "LlmProvider" ("id", "name", "displayName", "description", "defaultCredentialProvider", "defaultCredentialType", "supportsTools", "supportsJsonOutput", "supportsReasoning", "supportsParallelTool", "metadata")
|
||||
VALUES
|
||||
(gen_random_uuid(), 'openai', 'OpenAI', 'OpenAI language models', 'openai', 'api_key', true, true, true, true, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'anthropic', 'Anthropic', 'Anthropic Claude models', 'anthropic', 'api_key', true, true, true, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'groq', 'Groq', 'Groq inference API', 'groq', 'api_key', false, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'open_router', 'OpenRouter', 'OpenRouter unified API', 'open_router', 'api_key', true, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'aiml_api', 'AI/ML API', 'AI/ML API models', 'aiml_api', 'api_key', false, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'ollama', 'Ollama', 'Ollama local models', 'ollama', 'api_key', false, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'llama_api', 'Llama API', 'Llama API models', 'llama_api', 'api_key', false, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'v0', 'v0', 'v0 by Vercel models', 'v0', 'api_key', true, true, false, false, '{}'::jsonb)
|
||||
ON CONFLICT ("name") DO NOTHING;
|
||||
|
||||
-- Insert Models (using CTEs to reference provider IDs)
|
||||
WITH provider_ids AS (
|
||||
SELECT "id", "name" FROM "LlmProvider"
|
||||
)
|
||||
INSERT INTO "LlmModel" ("id", "slug", "displayName", "description", "providerId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
|
||||
SELECT
|
||||
gen_random_uuid(),
|
||||
model_slug,
|
||||
model_display_name,
|
||||
NULL,
|
||||
p."id",
|
||||
context_window,
|
||||
max_output_tokens,
|
||||
true,
|
||||
'{}'::jsonb,
|
||||
'{}'::jsonb
|
||||
FROM (VALUES
|
||||
-- OpenAI models
|
||||
('o3', 'O3', 'openai', 200000, 100000),
|
||||
('o3-mini', 'O3 Mini', 'openai', 200000, 100000),
|
||||
('o1', 'O1', 'openai', 200000, 100000),
|
||||
('o1-mini', 'O1 Mini', 'openai', 128000, 65536),
|
||||
('gpt-5-2025-08-07', 'GPT 5', 'openai', 400000, 128000),
|
||||
('gpt-5.1-2025-11-13', 'GPT 5.1', 'openai', 400000, 128000),
|
||||
('gpt-5-mini-2025-08-07', 'GPT 5 Mini', 'openai', 400000, 128000),
|
||||
('gpt-5-nano-2025-08-07', 'GPT 5 Nano', 'openai', 400000, 128000),
|
||||
('gpt-5-chat-latest', 'GPT 5 Chat', 'openai', 400000, 16384),
|
||||
('gpt-4.1-2025-04-14', 'GPT 4.1', 'openai', 1000000, 32768),
|
||||
('gpt-4.1-mini-2025-04-14', 'GPT 4.1 Mini', 'openai', 1047576, 32768),
|
||||
('gpt-4o-mini', 'GPT 4o Mini', 'openai', 128000, 16384),
|
||||
('gpt-4o', 'GPT 4o', 'openai', 128000, 16384),
|
||||
('gpt-4-turbo', 'GPT 4 Turbo', 'openai', 128000, 4096),
|
||||
('gpt-3.5-turbo', 'GPT 3.5 Turbo', 'openai', 16385, 4096),
|
||||
-- Anthropic models
|
||||
('claude-opus-4-1-20250805', 'Claude 4.1 Opus', 'anthropic', 200000, 32000),
|
||||
('claude-opus-4-20250514', 'Claude 4 Opus', 'anthropic', 200000, 32000),
|
||||
('claude-sonnet-4-20250514', 'Claude 4 Sonnet', 'anthropic', 200000, 64000),
|
||||
('claude-opus-4-5-20251101', 'Claude 4.5 Opus', 'anthropic', 200000, 64000),
|
||||
('claude-sonnet-4-5-20250929', 'Claude 4.5 Sonnet', 'anthropic', 200000, 64000),
|
||||
('claude-haiku-4-5-20251001', 'Claude 4.5 Haiku', 'anthropic', 200000, 64000),
|
||||
('claude-3-7-sonnet-20250219', 'Claude 3.7 Sonnet', 'anthropic', 200000, 64000),
|
||||
('claude-3-haiku-20240307', 'Claude 3 Haiku', 'anthropic', 200000, 4096),
|
||||
-- AI/ML API models
|
||||
('Qwen/Qwen2.5-72B-Instruct-Turbo', 'Qwen 2.5 72B', 'aiml_api', 32000, 8000),
|
||||
('nvidia/llama-3.1-nemotron-70b-instruct', 'Llama 3.1 Nemotron 70B', 'aiml_api', 128000, 40000),
|
||||
('meta-llama/Llama-3.3-70B-Instruct-Turbo', 'Llama 3.3 70B', 'aiml_api', 128000, NULL),
|
||||
('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 'Meta Llama 3.1 70B', 'aiml_api', 131000, 2000),
|
||||
('meta-llama/Llama-3.2-3B-Instruct-Turbo', 'Llama 3.2 3B', 'aiml_api', 128000, NULL),
|
||||
-- Groq models
|
||||
('llama-3.3-70b-versatile', 'Llama 3.3 70B', 'groq', 128000, 32768),
|
||||
('llama-3.1-8b-instant', 'Llama 3.1 8B', 'groq', 128000, 8192),
|
||||
-- Ollama models
|
||||
('llama3.3', 'Llama 3.3', 'ollama', 8192, NULL),
|
||||
('llama3.2', 'Llama 3.2', 'ollama', 8192, NULL),
|
||||
('llama3', 'Llama 3', 'ollama', 8192, NULL),
|
||||
('llama3.1:405b', 'Llama 3.1 405B', 'ollama', 8192, NULL),
|
||||
('dolphin-mistral:latest', 'Dolphin Mistral', 'ollama', 32768, NULL),
|
||||
-- OpenRouter models
|
||||
('google/gemini-2.5-pro-preview-03-25', 'Gemini 2.5 Pro', 'open_router', 1050000, 8192),
|
||||
('google/gemini-3-pro-preview', 'Gemini 3 Pro Preview', 'open_router', 1048576, 65535),
|
||||
('google/gemini-2.5-flash', 'Gemini 2.5 Flash', 'open_router', 1048576, 65535),
|
||||
('google/gemini-2.0-flash-001', 'Gemini 2.0 Flash', 'open_router', 1048576, 8192),
|
||||
('google/gemini-2.5-flash-lite-preview-06-17', 'Gemini 2.5 Flash Lite Preview', 'open_router', 1048576, 65535),
|
||||
('google/gemini-2.0-flash-lite-001', 'Gemini 2.0 Flash Lite', 'open_router', 1048576, 8192),
|
||||
('mistralai/mistral-nemo', 'Mistral Nemo', 'open_router', 128000, 4096),
|
||||
('cohere/command-r-08-2024', 'Command R', 'open_router', 128000, 4096),
|
||||
('cohere/command-r-plus-08-2024', 'Command R Plus', 'open_router', 128000, 4096),
|
||||
('deepseek/deepseek-chat', 'DeepSeek Chat', 'open_router', 64000, 2048),
|
||||
('deepseek/deepseek-r1-0528', 'DeepSeek R1', 'open_router', 163840, 163840),
|
||||
('perplexity/sonar', 'Perplexity Sonar', 'open_router', 127000, 8000),
|
||||
('perplexity/sonar-pro', 'Perplexity Sonar Pro', 'open_router', 200000, 8000),
|
||||
('perplexity/sonar-deep-research', 'Perplexity Sonar Deep Research', 'open_router', 128000, 16000),
|
||||
('nousresearch/hermes-3-llama-3.1-405b', 'Hermes 3 Llama 3.1 405B', 'open_router', 131000, 4096),
|
||||
('nousresearch/hermes-3-llama-3.1-70b', 'Hermes 3 Llama 3.1 70B', 'open_router', 12288, 12288),
|
||||
('openai/gpt-oss-120b', 'GPT OSS 120B', 'open_router', 131072, 131072),
|
||||
('openai/gpt-oss-20b', 'GPT OSS 20B', 'open_router', 131072, 32768),
|
||||
('amazon/nova-lite-v1', 'Amazon Nova Lite', 'open_router', 300000, 5120),
|
||||
('amazon/nova-micro-v1', 'Amazon Nova Micro', 'open_router', 128000, 5120),
|
||||
('amazon/nova-pro-v1', 'Amazon Nova Pro', 'open_router', 300000, 5120),
|
||||
('microsoft/wizardlm-2-8x22b', 'WizardLM 2 8x22B', 'open_router', 65536, 4096),
|
||||
('gryphe/mythomax-l2-13b', 'MythoMax L2 13B', 'open_router', 4096, 4096),
|
||||
('meta-llama/llama-4-scout', 'Llama 4 Scout', 'open_router', 131072, 131072),
|
||||
('meta-llama/llama-4-maverick', 'Llama 4 Maverick', 'open_router', 1048576, 1000000),
|
||||
('x-ai/grok-4', 'Grok 4', 'open_router', 256000, 256000),
|
||||
('x-ai/grok-4-fast', 'Grok 4 Fast', 'open_router', 2000000, 30000),
|
||||
('x-ai/grok-4.1-fast', 'Grok 4.1 Fast', 'open_router', 2000000, 30000),
|
||||
('x-ai/grok-code-fast-1', 'Grok Code Fast 1', 'open_router', 256000, 10000),
|
||||
('moonshotai/kimi-k2', 'Kimi K2', 'open_router', 131000, 131000),
|
||||
('qwen/qwen3-235b-a22b-thinking-2507', 'Qwen 3 235B Thinking', 'open_router', 262144, 262144),
|
||||
('qwen/qwen3-coder', 'Qwen 3 Coder', 'open_router', 262144, 262144),
|
||||
-- Llama API models
|
||||
('Llama-4-Scout-17B-16E-Instruct-FP8', 'Llama 4 Scout', 'llama_api', 128000, 4028),
|
||||
('Llama-4-Maverick-17B-128E-Instruct-FP8', 'Llama 4 Maverick', 'llama_api', 128000, 4028),
|
||||
('Llama-3.3-8B-Instruct', 'Llama 3.3 8B', 'llama_api', 128000, 4028),
|
||||
('Llama-3.3-70B-Instruct', 'Llama 3.3 70B', 'llama_api', 128000, 4028),
|
||||
-- v0 models
|
||||
('v0-1.5-md', 'v0 1.5 MD', 'v0', 128000, 64000),
|
||||
('v0-1.5-lg', 'v0 1.5 LG', 'v0', 512000, 64000),
|
||||
('v0-1.0-md', 'v0 1.0 MD', 'v0', 128000, 64000)
|
||||
) AS models(model_slug, model_display_name, provider_name, context_window, max_output_tokens)
|
||||
JOIN provider_ids p ON p."name" = models.provider_name
|
||||
ON CONFLICT ("slug") DO NOTHING;
|
||||
|
||||
-- Insert Costs (using CTEs to reference model IDs)
|
||||
WITH model_ids AS (
|
||||
SELECT "id", "slug", "providerId" FROM "LlmModel"
|
||||
),
|
||||
provider_ids AS (
|
||||
SELECT "id", "name" FROM "LlmProvider"
|
||||
)
|
||||
INSERT INTO "LlmModelCost" ("id", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
|
||||
SELECT
|
||||
gen_random_uuid(),
|
||||
'RUN'::"LlmCostUnit",
|
||||
cost,
|
||||
p."name",
|
||||
NULL,
|
||||
'api_key',
|
||||
NULL,
|
||||
'{}'::jsonb,
|
||||
m."id"
|
||||
FROM (VALUES
|
||||
-- OpenAI costs
|
||||
('o3', 4),
|
||||
('o3-mini', 2),
|
||||
('o1', 16),
|
||||
('o1-mini', 4),
|
||||
('gpt-5-2025-08-07', 2),
|
||||
('gpt-5.1-2025-11-13', 5),
|
||||
('gpt-5-mini-2025-08-07', 1),
|
||||
('gpt-5-nano-2025-08-07', 1),
|
||||
('gpt-5-chat-latest', 5),
|
||||
('gpt-4.1-2025-04-14', 2),
|
||||
('gpt-4.1-mini-2025-04-14', 1),
|
||||
('gpt-4o-mini', 1),
|
||||
('gpt-4o', 3),
|
||||
('gpt-4-turbo', 10),
|
||||
('gpt-3.5-turbo', 1),
|
||||
-- Anthropic costs
|
||||
('claude-opus-4-1-20250805', 21),
|
||||
('claude-opus-4-20250514', 21),
|
||||
('claude-sonnet-4-20250514', 5),
|
||||
('claude-haiku-4-5-20251001', 4),
|
||||
('claude-opus-4-5-20251101', 14),
|
||||
('claude-sonnet-4-5-20250929', 9),
|
||||
('claude-3-7-sonnet-20250219', 5),
|
||||
('claude-3-haiku-20240307', 1),
|
||||
-- AI/ML API costs
|
||||
('Qwen/Qwen2.5-72B-Instruct-Turbo', 1),
|
||||
('nvidia/llama-3.1-nemotron-70b-instruct', 1),
|
||||
('meta-llama/Llama-3.3-70B-Instruct-Turbo', 1),
|
||||
('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 1),
|
||||
('meta-llama/Llama-3.2-3B-Instruct-Turbo', 1),
|
||||
-- Groq costs
|
||||
('llama-3.3-70b-versatile', 1),
|
||||
('llama-3.1-8b-instant', 1),
|
||||
-- Ollama costs
|
||||
('llama3.3', 1),
|
||||
('llama3.2', 1),
|
||||
('llama3', 1),
|
||||
('llama3.1:405b', 1),
|
||||
('dolphin-mistral:latest', 1),
|
||||
-- OpenRouter costs
|
||||
('google/gemini-2.5-pro-preview-03-25', 4),
|
||||
('google/gemini-3-pro-preview', 5),
|
||||
('mistralai/mistral-nemo', 1),
|
||||
('cohere/command-r-08-2024', 1),
|
||||
('cohere/command-r-plus-08-2024', 3),
|
||||
('deepseek/deepseek-chat', 2),
|
||||
('perplexity/sonar', 1),
|
||||
('perplexity/sonar-pro', 5),
|
||||
('perplexity/sonar-deep-research', 10),
|
||||
('nousresearch/hermes-3-llama-3.1-405b', 1),
|
||||
('nousresearch/hermes-3-llama-3.1-70b', 1),
|
||||
('amazon/nova-lite-v1', 1),
|
||||
('amazon/nova-micro-v1', 1),
|
||||
('amazon/nova-pro-v1', 1),
|
||||
('microsoft/wizardlm-2-8x22b', 1),
|
||||
('gryphe/mythomax-l2-13b', 1),
|
||||
('meta-llama/llama-4-scout', 1),
|
||||
('meta-llama/llama-4-maverick', 1),
|
||||
('x-ai/grok-4', 9),
|
||||
('x-ai/grok-4-fast', 1),
|
||||
('x-ai/grok-4.1-fast', 1),
|
||||
('x-ai/grok-code-fast-1', 1),
|
||||
('moonshotai/kimi-k2', 1),
|
||||
('qwen/qwen3-235b-a22b-thinking-2507', 1),
|
||||
('qwen/qwen3-coder', 9),
|
||||
('google/gemini-2.5-flash', 1),
|
||||
('google/gemini-2.0-flash-001', 1),
|
||||
('google/gemini-2.5-flash-lite-preview-06-17', 1),
|
||||
('google/gemini-2.0-flash-lite-001', 1),
|
||||
('deepseek/deepseek-r1-0528', 1),
|
||||
('openai/gpt-oss-120b', 1),
|
||||
('openai/gpt-oss-20b', 1),
|
||||
-- Llama API costs
|
||||
('Llama-4-Scout-17B-16E-Instruct-FP8', 1),
|
||||
('Llama-4-Maverick-17B-128E-Instruct-FP8', 1),
|
||||
('Llama-3.3-8B-Instruct', 1),
|
||||
('Llama-3.3-70B-Instruct', 1),
|
||||
-- v0 costs
|
||||
('v0-1.5-md', 1),
|
||||
('v0-1.5-lg', 2),
|
||||
('v0-1.0-md', 1)
|
||||
) AS costs(model_slug, cost)
|
||||
JOIN model_ids m ON m."slug" = costs.model_slug
|
||||
JOIN provider_ids p ON p."id" = m."providerId"
|
||||
ON CONFLICT ("llmModelId", "credentialProvider", "unit") DO NOTHING;
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmModelMigration" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"sourceModelSlug" TEXT NOT NULL,
|
||||
"targetModelSlug" TEXT NOT NULL,
|
||||
"reason" TEXT,
|
||||
"migratedNodeIds" JSONB NOT NULL DEFAULT '[]',
|
||||
"nodeCount" INTEGER NOT NULL,
|
||||
"customCreditCost" INTEGER,
|
||||
"isReverted" BOOLEAN NOT NULL DEFAULT false,
|
||||
"revertedAt" TIMESTAMP(3),
|
||||
|
||||
CONSTRAINT "LlmModelMigration_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelMigration_sourceModelSlug_idx" ON "LlmModelMigration"("sourceModelSlug");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelMigration_targetModelSlug_idx" ON "LlmModelMigration"("targetModelSlug");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelMigration_isReverted_idx" ON "LlmModelMigration"("isReverted");
|
||||
@@ -1,127 +0,0 @@
|
||||
-- Add LlmModelCreator table
|
||||
-- Creator represents who made/trained the model (e.g., OpenAI, Meta)
|
||||
-- This is distinct from Provider who hosts/serves the model (e.g., OpenRouter)
|
||||
|
||||
-- Create the LlmModelCreator table
|
||||
CREATE TABLE "LlmModelCreator" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"name" TEXT NOT NULL,
|
||||
"displayName" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"websiteUrl" TEXT,
|
||||
"logoUrl" TEXT,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
|
||||
CONSTRAINT "LlmModelCreator_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- Create unique index on name
|
||||
CREATE UNIQUE INDEX "LlmModelCreator_name_key" ON "LlmModelCreator"("name");
|
||||
|
||||
-- Add creatorId column to LlmModel
|
||||
ALTER TABLE "LlmModel" ADD COLUMN "creatorId" TEXT;
|
||||
|
||||
-- Add foreign key constraint
|
||||
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey"
|
||||
FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
|
||||
-- Create index on creatorId
|
||||
CREATE INDEX "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");
|
||||
|
||||
-- Seed creators based on known model creators
|
||||
INSERT INTO "LlmModelCreator" ("id", "updatedAt", "name", "displayName", "description", "websiteUrl", "metadata")
|
||||
VALUES
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'openai', 'OpenAI', 'Creator of GPT models', 'https://openai.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'anthropic', 'Anthropic', 'Creator of Claude models', 'https://anthropic.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'meta', 'Meta', 'Creator of Llama models', 'https://ai.meta.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'google', 'Google', 'Creator of Gemini models', 'https://deepmind.google', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'mistral', 'Mistral AI', 'Creator of Mistral models', 'https://mistral.ai', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'cohere', 'Cohere', 'Creator of Command models', 'https://cohere.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'deepseek', 'DeepSeek', 'Creator of DeepSeek models', 'https://deepseek.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'perplexity', 'Perplexity AI', 'Creator of Sonar models', 'https://perplexity.ai', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'qwen', 'Qwen (Alibaba)', 'Creator of Qwen models', 'https://qwenlm.github.io', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'xai', 'xAI', 'Creator of Grok models', 'https://x.ai', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'amazon', 'Amazon', 'Creator of Nova models', 'https://aws.amazon.com/bedrock', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'microsoft', 'Microsoft', 'Creator of WizardLM models', 'https://microsoft.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'moonshot', 'Moonshot AI', 'Creator of Kimi models', 'https://moonshot.cn', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'nvidia', 'NVIDIA', 'Creator of Nemotron models', 'https://nvidia.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'nous_research', 'Nous Research', 'Creator of Hermes models', 'https://nousresearch.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'vercel', 'Vercel', 'Creator of v0 models', 'https://vercel.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'cognitive_computations', 'Cognitive Computations', 'Creator of Dolphin models', 'https://erichartford.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'gryphe', 'Gryphe', 'Creator of MythoMax models', 'https://huggingface.co/Gryphe', '{}')
|
||||
ON CONFLICT ("name") DO NOTHING;
|
||||
|
||||
-- Update existing models with their creators
|
||||
-- OpenAI models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'openai')
|
||||
WHERE "slug" LIKE 'gpt-%' OR "slug" LIKE 'o1%' OR "slug" LIKE 'o3%' OR "slug" LIKE 'openai/%';
|
||||
|
||||
-- Anthropic models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'anthropic')
|
||||
WHERE "slug" LIKE 'claude-%';
|
||||
|
||||
-- Meta/Llama models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'meta')
|
||||
WHERE "slug" LIKE 'llama%' OR "slug" LIKE 'Llama%' OR "slug" LIKE 'meta-llama/%' OR "slug" LIKE '%/llama-%';
|
||||
|
||||
-- Google models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'google')
|
||||
WHERE "slug" LIKE 'google/%' OR "slug" LIKE 'gemini%';
|
||||
|
||||
-- Mistral models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'mistral')
|
||||
WHERE "slug" LIKE 'mistral%' OR "slug" LIKE 'mistralai/%';
|
||||
|
||||
-- Cohere models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'cohere')
|
||||
WHERE "slug" LIKE 'cohere/%' OR "slug" LIKE 'command-%';
|
||||
|
||||
-- DeepSeek models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'deepseek')
|
||||
WHERE "slug" LIKE 'deepseek/%' OR "slug" LIKE 'deepseek-%';
|
||||
|
||||
-- Perplexity models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'perplexity')
|
||||
WHERE "slug" LIKE 'perplexity/%' OR "slug" LIKE 'sonar%';
|
||||
|
||||
-- Qwen models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'qwen')
|
||||
WHERE "slug" LIKE 'Qwen/%' OR "slug" LIKE 'qwen/%';
|
||||
|
||||
-- xAI/Grok models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'xai')
|
||||
WHERE "slug" LIKE 'x-ai/%' OR "slug" LIKE 'grok%';
|
||||
|
||||
-- Amazon models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'amazon')
|
||||
WHERE "slug" LIKE 'amazon/%' OR "slug" LIKE 'nova-%';
|
||||
|
||||
-- Microsoft models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'microsoft')
|
||||
WHERE "slug" LIKE 'microsoft/%' OR "slug" LIKE 'wizardlm%';
|
||||
|
||||
-- Moonshot models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'moonshot')
|
||||
WHERE "slug" LIKE 'moonshotai/%' OR "slug" LIKE 'kimi%';
|
||||
|
||||
-- NVIDIA models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'nvidia')
|
||||
WHERE "slug" LIKE 'nvidia/%' OR "slug" LIKE '%nemotron%';
|
||||
|
||||
-- Nous Research models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'nous_research')
|
||||
WHERE "slug" LIKE 'nousresearch/%' OR "slug" LIKE 'hermes%';
|
||||
|
||||
-- Vercel/v0 models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'vercel')
|
||||
WHERE "slug" LIKE 'v0-%';
|
||||
|
||||
-- Dolphin models (Cognitive Computations / Eric Hartford)
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'cognitive_computations')
|
||||
WHERE "slug" LIKE 'dolphin-%';
|
||||
|
||||
-- Gryphe models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'gryphe')
|
||||
WHERE "slug" LIKE 'gryphe/%' OR "slug" LIKE 'mythomax%';
|
||||
@@ -1,4 +0,0 @@
|
||||
-- CreateIndex
|
||||
-- Index for efficient LLM model lookups on AgentNode.constantInput->>'model'
|
||||
-- This improves performance of model migration queries in the LLM registry
|
||||
CREATE INDEX "AgentNode_constantInput_model_idx" ON "AgentNode" ((("constantInput"->>'model')));
|
||||
@@ -1,52 +0,0 @@
|
||||
-- Add GPT-5.2 model and update O3 slug
|
||||
-- This migration adds the new GPT-5.2 model added in dev branch
|
||||
|
||||
-- Update O3 slug to match dev branch format
|
||||
UPDATE "LlmModel"
|
||||
SET "slug" = 'o3-2025-04-16'
|
||||
WHERE "slug" = 'o3';
|
||||
|
||||
-- Update cost reference for O3 if needed
|
||||
-- (costs are linked by model ID, so no update needed)
|
||||
|
||||
-- Add GPT-5.2 model
|
||||
WITH provider_id AS (
|
||||
SELECT "id" FROM "LlmProvider" WHERE "name" = 'openai'
|
||||
)
|
||||
INSERT INTO "LlmModel" ("id", "slug", "displayName", "description", "providerId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
|
||||
SELECT
|
||||
gen_random_uuid(),
|
||||
'gpt-5.2-2025-12-11',
|
||||
'GPT 5.2',
|
||||
'OpenAI GPT-5.2 model',
|
||||
p."id",
|
||||
400000,
|
||||
128000,
|
||||
true,
|
||||
'{}'::jsonb,
|
||||
'{}'::jsonb
|
||||
FROM provider_id p
|
||||
ON CONFLICT ("slug") DO NOTHING;
|
||||
|
||||
-- Add cost for GPT-5.2
|
||||
WITH model_id AS (
|
||||
SELECT m."id", p."name" as provider_name
|
||||
FROM "LlmModel" m
|
||||
JOIN "LlmProvider" p ON p."id" = m."providerId"
|
||||
WHERE m."slug" = 'gpt-5.2-2025-12-11'
|
||||
)
|
||||
INSERT INTO "LlmModelCost" ("id", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
|
||||
SELECT
|
||||
gen_random_uuid(),
|
||||
'RUN'::"LlmCostUnit",
|
||||
3, -- Same cost tier as GPT-5.1
|
||||
m.provider_name,
|
||||
NULL,
|
||||
'api_key',
|
||||
NULL,
|
||||
'{}'::jsonb,
|
||||
m."id"
|
||||
FROM model_id m
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM "LlmModelCost" c WHERE c."llmModelId" = m."id"
|
||||
);
|
||||
@@ -1,11 +0,0 @@
|
||||
-- Add isRecommended field to LlmModel table
|
||||
-- This allows admins to mark a model as the recommended default
|
||||
|
||||
ALTER TABLE "LlmModel" ADD COLUMN "isRecommended" BOOLEAN NOT NULL DEFAULT false;
|
||||
|
||||
-- Set gpt-4o-mini as the default recommended model (if it exists)
|
||||
UPDATE "LlmModel" SET "isRecommended" = true WHERE "slug" = 'gpt-4o-mini' AND "isEnabled" = true;
|
||||
|
||||
-- Create unique partial index to enforce only one recommended model at the database level
|
||||
-- This prevents multiple rows from having isRecommended = true
|
||||
CREATE UNIQUE INDEX "LlmModel_single_recommended_idx" ON "LlmModel" ("isRecommended") WHERE "isRecommended" = true;
|
||||
@@ -0,0 +1,7 @@
|
||||
-- Remove NodeExecution foreign key from PendingHumanReview
|
||||
-- The nodeExecId column remains as the primary key, but we remove the FK constraint
|
||||
-- to AgentNodeExecution since PendingHumanReview records can persist after node
|
||||
-- execution records are deleted.
|
||||
|
||||
-- Drop foreign key constraint that linked PendingHumanReview.nodeExecId to AgentNodeExecution.id
|
||||
ALTER TABLE "PendingHumanReview" DROP CONSTRAINT IF EXISTS "PendingHumanReview_nodeExecId_fkey";
|
||||
@@ -1,61 +0,0 @@
|
||||
-- Add new columns to LlmModel table for extended model metadata
|
||||
-- These columns support the LLM Picker UI enhancements
|
||||
|
||||
-- Add priceTier column: 1=cheapest, 2=medium, 3=expensive
|
||||
ALTER TABLE "LlmModel" ADD COLUMN IF NOT EXISTS "priceTier" INTEGER NOT NULL DEFAULT 1;
|
||||
|
||||
-- Add creatorId column for model creator relationship (if not exists)
|
||||
ALTER TABLE "LlmModel" ADD COLUMN IF NOT EXISTS "creatorId" TEXT;
|
||||
|
||||
-- Add isRecommended column (if not exists)
|
||||
ALTER TABLE "LlmModel" ADD COLUMN IF NOT EXISTS "isRecommended" BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
-- Add index on creatorId if not exists
|
||||
CREATE INDEX IF NOT EXISTS "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");
|
||||
|
||||
-- Add foreign key for creatorId if not exists
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'LlmModel_creatorId_fkey') THEN
|
||||
-- Only add FK if LlmModelCreator table exists
|
||||
IF EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'LlmModelCreator') THEN
|
||||
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey"
|
||||
FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
END IF;
|
||||
END IF;
|
||||
END $$;
|
||||
|
||||
-- Update priceTier values for existing models based on original MODEL_METADATA
|
||||
-- Tier 1 = cheapest, Tier 2 = medium, Tier 3 = expensive
|
||||
|
||||
-- OpenAI models
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'o3';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'o3-mini';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'o1';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'o1-mini';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'gpt-5.2';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'gpt-5.1';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-5';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-5-mini';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-5-nano';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'gpt-5-chat-latest';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" LIKE 'gpt-4.1%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-4o-mini';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'gpt-4o';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'gpt-4-turbo';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-3.5-turbo';
|
||||
|
||||
-- Anthropic models
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE 'claude-opus%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE 'claude-sonnet%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE 'claude%-4-5-sonnet%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE 'claude%-haiku%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'claude-3-haiku-20240307';
|
||||
|
||||
-- OpenRouter models - Pro/expensive tiers
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE 'google/gemini%-pro%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE '%command-r-plus%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE '%sonar-pro%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE '%sonar-deep-research%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'x-ai/grok-4';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE '%qwen3-coder%';
|
||||
@@ -517,8 +517,6 @@ model AgentNodeExecution {
|
||||
|
||||
stats Json?
|
||||
|
||||
PendingHumanReview PendingHumanReview?
|
||||
|
||||
@@index([agentGraphExecutionId, agentNodeId, executionStatus])
|
||||
@@index([agentNodeId, executionStatus])
|
||||
@@index([addedTime, queuedTime])
|
||||
@@ -567,6 +565,7 @@ enum ReviewStatus {
|
||||
}
|
||||
|
||||
// Pending human reviews for Human-in-the-loop blocks
|
||||
// Also stores auto-approval records with special nodeExecId patterns (e.g., "auto_approve_{graph_exec_id}_{node_id}")
|
||||
model PendingHumanReview {
|
||||
nodeExecId String @id
|
||||
userId String
|
||||
@@ -585,7 +584,6 @@ model PendingHumanReview {
|
||||
reviewedAt DateTime?
|
||||
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
NodeExecution AgentNodeExecution @relation(fields: [nodeExecId], references: [id], onDelete: Cascade)
|
||||
GraphExecution AgentGraphExecution @relation(fields: [graphExecId], references: [id], onDelete: Cascade)
|
||||
|
||||
@@unique([nodeExecId]) // One pending review per node execution
|
||||
@@ -1096,153 +1094,6 @@ enum APIKeyStatus {
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
///////////// LLM REGISTRY AND BILLING DATA /////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// LlmCostUnit: Defines how LLM MODEL costs are calculated (per run or per token).
|
||||
// This is distinct from BlockCostType (in backend/data/block.py) which defines
|
||||
// how BLOCK EXECUTION costs are calculated (per run, per byte, or per second).
|
||||
// LlmCostUnit is for pricing individual LLM model API calls in the registry,
|
||||
// while BlockCostType is for billing platform block executions.
|
||||
enum LlmCostUnit {
|
||||
RUN
|
||||
TOKENS
|
||||
}
|
||||
|
||||
model LlmModelCreator {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
name String @unique // e.g., "openai", "anthropic", "meta"
|
||||
displayName String // e.g., "OpenAI", "Anthropic", "Meta"
|
||||
description String?
|
||||
websiteUrl String? // Link to creator's website
|
||||
logoUrl String? // URL to creator's logo
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
Models LlmModel[]
|
||||
}
|
||||
|
||||
model LlmProvider {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
name String @unique
|
||||
displayName String
|
||||
description String?
|
||||
|
||||
defaultCredentialProvider String?
|
||||
defaultCredentialId String?
|
||||
defaultCredentialType String?
|
||||
|
||||
supportsTools Boolean @default(true)
|
||||
supportsJsonOutput Boolean @default(true)
|
||||
supportsReasoning Boolean @default(false)
|
||||
supportsParallelTool Boolean @default(false)
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
Models LlmModel[]
|
||||
}
|
||||
|
||||
model LlmModel {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
slug String @unique
|
||||
displayName String
|
||||
description String?
|
||||
|
||||
providerId String
|
||||
Provider LlmProvider @relation(fields: [providerId], references: [id], onDelete: Restrict)
|
||||
|
||||
// Creator is the organization that created/trained the model (e.g., OpenAI, Meta)
|
||||
// This is distinct from the provider who hosts/serves the model (e.g., OpenRouter)
|
||||
creatorId String?
|
||||
Creator LlmModelCreator? @relation(fields: [creatorId], references: [id], onDelete: SetNull)
|
||||
|
||||
contextWindow Int
|
||||
maxOutputTokens Int?
|
||||
priceTier Int @default(1) // 1=cheapest, 2=medium, 3=expensive
|
||||
isEnabled Boolean @default(true)
|
||||
isRecommended Boolean @default(false)
|
||||
|
||||
capabilities Json @default("{}")
|
||||
metadata Json @default("{}")
|
||||
|
||||
Costs LlmModelCost[]
|
||||
|
||||
@@index([providerId, isEnabled])
|
||||
@@index([creatorId])
|
||||
@@index([slug])
|
||||
}
|
||||
|
||||
model LlmModelCost {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
unit LlmCostUnit @default(RUN)
|
||||
|
||||
creditCost Int
|
||||
|
||||
credentialProvider String
|
||||
credentialId String?
|
||||
credentialType String?
|
||||
currency String?
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
llmModelId String
|
||||
Model LlmModel @relation(fields: [llmModelId], references: [id], onDelete: Cascade)
|
||||
|
||||
@@unique([llmModelId, credentialProvider, unit])
|
||||
@@index([llmModelId])
|
||||
@@index([credentialProvider])
|
||||
}
|
||||
|
||||
// Tracks model migrations for revert capability
|
||||
// When a model is disabled with migration, we record which nodes were affected
|
||||
// so they can be reverted when the original model is back online
|
||||
model LlmModelMigration {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
sourceModelSlug String // The original model that was disabled
|
||||
targetModelSlug String // The model workflows were migrated to
|
||||
reason String? // Why the migration happened (e.g., "Provider outage")
|
||||
|
||||
// Track affected nodes as JSON array of node IDs
|
||||
// Format: ["node-uuid-1", "node-uuid-2", ...]
|
||||
migratedNodeIds Json @default("[]")
|
||||
nodeCount Int // Number of nodes migrated
|
||||
|
||||
// Custom pricing override for migrated workflows during the migration period.
|
||||
// Use case: When migrating users from an expensive model (e.g., GPT-4) to a cheaper
|
||||
// one (e.g., GPT-3.5), you may want to temporarily maintain the original pricing
|
||||
// to avoid billing surprises, or offer a discount during the transition.
|
||||
//
|
||||
// IMPORTANT: This field is intended for integration with the billing system.
|
||||
// When billing calculates costs for nodes affected by this migration, it should
|
||||
// check if customCreditCost is set and use it instead of the target model's cost.
|
||||
// If null, the target model's normal cost applies.
|
||||
//
|
||||
// TODO: Integrate with billing system to apply this override during cost calculation.
|
||||
customCreditCost Int?
|
||||
|
||||
// Revert tracking
|
||||
isReverted Boolean @default(false)
|
||||
revertedAt DateTime?
|
||||
|
||||
@@index([sourceModelSlug])
|
||||
@@index([targetModelSlug])
|
||||
@@index([isReverted])
|
||||
}
|
||||
////////////// OAUTH PROVIDER TABLES //////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -34,7 +34,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Default output directory relative to repo root
|
||||
DEFAULT_OUTPUT_DIR = (
|
||||
Path(__file__).parent.parent.parent.parent / "docs" / "integrations"
|
||||
Path(__file__).parent.parent.parent.parent
|
||||
/ "docs"
|
||||
/ "integrations"
|
||||
/ "block-integrations"
|
||||
)
|
||||
|
||||
|
||||
@@ -421,6 +424,14 @@ def generate_block_markdown(
|
||||
lines.append("<!-- END MANUAL -->")
|
||||
lines.append("")
|
||||
|
||||
# Optional per-block extras (only include if has content)
|
||||
extras = manual_content.get("extras", "")
|
||||
if extras:
|
||||
lines.append("<!-- MANUAL: extras -->")
|
||||
lines.append(extras)
|
||||
lines.append("<!-- END MANUAL -->")
|
||||
lines.append("")
|
||||
|
||||
lines.append("---")
|
||||
lines.append("")
|
||||
|
||||
@@ -456,25 +467,52 @@ def get_block_file_mapping(blocks: list[BlockDoc]) -> dict[str, list[BlockDoc]]:
|
||||
return dict(file_mapping)
|
||||
|
||||
|
||||
def generate_overview_table(blocks: list[BlockDoc]) -> str:
|
||||
"""Generate the overview table markdown (blocks.md)."""
|
||||
def generate_overview_table(blocks: list[BlockDoc], block_dir_prefix: str = "") -> str:
|
||||
"""Generate the overview table markdown (blocks.md).
|
||||
|
||||
Args:
|
||||
blocks: List of block documentation objects
|
||||
block_dir_prefix: Prefix for block file links (e.g., "block-integrations/")
|
||||
"""
|
||||
lines = []
|
||||
|
||||
# GitBook YAML frontmatter
|
||||
lines.append("---")
|
||||
lines.append("layout:")
|
||||
lines.append(" width: default")
|
||||
lines.append(" title:")
|
||||
lines.append(" visible: true")
|
||||
lines.append(" description:")
|
||||
lines.append(" visible: true")
|
||||
lines.append(" tableOfContents:")
|
||||
lines.append(" visible: false")
|
||||
lines.append(" outline:")
|
||||
lines.append(" visible: true")
|
||||
lines.append(" pagination:")
|
||||
lines.append(" visible: true")
|
||||
lines.append(" metadata:")
|
||||
lines.append(" visible: true")
|
||||
lines.append("---")
|
||||
lines.append("")
|
||||
|
||||
lines.append("# AutoGPT Blocks Overview")
|
||||
lines.append("")
|
||||
lines.append(
|
||||
'AutoGPT uses a modular approach with various "blocks" to handle different tasks. These blocks are the building blocks of AutoGPT workflows, allowing users to create complex automations by combining simple, specialized components.'
|
||||
)
|
||||
lines.append("")
|
||||
lines.append('!!! info "Creating Your Own Blocks"')
|
||||
lines.append(" Want to create your own custom blocks? Check out our guides:")
|
||||
lines.append(" ")
|
||||
lines.append('{% hint style="info" %}')
|
||||
lines.append("**Creating Your Own Blocks**")
|
||||
lines.append("")
|
||||
lines.append("Want to create your own custom blocks? Check out our guides:")
|
||||
lines.append("")
|
||||
lines.append(
|
||||
" - [Build your own Blocks](https://docs.agpt.co/platform/new_blocks/) - Step-by-step tutorial with examples"
|
||||
"* [Build your own Blocks](https://docs.agpt.co/platform/new_blocks/) - Step-by-step tutorial with examples"
|
||||
)
|
||||
lines.append(
|
||||
" - [Block SDK Guide](https://docs.agpt.co/platform/block-sdk-guide/) - Advanced SDK patterns with OAuth, webhooks, and provider configuration"
|
||||
"* [Block SDK Guide](https://docs.agpt.co/platform/block-sdk-guide/) - Advanced SDK patterns with OAuth, webhooks, and provider configuration"
|
||||
)
|
||||
lines.append("{% endhint %}")
|
||||
lines.append("")
|
||||
lines.append(
|
||||
"Below is a comprehensive list of all available blocks, categorized by their primary function. Click on any block name to view its detailed documentation."
|
||||
@@ -537,7 +575,8 @@ def generate_overview_table(blocks: list[BlockDoc]) -> str:
|
||||
else "No description"
|
||||
)
|
||||
short_desc = short_desc.replace("\n", " ").replace("|", "\\|")
|
||||
lines.append(f"| [{block.name}]({file_path}#{anchor}) | {short_desc} |")
|
||||
link_path = f"{block_dir_prefix}{file_path}"
|
||||
lines.append(f"| [{block.name}]({link_path}#{anchor}) | {short_desc} |")
|
||||
lines.append("")
|
||||
continue
|
||||
|
||||
@@ -563,13 +602,55 @@ def generate_overview_table(blocks: list[BlockDoc]) -> str:
|
||||
)
|
||||
short_desc = short_desc.replace("\n", " ").replace("|", "\\|")
|
||||
|
||||
lines.append(f"| [{block.name}]({file_path}#{anchor}) | {short_desc} |")
|
||||
link_path = f"{block_dir_prefix}{file_path}"
|
||||
lines.append(f"| [{block.name}]({link_path}#{anchor}) | {short_desc} |")
|
||||
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def generate_summary_md(
|
||||
blocks: list[BlockDoc], root_dir: Path, block_dir_prefix: str = ""
|
||||
) -> str:
|
||||
"""Generate SUMMARY.md for GitBook navigation.
|
||||
|
||||
Args:
|
||||
blocks: List of block documentation objects
|
||||
root_dir: The root docs directory (e.g., docs/integrations/)
|
||||
block_dir_prefix: Prefix for block file links (e.g., "block-integrations/")
|
||||
"""
|
||||
lines = []
|
||||
lines.append("# Table of contents")
|
||||
lines.append("")
|
||||
lines.append("* [AutoGPT Blocks Overview](README.md)")
|
||||
lines.append("")
|
||||
|
||||
# Check for guides/ directory at the root level (docs/integrations/guides/)
|
||||
guides_dir = root_dir / "guides"
|
||||
if guides_dir.exists():
|
||||
lines.append("## Guides")
|
||||
lines.append("")
|
||||
for guide_file in sorted(guides_dir.glob("*.md")):
|
||||
# Use just the file name for title (replace hyphens/underscores with spaces)
|
||||
title = file_path_to_title(guide_file.stem.replace("-", "_") + ".md")
|
||||
lines.append(f"* [{title}](guides/{guide_file.name})")
|
||||
lines.append("")
|
||||
|
||||
lines.append("## Block Integrations")
|
||||
lines.append("")
|
||||
|
||||
file_mapping = get_block_file_mapping(blocks)
|
||||
for file_path in sorted(file_mapping.keys()):
|
||||
title = file_path_to_title(file_path)
|
||||
link_path = f"{block_dir_prefix}{file_path}"
|
||||
lines.append(f"* [{title}]({link_path})")
|
||||
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def load_all_blocks_for_docs() -> list[BlockDoc]:
|
||||
"""Load all blocks and extract documentation."""
|
||||
from backend.blocks import load_all_blocks
|
||||
@@ -653,6 +734,16 @@ def write_block_docs(
|
||||
)
|
||||
)
|
||||
|
||||
# Add file-level additional_content section if present
|
||||
file_additional = extract_manual_content(existing_content).get(
|
||||
"additional_content", ""
|
||||
)
|
||||
if file_additional:
|
||||
content_parts.append("<!-- MANUAL: additional_content -->")
|
||||
content_parts.append(file_additional)
|
||||
content_parts.append("<!-- END MANUAL -->")
|
||||
content_parts.append("")
|
||||
|
||||
full_content = file_header + "\n" + "\n".join(content_parts)
|
||||
generated_files[str(file_path)] = full_content
|
||||
|
||||
@@ -661,14 +752,28 @@ def write_block_docs(
|
||||
|
||||
full_path.write_text(full_content)
|
||||
|
||||
# Generate overview file
|
||||
overview_content = generate_overview_table(blocks)
|
||||
overview_path = output_dir / "README.md"
|
||||
# Generate overview file at the parent directory (docs/integrations/)
|
||||
# with links prefixed to point into block-integrations/
|
||||
root_dir = output_dir.parent
|
||||
block_dir_name = output_dir.name # "block-integrations"
|
||||
block_dir_prefix = f"{block_dir_name}/"
|
||||
|
||||
overview_content = generate_overview_table(blocks, block_dir_prefix)
|
||||
overview_path = root_dir / "README.md"
|
||||
generated_files["README.md"] = overview_content
|
||||
overview_path.write_text(overview_content)
|
||||
|
||||
if verbose:
|
||||
print(" Writing README.md (overview)")
|
||||
print(" Writing README.md (overview) to parent directory")
|
||||
|
||||
# Generate SUMMARY.md for GitBook navigation at the parent directory
|
||||
summary_content = generate_summary_md(blocks, root_dir, block_dir_prefix)
|
||||
summary_path = root_dir / "SUMMARY.md"
|
||||
generated_files["SUMMARY.md"] = summary_content
|
||||
summary_path.write_text(summary_content)
|
||||
|
||||
if verbose:
|
||||
print(" Writing SUMMARY.md (navigation) to parent directory")
|
||||
|
||||
return generated_files
|
||||
|
||||
@@ -748,6 +853,16 @@ def check_docs_in_sync(output_dir: Path, blocks: list[BlockDoc]) -> bool:
|
||||
elif block_match.group(1).strip() != expected_block_content.strip():
|
||||
mismatched_blocks.append(block.name)
|
||||
|
||||
# Add file-level additional_content to expected content (matches write_block_docs)
|
||||
file_additional = extract_manual_content(existing_content).get(
|
||||
"additional_content", ""
|
||||
)
|
||||
if file_additional:
|
||||
content_parts.append("<!-- MANUAL: additional_content -->")
|
||||
content_parts.append(file_additional)
|
||||
content_parts.append("<!-- END MANUAL -->")
|
||||
content_parts.append("")
|
||||
|
||||
expected_content = file_header + "\n" + "\n".join(content_parts)
|
||||
|
||||
if existing_content.strip() != expected_content.strip():
|
||||
@@ -757,11 +872,15 @@ def check_docs_in_sync(output_dir: Path, blocks: list[BlockDoc]) -> bool:
|
||||
out_of_sync_details.append((file_path, mismatched_blocks))
|
||||
all_match = False
|
||||
|
||||
# Check overview
|
||||
overview_path = output_dir / "README.md"
|
||||
# Check overview at the parent directory (docs/integrations/)
|
||||
root_dir = output_dir.parent
|
||||
block_dir_name = output_dir.name # "block-integrations"
|
||||
block_dir_prefix = f"{block_dir_name}/"
|
||||
|
||||
overview_path = root_dir / "README.md"
|
||||
if overview_path.exists():
|
||||
existing_overview = overview_path.read_text()
|
||||
expected_overview = generate_overview_table(blocks)
|
||||
expected_overview = generate_overview_table(blocks, block_dir_prefix)
|
||||
if existing_overview.strip() != expected_overview.strip():
|
||||
print("OUT OF SYNC: README.md (overview)")
|
||||
print(" The blocks overview table needs regeneration")
|
||||
@@ -772,6 +891,21 @@ def check_docs_in_sync(output_dir: Path, blocks: list[BlockDoc]) -> bool:
|
||||
out_of_sync_details.append(("README.md", ["overview table"]))
|
||||
all_match = False
|
||||
|
||||
# Check SUMMARY.md at the parent directory
|
||||
summary_path = root_dir / "SUMMARY.md"
|
||||
if summary_path.exists():
|
||||
existing_summary = summary_path.read_text()
|
||||
expected_summary = generate_summary_md(blocks, root_dir, block_dir_prefix)
|
||||
if existing_summary.strip() != expected_summary.strip():
|
||||
print("OUT OF SYNC: SUMMARY.md (navigation)")
|
||||
print(" The GitBook navigation needs regeneration")
|
||||
out_of_sync_details.append(("SUMMARY.md", ["navigation"]))
|
||||
all_match = False
|
||||
else:
|
||||
print("MISSING: SUMMARY.md (navigation)")
|
||||
out_of_sync_details.append(("SUMMARY.md", ["navigation"]))
|
||||
all_match = False
|
||||
|
||||
# Check for unfilled manual sections
|
||||
unfilled_patterns = [
|
||||
"_Add a description of this category of blocks._",
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Tests for agent generator module."""
|
||||
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
Tests for the Agent Generator core module.
|
||||
|
||||
This test suite verifies that the core functions correctly delegate to
|
||||
the external Agent Generator service.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.chat.tools.agent_generator import core
|
||||
from backend.api.features.chat.tools.agent_generator.core import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
)
|
||||
|
||||
|
||||
class TestServiceNotConfigured:
|
||||
"""Test that functions raise AgentGeneratorNotConfiguredError when service is not configured."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_raises_when_not_configured(self):
|
||||
"""Test that decompose_goal raises error when service not configured."""
|
||||
with patch.object(core, "is_external_service_configured", return_value=False):
|
||||
with pytest.raises(AgentGeneratorNotConfiguredError):
|
||||
await core.decompose_goal("Build a chatbot")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_raises_when_not_configured(self):
|
||||
"""Test that generate_agent raises error when service not configured."""
|
||||
with patch.object(core, "is_external_service_configured", return_value=False):
|
||||
with pytest.raises(AgentGeneratorNotConfiguredError):
|
||||
await core.generate_agent({"steps": []})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_patch_raises_when_not_configured(self):
|
||||
"""Test that generate_agent_patch raises error when service not configured."""
|
||||
with patch.object(core, "is_external_service_configured", return_value=False):
|
||||
with pytest.raises(AgentGeneratorNotConfiguredError):
|
||||
await core.generate_agent_patch("Add a node", {"nodes": []})
|
||||
|
||||
|
||||
class TestDecomposeGoal:
|
||||
"""Test decompose_goal function service delegation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_external_service(self):
|
||||
"""Test that decompose_goal calls the external service."""
|
||||
expected_result = {"type": "instructions", "steps": ["Step 1"]}
|
||||
|
||||
with patch.object(
|
||||
core, "is_external_service_configured", return_value=True
|
||||
), patch.object(
|
||||
core, "decompose_goal_external", new_callable=AsyncMock
|
||||
) as mock_external:
|
||||
mock_external.return_value = expected_result
|
||||
|
||||
result = await core.decompose_goal("Build a chatbot")
|
||||
|
||||
mock_external.assert_called_once_with("Build a chatbot", "")
|
||||
assert result == expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passes_context_to_external_service(self):
|
||||
"""Test that decompose_goal passes context to external service."""
|
||||
expected_result = {"type": "instructions", "steps": ["Step 1"]}
|
||||
|
||||
with patch.object(
|
||||
core, "is_external_service_configured", return_value=True
|
||||
), patch.object(
|
||||
core, "decompose_goal_external", new_callable=AsyncMock
|
||||
) as mock_external:
|
||||
mock_external.return_value = expected_result
|
||||
|
||||
await core.decompose_goal("Build a chatbot", "Use Python")
|
||||
|
||||
mock_external.assert_called_once_with("Build a chatbot", "Use Python")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_service_failure(self):
|
||||
"""Test that decompose_goal returns None when external service fails."""
|
||||
with patch.object(
|
||||
core, "is_external_service_configured", return_value=True
|
||||
), patch.object(
|
||||
core, "decompose_goal_external", new_callable=AsyncMock
|
||||
) as mock_external:
|
||||
mock_external.return_value = None
|
||||
|
||||
result = await core.decompose_goal("Build a chatbot")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGenerateAgent:
|
||||
"""Test generate_agent function service delegation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_external_service(self):
|
||||
"""Test that generate_agent calls the external service."""
|
||||
expected_result = {"name": "Test Agent", "nodes": [], "links": []}
|
||||
|
||||
with patch.object(
|
||||
core, "is_external_service_configured", return_value=True
|
||||
), patch.object(
|
||||
core, "generate_agent_external", new_callable=AsyncMock
|
||||
) as mock_external:
|
||||
mock_external.return_value = expected_result
|
||||
|
||||
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
||||
result = await core.generate_agent(instructions)
|
||||
|
||||
mock_external.assert_called_once_with(instructions)
|
||||
# Result should have id, version, is_active added if not present
|
||||
assert result is not None
|
||||
assert result["name"] == "Test Agent"
|
||||
assert "id" in result
|
||||
assert result["version"] == 1
|
||||
assert result["is_active"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserves_existing_id_and_version(self):
|
||||
"""Test that external service result preserves existing id and version."""
|
||||
expected_result = {
|
||||
"id": "existing-id",
|
||||
"version": 3,
|
||||
"is_active": False,
|
||||
"name": "Test Agent",
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
core, "is_external_service_configured", return_value=True
|
||||
), patch.object(
|
||||
core, "generate_agent_external", new_callable=AsyncMock
|
||||
) as mock_external:
|
||||
mock_external.return_value = expected_result.copy()
|
||||
|
||||
result = await core.generate_agent({"steps": []})
|
||||
|
||||
assert result is not None
|
||||
assert result["id"] == "existing-id"
|
||||
assert result["version"] == 3
|
||||
assert result["is_active"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_external_service_fails(self):
|
||||
"""Test that generate_agent returns None when external service fails."""
|
||||
with patch.object(
|
||||
core, "is_external_service_configured", return_value=True
|
||||
), patch.object(
|
||||
core, "generate_agent_external", new_callable=AsyncMock
|
||||
) as mock_external:
|
||||
mock_external.return_value = None
|
||||
|
||||
result = await core.generate_agent({"steps": []})
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGenerateAgentPatch:
|
||||
"""Test generate_agent_patch function service delegation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_external_service(self):
|
||||
"""Test that generate_agent_patch calls the external service."""
|
||||
expected_result = {"name": "Updated Agent", "nodes": [], "links": []}
|
||||
|
||||
with patch.object(
|
||||
core, "is_external_service_configured", return_value=True
|
||||
), patch.object(
|
||||
core, "generate_agent_patch_external", new_callable=AsyncMock
|
||||
) as mock_external:
|
||||
mock_external.return_value = expected_result
|
||||
|
||||
current_agent = {"nodes": [], "links": []}
|
||||
result = await core.generate_agent_patch("Add a node", current_agent)
|
||||
|
||||
mock_external.assert_called_once_with("Add a node", current_agent)
|
||||
assert result == expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_clarifying_questions(self):
|
||||
"""Test that generate_agent_patch returns clarifying questions."""
|
||||
expected_result = {
|
||||
"type": "clarifying_questions",
|
||||
"questions": [{"question": "What type of node?"}],
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
core, "is_external_service_configured", return_value=True
|
||||
), patch.object(
|
||||
core, "generate_agent_patch_external", new_callable=AsyncMock
|
||||
) as mock_external:
|
||||
mock_external.return_value = expected_result
|
||||
|
||||
result = await core.generate_agent_patch("Add a node", {"nodes": []})
|
||||
|
||||
assert result == expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_external_service_fails(self):
|
||||
"""Test that generate_agent_patch returns None when service fails."""
|
||||
with patch.object(
|
||||
core, "is_external_service_configured", return_value=True
|
||||
), patch.object(
|
||||
core, "generate_agent_patch_external", new_callable=AsyncMock
|
||||
) as mock_external:
|
||||
mock_external.return_value = None
|
||||
|
||||
result = await core.generate_agent_patch("Add a node", {"nodes": []})
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestJsonToGraph:
|
||||
"""Test json_to_graph function."""
|
||||
|
||||
def test_converts_agent_json_to_graph(self):
|
||||
"""Test conversion of agent JSON to Graph model."""
|
||||
agent_json = {
|
||||
"id": "test-id",
|
||||
"version": 2,
|
||||
"is_active": True,
|
||||
"name": "Test Agent",
|
||||
"description": "A test agent",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node1",
|
||||
"block_id": "block1",
|
||||
"input_default": {"key": "value"},
|
||||
"metadata": {"x": 100},
|
||||
}
|
||||
],
|
||||
"links": [
|
||||
{
|
||||
"id": "link1",
|
||||
"source_id": "node1",
|
||||
"sink_id": "output",
|
||||
"source_name": "result",
|
||||
"sink_name": "input",
|
||||
"is_static": False,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
graph = core.json_to_graph(agent_json)
|
||||
|
||||
assert graph.id == "test-id"
|
||||
assert graph.version == 2
|
||||
assert graph.is_active is True
|
||||
assert graph.name == "Test Agent"
|
||||
assert graph.description == "A test agent"
|
||||
assert len(graph.nodes) == 1
|
||||
assert graph.nodes[0].id == "node1"
|
||||
assert graph.nodes[0].block_id == "block1"
|
||||
assert len(graph.links) == 1
|
||||
assert graph.links[0].source_id == "node1"
|
||||
|
||||
def test_generates_ids_if_missing(self):
|
||||
"""Test that missing IDs are generated."""
|
||||
agent_json = {
|
||||
"name": "Test Agent",
|
||||
"nodes": [{"block_id": "block1"}],
|
||||
"links": [],
|
||||
}
|
||||
|
||||
graph = core.json_to_graph(agent_json)
|
||||
|
||||
assert graph.id is not None
|
||||
assert graph.nodes[0].id is not None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
422
autogpt_platform/backend/test/agent_generator/test_service.py
Normal file
422
autogpt_platform/backend/test/agent_generator/test_service.py
Normal file
@@ -0,0 +1,422 @@
|
||||
"""
|
||||
Tests for the Agent Generator external service client.
|
||||
|
||||
This test suite verifies the external Agent Generator service integration,
|
||||
including service detection, API calls, and error handling.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from backend.api.features.chat.tools.agent_generator import service
|
||||
|
||||
|
||||
class TestServiceConfiguration:
|
||||
"""Test service configuration detection."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset settings singleton before each test."""
|
||||
service._settings = None
|
||||
service._client = None
|
||||
|
||||
def test_external_service_not_configured_when_host_empty(self):
|
||||
"""Test that external service is not configured when host is empty."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.config.agentgenerator_host = ""
|
||||
|
||||
with patch.object(service, "_get_settings", return_value=mock_settings):
|
||||
assert service.is_external_service_configured() is False
|
||||
|
||||
def test_external_service_configured_when_host_set(self):
|
||||
"""Test that external service is configured when host is set."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.config.agentgenerator_host = "agent-generator.local"
|
||||
|
||||
with patch.object(service, "_get_settings", return_value=mock_settings):
|
||||
assert service.is_external_service_configured() is True
|
||||
|
||||
def test_get_base_url(self):
|
||||
"""Test base URL construction."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.config.agentgenerator_host = "agent-generator.local"
|
||||
mock_settings.config.agentgenerator_port = 8000
|
||||
|
||||
with patch.object(service, "_get_settings", return_value=mock_settings):
|
||||
url = service._get_base_url()
|
||||
assert url == "http://agent-generator.local:8000"
|
||||
|
||||
|
||||
class TestDecomposeGoalExternal:
|
||||
"""Test decompose_goal_external function."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset client singleton before each test."""
|
||||
service._settings = None
|
||||
service._client = None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_returns_instructions(self):
|
||||
"""Test successful decomposition returning instructions."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "instructions",
|
||||
"steps": ["Step 1", "Step 2"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.decompose_goal_external("Build a chatbot")
|
||||
|
||||
assert result == {"type": "instructions", "steps": ["Step 1", "Step 2"]}
|
||||
mock_client.post.assert_called_once_with(
|
||||
"/api/decompose-description", json={"description": "Build a chatbot"}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_returns_clarifying_questions(self):
|
||||
"""Test decomposition returning clarifying questions."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "clarifying_questions",
|
||||
"questions": ["What platform?", "What language?"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.decompose_goal_external("Build something")
|
||||
|
||||
assert result == {
|
||||
"type": "clarifying_questions",
|
||||
"questions": ["What platform?", "What language?"],
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_with_context(self):
|
||||
"""Test decomposition with additional context."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "instructions",
|
||||
"steps": ["Step 1"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
await service.decompose_goal_external(
|
||||
"Build a chatbot", context="Use Python"
|
||||
)
|
||||
|
||||
mock_client.post.assert_called_once_with(
|
||||
"/api/decompose-description",
|
||||
json={"description": "Build a chatbot", "user_instruction": "Use Python"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_returns_unachievable_goal(self):
|
||||
"""Test decomposition returning unachievable goal response."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "unachievable_goal",
|
||||
"reason": "Cannot do X",
|
||||
"suggested_goal": "Try Y instead",
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.decompose_goal_external("Do something impossible")
|
||||
|
||||
assert result == {
|
||||
"type": "unachievable_goal",
|
||||
"reason": "Cannot do X",
|
||||
"suggested_goal": "Try Y instead",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_handles_http_error(self):
|
||||
"""Test decomposition handles HTTP errors gracefully."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.side_effect = httpx.HTTPStatusError(
|
||||
"Server error", request=MagicMock(), response=MagicMock()
|
||||
)
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.decompose_goal_external("Build a chatbot")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_handles_request_error(self):
|
||||
"""Test decomposition handles request errors gracefully."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.side_effect = httpx.RequestError("Connection failed")
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.decompose_goal_external("Build a chatbot")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_handles_service_error(self):
|
||||
"""Test decomposition handles service returning error."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": False,
|
||||
"error": "Internal error",
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.decompose_goal_external("Build a chatbot")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGenerateAgentExternal:
|
||||
"""Test generate_agent_external function."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset client singleton before each test."""
|
||||
service._settings = None
|
||||
service._client = None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_success(self):
|
||||
"""Test successful agent generation."""
|
||||
agent_json = {
|
||||
"name": "Test Agent",
|
||||
"nodes": [],
|
||||
"links": [],
|
||||
}
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"agent_json": agent_json,
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.generate_agent_external(instructions)
|
||||
|
||||
assert result == agent_json
|
||||
mock_client.post.assert_called_once_with(
|
||||
"/api/generate-agent", json={"instructions": instructions}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_handles_error(self):
|
||||
"""Test agent generation handles errors gracefully."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.side_effect = httpx.RequestError("Connection failed")
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.generate_agent_external({"steps": []})
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGenerateAgentPatchExternal:
|
||||
"""Test generate_agent_patch_external function."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset client singleton before each test."""
|
||||
service._settings = None
|
||||
service._client = None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_patch_returns_updated_agent(self):
|
||||
"""Test successful patch generation returning updated agent."""
|
||||
updated_agent = {
|
||||
"name": "Updated Agent",
|
||||
"nodes": [{"id": "1", "block_id": "test"}],
|
||||
"links": [],
|
||||
}
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"agent_json": updated_agent,
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
current_agent = {"name": "Old Agent", "nodes": [], "links": []}
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.generate_agent_patch_external(
|
||||
"Add a new node", current_agent
|
||||
)
|
||||
|
||||
assert result == updated_agent
|
||||
mock_client.post.assert_called_once_with(
|
||||
"/api/update-agent",
|
||||
json={
|
||||
"update_request": "Add a new node",
|
||||
"current_agent_json": current_agent,
|
||||
},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_patch_returns_clarifying_questions(self):
|
||||
"""Test patch generation returning clarifying questions."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "clarifying_questions",
|
||||
"questions": ["What type of node?"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.generate_agent_patch_external(
|
||||
"Add something", {"nodes": []}
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"type": "clarifying_questions",
|
||||
"questions": ["What type of node?"],
|
||||
}
|
||||
|
||||
|
||||
class TestHealthCheck:
|
||||
"""Test health_check function."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset singletons before each test."""
|
||||
service._settings = None
|
||||
service._client = None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_returns_false_when_not_configured(self):
|
||||
"""Test health check returns False when service not configured."""
|
||||
with patch.object(
|
||||
service, "is_external_service_configured", return_value=False
|
||||
):
|
||||
result = await service.health_check()
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_returns_true_when_healthy(self):
|
||||
"""Test health check returns True when service is healthy."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"status": "healthy",
|
||||
"blocks_loaded": True,
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch.object(service, "is_external_service_configured", return_value=True):
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.health_check()
|
||||
|
||||
assert result is True
|
||||
mock_client.get.assert_called_once_with("/health")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_returns_false_when_not_healthy(self):
|
||||
"""Test health check returns False when service is not healthy."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"status": "unhealthy",
|
||||
"blocks_loaded": False,
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch.object(service, "is_external_service_configured", return_value=True):
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.health_check()
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_returns_false_on_error(self):
|
||||
"""Test health check returns False on connection error."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.side_effect = httpx.RequestError("Connection failed")
|
||||
|
||||
with patch.object(service, "is_external_service_configured", return_value=True):
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.health_check()
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestGetBlocksExternal:
|
||||
"""Test get_blocks_external function."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset client singleton before each test."""
|
||||
service._settings = None
|
||||
service._client = None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_blocks_success(self):
|
||||
"""Test successful blocks retrieval."""
|
||||
blocks = [
|
||||
{"id": "block1", "name": "Block 1"},
|
||||
{"id": "block2", "name": "Block 2"},
|
||||
]
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"blocks": blocks,
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.get_blocks_external()
|
||||
|
||||
assert result == blocks
|
||||
mock_client.get.assert_called_once_with("/api/blocks")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_blocks_handles_error(self):
|
||||
"""Test blocks retrieval handles errors gracefully."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.side_effect = httpx.RequestError("Connection failed")
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.get_blocks_external()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -1,8 +1,5 @@
|
||||
"use client";
|
||||
|
||||
import { Sidebar } from "@/components/__legacy__/Sidebar";
|
||||
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
|
||||
import { Cpu } from "@phosphor-icons/react";
|
||||
|
||||
import { IconSliders } from "@/components/__legacy__/ui/icons";
|
||||
|
||||
@@ -29,11 +26,6 @@ const sidebarLinkGroups = [
|
||||
href: "/admin/execution-analytics",
|
||||
icon: <FileText className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "LLM Registry",
|
||||
href: "/admin/llms",
|
||||
icon: <Cpu size={24} />,
|
||||
},
|
||||
{
|
||||
text: "Admin User Management",
|
||||
href: "/admin/settings",
|
||||
|
||||
@@ -1,493 +0,0 @@
|
||||
"use server";
|
||||
|
||||
import { revalidatePath } from "next/cache";
|
||||
|
||||
// Generated API functions
|
||||
import {
|
||||
getV2ListLlmProviders,
|
||||
postV2CreateLlmProvider,
|
||||
patchV2UpdateLlmProvider,
|
||||
deleteV2DeleteLlmProvider,
|
||||
getV2ListLlmModels,
|
||||
postV2CreateLlmModel,
|
||||
patchV2UpdateLlmModel,
|
||||
patchV2ToggleLlmModelAvailability,
|
||||
deleteV2DeleteLlmModelAndMigrateWorkflows,
|
||||
getV2GetModelUsageCount,
|
||||
getV2ListModelMigrations,
|
||||
postV2RevertAModelMigration,
|
||||
getV2ListModelCreators,
|
||||
postV2CreateModelCreator,
|
||||
patchV2UpdateModelCreator,
|
||||
deleteV2DeleteModelCreator,
|
||||
postV2SetRecommendedModel,
|
||||
} from "@/app/api/__generated__/endpoints/admin/admin";
|
||||
|
||||
// Generated types
|
||||
import type { LlmProvidersResponse } from "@/app/api/__generated__/models/llmProvidersResponse";
|
||||
import type { LlmModelsResponse } from "@/app/api/__generated__/models/llmModelsResponse";
|
||||
import type { UpsertLlmProviderRequest } from "@/app/api/__generated__/models/upsertLlmProviderRequest";
|
||||
import type { CreateLlmModelRequest } from "@/app/api/__generated__/models/createLlmModelRequest";
|
||||
import type { UpdateLlmModelRequest } from "@/app/api/__generated__/models/updateLlmModelRequest";
|
||||
import type { ToggleLlmModelRequest } from "@/app/api/__generated__/models/toggleLlmModelRequest";
|
||||
import type { LlmMigrationsResponse } from "@/app/api/__generated__/models/llmMigrationsResponse";
|
||||
import type { LlmCreatorsResponse } from "@/app/api/__generated__/models/llmCreatorsResponse";
|
||||
import type { UpsertLlmCreatorRequest } from "@/app/api/__generated__/models/upsertLlmCreatorRequest";
|
||||
import type { LlmModelUsageResponse } from "@/app/api/__generated__/models/llmModelUsageResponse";
|
||||
import { LlmCostUnit } from "@/app/api/__generated__/models/llmCostUnit";
|
||||
|
||||
const ADMIN_LLM_PATH = "/admin/llms";
|
||||
|
||||
// =============================================================================
|
||||
// Utilities
|
||||
// =============================================================================
|
||||
|
||||
/**
|
||||
* Extracts and validates a required string field from FormData.
|
||||
* Throws an error if the field is missing or empty.
|
||||
*/
|
||||
function getRequiredFormField(
|
||||
formData: FormData,
|
||||
fieldName: string,
|
||||
displayName?: string,
|
||||
): string {
|
||||
const raw = formData.get(fieldName);
|
||||
const value = raw ? String(raw).trim() : "";
|
||||
if (!value) {
|
||||
throw new Error(`${displayName || fieldName} is required`);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts and validates a required positive number field from FormData.
|
||||
* Throws an error if the field is missing, empty, or not a positive number.
|
||||
*/
|
||||
function getRequiredPositiveNumber(
|
||||
formData: FormData,
|
||||
fieldName: string,
|
||||
displayName?: string,
|
||||
): number {
|
||||
const raw = formData.get(fieldName);
|
||||
const value = Number(raw);
|
||||
if (raw === null || raw === "" || !Number.isFinite(value) || value <= 0) {
|
||||
throw new Error(`${displayName || fieldName} must be a positive number`);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts and validates a required number field from FormData.
|
||||
* Throws an error if the field is missing, empty, or not a finite number.
|
||||
*/
|
||||
function getRequiredNumber(
|
||||
formData: FormData,
|
||||
fieldName: string,
|
||||
displayName?: string,
|
||||
): number {
|
||||
const raw = formData.get(fieldName);
|
||||
const value = Number(raw);
|
||||
if (raw === null || raw === "" || !Number.isFinite(value)) {
|
||||
throw new Error(`${displayName || fieldName} is required`);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Provider Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function fetchLlmProviders(): Promise<LlmProvidersResponse> {
|
||||
const response = await getV2ListLlmProviders({ include_models: true });
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to fetch LLM providers");
|
||||
}
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function createLlmProviderAction(formData: FormData) {
|
||||
const payload: UpsertLlmProviderRequest = {
|
||||
name: String(formData.get("name") || "").trim(),
|
||||
display_name: String(formData.get("display_name") || "").trim(),
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
default_credential_provider: formData.get("default_credential_provider")
|
||||
? String(formData.get("default_credential_provider")).trim()
|
||||
: undefined,
|
||||
default_credential_id: formData.get("default_credential_id")
|
||||
? String(formData.get("default_credential_id")).trim()
|
||||
: undefined,
|
||||
default_credential_type: formData.get("default_credential_type")
|
||||
? String(formData.get("default_credential_type")).trim()
|
||||
: "api_key",
|
||||
supports_tools: formData.getAll("supports_tools").includes("on"),
|
||||
supports_json_output: formData
|
||||
.getAll("supports_json_output")
|
||||
.includes("on"),
|
||||
supports_reasoning: formData.getAll("supports_reasoning").includes("on"),
|
||||
supports_parallel_tool: formData
|
||||
.getAll("supports_parallel_tool")
|
||||
.includes("on"),
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
const response = await postV2CreateLlmProvider(payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to create LLM provider");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function deleteLlmProviderAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const providerId = getRequiredFormField(
|
||||
formData,
|
||||
"provider_id",
|
||||
"Provider id",
|
||||
);
|
||||
|
||||
const response = await deleteV2DeleteLlmProvider(providerId);
|
||||
if (response.status !== 200) {
|
||||
const errorData = response.data as { detail?: string };
|
||||
throw new Error(errorData?.detail || "Failed to delete provider");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function updateLlmProviderAction(formData: FormData) {
|
||||
const providerId = getRequiredFormField(
|
||||
formData,
|
||||
"provider_id",
|
||||
"Provider id",
|
||||
);
|
||||
|
||||
const payload: UpsertLlmProviderRequest = {
|
||||
name: String(formData.get("name") || "").trim(),
|
||||
display_name: String(formData.get("display_name") || "").trim(),
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
default_credential_provider: formData.get("default_credential_provider")
|
||||
? String(formData.get("default_credential_provider")).trim()
|
||||
: undefined,
|
||||
default_credential_id: formData.get("default_credential_id")
|
||||
? String(formData.get("default_credential_id")).trim()
|
||||
: undefined,
|
||||
default_credential_type: formData.get("default_credential_type")
|
||||
? String(formData.get("default_credential_type")).trim()
|
||||
: "api_key",
|
||||
supports_tools: formData.getAll("supports_tools").includes("on"),
|
||||
supports_json_output: formData
|
||||
.getAll("supports_json_output")
|
||||
.includes("on"),
|
||||
supports_reasoning: formData.getAll("supports_reasoning").includes("on"),
|
||||
supports_parallel_tool: formData
|
||||
.getAll("supports_parallel_tool")
|
||||
.includes("on"),
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
const response = await patchV2UpdateLlmProvider(providerId, payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to update LLM provider");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Model Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function fetchLlmModels(): Promise<LlmModelsResponse> {
|
||||
const response = await getV2ListLlmModels();
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to fetch LLM models");
|
||||
}
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function createLlmModelAction(formData: FormData) {
|
||||
const providerId = getRequiredFormField(formData, "provider_id", "Provider");
|
||||
const creatorId = formData.get("creator_id");
|
||||
const contextWindow = getRequiredPositiveNumber(
|
||||
formData,
|
||||
"context_window",
|
||||
"Context window",
|
||||
);
|
||||
const creditCost = getRequiredNumber(formData, "credit_cost", "Credit cost");
|
||||
|
||||
// Fetch provider to get default credentials
|
||||
const providersResponse = await getV2ListLlmProviders({
|
||||
include_models: false,
|
||||
});
|
||||
if (providersResponse.status !== 200) {
|
||||
throw new Error("Failed to fetch providers");
|
||||
}
|
||||
const provider = providersResponse.data.providers.find(
|
||||
(p) => p.id === providerId,
|
||||
);
|
||||
|
||||
if (!provider) {
|
||||
throw new Error("Provider not found");
|
||||
}
|
||||
|
||||
const payload: CreateLlmModelRequest = {
|
||||
slug: String(formData.get("slug") || "").trim(),
|
||||
display_name: String(formData.get("display_name") || "").trim(),
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
provider_id: providerId,
|
||||
creator_id: creatorId ? String(creatorId) : undefined,
|
||||
context_window: contextWindow,
|
||||
max_output_tokens: formData.get("max_output_tokens")
|
||||
? Number(formData.get("max_output_tokens"))
|
||||
: undefined,
|
||||
is_enabled: formData.getAll("is_enabled").includes("on"),
|
||||
capabilities: {},
|
||||
metadata: {},
|
||||
costs: [
|
||||
{
|
||||
unit: (formData.get("unit") as LlmCostUnit) || LlmCostUnit.RUN,
|
||||
credit_cost: creditCost,
|
||||
credential_provider:
|
||||
provider.default_credential_provider || provider.name,
|
||||
credential_id: provider.default_credential_id || undefined,
|
||||
credential_type: provider.default_credential_type || "api_key",
|
||||
metadata: {},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const response = await postV2CreateLlmModel(payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to create LLM model");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function updateLlmModelAction(formData: FormData) {
|
||||
const modelId = getRequiredFormField(formData, "model_id", "Model id");
|
||||
const creatorId = formData.get("creator_id");
|
||||
|
||||
const payload: UpdateLlmModelRequest = {
|
||||
display_name: formData.get("display_name")
|
||||
? String(formData.get("display_name"))
|
||||
: undefined,
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
provider_id: formData.get("provider_id")
|
||||
? String(formData.get("provider_id"))
|
||||
: undefined,
|
||||
creator_id: creatorId ? String(creatorId) : undefined,
|
||||
context_window: formData.get("context_window")
|
||||
? Number(formData.get("context_window"))
|
||||
: undefined,
|
||||
max_output_tokens: formData.get("max_output_tokens")
|
||||
? Number(formData.get("max_output_tokens"))
|
||||
: undefined,
|
||||
is_enabled: formData.has("is_enabled")
|
||||
? formData.getAll("is_enabled").includes("on")
|
||||
: undefined,
|
||||
costs: formData.get("credit_cost")
|
||||
? [
|
||||
{
|
||||
unit: (formData.get("unit") as LlmCostUnit) || LlmCostUnit.RUN,
|
||||
credit_cost: Number(formData.get("credit_cost")),
|
||||
credential_provider: String(
|
||||
formData.get("credential_provider") || "",
|
||||
).trim(),
|
||||
credential_id: formData.get("credential_id")
|
||||
? String(formData.get("credential_id"))
|
||||
: undefined,
|
||||
credential_type: formData.get("credential_type")
|
||||
? String(formData.get("credential_type"))
|
||||
: undefined,
|
||||
metadata: {},
|
||||
},
|
||||
]
|
||||
: undefined,
|
||||
};
|
||||
|
||||
const response = await patchV2UpdateLlmModel(modelId, payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to update LLM model");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function toggleLlmModelAction(formData: FormData): Promise<void> {
|
||||
const modelId = getRequiredFormField(formData, "model_id", "Model id");
|
||||
const shouldEnable = formData.get("is_enabled") === "true";
|
||||
const migrateToSlug = formData.get("migrate_to_slug");
|
||||
const migrationReason = formData.get("migration_reason");
|
||||
const customCreditCost = formData.get("custom_credit_cost");
|
||||
|
||||
const payload: ToggleLlmModelRequest = {
|
||||
is_enabled: shouldEnable,
|
||||
migrate_to_slug: migrateToSlug ? String(migrateToSlug) : undefined,
|
||||
migration_reason: migrationReason ? String(migrationReason) : undefined,
|
||||
custom_credit_cost: customCreditCost ? Number(customCreditCost) : undefined,
|
||||
};
|
||||
|
||||
const response = await patchV2ToggleLlmModelAvailability(modelId, payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to toggle LLM model");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function deleteLlmModelAction(formData: FormData): Promise<void> {
|
||||
const modelId = getRequiredFormField(formData, "model_id", "Model id");
|
||||
const rawReplacement = formData.get("replacement_model_slug");
|
||||
const replacementModelSlug =
|
||||
rawReplacement && String(rawReplacement).trim()
|
||||
? String(rawReplacement).trim()
|
||||
: undefined;
|
||||
|
||||
const response = await deleteV2DeleteLlmModelAndMigrateWorkflows(modelId, {
|
||||
replacement_model_slug: replacementModelSlug,
|
||||
});
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to delete model");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function fetchLlmModelUsage(
|
||||
modelId: string,
|
||||
): Promise<LlmModelUsageResponse> {
|
||||
const response = await getV2GetModelUsageCount(modelId);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to fetch model usage");
|
||||
}
|
||||
return response.data;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Migration Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function fetchLlmMigrations(
|
||||
includeReverted: boolean = false,
|
||||
): Promise<LlmMigrationsResponse> {
|
||||
const response = await getV2ListModelMigrations({
|
||||
include_reverted: includeReverted,
|
||||
});
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to fetch migrations");
|
||||
}
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function revertLlmMigrationAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const migrationId = getRequiredFormField(
|
||||
formData,
|
||||
"migration_id",
|
||||
"Migration id",
|
||||
);
|
||||
|
||||
const response = await postV2RevertAModelMigration(migrationId, null);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to revert migration");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Creator Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function fetchLlmCreators(): Promise<LlmCreatorsResponse> {
|
||||
const response = await getV2ListModelCreators();
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to fetch creators");
|
||||
}
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function createLlmCreatorAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const payload: UpsertLlmCreatorRequest = {
|
||||
name: String(formData.get("name") || "").trim(),
|
||||
display_name: String(formData.get("display_name") || "").trim(),
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
website_url: formData.get("website_url")
|
||||
? String(formData.get("website_url")).trim()
|
||||
: undefined,
|
||||
logo_url: formData.get("logo_url")
|
||||
? String(formData.get("logo_url")).trim()
|
||||
: undefined,
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
const response = await postV2CreateModelCreator(payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to create creator");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function updateLlmCreatorAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const creatorId = getRequiredFormField(formData, "creator_id", "Creator id");
|
||||
|
||||
const payload: UpsertLlmCreatorRequest = {
|
||||
name: String(formData.get("name") || "").trim(),
|
||||
display_name: String(formData.get("display_name") || "").trim(),
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
website_url: formData.get("website_url")
|
||||
? String(formData.get("website_url")).trim()
|
||||
: undefined,
|
||||
logo_url: formData.get("logo_url")
|
||||
? String(formData.get("logo_url")).trim()
|
||||
: undefined,
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
const response = await patchV2UpdateModelCreator(creatorId, payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to update creator");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function deleteLlmCreatorAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const creatorId = getRequiredFormField(formData, "creator_id", "Creator id");
|
||||
|
||||
const response = await deleteV2DeleteModelCreator(creatorId);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to delete creator");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Recommended Model Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function setRecommendedModelAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const modelId = getRequiredFormField(formData, "model_id", "Model id");
|
||||
|
||||
const response = await postV2SetRecommendedModel({ model_id: modelId });
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to set recommended model");
|
||||
}
|
||||
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
@@ -1,147 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { createLlmCreatorAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
export function AddCreatorModal() {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await createLlmCreatorAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to create creator");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Add Creator"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "512px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="primary" size="small">
|
||||
Add Creator
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
Add a new model creator (the organization that made/trained the
|
||||
model).
|
||||
</div>
|
||||
|
||||
<form action={handleSubmit} className="space-y-4">
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Name (slug) <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="name"
|
||||
required
|
||||
name="name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="openai"
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Lowercase identifier (e.g., openai, meta, anthropic)
|
||||
</p>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="display_name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Display Name <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="display_name"
|
||||
required
|
||||
name="display_name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="OpenAI"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="description"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Description
|
||||
</label>
|
||||
<textarea
|
||||
id="description"
|
||||
name="description"
|
||||
rows={2}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="Creator of GPT models..."
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="website_url"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Website URL
|
||||
</label>
|
||||
<input
|
||||
id="website_url"
|
||||
name="website_url"
|
||||
type="url"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="https://openai.com"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Creating..." : "Add Creator"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,314 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
|
||||
import { createLlmModelAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
interface Props {
|
||||
providers: LlmProvider[];
|
||||
creators: LlmModelCreator[];
|
||||
}
|
||||
|
||||
export function AddModelModal({ providers, creators }: Props) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [selectedCreatorId, setSelectedCreatorId] = useState("");
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await createLlmModelAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to create model");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
// When provider changes, auto-select matching creator if one exists
|
||||
function handleProviderChange(providerId: string) {
|
||||
const provider = providers.find((p) => p.id === providerId);
|
||||
if (provider) {
|
||||
// Find creator with same name as provider (e.g., "openai" -> "openai")
|
||||
const matchingCreator = creators.find((c) => c.name === provider.name);
|
||||
if (matchingCreator) {
|
||||
setSelectedCreatorId(matchingCreator.id);
|
||||
} else {
|
||||
// No matching creator (e.g., OpenRouter hosts other creators' models)
|
||||
setSelectedCreatorId("");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Add Model"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="primary" size="small">
|
||||
Add Model
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
Register a new model slug, metadata, and pricing.
|
||||
</div>
|
||||
|
||||
<form action={handleSubmit} className="space-y-6">
|
||||
{/* Basic Information */}
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Basic Information
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Core model details
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="slug"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Model Slug <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="slug"
|
||||
required
|
||||
name="slug"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="gpt-4.1-mini-2025-04-14"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="display_name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Display Name <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="display_name"
|
||||
required
|
||||
name="display_name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="GPT 4.1 Mini"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="description"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Description
|
||||
</label>
|
||||
<textarea
|
||||
id="description"
|
||||
name="description"
|
||||
rows={3}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="Optional description..."
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Model Configuration */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Model Configuration
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Model capabilities and limits
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="provider_id"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Provider <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<select
|
||||
id="provider_id"
|
||||
required
|
||||
name="provider_id"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
defaultValue=""
|
||||
onChange={(e) => handleProviderChange(e.target.value)}
|
||||
>
|
||||
<option value="" disabled>
|
||||
Select provider
|
||||
</option>
|
||||
{providers.map((provider) => (
|
||||
<option key={provider.id} value={provider.id}>
|
||||
{provider.display_name} ({provider.name})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Who hosts/serves the model
|
||||
</p>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="creator_id"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Creator
|
||||
</label>
|
||||
<select
|
||||
id="creator_id"
|
||||
name="creator_id"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
value={selectedCreatorId}
|
||||
onChange={(e) => setSelectedCreatorId(e.target.value)}
|
||||
>
|
||||
<option value="">No creator selected</option>
|
||||
{creators.map((creator) => (
|
||||
<option key={creator.id} value={creator.id}>
|
||||
{creator.display_name} ({creator.name})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Who made/trained the model (e.g., OpenAI, Meta)
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="context_window"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Context Window <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="context_window"
|
||||
required
|
||||
type="number"
|
||||
name="context_window"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="128000"
|
||||
min={1}
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="max_output_tokens"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Max Output Tokens
|
||||
</label>
|
||||
<input
|
||||
id="max_output_tokens"
|
||||
type="number"
|
||||
name="max_output_tokens"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="16384"
|
||||
min={1}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Pricing */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">Pricing</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Credit cost per run (credentials are managed via the provider)
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-1">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="credit_cost"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Credit Cost <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="credit_cost"
|
||||
required
|
||||
type="number"
|
||||
name="credit_cost"
|
||||
step="1"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="5"
|
||||
min={0}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Credit cost is always in platform credits. Credentials are
|
||||
inherited from the selected provider.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Enabled Toggle */}
|
||||
<div className="flex items-center gap-3 border-t border-border pt-6">
|
||||
<input type="hidden" name="is_enabled" value="off" />
|
||||
<input
|
||||
id="is_enabled"
|
||||
type="checkbox"
|
||||
name="is_enabled"
|
||||
defaultChecked
|
||||
className="h-4 w-4 rounded border-input"
|
||||
/>
|
||||
<label
|
||||
htmlFor="is_enabled"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Enabled by default
|
||||
</label>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Creating..." : "Save Model"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,268 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { createLlmProviderAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
export function AddProviderModal() {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await createLlmProviderAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(
|
||||
err instanceof Error ? err.message : "Failed to create provider",
|
||||
);
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Add Provider"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="primary" size="small">
|
||||
Add Provider
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
Define a new upstream provider and default credential information.
|
||||
</div>
|
||||
|
||||
{/* Setup Instructions */}
|
||||
<div className="mb-6 rounded-lg border border-primary/30 bg-primary/5 p-4">
|
||||
<div className="space-y-2">
|
||||
<h4 className="text-sm font-semibold text-foreground">
|
||||
Before Adding a Provider
|
||||
</h4>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
To use a new provider, you must first configure its credentials in
|
||||
the backend:
|
||||
</p>
|
||||
<ol className="list-inside list-decimal space-y-1 text-xs text-muted-foreground">
|
||||
<li>
|
||||
Add the credential to{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono">
|
||||
backend/integrations/credentials_store.py
|
||||
</code>{" "}
|
||||
with a UUID, provider name, and settings secret reference
|
||||
</li>
|
||||
<li>
|
||||
Add it to the{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono">
|
||||
PROVIDER_CREDENTIALS
|
||||
</code>{" "}
|
||||
dictionary in{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono">
|
||||
backend/data/block_cost_config.py
|
||||
</code>
|
||||
</li>
|
||||
<li>
|
||||
Use the <strong>same provider name</strong> in the
|
||||
"Credential Provider" field below that matches the key
|
||||
in{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono">
|
||||
PROVIDER_CREDENTIALS
|
||||
</code>
|
||||
</li>
|
||||
</ol>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<form action={handleSubmit} className="space-y-6">
|
||||
{/* Basic Information */}
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Basic Information
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Core provider details
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Provider Slug <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="name"
|
||||
required
|
||||
name="name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="e.g. openai"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="display_name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Display Name <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="display_name"
|
||||
required
|
||||
name="display_name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="OpenAI"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="description"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Description
|
||||
</label>
|
||||
<textarea
|
||||
id="description"
|
||||
name="description"
|
||||
rows={3}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="Optional description..."
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Default Credentials */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Default Credentials
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Credential provider name that matches the key in{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
|
||||
PROVIDER_CREDENTIALS
|
||||
</code>
|
||||
</p>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="default_credential_provider"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Credential Provider <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="default_credential_provider"
|
||||
name="default_credential_provider"
|
||||
required
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="openai"
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
<strong>Important:</strong> This must exactly match the key in
|
||||
the{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
|
||||
PROVIDER_CREDENTIALS
|
||||
</code>{" "}
|
||||
dictionary in{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
|
||||
block_cost_config.py
|
||||
</code>
|
||||
. Common values: "openai", "anthropic",
|
||||
"groq", "open_router", etc.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Capabilities */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Capabilities
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Provider feature flags
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-3 sm:grid-cols-2">
|
||||
{[
|
||||
{ name: "supports_tools", label: "Supports tools" },
|
||||
{ name: "supports_json_output", label: "Supports JSON output" },
|
||||
{ name: "supports_reasoning", label: "Supports reasoning" },
|
||||
{
|
||||
name: "supports_parallel_tool",
|
||||
label: "Supports parallel tool calls",
|
||||
},
|
||||
].map(({ name, label }) => (
|
||||
<div
|
||||
key={name}
|
||||
className="flex items-center gap-3 rounded-md border border-border bg-muted/30 px-4 py-3 transition-colors hover:bg-muted/50"
|
||||
>
|
||||
<input type="hidden" name={name} value="off" />
|
||||
<input
|
||||
id={name}
|
||||
type="checkbox"
|
||||
name={name}
|
||||
defaultChecked={
|
||||
name !== "supports_reasoning" &&
|
||||
name !== "supports_parallel_tool"
|
||||
}
|
||||
className="h-4 w-4 rounded border-input"
|
||||
/>
|
||||
<label
|
||||
htmlFor={name}
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
{label}
|
||||
</label>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Creating..." : "Save Provider"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,195 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/atoms/Table/Table";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { updateLlmCreatorAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { DeleteCreatorModal } from "./DeleteCreatorModal";
|
||||
|
||||
export function CreatorsTable({ creators }: { creators: LlmModelCreator[] }) {
|
||||
if (!creators.length) {
|
||||
return (
|
||||
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
|
||||
No creators registered yet.
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Creator</TableHead>
|
||||
<TableHead>Description</TableHead>
|
||||
<TableHead>Website</TableHead>
|
||||
<TableHead>Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{creators.map((creator) => (
|
||||
<TableRow key={creator.id}>
|
||||
<TableCell>
|
||||
<div className="font-medium">{creator.display_name}</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{creator.name}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<span className="text-sm text-muted-foreground">
|
||||
{creator.description || "—"}
|
||||
</span>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{creator.website_url ? (
|
||||
<a
|
||||
href={creator.website_url}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-sm text-primary hover:underline"
|
||||
>
|
||||
{(() => {
|
||||
try {
|
||||
return new URL(creator.website_url).hostname;
|
||||
} catch {
|
||||
return creator.website_url;
|
||||
}
|
||||
})()}
|
||||
</a>
|
||||
) : (
|
||||
<span className="text-muted-foreground">—</span>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
<EditCreatorModal creator={creator} />
|
||||
<DeleteCreatorModal creator={creator} />
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function EditCreatorModal({ creator }: { creator: LlmModelCreator }) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await updateLlmCreatorAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to update creator");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Edit Creator"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "512px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="outline" size="small" className="min-w-0">
|
||||
Edit
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<form action={handleSubmit} className="space-y-4">
|
||||
<input type="hidden" name="creator_id" value={creator.id} />
|
||||
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Name (slug)</label>
|
||||
<input
|
||||
required
|
||||
name="name"
|
||||
defaultValue={creator.name}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Display Name</label>
|
||||
<input
|
||||
required
|
||||
name="display_name"
|
||||
defaultValue={creator.display_name}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Description</label>
|
||||
<textarea
|
||||
name="description"
|
||||
rows={2}
|
||||
defaultValue={creator.description ?? ""}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Website URL</label>
|
||||
<input
|
||||
name="website_url"
|
||||
type="url"
|
||||
defaultValue={creator.website_url ?? ""}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Updating..." : "Update"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,107 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
|
||||
import { deleteLlmCreatorAction } from "../actions";
|
||||
|
||||
export function DeleteCreatorModal({ creator }: { creator: LlmModelCreator }) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isDeleting, setIsDeleting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleDelete(formData: FormData) {
|
||||
setIsDeleting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await deleteLlmCreatorAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to delete creator");
|
||||
} finally {
|
||||
setIsDeleting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Delete Creator"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "480px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
className="min-w-0 text-destructive hover:bg-destructive/10"
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="space-y-4">
|
||||
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
|
||||
<div className="flex items-start gap-3">
|
||||
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
|
||||
⚠️
|
||||
</div>
|
||||
<div className="text-sm text-foreground">
|
||||
<p className="font-semibold">You are about to delete:</p>
|
||||
<p className="mt-1">
|
||||
<span className="font-medium">{creator.display_name}</span>{" "}
|
||||
<span className="text-muted-foreground">
|
||||
({creator.name})
|
||||
</span>
|
||||
</p>
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
Models using this creator will have their creator field
|
||||
cleared. This is safe and won't affect model
|
||||
functionality.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<form action={handleDelete} className="space-y-4">
|
||||
<input type="hidden" name="creator_id" value={creator.id} />
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isDeleting}
|
||||
type="button"
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
size="small"
|
||||
disabled={isDeleting}
|
||||
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
|
||||
>
|
||||
{isDeleting ? "Deleting..." : "Delete Creator"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,224 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
|
||||
import { deleteLlmModelAction, fetchLlmModelUsage } from "../actions";
|
||||
|
||||
export function DeleteModelModal({
|
||||
model,
|
||||
availableModels,
|
||||
}: {
|
||||
model: LlmModel;
|
||||
availableModels: LlmModel[];
|
||||
}) {
|
||||
const router = useRouter();
|
||||
const [open, setOpen] = useState(false);
|
||||
const [selectedReplacement, setSelectedReplacement] = useState<string>("");
|
||||
const [isDeleting, setIsDeleting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [usageCount, setUsageCount] = useState<number | null>(null);
|
||||
const [usageLoading, setUsageLoading] = useState(false);
|
||||
const [usageError, setUsageError] = useState<string | null>(null);
|
||||
|
||||
// Filter out the current model and disabled models from replacement options
|
||||
const replacementOptions = availableModels.filter(
|
||||
(m) => m.id !== model.id && m.is_enabled,
|
||||
);
|
||||
|
||||
// Check if migration is required (has blocks using this model)
|
||||
const requiresMigration = usageCount !== null && usageCount > 0;
|
||||
|
||||
async function fetchUsage() {
|
||||
setUsageLoading(true);
|
||||
setUsageError(null);
|
||||
try {
|
||||
const usage = await fetchLlmModelUsage(model.id);
|
||||
setUsageCount(usage.node_count);
|
||||
} catch (err) {
|
||||
console.error("Failed to fetch model usage:", err);
|
||||
setUsageError("Failed to load usage count");
|
||||
setUsageCount(null);
|
||||
} finally {
|
||||
setUsageLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleDelete(formData: FormData) {
|
||||
setIsDeleting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await deleteLlmModelAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to delete model");
|
||||
} finally {
|
||||
setIsDeleting(false);
|
||||
}
|
||||
}
|
||||
|
||||
// Determine if delete button should be enabled
|
||||
const canDelete =
|
||||
!isDeleting &&
|
||||
!usageLoading &&
|
||||
usageCount !== null &&
|
||||
(requiresMigration
|
||||
? selectedReplacement && replacementOptions.length > 0
|
||||
: true);
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Delete Model"
|
||||
controlled={{
|
||||
isOpen: open,
|
||||
set: async (isOpen) => {
|
||||
setOpen(isOpen);
|
||||
if (isOpen) {
|
||||
setUsageCount(null);
|
||||
setUsageError(null);
|
||||
setError(null);
|
||||
setSelectedReplacement("");
|
||||
await fetchUsage();
|
||||
}
|
||||
},
|
||||
}}
|
||||
styling={{ maxWidth: "600px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
className="min-w-0 text-destructive hover:bg-destructive/10"
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
{requiresMigration
|
||||
? "This action cannot be undone. All workflows using this model will be migrated to the replacement model you select."
|
||||
: "This action cannot be undone."}
|
||||
</div>
|
||||
|
||||
<div className="space-y-4">
|
||||
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
|
||||
<div className="flex items-start gap-3">
|
||||
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
|
||||
⚠️
|
||||
</div>
|
||||
<div className="text-sm text-foreground">
|
||||
<p className="font-semibold">You are about to delete:</p>
|
||||
<p className="mt-1">
|
||||
<span className="font-medium">{model.display_name}</span>{" "}
|
||||
<span className="text-muted-foreground">({model.slug})</span>
|
||||
</p>
|
||||
{usageLoading && (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
Loading usage count...
|
||||
</p>
|
||||
)}
|
||||
{usageError && (
|
||||
<p className="mt-2 text-destructive">{usageError}</p>
|
||||
)}
|
||||
{!usageLoading && !usageError && usageCount !== null && (
|
||||
<p className="mt-2 font-semibold">
|
||||
Impact: {usageCount} block{usageCount !== 1 ? "s" : ""}{" "}
|
||||
currently use this model
|
||||
</p>
|
||||
)}
|
||||
{requiresMigration && (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
All workflows currently using this model will be
|
||||
automatically updated to use the replacement model you
|
||||
choose below.
|
||||
</p>
|
||||
)}
|
||||
{!usageLoading && usageCount === 0 && (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
No workflows are using this model. It can be safely deleted.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<form action={handleDelete} className="space-y-4">
|
||||
<input type="hidden" name="model_id" value={model.id} />
|
||||
<input
|
||||
type="hidden"
|
||||
name="replacement_model_slug"
|
||||
value={selectedReplacement}
|
||||
/>
|
||||
|
||||
{requiresMigration && (
|
||||
<label className="text-sm font-medium">
|
||||
<span className="mb-2 block">
|
||||
Select Replacement Model{" "}
|
||||
<span className="text-destructive">*</span>
|
||||
</span>
|
||||
<select
|
||||
required
|
||||
value={selectedReplacement}
|
||||
onChange={(e) => setSelectedReplacement(e.target.value)}
|
||||
className="w-full rounded border border-input bg-background p-2 text-sm"
|
||||
>
|
||||
<option value="">-- Choose a replacement model --</option>
|
||||
{replacementOptions.map((m) => (
|
||||
<option key={m.id} value={m.slug}>
|
||||
{m.display_name} ({m.slug})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
{replacementOptions.length === 0 && (
|
||||
<p className="mt-2 text-xs text-destructive">
|
||||
No replacement models available. You must have at least one
|
||||
other enabled model before deleting this one.
|
||||
</p>
|
||||
)}
|
||||
</label>
|
||||
)}
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setSelectedReplacement("");
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isDeleting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
size="small"
|
||||
disabled={!canDelete}
|
||||
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
|
||||
>
|
||||
{isDeleting
|
||||
? "Deleting..."
|
||||
: requiresMigration
|
||||
? "Delete and Migrate"
|
||||
: "Delete"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,129 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
import { deleteLlmProviderAction } from "../actions";
|
||||
|
||||
export function DeleteProviderModal({ provider }: { provider: LlmProvider }) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isDeleting, setIsDeleting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
const modelCount = provider.models?.length ?? 0;
|
||||
const hasModels = modelCount > 0;
|
||||
|
||||
async function handleDelete(formData: FormData) {
|
||||
setIsDeleting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await deleteLlmProviderAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(
|
||||
err instanceof Error ? err.message : "Failed to delete provider",
|
||||
);
|
||||
} finally {
|
||||
setIsDeleting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Delete Provider"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "480px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
className="min-w-0 text-destructive hover:bg-destructive/10"
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="space-y-4">
|
||||
<div
|
||||
className={`rounded-lg border p-4 ${
|
||||
hasModels
|
||||
? "border-destructive/30 bg-destructive/10"
|
||||
: "border-amber-500/30 bg-amber-500/10 dark:border-amber-400/30 dark:bg-amber-400/10"
|
||||
}`}
|
||||
>
|
||||
<div className="flex items-start gap-3">
|
||||
<div
|
||||
className={`flex-shrink-0 ${
|
||||
hasModels
|
||||
? "text-destructive"
|
||||
: "text-amber-600 dark:text-amber-400"
|
||||
}`}
|
||||
>
|
||||
{hasModels ? "🚫" : "⚠️"}
|
||||
</div>
|
||||
<div className="text-sm text-foreground">
|
||||
<p className="font-semibold">You are about to delete:</p>
|
||||
<p className="mt-1">
|
||||
<span className="font-medium">{provider.display_name}</span>{" "}
|
||||
<span className="text-muted-foreground">
|
||||
({provider.name})
|
||||
</span>
|
||||
</p>
|
||||
{hasModels ? (
|
||||
<p className="mt-2 text-destructive">
|
||||
This provider has {modelCount} model(s). You must delete all
|
||||
models before you can delete this provider.
|
||||
</p>
|
||||
) : (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
This provider has no models and can be safely deleted.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<form action={handleDelete} className="space-y-4">
|
||||
<input type="hidden" name="provider_id" value={provider.id} />
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isDeleting}
|
||||
type="button"
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
size="small"
|
||||
disabled={isDeleting || hasModels}
|
||||
className="bg-destructive text-destructive-foreground hover:bg-destructive/90 disabled:opacity-50"
|
||||
>
|
||||
{isDeleting ? "Deleting..." : "Delete Provider"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,288 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
|
||||
import { toggleLlmModelAction, fetchLlmModelUsage } from "../actions";
|
||||
|
||||
export function DisableModelModal({
|
||||
model,
|
||||
availableModels,
|
||||
}: {
|
||||
model: LlmModel;
|
||||
availableModels: LlmModel[];
|
||||
}) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isDisabling, setIsDisabling] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [usageCount, setUsageCount] = useState<number | null>(null);
|
||||
const [selectedMigration, setSelectedMigration] = useState<string>("");
|
||||
const [wantsMigration, setWantsMigration] = useState(false);
|
||||
const [migrationReason, setMigrationReason] = useState("");
|
||||
const [customCreditCost, setCustomCreditCost] = useState<string>("");
|
||||
|
||||
// Filter out the current model and disabled models from replacement options
|
||||
const migrationOptions = availableModels.filter(
|
||||
(m) => m.id !== model.id && m.is_enabled,
|
||||
);
|
||||
|
||||
async function fetchUsage() {
|
||||
try {
|
||||
const usage = await fetchLlmModelUsage(model.id);
|
||||
setUsageCount(usage.node_count);
|
||||
} catch {
|
||||
setUsageCount(null);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleDisable(formData: FormData) {
|
||||
setIsDisabling(true);
|
||||
setError(null);
|
||||
try {
|
||||
await toggleLlmModelAction(formData);
|
||||
setOpen(false);
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to disable model");
|
||||
} finally {
|
||||
setIsDisabling(false);
|
||||
}
|
||||
}
|
||||
|
||||
function resetState() {
|
||||
setError(null);
|
||||
setSelectedMigration("");
|
||||
setWantsMigration(false);
|
||||
setMigrationReason("");
|
||||
setCustomCreditCost("");
|
||||
}
|
||||
|
||||
const hasUsage = usageCount !== null && usageCount > 0;
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Disable Model"
|
||||
controlled={{
|
||||
isOpen: open,
|
||||
set: async (isOpen) => {
|
||||
setOpen(isOpen);
|
||||
if (isOpen) {
|
||||
setUsageCount(null);
|
||||
resetState();
|
||||
await fetchUsage();
|
||||
}
|
||||
},
|
||||
}}
|
||||
styling={{ maxWidth: "600px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
className="min-w-0"
|
||||
>
|
||||
Disable
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
Disabling a model will hide it from users when creating new workflows.
|
||||
</div>
|
||||
|
||||
<div className="space-y-4">
|
||||
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
|
||||
<div className="flex items-start gap-3">
|
||||
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
|
||||
⚠️
|
||||
</div>
|
||||
<div className="text-sm text-foreground">
|
||||
<p className="font-semibold">You are about to disable:</p>
|
||||
<p className="mt-1">
|
||||
<span className="font-medium">{model.display_name}</span>{" "}
|
||||
<span className="text-muted-foreground">({model.slug})</span>
|
||||
</p>
|
||||
{usageCount === null ? (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
Loading usage data...
|
||||
</p>
|
||||
) : usageCount > 0 ? (
|
||||
<p className="mt-2 font-semibold">
|
||||
Impact: {usageCount} block{usageCount !== 1 ? "s" : ""}{" "}
|
||||
currently use this model
|
||||
</p>
|
||||
) : (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
No workflows are currently using this model.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{hasUsage && (
|
||||
<div className="space-y-4 rounded-lg border border-border bg-muted/50 p-4">
|
||||
<label className="flex items-start gap-3">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={wantsMigration}
|
||||
onChange={(e) => {
|
||||
setWantsMigration(e.target.checked);
|
||||
if (!e.target.checked) {
|
||||
setSelectedMigration("");
|
||||
}
|
||||
}}
|
||||
className="mt-1"
|
||||
/>
|
||||
<div className="text-sm">
|
||||
<span className="font-medium">
|
||||
Migrate existing workflows to another model
|
||||
</span>
|
||||
<p className="mt-1 text-muted-foreground">
|
||||
Creates a revertible migration record. If unchecked,
|
||||
existing workflows will use automatic fallback to an enabled
|
||||
model from the same provider.
|
||||
</p>
|
||||
</div>
|
||||
</label>
|
||||
|
||||
{wantsMigration && (
|
||||
<div className="space-y-4 border-t border-border pt-4">
|
||||
<label className="block text-sm font-medium">
|
||||
<span className="mb-2 block">
|
||||
Replacement Model{" "}
|
||||
<span className="text-destructive">*</span>
|
||||
</span>
|
||||
<select
|
||||
required
|
||||
value={selectedMigration}
|
||||
onChange={(e) => setSelectedMigration(e.target.value)}
|
||||
className="w-full rounded border border-input bg-background p-2 text-sm"
|
||||
>
|
||||
<option value="">-- Choose a replacement model --</option>
|
||||
{migrationOptions.map((m) => (
|
||||
<option key={m.id} value={m.slug}>
|
||||
{m.display_name} ({m.slug})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
{migrationOptions.length === 0 && (
|
||||
<p className="mt-2 text-xs text-destructive">
|
||||
No other enabled models available for migration.
|
||||
</p>
|
||||
)}
|
||||
</label>
|
||||
|
||||
<label className="block text-sm font-medium">
|
||||
<span className="mb-2 block">
|
||||
Migration Reason{" "}
|
||||
<span className="font-normal text-muted-foreground">
|
||||
(optional)
|
||||
</span>
|
||||
</span>
|
||||
<input
|
||||
type="text"
|
||||
value={migrationReason}
|
||||
onChange={(e) => setMigrationReason(e.target.value)}
|
||||
placeholder="e.g., Provider outage, Cost reduction"
|
||||
className="w-full rounded border border-input bg-background p-2 text-sm"
|
||||
/>
|
||||
<p className="mt-1 text-xs text-muted-foreground">
|
||||
Helps track why the migration was made
|
||||
</p>
|
||||
</label>
|
||||
|
||||
<label className="block text-sm font-medium">
|
||||
<span className="mb-2 block">
|
||||
Custom Credit Cost{" "}
|
||||
<span className="font-normal text-muted-foreground">
|
||||
(optional)
|
||||
</span>
|
||||
</span>
|
||||
<input
|
||||
type="number"
|
||||
min="0"
|
||||
value={customCreditCost}
|
||||
onChange={(e) => setCustomCreditCost(e.target.value)}
|
||||
placeholder="Leave blank to use target model's cost"
|
||||
className="w-full rounded border border-input bg-background p-2 text-sm"
|
||||
/>
|
||||
<p className="mt-1 text-xs text-muted-foreground">
|
||||
Override pricing for migrated workflows. When set, billing
|
||||
will use this cost instead of the target model's
|
||||
cost.
|
||||
</p>
|
||||
</label>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<form action={handleDisable} className="space-y-4">
|
||||
<input type="hidden" name="model_id" value={model.id} />
|
||||
<input type="hidden" name="is_enabled" value="false" />
|
||||
{wantsMigration && selectedMigration && (
|
||||
<>
|
||||
<input
|
||||
type="hidden"
|
||||
name="migrate_to_slug"
|
||||
value={selectedMigration}
|
||||
/>
|
||||
{migrationReason && (
|
||||
<input
|
||||
type="hidden"
|
||||
name="migration_reason"
|
||||
value={migrationReason}
|
||||
/>
|
||||
)}
|
||||
{customCreditCost && (
|
||||
<input
|
||||
type="hidden"
|
||||
name="custom_credit_cost"
|
||||
value={customCreditCost}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
resetState();
|
||||
}}
|
||||
disabled={isDisabling}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
size="small"
|
||||
disabled={
|
||||
isDisabling ||
|
||||
(wantsMigration && !selectedMigration) ||
|
||||
usageCount === null
|
||||
}
|
||||
>
|
||||
{isDisabling
|
||||
? "Disabling..."
|
||||
: wantsMigration && selectedMigration
|
||||
? "Disable & Migrate"
|
||||
: "Disable Model"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,223 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
|
||||
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
import { updateLlmModelAction } from "../actions";
|
||||
|
||||
export function EditModelModal({
|
||||
model,
|
||||
providers,
|
||||
creators,
|
||||
}: {
|
||||
model: LlmModel;
|
||||
providers: LlmProvider[];
|
||||
creators: LlmModelCreator[];
|
||||
}) {
|
||||
const router = useRouter();
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const cost = model.costs?.[0];
|
||||
const provider = providers.find((p) => p.id === model.provider_id);
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await updateLlmModelAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to update model");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Edit Model"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="outline" size="small" className="min-w-0">
|
||||
Edit
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
Update model metadata and pricing information.
|
||||
</div>
|
||||
{error && (
|
||||
<div className="mb-4 rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
<form action={handleSubmit} className="space-y-4">
|
||||
<input type="hidden" name="model_id" value={model.id} />
|
||||
|
||||
<div className="grid gap-4 md:grid-cols-2">
|
||||
<label className="text-sm font-medium">
|
||||
Display Name
|
||||
<input
|
||||
required
|
||||
name="display_name"
|
||||
defaultValue={model.display_name}
|
||||
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
|
||||
/>
|
||||
</label>
|
||||
<label className="text-sm font-medium">
|
||||
Provider
|
||||
<select
|
||||
required
|
||||
name="provider_id"
|
||||
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
|
||||
defaultValue={model.provider_id}
|
||||
>
|
||||
{providers.map((p) => (
|
||||
<option key={p.id} value={p.id}>
|
||||
{p.display_name} ({p.name})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<span className="text-xs text-muted-foreground">
|
||||
Who hosts/serves the model
|
||||
</span>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-4 md:grid-cols-2">
|
||||
<label className="text-sm font-medium">
|
||||
Creator
|
||||
<select
|
||||
name="creator_id"
|
||||
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
|
||||
defaultValue={model.creator_id ?? ""}
|
||||
>
|
||||
<option value="">No creator selected</option>
|
||||
{creators.map((c) => (
|
||||
<option key={c.id} value={c.id}>
|
||||
{c.display_name} ({c.name})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<span className="text-xs text-muted-foreground">
|
||||
Who made/trained the model (e.g., OpenAI, Meta)
|
||||
</span>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<label className="text-sm font-medium">
|
||||
Description
|
||||
<textarea
|
||||
name="description"
|
||||
rows={2}
|
||||
defaultValue={model.description ?? ""}
|
||||
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
|
||||
placeholder="Optional description..."
|
||||
/>
|
||||
</label>
|
||||
|
||||
<div className="grid gap-4 md:grid-cols-2">
|
||||
<label className="text-sm font-medium">
|
||||
Context Window
|
||||
<input
|
||||
required
|
||||
type="number"
|
||||
name="context_window"
|
||||
defaultValue={model.context_window}
|
||||
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
|
||||
min={1}
|
||||
/>
|
||||
</label>
|
||||
<label className="text-sm font-medium">
|
||||
Max Output Tokens
|
||||
<input
|
||||
type="number"
|
||||
name="max_output_tokens"
|
||||
defaultValue={model.max_output_tokens ?? undefined}
|
||||
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
|
||||
min={1}
|
||||
/>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-4 md:grid-cols-2">
|
||||
<label className="text-sm font-medium">
|
||||
Credit Cost
|
||||
<input
|
||||
required
|
||||
type="number"
|
||||
name="credit_cost"
|
||||
defaultValue={cost?.credit_cost ?? 0}
|
||||
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
|
||||
min={0}
|
||||
/>
|
||||
<span className="text-xs text-muted-foreground">
|
||||
Credits charged per run
|
||||
</span>
|
||||
</label>
|
||||
<label className="text-sm font-medium">
|
||||
Credential Provider
|
||||
<select
|
||||
required
|
||||
name="credential_provider"
|
||||
defaultValue={cost?.credential_provider ?? provider?.name ?? ""}
|
||||
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
|
||||
>
|
||||
<option value="" disabled>
|
||||
Select provider
|
||||
</option>
|
||||
{providers.map((p) => (
|
||||
<option key={p.id} value={p.name}>
|
||||
{p.display_name} ({p.name})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<span className="text-xs text-muted-foreground">
|
||||
Must match a key in PROVIDER_CREDENTIALS
|
||||
</span>
|
||||
</label>
|
||||
</div>
|
||||
{/* Hidden defaults for credential_type and unit */}
|
||||
<input
|
||||
type="hidden"
|
||||
name="credential_type"
|
||||
value={
|
||||
cost?.credential_type ??
|
||||
provider?.default_credential_type ??
|
||||
"api_key"
|
||||
}
|
||||
/>
|
||||
<input type="hidden" name="unit" value={cost?.unit ?? "RUN"} />
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={() => setOpen(false)}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Updating..." : "Update Model"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,263 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { updateLlmProviderAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
|
||||
export function EditProviderModal({ provider }: { provider: LlmProvider }) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await updateLlmProviderAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(
|
||||
err instanceof Error ? err.message : "Failed to update provider",
|
||||
);
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Edit Provider"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="outline" size="small">
|
||||
Edit
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
Update provider configuration and capabilities.
|
||||
</div>
|
||||
|
||||
<form action={handleSubmit} className="space-y-6">
|
||||
<input type="hidden" name="provider_id" value={provider.id} />
|
||||
|
||||
{/* Basic Information */}
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Basic Information
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Core provider details
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Provider Slug <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="name"
|
||||
required
|
||||
name="name"
|
||||
defaultValue={provider.name}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="e.g. openai"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="display_name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Display Name <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="display_name"
|
||||
required
|
||||
name="display_name"
|
||||
defaultValue={provider.display_name}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="OpenAI"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="description"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Description
|
||||
</label>
|
||||
<textarea
|
||||
id="description"
|
||||
name="description"
|
||||
rows={3}
|
||||
defaultValue={provider.description ?? ""}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="Optional description..."
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Default Credentials */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Default Credentials
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Credential provider name that matches the key in{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
|
||||
PROVIDER_CREDENTIALS
|
||||
</code>
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="default_credential_provider"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Credential Provider
|
||||
</label>
|
||||
<input
|
||||
id="default_credential_provider"
|
||||
name="default_credential_provider"
|
||||
defaultValue={provider.default_credential_provider ?? ""}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="openai"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="default_credential_id"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Credential ID
|
||||
</label>
|
||||
<input
|
||||
id="default_credential_id"
|
||||
name="default_credential_id"
|
||||
defaultValue={provider.default_credential_id ?? ""}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="Optional credential ID"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="default_credential_type"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Credential Type
|
||||
</label>
|
||||
<input
|
||||
id="default_credential_type"
|
||||
name="default_credential_type"
|
||||
defaultValue={provider.default_credential_type ?? "api_key"}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="api_key"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Capabilities */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Capabilities
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Provider feature flags
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-3 sm:grid-cols-2">
|
||||
{[
|
||||
{
|
||||
name: "supports_tools",
|
||||
label: "Supports tools",
|
||||
checked: provider.supports_tools,
|
||||
},
|
||||
{
|
||||
name: "supports_json_output",
|
||||
label: "Supports JSON output",
|
||||
checked: provider.supports_json_output,
|
||||
},
|
||||
{
|
||||
name: "supports_reasoning",
|
||||
label: "Supports reasoning",
|
||||
checked: provider.supports_reasoning,
|
||||
},
|
||||
{
|
||||
name: "supports_parallel_tool",
|
||||
label: "Supports parallel tool calls",
|
||||
checked: provider.supports_parallel_tool,
|
||||
},
|
||||
].map(({ name, label, checked }) => (
|
||||
<div
|
||||
key={name}
|
||||
className="flex items-center gap-3 rounded-md border border-border bg-muted/30 px-4 py-3 transition-colors hover:bg-muted/50"
|
||||
>
|
||||
<input type="hidden" name={name} value="off" />
|
||||
<input
|
||||
id={name}
|
||||
type="checkbox"
|
||||
name={name}
|
||||
defaultChecked={checked}
|
||||
className="h-4 w-4 rounded border-input"
|
||||
/>
|
||||
<label
|
||||
htmlFor={name}
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
{label}
|
||||
</label>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Saving..." : "Save Changes"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,131 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
|
||||
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
|
||||
import type { LlmModelMigration } from "@/app/api/__generated__/models/llmModelMigration";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
import { ErrorBoundary } from "@/components/molecules/ErrorBoundary/ErrorBoundary";
|
||||
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||
import { AddProviderModal } from "./AddProviderModal";
|
||||
import { AddModelModal } from "./AddModelModal";
|
||||
import { AddCreatorModal } from "./AddCreatorModal";
|
||||
import { ProviderList } from "./ProviderList";
|
||||
import { ModelsTable } from "./ModelsTable";
|
||||
import { MigrationsTable } from "./MigrationsTable";
|
||||
import { CreatorsTable } from "./CreatorsTable";
|
||||
import { RecommendedModelSelector } from "./RecommendedModelSelector";
|
||||
|
||||
interface Props {
|
||||
providers: LlmProvider[];
|
||||
models: LlmModel[];
|
||||
migrations: LlmModelMigration[];
|
||||
creators: LlmModelCreator[];
|
||||
}
|
||||
|
||||
function AdminErrorFallback() {
|
||||
return (
|
||||
<div className="mx-auto max-w-xl p-6">
|
||||
<ErrorCard
|
||||
responseError={{
|
||||
message:
|
||||
"An error occurred while loading the LLM Registry. Please refresh the page.",
|
||||
}}
|
||||
context="llm-registry"
|
||||
onRetry={() => window.location.reload()}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function LlmRegistryDashboard({
|
||||
providers,
|
||||
models,
|
||||
migrations,
|
||||
creators,
|
||||
}: Props) {
|
||||
return (
|
||||
<ErrorBoundary fallback={<AdminErrorFallback />} context="llm-registry">
|
||||
<div className="mx-auto p-6">
|
||||
<div className="flex flex-col gap-6">
|
||||
{/* Header */}
|
||||
<div>
|
||||
<h1 className="text-3xl font-bold">LLM Registry</h1>
|
||||
<p className="text-muted-foreground">
|
||||
Manage providers, creators, models, and credit pricing
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Active Migrations Section - Only show if there are migrations */}
|
||||
{migrations.length > 0 && (
|
||||
<div className="rounded-lg border border-primary/30 bg-primary/5 p-6 shadow-sm">
|
||||
<div className="mb-4">
|
||||
<h2 className="text-xl font-semibold">Active Migrations</h2>
|
||||
<p className="mt-1 text-sm text-muted-foreground">
|
||||
These migrations can be reverted to restore workflows to their
|
||||
original model
|
||||
</p>
|
||||
</div>
|
||||
<MigrationsTable migrations={migrations} />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Providers & Creators Section - Side by Side */}
|
||||
<div className="grid gap-6 lg:grid-cols-2">
|
||||
{/* Providers */}
|
||||
<div className="rounded-lg border bg-card p-6 shadow-sm">
|
||||
<div className="mb-4 flex items-center justify-between">
|
||||
<div>
|
||||
<h2 className="text-xl font-semibold">Providers</h2>
|
||||
<p className="mt-1 text-sm text-muted-foreground">
|
||||
Who hosts/serves the models
|
||||
</p>
|
||||
</div>
|
||||
<AddProviderModal />
|
||||
</div>
|
||||
<ProviderList providers={providers} />
|
||||
</div>
|
||||
|
||||
{/* Creators */}
|
||||
<div className="rounded-lg border bg-card p-6 shadow-sm">
|
||||
<div className="mb-4 flex items-center justify-between">
|
||||
<div>
|
||||
<h2 className="text-xl font-semibold">Creators</h2>
|
||||
<p className="mt-1 text-sm text-muted-foreground">
|
||||
Who made/trained the models
|
||||
</p>
|
||||
</div>
|
||||
<AddCreatorModal />
|
||||
</div>
|
||||
<CreatorsTable creators={creators} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Models Section */}
|
||||
<div className="rounded-lg border bg-card p-6 shadow-sm">
|
||||
<div className="mb-4 flex items-center justify-between">
|
||||
<div>
|
||||
<h2 className="text-xl font-semibold">Models</h2>
|
||||
<p className="mt-1 text-sm text-muted-foreground">
|
||||
Toggle availability, adjust context windows, and update credit
|
||||
pricing
|
||||
</p>
|
||||
</div>
|
||||
<AddModelModal providers={providers} creators={creators} />
|
||||
</div>
|
||||
|
||||
{/* Recommended Model Selector */}
|
||||
<div className="mb-6">
|
||||
<RecommendedModelSelector models={models} />
|
||||
</div>
|
||||
|
||||
<ModelsTable
|
||||
models={models}
|
||||
providers={providers}
|
||||
creators={creators}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</ErrorBoundary>
|
||||
);
|
||||
}
|
||||
@@ -1,133 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import type { LlmModelMigration } from "@/app/api/__generated__/models/llmModelMigration";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/atoms/Table/Table";
|
||||
import { revertLlmMigrationAction } from "../actions";
|
||||
|
||||
export function MigrationsTable({
|
||||
migrations,
|
||||
}: {
|
||||
migrations: LlmModelMigration[];
|
||||
}) {
|
||||
if (!migrations.length) {
|
||||
return (
|
||||
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
|
||||
No active migrations. Migrations are created when you disable a model
|
||||
with the "Migrate existing workflows" option.
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Migration</TableHead>
|
||||
<TableHead>Reason</TableHead>
|
||||
<TableHead>Nodes Affected</TableHead>
|
||||
<TableHead>Custom Cost</TableHead>
|
||||
<TableHead>Created</TableHead>
|
||||
<TableHead className="text-right">Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{migrations.map((migration) => (
|
||||
<MigrationRow key={migration.id} migration={migration} />
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function MigrationRow({ migration }: { migration: LlmModelMigration }) {
|
||||
const [isReverting, setIsReverting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
async function handleRevert(formData: FormData) {
|
||||
setIsReverting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await revertLlmMigrationAction(formData);
|
||||
} catch (err) {
|
||||
setError(
|
||||
err instanceof Error ? err.message : "Failed to revert migration",
|
||||
);
|
||||
} finally {
|
||||
setIsReverting(false);
|
||||
}
|
||||
}
|
||||
|
||||
const createdDate = new Date(migration.created_at);
|
||||
|
||||
return (
|
||||
<>
|
||||
<TableRow>
|
||||
<TableCell>
|
||||
<div className="text-sm">
|
||||
<span className="font-medium">{migration.source_model_slug}</span>
|
||||
<span className="mx-2 text-muted-foreground">→</span>
|
||||
<span className="font-medium">{migration.target_model_slug}</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="text-sm text-muted-foreground">
|
||||
{migration.reason || "—"}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="text-sm">{migration.node_count}</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="text-sm">
|
||||
{migration.custom_credit_cost !== null &&
|
||||
migration.custom_credit_cost !== undefined
|
||||
? `${migration.custom_credit_cost} credits`
|
||||
: "—"}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="text-sm text-muted-foreground">
|
||||
{createdDate.toLocaleDateString()}{" "}
|
||||
{createdDate.toLocaleTimeString([], {
|
||||
hour: "2-digit",
|
||||
minute: "2-digit",
|
||||
})}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell className="text-right">
|
||||
<form action={handleRevert} className="inline">
|
||||
<input type="hidden" name="migration_id" value={migration.id} />
|
||||
<Button
|
||||
type="submit"
|
||||
variant="outline"
|
||||
size="small"
|
||||
disabled={isReverting}
|
||||
>
|
||||
{isReverting ? "Reverting..." : "Revert"}
|
||||
</Button>
|
||||
</form>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
{error && (
|
||||
<TableRow>
|
||||
<TableCell colSpan={6}>
|
||||
<div className="rounded border border-destructive/30 bg-destructive/10 p-2 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -1,262 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect, useRef } from "react";
|
||||
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
|
||||
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/atoms/Table/Table";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { toggleLlmModelAction } from "../actions";
|
||||
import { DeleteModelModal } from "./DeleteModelModal";
|
||||
import { DisableModelModal } from "./DisableModelModal";
|
||||
import { EditModelModal } from "./EditModelModal";
|
||||
import { Star, Spinner } from "@phosphor-icons/react";
|
||||
import { getV2ListLlmModels } from "@/app/api/__generated__/endpoints/admin/admin";
|
||||
|
||||
const PAGE_SIZE = 50;
|
||||
|
||||
export function ModelsTable({
|
||||
models: initialModels,
|
||||
providers,
|
||||
creators,
|
||||
}: {
|
||||
models: LlmModel[];
|
||||
providers: LlmProvider[];
|
||||
creators: LlmModelCreator[];
|
||||
}) {
|
||||
const [models, setModels] = useState<LlmModel[]>(initialModels);
|
||||
const [currentPage, setCurrentPage] = useState(1);
|
||||
const [hasMore, setHasMore] = useState(initialModels.length === PAGE_SIZE);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const loadedPagesRef = useRef(1);
|
||||
|
||||
// Sync with parent when initialModels changes (e.g., after enable/disable)
|
||||
// Re-fetch all loaded pages to preserve expanded state
|
||||
useEffect(() => {
|
||||
async function refetchAllPages() {
|
||||
const pagesToLoad = loadedPagesRef.current;
|
||||
|
||||
if (pagesToLoad === 1) {
|
||||
// Only first page loaded, just use initialModels
|
||||
setModels(initialModels);
|
||||
setHasMore(initialModels.length === PAGE_SIZE);
|
||||
return;
|
||||
}
|
||||
|
||||
// Re-fetch all pages we had loaded
|
||||
const allModels: LlmModel[] = [...initialModels];
|
||||
let lastPageHadFullResults = initialModels.length === PAGE_SIZE;
|
||||
|
||||
for (let page = 2; page <= pagesToLoad; page++) {
|
||||
try {
|
||||
const response = await getV2ListLlmModels({
|
||||
page,
|
||||
page_size: PAGE_SIZE,
|
||||
});
|
||||
if (response.status === 200) {
|
||||
allModels.push(...response.data.models);
|
||||
lastPageHadFullResults = response.data.models.length === PAGE_SIZE;
|
||||
}
|
||||
} catch (err) {
|
||||
console.error(`Error refetching page ${page}:`, err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
setModels(allModels);
|
||||
setHasMore(lastPageHadFullResults);
|
||||
}
|
||||
|
||||
refetchAllPages();
|
||||
}, [initialModels]);
|
||||
|
||||
async function loadMore() {
|
||||
if (isLoading) return;
|
||||
setIsLoading(true);
|
||||
|
||||
try {
|
||||
const nextPage = currentPage + 1;
|
||||
const response = await getV2ListLlmModels({
|
||||
page: nextPage,
|
||||
page_size: PAGE_SIZE,
|
||||
});
|
||||
|
||||
if (response.status === 200) {
|
||||
setModels((prev) => [...prev, ...response.data.models]);
|
||||
setCurrentPage(nextPage);
|
||||
loadedPagesRef.current = nextPage;
|
||||
setHasMore(response.data.models.length === PAGE_SIZE);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error("Error loading more models:", err);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}
|
||||
if (!models.length) {
|
||||
return (
|
||||
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
|
||||
No models registered yet.
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const providerLookup = new Map(
|
||||
providers.map((provider) => [provider.id, provider]),
|
||||
);
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className="rounded-lg border">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Model</TableHead>
|
||||
<TableHead>Provider</TableHead>
|
||||
<TableHead>Creator</TableHead>
|
||||
<TableHead>Context Window</TableHead>
|
||||
<TableHead>Max Output</TableHead>
|
||||
<TableHead>Cost</TableHead>
|
||||
<TableHead>Status</TableHead>
|
||||
<TableHead>Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{models.map((model) => {
|
||||
const cost = model.costs?.[0];
|
||||
const provider = providerLookup.get(model.provider_id);
|
||||
return (
|
||||
<TableRow
|
||||
key={model.id}
|
||||
className={model.is_enabled ? "" : "opacity-60"}
|
||||
>
|
||||
<TableCell>
|
||||
<div className="font-medium">{model.display_name}</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{model.slug}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{provider ? (
|
||||
<>
|
||||
<div>{provider.display_name}</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{provider.name}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
model.provider_id
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{model.creator ? (
|
||||
<>
|
||||
<div>{model.creator.display_name}</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{model.creator.name}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<span className="text-muted-foreground">—</span>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>{model.context_window.toLocaleString()}</TableCell>
|
||||
<TableCell>
|
||||
{model.max_output_tokens
|
||||
? model.max_output_tokens.toLocaleString()
|
||||
: "—"}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{cost ? (
|
||||
<>
|
||||
<div className="font-medium">
|
||||
{cost.credit_cost} credits
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{cost.credential_provider}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
"—"
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex flex-col gap-1">
|
||||
<span
|
||||
className={`inline-flex rounded-full px-2.5 py-1 text-xs font-semibold ${
|
||||
model.is_enabled
|
||||
? "bg-primary/10 text-primary"
|
||||
: "bg-muted text-muted-foreground"
|
||||
}`}
|
||||
>
|
||||
{model.is_enabled ? "Enabled" : "Disabled"}
|
||||
</span>
|
||||
{model.is_recommended && (
|
||||
<span className="inline-flex items-center gap-1 rounded-full bg-amber-500/10 px-2.5 py-1 text-xs font-semibold text-amber-600 dark:text-amber-400">
|
||||
<Star size={12} weight="fill" />
|
||||
Recommended
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
{model.is_enabled ? (
|
||||
<DisableModelModal
|
||||
model={model}
|
||||
availableModels={models}
|
||||
/>
|
||||
) : (
|
||||
<EnableModelButton modelId={model.id} />
|
||||
)}
|
||||
<EditModelModal
|
||||
model={model}
|
||||
providers={providers}
|
||||
creators={creators}
|
||||
/>
|
||||
<DeleteModelModal model={model} availableModels={models} />
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
);
|
||||
})}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
|
||||
{hasMore && (
|
||||
<div className="mt-4 flex justify-center">
|
||||
<Button onClick={loadMore} disabled={isLoading} variant="outline">
|
||||
{isLoading ? (
|
||||
<>
|
||||
<Spinner className="mr-2 h-4 w-4 animate-spin" />
|
||||
Loading...
|
||||
</>
|
||||
) : (
|
||||
"Load More"
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function EnableModelButton({ modelId }: { modelId: string }) {
|
||||
return (
|
||||
<form action={toggleLlmModelAction} className="inline">
|
||||
<input type="hidden" name="model_id" value={modelId} />
|
||||
<input type="hidden" name="is_enabled" value="true" />
|
||||
<Button type="submit" variant="outline" size="small" className="min-w-0">
|
||||
Enable
|
||||
</Button>
|
||||
</form>
|
||||
);
|
||||
}
|
||||
@@ -1,94 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/atoms/Table/Table";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
import { DeleteProviderModal } from "./DeleteProviderModal";
|
||||
import { EditProviderModal } from "./EditProviderModal";
|
||||
|
||||
export function ProviderList({ providers }: { providers: LlmProvider[] }) {
|
||||
if (!providers.length) {
|
||||
return (
|
||||
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
|
||||
No providers configured yet.
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>Display Name</TableHead>
|
||||
<TableHead>Default Credential</TableHead>
|
||||
<TableHead>Capabilities</TableHead>
|
||||
<TableHead>Models</TableHead>
|
||||
<TableHead className="w-[100px]">Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{providers.map((provider) => (
|
||||
<TableRow key={provider.id}>
|
||||
<TableCell className="font-medium">{provider.name}</TableCell>
|
||||
<TableCell>{provider.display_name}</TableCell>
|
||||
<TableCell>
|
||||
{provider.default_credential_provider
|
||||
? `${provider.default_credential_provider} (${provider.default_credential_id ?? "id?"})`
|
||||
: "—"}
|
||||
</TableCell>
|
||||
<TableCell className="text-sm text-muted-foreground">
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{provider.supports_tools && (
|
||||
<span className="rounded bg-muted px-2 py-0.5 text-xs">
|
||||
Tools
|
||||
</span>
|
||||
)}
|
||||
{provider.supports_json_output && (
|
||||
<span className="rounded bg-muted px-2 py-0.5 text-xs">
|
||||
JSON
|
||||
</span>
|
||||
)}
|
||||
{provider.supports_reasoning && (
|
||||
<span className="rounded bg-muted px-2 py-0.5 text-xs">
|
||||
Reasoning
|
||||
</span>
|
||||
)}
|
||||
{provider.supports_parallel_tool && (
|
||||
<span className="rounded bg-muted px-2 py-0.5 text-xs">
|
||||
Parallel Tools
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell className="text-sm">
|
||||
<span
|
||||
className={
|
||||
(provider.models?.length ?? 0) > 0
|
||||
? "text-foreground"
|
||||
: "text-muted-foreground"
|
||||
}
|
||||
>
|
||||
{provider.models?.length ?? 0}
|
||||
</span>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex gap-2">
|
||||
<EditProviderModal provider={provider} />
|
||||
<DeleteProviderModal provider={provider} />
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,87 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { setRecommendedModelAction } from "../actions";
|
||||
import { Star } from "@phosphor-icons/react";
|
||||
|
||||
export function RecommendedModelSelector({ models }: { models: LlmModel[] }) {
|
||||
const router = useRouter();
|
||||
const enabledModels = models.filter((m) => m.is_enabled);
|
||||
const currentRecommended = models.find((m) => m.is_recommended);
|
||||
|
||||
const [selectedModelId, setSelectedModelId] = useState<string>(
|
||||
currentRecommended?.id || "",
|
||||
);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const hasChanges = selectedModelId !== (currentRecommended?.id || "");
|
||||
|
||||
async function handleSave() {
|
||||
if (!selectedModelId) return;
|
||||
|
||||
setIsSaving(true);
|
||||
setError(null);
|
||||
try {
|
||||
const formData = new FormData();
|
||||
formData.set("model_id", selectedModelId);
|
||||
await setRecommendedModelAction(formData);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to save");
|
||||
} finally {
|
||||
setIsSaving(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border border-border bg-card p-4">
|
||||
<div className="mb-3 flex items-center gap-2">
|
||||
<Star size={20} weight="fill" className="text-amber-500" />
|
||||
<h3 className="text-sm font-semibold">Recommended Model</h3>
|
||||
</div>
|
||||
<p className="mb-3 text-xs text-muted-foreground">
|
||||
The recommended model is shown as the default suggestion in model
|
||||
selection dropdowns throughout the platform.
|
||||
</p>
|
||||
|
||||
<div className="flex items-center gap-3">
|
||||
<select
|
||||
value={selectedModelId}
|
||||
onChange={(e) => setSelectedModelId(e.target.value)}
|
||||
className="flex-1 rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
disabled={isSaving}
|
||||
>
|
||||
<option value="">-- Select a model --</option>
|
||||
{enabledModels.map((model) => (
|
||||
<option key={model.id} value={model.id}>
|
||||
{model.display_name} ({model.slug})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
|
||||
<Button
|
||||
type="button"
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={handleSave}
|
||||
disabled={!hasChanges || !selectedModelId || isSaving}
|
||||
>
|
||||
{isSaving ? "Saving..." : "Save"}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{error && <p className="mt-2 text-xs text-destructive">{error}</p>}
|
||||
|
||||
{currentRecommended && !hasChanges && (
|
||||
<p className="mt-2 text-xs text-muted-foreground">
|
||||
Currently set to:{" "}
|
||||
<span className="font-medium">{currentRecommended.display_name}</span>
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
/**
|
||||
* Server-side data fetching for LLM Registry page.
|
||||
*/
|
||||
|
||||
import {
|
||||
fetchLlmCreators,
|
||||
fetchLlmMigrations,
|
||||
fetchLlmModels,
|
||||
fetchLlmProviders,
|
||||
} from "./actions";
|
||||
|
||||
export async function getLlmRegistryPageData() {
|
||||
// Fetch providers and models (required)
|
||||
const [providersResponse, modelsResponse] = await Promise.all([
|
||||
fetchLlmProviders(),
|
||||
fetchLlmModels(),
|
||||
]);
|
||||
|
||||
// Fetch migrations separately with fallback (table might not exist yet)
|
||||
let migrations: Awaited<ReturnType<typeof fetchLlmMigrations>>["migrations"] =
|
||||
[];
|
||||
try {
|
||||
const migrationsResponse = await fetchLlmMigrations(false);
|
||||
migrations = migrationsResponse.migrations;
|
||||
} catch {
|
||||
// Migrations table might not exist yet - that's ok, just show empty list
|
||||
console.warn("Could not fetch migrations - table may not exist yet");
|
||||
}
|
||||
|
||||
// Fetch creators separately with fallback (table might not exist yet)
|
||||
let creators: Awaited<ReturnType<typeof fetchLlmCreators>>["creators"] = [];
|
||||
try {
|
||||
const creatorsResponse = await fetchLlmCreators();
|
||||
creators = creatorsResponse.creators;
|
||||
} catch {
|
||||
// Creators table might not exist yet - that's ok, just show empty list
|
||||
console.warn("Could not fetch creators - table may not exist yet");
|
||||
}
|
||||
|
||||
return {
|
||||
providers: providersResponse.providers,
|
||||
models: modelsResponse.models,
|
||||
migrations,
|
||||
creators,
|
||||
};
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
import { withRoleAccess } from "@/lib/withRoleAccess";
|
||||
import { getLlmRegistryPageData } from "./getLlmRegistryPage";
|
||||
import { LlmRegistryDashboard } from "./components/LlmRegistryDashboard";
|
||||
|
||||
async function LlmRegistryPage() {
|
||||
const data = await getLlmRegistryPageData();
|
||||
return <LlmRegistryDashboard {...data} />;
|
||||
}
|
||||
|
||||
export default async function AdminLlmRegistryPage() {
|
||||
const withAdminAccess = await withRoleAccess(["admin"]);
|
||||
const ProtectedLlmRegistryPage = await withAdminAccess(LlmRegistryPage);
|
||||
return <ProtectedLlmRegistryPage />;
|
||||
}
|
||||
@@ -38,8 +38,12 @@ export const AgentOutputs = ({ flowID }: { flowID: string | null }) => {
|
||||
|
||||
return outputNodes
|
||||
.map((node) => {
|
||||
const executionResult = node.data.nodeExecutionResult;
|
||||
const outputData = executionResult?.output_data?.output;
|
||||
const executionResults = node.data.nodeExecutionResults || [];
|
||||
const latestResult =
|
||||
executionResults.length > 0
|
||||
? executionResults[executionResults.length - 1]
|
||||
: undefined;
|
||||
const outputData = latestResult?.output_data?.output;
|
||||
|
||||
const renderer = globalRegistry.getRenderer(outputData);
|
||||
|
||||
|
||||
@@ -153,6 +153,9 @@ export const useRunInputDialog = ({
|
||||
Object.entries(credentialValues).filter(([_, cred]) => cred && cred.id),
|
||||
);
|
||||
|
||||
useNodeStore.getState().clearAllNodeExecutionResults();
|
||||
useNodeStore.getState().cleanNodesStatuses();
|
||||
|
||||
await executeGraph({
|
||||
graphId: flowID ?? "",
|
||||
graphVersion: flowVersion || null,
|
||||
|
||||
@@ -86,7 +86,6 @@ export function FloatingSafeModeToggle({
|
||||
const {
|
||||
currentHITLSafeMode,
|
||||
showHITLToggle,
|
||||
isHITLStateUndetermined,
|
||||
handleHITLToggle,
|
||||
currentSensitiveActionSafeMode,
|
||||
showSensitiveActionToggle,
|
||||
@@ -99,16 +98,9 @@ export function FloatingSafeModeToggle({
|
||||
return null;
|
||||
}
|
||||
|
||||
const showHITL = showHITLToggle && !isHITLStateUndetermined;
|
||||
const showSensitive = showSensitiveActionToggle;
|
||||
|
||||
if (!showHITL && !showSensitive) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={cn("fixed z-50 flex flex-col gap-2", className)}>
|
||||
{showHITL && (
|
||||
{showHITLToggle && (
|
||||
<SafeModeButton
|
||||
isEnabled={currentHITLSafeMode}
|
||||
label="Human in the loop block approval"
|
||||
@@ -119,7 +111,7 @@ export function FloatingSafeModeToggle({
|
||||
fullWidth={fullWidth}
|
||||
/>
|
||||
)}
|
||||
{showSensitive && (
|
||||
{showSensitiveActionToggle && (
|
||||
<SafeModeButton
|
||||
isEnabled={currentSensitiveActionSafeMode}
|
||||
label="Sensitive actions blocks approval"
|
||||
|
||||
@@ -34,7 +34,7 @@ export type CustomNodeData = {
|
||||
uiType: BlockUIType;
|
||||
block_id: string;
|
||||
status?: AgentExecutionStatus;
|
||||
nodeExecutionResult?: NodeExecutionResult;
|
||||
nodeExecutionResults?: NodeExecutionResult[];
|
||||
staticOutput?: boolean;
|
||||
// TODO : We need better type safety for the following backend fields.
|
||||
costs: BlockCost[];
|
||||
@@ -75,7 +75,11 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||
(value) => value !== null && value !== undefined && value !== "",
|
||||
);
|
||||
|
||||
const outputData = data.nodeExecutionResult?.output_data;
|
||||
const latestResult =
|
||||
data.nodeExecutionResults && data.nodeExecutionResults.length > 0
|
||||
? data.nodeExecutionResults[data.nodeExecutionResults.length - 1]
|
||||
: undefined;
|
||||
const outputData = latestResult?.output_data;
|
||||
const hasOutputError =
|
||||
typeof outputData === "object" &&
|
||||
outputData !== null &&
|
||||
|
||||
@@ -14,10 +14,15 @@ import { useNodeOutput } from "./useNodeOutput";
|
||||
import { ViewMoreData } from "./components/ViewMoreData";
|
||||
|
||||
export const NodeDataRenderer = ({ nodeId }: { nodeId: string }) => {
|
||||
const { outputData, copiedKey, handleCopy, executionResultId, inputData } =
|
||||
useNodeOutput(nodeId);
|
||||
const {
|
||||
latestOutputData,
|
||||
copiedKey,
|
||||
handleCopy,
|
||||
executionResultId,
|
||||
latestInputData,
|
||||
} = useNodeOutput(nodeId);
|
||||
|
||||
if (Object.keys(outputData).length === 0) {
|
||||
if (Object.keys(latestOutputData).length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -41,18 +46,19 @@ export const NodeDataRenderer = ({ nodeId }: { nodeId: string }) => {
|
||||
<div className="space-y-2">
|
||||
<Text variant="small-medium">Input</Text>
|
||||
|
||||
<ContentRenderer value={inputData} shortContent={false} />
|
||||
<ContentRenderer value={latestInputData} shortContent={false} />
|
||||
|
||||
<div className="mt-1 flex justify-end gap-1">
|
||||
<NodeDataViewer
|
||||
data={inputData}
|
||||
pinName="Input"
|
||||
nodeId={nodeId}
|
||||
execId={executionResultId}
|
||||
dataType="input"
|
||||
/>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="small"
|
||||
onClick={() => handleCopy("input", inputData)}
|
||||
onClick={() => handleCopy("input", latestInputData)}
|
||||
className={cn(
|
||||
"h-fit min-w-0 gap-1.5 border border-zinc-200 p-2 text-black hover:text-slate-900",
|
||||
copiedKey === "input" &&
|
||||
@@ -68,70 +74,72 @@ export const NodeDataRenderer = ({ nodeId }: { nodeId: string }) => {
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{Object.entries(outputData)
|
||||
{Object.entries(latestOutputData)
|
||||
.slice(0, 2)
|
||||
.map(([key, value]) => (
|
||||
<div key={key} className="flex flex-col gap-2">
|
||||
<div className="flex items-center gap-2">
|
||||
<Text
|
||||
variant="small-medium"
|
||||
className="!font-semibold text-slate-600"
|
||||
>
|
||||
Pin:
|
||||
</Text>
|
||||
<Text variant="small" className="text-slate-700">
|
||||
{beautifyString(key)}
|
||||
</Text>
|
||||
</div>
|
||||
<div className="w-full space-y-2">
|
||||
<Text
|
||||
variant="small"
|
||||
className="!font-semibold text-slate-600"
|
||||
>
|
||||
Data:
|
||||
</Text>
|
||||
<div className="relative space-y-2">
|
||||
{value.map((item, index) => (
|
||||
<div key={index}>
|
||||
<ContentRenderer value={item} shortContent={true} />
|
||||
</div>
|
||||
))}
|
||||
.map(([key, value]) => {
|
||||
return (
|
||||
<div key={key} className="flex flex-col gap-2">
|
||||
<div className="flex items-center gap-2">
|
||||
<Text
|
||||
variant="small-medium"
|
||||
className="!font-semibold text-slate-600"
|
||||
>
|
||||
Pin:
|
||||
</Text>
|
||||
<Text variant="small" className="text-slate-700">
|
||||
{beautifyString(key)}
|
||||
</Text>
|
||||
</div>
|
||||
<div className="w-full space-y-2">
|
||||
<Text
|
||||
variant="small"
|
||||
className="!font-semibold text-slate-600"
|
||||
>
|
||||
Data:
|
||||
</Text>
|
||||
<div className="relative space-y-2">
|
||||
{value.map((item, index) => (
|
||||
<div key={index}>
|
||||
<ContentRenderer
|
||||
value={item}
|
||||
shortContent={true}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
|
||||
<div className="mt-1 flex justify-end gap-1">
|
||||
<NodeDataViewer
|
||||
data={value}
|
||||
pinName={key}
|
||||
execId={executionResultId}
|
||||
/>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="small"
|
||||
onClick={() => handleCopy(key, value)}
|
||||
className={cn(
|
||||
"h-fit min-w-0 gap-1.5 border border-zinc-200 p-2 text-black hover:text-slate-900",
|
||||
copiedKey === key &&
|
||||
"border-green-400 bg-green-100 hover:border-green-400 hover:bg-green-200",
|
||||
)}
|
||||
>
|
||||
{copiedKey === key ? (
|
||||
<CheckIcon size={12} className="text-green-600" />
|
||||
) : (
|
||||
<CopyIcon size={12} />
|
||||
)}
|
||||
</Button>
|
||||
<div className="mt-1 flex justify-end gap-1">
|
||||
<NodeDataViewer
|
||||
pinName={key}
|
||||
nodeId={nodeId}
|
||||
execId={executionResultId}
|
||||
/>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="small"
|
||||
onClick={() => handleCopy(key, value)}
|
||||
className={cn(
|
||||
"h-fit min-w-0 gap-1.5 border border-zinc-200 p-2 text-black hover:text-slate-900",
|
||||
copiedKey === key &&
|
||||
"border-green-400 bg-green-100 hover:border-green-400 hover:bg-green-200",
|
||||
)}
|
||||
>
|
||||
{copiedKey === key ? (
|
||||
<CheckIcon
|
||||
size={12}
|
||||
className="text-green-600"
|
||||
/>
|
||||
) : (
|
||||
<CopyIcon size={12} />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
|
||||
{Object.keys(outputData).length > 2 && (
|
||||
<ViewMoreData
|
||||
outputData={outputData}
|
||||
execId={executionResultId}
|
||||
/>
|
||||
)}
|
||||
<ViewMoreData nodeId={nodeId} />
|
||||
</AccordionContent>
|
||||
</AccordionItem>
|
||||
</Accordion>
|
||||
|
||||
@@ -19,22 +19,51 @@ import {
|
||||
CopyIcon,
|
||||
DownloadIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { FC } from "react";
|
||||
import React, { FC } from "react";
|
||||
import { useNodeDataViewer } from "./useNodeDataViewer";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { useShallow } from "zustand/react/shallow";
|
||||
import { NodeDataType } from "../../helpers";
|
||||
|
||||
interface NodeDataViewerProps {
|
||||
data: any;
|
||||
export interface NodeDataViewerProps {
|
||||
data?: any;
|
||||
pinName: string;
|
||||
nodeId?: string;
|
||||
execId?: string;
|
||||
isViewMoreData?: boolean;
|
||||
dataType?: NodeDataType;
|
||||
}
|
||||
|
||||
export const NodeDataViewer: FC<NodeDataViewerProps> = ({
|
||||
data,
|
||||
pinName,
|
||||
nodeId,
|
||||
execId = "N/A",
|
||||
isViewMoreData = false,
|
||||
dataType = "output",
|
||||
}) => {
|
||||
const executionResults = useNodeStore(
|
||||
useShallow((state) =>
|
||||
nodeId ? state.getNodeExecutionResults(nodeId) : [],
|
||||
),
|
||||
);
|
||||
const latestInputData = useNodeStore(
|
||||
useShallow((state) =>
|
||||
nodeId ? state.getLatestNodeInputData(nodeId) : undefined,
|
||||
),
|
||||
);
|
||||
const accumulatedOutputData = useNodeStore(
|
||||
useShallow((state) =>
|
||||
nodeId ? state.getAccumulatedNodeOutputData(nodeId) : {},
|
||||
),
|
||||
);
|
||||
|
||||
const resolvedData =
|
||||
data ??
|
||||
(dataType === "input"
|
||||
? (latestInputData ?? {})
|
||||
: (accumulatedOutputData[pinName] ?? []));
|
||||
|
||||
const {
|
||||
outputItems,
|
||||
copyExecutionId,
|
||||
@@ -42,7 +71,20 @@ export const NodeDataViewer: FC<NodeDataViewerProps> = ({
|
||||
handleDownloadItem,
|
||||
dataArray,
|
||||
copiedIndex,
|
||||
} = useNodeDataViewer(data, pinName, execId);
|
||||
groupedExecutions,
|
||||
totalGroupedItems,
|
||||
handleCopyGroupedItem,
|
||||
handleDownloadGroupedItem,
|
||||
copiedKey,
|
||||
} = useNodeDataViewer(
|
||||
resolvedData,
|
||||
pinName,
|
||||
execId,
|
||||
executionResults,
|
||||
dataType,
|
||||
);
|
||||
|
||||
const shouldGroupExecutions = groupedExecutions.length > 0;
|
||||
return (
|
||||
<Dialog styling={{ width: "600px" }}>
|
||||
<TooltipProvider>
|
||||
@@ -68,44 +110,141 @@ export const NodeDataViewer: FC<NodeDataViewerProps> = ({
|
||||
<div className="flex items-center gap-4">
|
||||
<div className="flex items-center gap-2">
|
||||
<Text variant="large-medium" className="text-slate-900">
|
||||
Full Output Preview
|
||||
Full {dataType === "input" ? "Input" : "Output"} Preview
|
||||
</Text>
|
||||
</div>
|
||||
<div className="rounded-full border border-slate-300 bg-slate-100 px-3 py-1.5 text-xs font-medium text-black">
|
||||
{dataArray.length} item{dataArray.length !== 1 ? "s" : ""} total
|
||||
{shouldGroupExecutions ? totalGroupedItems : dataArray.length}{" "}
|
||||
item
|
||||
{shouldGroupExecutions
|
||||
? totalGroupedItems !== 1
|
||||
? "s"
|
||||
: ""
|
||||
: dataArray.length !== 1
|
||||
? "s"
|
||||
: ""}{" "}
|
||||
total
|
||||
</div>
|
||||
</div>
|
||||
<div className="text-sm text-gray-600">
|
||||
<div className="flex items-center gap-2">
|
||||
<Text variant="body" className="text-slate-600">
|
||||
Execution ID:
|
||||
</Text>
|
||||
<Text
|
||||
variant="body-medium"
|
||||
className="rounded-full border border-gray-300 bg-gray-50 px-2 py-1 font-mono text-xs"
|
||||
>
|
||||
{execId}
|
||||
</Text>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={copyExecutionId}
|
||||
className="h-6 w-6 min-w-0 p-0"
|
||||
>
|
||||
<CopyIcon size={14} />
|
||||
</Button>
|
||||
</div>
|
||||
<div className="mt-2">
|
||||
Pin:{" "}
|
||||
<span className="font-semibold">{beautifyString(pinName)}</span>
|
||||
</div>
|
||||
{shouldGroupExecutions ? (
|
||||
<div>
|
||||
Pin:{" "}
|
||||
<span className="font-semibold">{beautifyString(pinName)}</span>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div className="flex items-center gap-2">
|
||||
<Text variant="body" className="text-slate-600">
|
||||
Execution ID:
|
||||
</Text>
|
||||
<Text
|
||||
variant="body-medium"
|
||||
className="rounded-full border border-gray-300 bg-gray-50 px-2 py-1 font-mono text-xs"
|
||||
>
|
||||
{execId}
|
||||
</Text>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={copyExecutionId}
|
||||
className="h-6 w-6 min-w-0 p-0"
|
||||
>
|
||||
<CopyIcon size={14} />
|
||||
</Button>
|
||||
</div>
|
||||
<div className="mt-2">
|
||||
Pin:{" "}
|
||||
<span className="font-semibold">
|
||||
{beautifyString(pinName)}
|
||||
</span>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex-1 overflow-hidden">
|
||||
<ScrollArea className="h-full">
|
||||
<div className="my-4">
|
||||
{dataArray.length > 0 ? (
|
||||
{shouldGroupExecutions ? (
|
||||
<div className="space-y-4">
|
||||
{groupedExecutions.map((execution) => (
|
||||
<div
|
||||
key={execution.execId}
|
||||
className="rounded-3xl border border-slate-200 bg-white p-4 shadow-sm"
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<Text variant="body" className="text-slate-600">
|
||||
Execution ID:
|
||||
</Text>
|
||||
<Text
|
||||
variant="body-medium"
|
||||
className="rounded-full border border-gray-300 bg-gray-50 px-2 py-1 font-mono text-xs"
|
||||
>
|
||||
{execution.execId}
|
||||
</Text>
|
||||
</div>
|
||||
<div className="mt-2 space-y-4">
|
||||
{execution.outputItems.length > 0 ? (
|
||||
execution.outputItems.map((item, index) => (
|
||||
<div
|
||||
key={item.key}
|
||||
className="group flex items-start gap-4"
|
||||
>
|
||||
<div className="w-full flex-1">
|
||||
<OutputItem
|
||||
value={item.value}
|
||||
metadata={item.metadata}
|
||||
renderer={item.renderer}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="flex w-fit gap-3">
|
||||
<Button
|
||||
variant="secondary"
|
||||
className="min-w-0 p-1"
|
||||
size="icon"
|
||||
onClick={() =>
|
||||
handleCopyGroupedItem(
|
||||
execution.execId,
|
||||
index,
|
||||
item,
|
||||
)
|
||||
}
|
||||
aria-label="Copy item"
|
||||
>
|
||||
{copiedKey ===
|
||||
`${execution.execId}-${index}` ? (
|
||||
<CheckIcon className="size-4 text-green-600" />
|
||||
) : (
|
||||
<CopyIcon className="size-4 text-black" />
|
||||
)}
|
||||
</Button>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="icon"
|
||||
className="min-w-0 p-1"
|
||||
onClick={() =>
|
||||
handleDownloadGroupedItem(item)
|
||||
}
|
||||
aria-label="Download item"
|
||||
>
|
||||
<DownloadIcon className="size-4 text-black" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
))
|
||||
) : (
|
||||
<div className="py-4 text-center text-gray-500">
|
||||
No data available
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
) : dataArray.length > 0 ? (
|
||||
<div className="space-y-4">
|
||||
{outputItems.map((item, index) => (
|
||||
<div key={item.key} className="group relative">
|
||||
|
||||
@@ -1,82 +1,70 @@
|
||||
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
|
||||
import { globalRegistry } from "@/components/contextual/OutputRenderers";
|
||||
import { downloadOutputs } from "@/components/contextual/OutputRenderers/utils/download";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { beautifyString } from "@/lib/utils";
|
||||
import React, { useMemo, useState } from "react";
|
||||
import { useState } from "react";
|
||||
import type { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult";
|
||||
import {
|
||||
NodeDataType,
|
||||
createOutputItems,
|
||||
getExecutionData,
|
||||
normalizeToArray,
|
||||
type OutputItem,
|
||||
} from "../../helpers";
|
||||
|
||||
export type GroupedExecution = {
|
||||
execId: string;
|
||||
outputItems: Array<OutputItem>;
|
||||
};
|
||||
|
||||
export const useNodeDataViewer = (
|
||||
data: any,
|
||||
pinName: string,
|
||||
execId: string,
|
||||
executionResults?: NodeExecutionResult[],
|
||||
dataType?: NodeDataType,
|
||||
) => {
|
||||
const { toast } = useToast();
|
||||
const [copiedIndex, setCopiedIndex] = useState<number | null>(null);
|
||||
const [copiedKey, setCopiedKey] = useState<string | null>(null);
|
||||
|
||||
// Normalize data to array format
|
||||
const dataArray = useMemo(() => {
|
||||
return Array.isArray(data) ? data : [data];
|
||||
}, [data]);
|
||||
const dataArray = Array.isArray(data) ? data : [data];
|
||||
|
||||
// Prepare items for the enhanced renderer system
|
||||
const outputItems = useMemo(() => {
|
||||
if (!dataArray) return [];
|
||||
|
||||
const items: Array<{
|
||||
key: string;
|
||||
label: string;
|
||||
value: unknown;
|
||||
metadata?: OutputMetadata;
|
||||
renderer: any;
|
||||
}> = [];
|
||||
|
||||
dataArray.forEach((value, index) => {
|
||||
const metadata: OutputMetadata = {};
|
||||
|
||||
// Extract metadata from the value if it's an object
|
||||
if (
|
||||
typeof value === "object" &&
|
||||
value !== null &&
|
||||
!React.isValidElement(value)
|
||||
) {
|
||||
const objValue = value as any;
|
||||
if (objValue.type) metadata.type = objValue.type;
|
||||
if (objValue.mimeType) metadata.mimeType = objValue.mimeType;
|
||||
if (objValue.filename) metadata.filename = objValue.filename;
|
||||
if (objValue.language) metadata.language = objValue.language;
|
||||
}
|
||||
|
||||
const renderer = globalRegistry.getRenderer(value, metadata);
|
||||
if (renderer) {
|
||||
items.push({
|
||||
key: `item-${index}`,
|
||||
const outputItems =
|
||||
!dataArray || dataArray.length === 0
|
||||
? []
|
||||
: createOutputItems(dataArray).map((item, index) => ({
|
||||
...item,
|
||||
label: index === 0 ? beautifyString(pinName) : "",
|
||||
value,
|
||||
metadata,
|
||||
renderer,
|
||||
});
|
||||
} else {
|
||||
// Fallback to text renderer
|
||||
const textRenderer = globalRegistry
|
||||
.getAllRenderers()
|
||||
.find((r) => r.name === "TextRenderer");
|
||||
if (textRenderer) {
|
||||
items.push({
|
||||
key: `item-${index}`,
|
||||
label: index === 0 ? beautifyString(pinName) : "",
|
||||
value:
|
||||
typeof value === "string"
|
||||
? value
|
||||
: JSON.stringify(value, null, 2),
|
||||
metadata,
|
||||
renderer: textRenderer,
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
}));
|
||||
|
||||
return items;
|
||||
}, [dataArray, pinName]);
|
||||
const groupedExecutions =
|
||||
!executionResults || executionResults.length === 0
|
||||
? []
|
||||
: [...executionResults].reverse().map((result) => {
|
||||
const rawData = getExecutionData(
|
||||
result,
|
||||
dataType || "output",
|
||||
pinName,
|
||||
);
|
||||
let dataArray: unknown[];
|
||||
if (dataType === "input") {
|
||||
dataArray =
|
||||
rawData !== undefined && rawData !== null ? [rawData] : [];
|
||||
} else {
|
||||
dataArray = normalizeToArray(rawData);
|
||||
}
|
||||
|
||||
const outputItems = createOutputItems(dataArray);
|
||||
return {
|
||||
execId: result.node_exec_id,
|
||||
outputItems,
|
||||
};
|
||||
});
|
||||
|
||||
const totalGroupedItems = groupedExecutions.reduce(
|
||||
(total, execution) => total + execution.outputItems.length,
|
||||
0,
|
||||
);
|
||||
|
||||
const copyExecutionId = () => {
|
||||
navigator.clipboard.writeText(execId).then(() => {
|
||||
@@ -122,6 +110,45 @@ export const useNodeDataViewer = (
|
||||
]);
|
||||
};
|
||||
|
||||
const handleCopyGroupedItem = async (
|
||||
execId: string,
|
||||
index: number,
|
||||
item: OutputItem,
|
||||
) => {
|
||||
const copyContent = item.renderer.getCopyContent(item.value, item.metadata);
|
||||
|
||||
if (!copyContent) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
let text: string;
|
||||
if (typeof copyContent.data === "string") {
|
||||
text = copyContent.data;
|
||||
} else if (copyContent.fallbackText) {
|
||||
text = copyContent.fallbackText;
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
|
||||
await navigator.clipboard.writeText(text);
|
||||
setCopiedKey(`${execId}-${index}`);
|
||||
setTimeout(() => setCopiedKey(null), 2000);
|
||||
} catch (error) {
|
||||
console.error("Failed to copy:", error);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDownloadGroupedItem = (item: OutputItem) => {
|
||||
downloadOutputs([
|
||||
{
|
||||
value: item.value,
|
||||
metadata: item.metadata,
|
||||
renderer: item.renderer,
|
||||
},
|
||||
]);
|
||||
};
|
||||
|
||||
return {
|
||||
outputItems,
|
||||
dataArray,
|
||||
@@ -129,5 +156,10 @@ export const useNodeDataViewer = (
|
||||
handleCopyItem,
|
||||
handleDownloadItem,
|
||||
copiedIndex,
|
||||
groupedExecutions,
|
||||
totalGroupedItems,
|
||||
handleCopyGroupedItem,
|
||||
handleDownloadGroupedItem,
|
||||
copiedKey,
|
||||
};
|
||||
};
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user