mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-30 09:28:19 -05:00
Compare commits
17 Commits
add-llm-ma
...
abhi/check
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
26add35418 | ||
|
|
c6e5f83de8 | ||
|
|
c3a126e705 | ||
|
|
73d8323fe4 | ||
|
|
e0dfae5732 | ||
|
|
7df867d645 | ||
|
|
d855f79874 | ||
|
|
dac99694fe | ||
|
|
0953983944 | ||
|
|
0058cd3ba6 | ||
|
|
ea035224bc | ||
|
|
62813a1ea6 | ||
|
|
67405f7eb9 | ||
|
|
171ff6e776 | ||
|
|
349b1f9c79 | ||
|
|
277b0537e9 | ||
|
|
071b3bb5cd |
@@ -122,24 +122,6 @@ class ConnectionManager:
|
||||
|
||||
return len(connections)
|
||||
|
||||
async def broadcast_to_all(self, *, method: WSMethod, data: dict) -> int:
|
||||
"""Broadcast a message to all active websocket connections."""
|
||||
message = WSMessage(
|
||||
method=method,
|
||||
data=data,
|
||||
).model_dump_json()
|
||||
|
||||
connections = tuple(self.active_connections)
|
||||
if not connections:
|
||||
return 0
|
||||
|
||||
await asyncio.gather(
|
||||
*(connection.send_text(message) for connection in connections),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
return len(connections)
|
||||
|
||||
async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
|
||||
if channel_key not in self.subscriptions:
|
||||
self.subscriptions[channel_key] = set()
|
||||
|
||||
@@ -176,64 +176,30 @@ async def get_execution_analytics_config(
|
||||
# Return with provider prefix for clarity
|
||||
return f"{provider_name}: {model_name}"
|
||||
|
||||
# Get all models from the registry (dynamic, not hardcoded enum)
|
||||
from backend.data import llm_registry
|
||||
from backend.server.v2.llm import db as llm_db
|
||||
|
||||
# Get the recommended model from the database (configurable via admin UI)
|
||||
recommended_model_slug = await llm_db.get_recommended_model_slug()
|
||||
|
||||
# Build the available models list
|
||||
first_enabled_slug = None
|
||||
for registry_model in llm_registry.iter_dynamic_models():
|
||||
# Only include enabled models in the list
|
||||
if not registry_model.is_enabled:
|
||||
continue
|
||||
|
||||
# Track first enabled model as fallback
|
||||
if first_enabled_slug is None:
|
||||
first_enabled_slug = registry_model.slug
|
||||
|
||||
model_enum = LlmModel(registry_model.slug) # Create enum instance from slug
|
||||
label = generate_model_label(model_enum)
|
||||
# Include all LlmModel values (no more filtering by hardcoded list)
|
||||
recommended_model = LlmModel.GPT4O_MINI.value
|
||||
for model in LlmModel:
|
||||
label = generate_model_label(model)
|
||||
# Add "(Recommended)" suffix to the recommended model
|
||||
if registry_model.slug == recommended_model_slug:
|
||||
if model.value == recommended_model:
|
||||
label += " (Recommended)"
|
||||
|
||||
available_models.append(
|
||||
ModelInfo(
|
||||
value=registry_model.slug,
|
||||
value=model.value,
|
||||
label=label,
|
||||
provider=registry_model.metadata.provider,
|
||||
provider=model.provider,
|
||||
)
|
||||
)
|
||||
|
||||
# Sort models by provider and name for better UX
|
||||
available_models.sort(key=lambda x: (x.provider, x.label))
|
||||
|
||||
# Handle case where no models are available
|
||||
if not available_models:
|
||||
logger.warning(
|
||||
"No enabled LLM models found in registry. "
|
||||
"Ensure models are configured and enabled in the LLM Registry."
|
||||
)
|
||||
# Provide a placeholder entry so admins see meaningful feedback
|
||||
available_models.append(
|
||||
ModelInfo(
|
||||
value="",
|
||||
label="No models available - configure in LLM Registry",
|
||||
provider="none",
|
||||
)
|
||||
)
|
||||
|
||||
# Use the DB recommended model, or fallback to first enabled model
|
||||
final_recommended = recommended_model_slug or first_enabled_slug or ""
|
||||
|
||||
return ExecutionAnalyticsConfig(
|
||||
available_models=available_models,
|
||||
default_system_prompt=DEFAULT_SYSTEM_PROMPT,
|
||||
default_user_prompt=DEFAULT_USER_PROMPT,
|
||||
recommended_model=final_recommended,
|
||||
recommended_model=recommended_model,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,595 +0,0 @@
|
||||
import logging
|
||||
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block_cost_config import refresh_llm_costs
|
||||
from backend.server.v2.llm import db as llm_db
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = fastapi.APIRouter(
|
||||
tags=["llm", "admin"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
async def _refresh_runtime_state() -> None:
|
||||
"""Refresh the LLM registry and clear all related caches to ensure real-time updates."""
|
||||
logger.info("Refreshing LLM registry runtime state...")
|
||||
try:
|
||||
# Refresh registry from database
|
||||
await llm_registry.refresh_llm_registry()
|
||||
refresh_llm_costs()
|
||||
|
||||
# Clear block schema caches so they're regenerated with updated model options
|
||||
from backend.data.block import BlockSchema
|
||||
|
||||
BlockSchema.clear_all_schema_caches()
|
||||
logger.info("Cleared all block schema caches")
|
||||
|
||||
# Clear the /blocks endpoint cache so frontend gets updated schemas
|
||||
try:
|
||||
from backend.api.features.v1 import _get_cached_blocks
|
||||
|
||||
_get_cached_blocks.cache_clear()
|
||||
logger.info("Cleared /blocks endpoint cache")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to clear /blocks cache: %s", e)
|
||||
|
||||
# Clear the v2 builder caches (if they exist)
|
||||
try:
|
||||
from backend.api.features.builder import db as builder_db
|
||||
|
||||
if hasattr(builder_db, "_get_all_providers"):
|
||||
builder_db._get_all_providers.cache_clear()
|
||||
logger.info("Cleared v2 builder providers cache")
|
||||
if hasattr(builder_db, "_build_cached_search_results"):
|
||||
builder_db._build_cached_search_results.cache_clear()
|
||||
logger.info("Cleared v2 builder search results cache")
|
||||
except Exception as e:
|
||||
logger.debug("Could not clear v2 builder cache: %s", e)
|
||||
|
||||
# Notify all executor services to refresh their registry cache
|
||||
from backend.data.llm_registry import publish_registry_refresh_notification
|
||||
|
||||
await publish_registry_refresh_notification()
|
||||
logger.info("Published registry refresh notification")
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"LLM runtime state refresh failed; caches may be stale: %s", exc
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/providers",
|
||||
summary="List LLM providers",
|
||||
response_model=llm_model.LlmProvidersResponse,
|
||||
)
|
||||
async def list_llm_providers(include_models: bool = True):
|
||||
providers = await llm_db.list_providers(include_models=include_models)
|
||||
return llm_model.LlmProvidersResponse(providers=providers)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/providers",
|
||||
summary="Create LLM provider",
|
||||
response_model=llm_model.LlmProvider,
|
||||
)
|
||||
async def create_llm_provider(request: llm_model.UpsertLlmProviderRequest):
|
||||
provider = await llm_db.upsert_provider(request=request)
|
||||
await _refresh_runtime_state()
|
||||
return provider
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/providers/{provider_id}",
|
||||
summary="Update LLM provider",
|
||||
response_model=llm_model.LlmProvider,
|
||||
)
|
||||
async def update_llm_provider(
|
||||
provider_id: str,
|
||||
request: llm_model.UpsertLlmProviderRequest,
|
||||
):
|
||||
provider = await llm_db.upsert_provider(request=request, provider_id=provider_id)
|
||||
await _refresh_runtime_state()
|
||||
return provider
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/providers/{provider_id}",
|
||||
summary="Delete LLM provider",
|
||||
response_model=dict,
|
||||
)
|
||||
async def delete_llm_provider(provider_id: str):
|
||||
"""
|
||||
Delete an LLM provider.
|
||||
|
||||
A provider can only be deleted if it has no associated models.
|
||||
Delete all models from the provider first before deleting the provider.
|
||||
"""
|
||||
try:
|
||||
await llm_db.delete_provider(provider_id)
|
||||
await _refresh_runtime_state()
|
||||
logger.info("Deleted LLM provider '%s'", provider_id)
|
||||
return {"success": True, "message": "Provider deleted successfully"}
|
||||
except ValueError as e:
|
||||
logger.warning("Failed to delete provider '%s': %s", provider_id, e)
|
||||
raise fastapi.HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Failed to delete provider '%s': %s", provider_id, e)
|
||||
raise fastapi.HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/models",
|
||||
summary="List LLM models",
|
||||
response_model=llm_model.LlmModelsResponse,
|
||||
)
|
||||
async def list_llm_models(
|
||||
provider_id: str | None = fastapi.Query(default=None),
|
||||
page: int = fastapi.Query(default=1, ge=1, description="Page number (1-indexed)"),
|
||||
page_size: int = fastapi.Query(
|
||||
default=50, ge=1, le=100, description="Number of models per page"
|
||||
),
|
||||
):
|
||||
return await llm_db.list_models(
|
||||
provider_id=provider_id, page=page, page_size=page_size
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/models",
|
||||
summary="Create LLM model",
|
||||
response_model=llm_model.LlmModel,
|
||||
)
|
||||
async def create_llm_model(request: llm_model.CreateLlmModelRequest):
|
||||
model = await llm_db.create_model(request=request)
|
||||
await _refresh_runtime_state()
|
||||
return model
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/models/{model_id}",
|
||||
summary="Update LLM model",
|
||||
response_model=llm_model.LlmModel,
|
||||
)
|
||||
async def update_llm_model(
|
||||
model_id: str,
|
||||
request: llm_model.UpdateLlmModelRequest,
|
||||
):
|
||||
model = await llm_db.update_model(model_id=model_id, request=request)
|
||||
await _refresh_runtime_state()
|
||||
return model
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/models/{model_id}/toggle",
|
||||
summary="Toggle LLM model availability",
|
||||
response_model=llm_model.ToggleLlmModelResponse,
|
||||
)
|
||||
async def toggle_llm_model(
|
||||
model_id: str,
|
||||
request: llm_model.ToggleLlmModelRequest,
|
||||
):
|
||||
"""
|
||||
Toggle a model's enabled status, optionally migrating workflows when disabling.
|
||||
|
||||
If disabling a model and `migrate_to_slug` is provided, all workflows using
|
||||
this model will be migrated to the specified replacement model before disabling.
|
||||
A migration record is created which can be reverted later using the revert endpoint.
|
||||
|
||||
Optional fields:
|
||||
- `migration_reason`: Reason for the migration (e.g., "Provider outage")
|
||||
- `custom_credit_cost`: Custom pricing override for billing during migration
|
||||
"""
|
||||
try:
|
||||
result = await llm_db.toggle_model(
|
||||
model_id=model_id,
|
||||
is_enabled=request.is_enabled,
|
||||
migrate_to_slug=request.migrate_to_slug,
|
||||
migration_reason=request.migration_reason,
|
||||
custom_credit_cost=request.custom_credit_cost,
|
||||
)
|
||||
await _refresh_runtime_state()
|
||||
if result.nodes_migrated > 0:
|
||||
logger.info(
|
||||
"Toggled model '%s' to %s and migrated %d nodes to '%s' (migration_id=%s)",
|
||||
result.model.slug,
|
||||
"enabled" if request.is_enabled else "disabled",
|
||||
result.nodes_migrated,
|
||||
result.migrated_to_slug,
|
||||
result.migration_id,
|
||||
)
|
||||
return result
|
||||
except ValueError as exc:
|
||||
logger.warning("Model toggle validation failed: %s", exc)
|
||||
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to toggle LLM model %s: %s", model_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to toggle model availability",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/models/{model_id}/usage",
|
||||
summary="Get model usage count",
|
||||
response_model=llm_model.LlmModelUsageResponse,
|
||||
)
|
||||
async def get_llm_model_usage(model_id: str):
|
||||
"""Get the number of workflow nodes using this model."""
|
||||
try:
|
||||
return await llm_db.get_model_usage(model_id=model_id)
|
||||
except ValueError as exc:
|
||||
raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to get model usage %s: %s", model_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to get model usage",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/models/{model_id}",
|
||||
summary="Delete LLM model and migrate workflows",
|
||||
response_model=llm_model.DeleteLlmModelResponse,
|
||||
)
|
||||
async def delete_llm_model(
|
||||
model_id: str,
|
||||
replacement_model_slug: str | None = fastapi.Query(
|
||||
default=None,
|
||||
description="Slug of the model to migrate existing workflows to (required only if workflows use this model)",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Delete a model and optionally migrate workflows using it to a replacement model.
|
||||
|
||||
If no workflows are using this model, it can be deleted without providing a
|
||||
replacement. If workflows exist, replacement_model_slug is required.
|
||||
|
||||
This endpoint:
|
||||
1. Counts how many workflow nodes use the model being deleted
|
||||
2. If nodes exist, validates the replacement model and migrates them
|
||||
3. Deletes the model record
|
||||
4. Refreshes all caches and notifies executors
|
||||
|
||||
Example: DELETE /admin/llm/models/{id}?replacement_model_slug=gpt-4o
|
||||
Example (no usage): DELETE /admin/llm/models/{id}
|
||||
"""
|
||||
try:
|
||||
result = await llm_db.delete_model(
|
||||
model_id=model_id, replacement_model_slug=replacement_model_slug
|
||||
)
|
||||
await _refresh_runtime_state()
|
||||
logger.info(
|
||||
"Deleted model '%s' and migrated %d nodes to '%s'",
|
||||
result.deleted_model_slug,
|
||||
result.nodes_migrated,
|
||||
result.replacement_model_slug,
|
||||
)
|
||||
return result
|
||||
except ValueError as exc:
|
||||
# Validation errors (model not found, replacement invalid, etc.)
|
||||
logger.warning("Model deletion validation failed: %s", exc)
|
||||
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to delete LLM model %s: %s", model_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to delete model and migrate workflows",
|
||||
) from exc
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Migration Management Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/migrations",
|
||||
summary="List model migrations",
|
||||
response_model=llm_model.LlmMigrationsResponse,
|
||||
)
|
||||
async def list_llm_migrations(
|
||||
include_reverted: bool = fastapi.Query(
|
||||
default=False, description="Include reverted migrations in the list"
|
||||
),
|
||||
):
|
||||
"""
|
||||
List all model migrations.
|
||||
|
||||
Migrations are created when disabling a model with the migrate_to_slug option.
|
||||
They can be reverted to restore the original model configuration.
|
||||
"""
|
||||
try:
|
||||
migrations = await llm_db.list_migrations(include_reverted=include_reverted)
|
||||
return llm_model.LlmMigrationsResponse(migrations=migrations)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to list migrations: %s", exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to list migrations",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/migrations/{migration_id}",
|
||||
summary="Get migration details",
|
||||
response_model=llm_model.LlmModelMigration,
|
||||
)
|
||||
async def get_llm_migration(migration_id: str):
|
||||
"""Get details of a specific migration."""
|
||||
try:
|
||||
migration = await llm_db.get_migration(migration_id)
|
||||
if not migration:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404, detail=f"Migration '{migration_id}' not found"
|
||||
)
|
||||
return migration
|
||||
except fastapi.HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to get migration %s: %s", migration_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to get migration",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/migrations/{migration_id}/revert",
|
||||
summary="Revert a model migration",
|
||||
response_model=llm_model.RevertMigrationResponse,
|
||||
)
|
||||
async def revert_llm_migration(
|
||||
migration_id: str,
|
||||
request: llm_model.RevertMigrationRequest | None = None,
|
||||
):
|
||||
"""
|
||||
Revert a model migration, restoring affected workflows to their original model.
|
||||
|
||||
This only reverts the specific nodes that were part of the migration.
|
||||
The source model must exist for the revert to succeed.
|
||||
|
||||
Options:
|
||||
- `re_enable_source_model`: Whether to re-enable the source model if disabled (default: True)
|
||||
|
||||
Response includes:
|
||||
- `nodes_reverted`: Number of nodes successfully reverted
|
||||
- `nodes_already_changed`: Number of nodes that were modified since migration (not reverted)
|
||||
- `source_model_re_enabled`: Whether the source model was re-enabled
|
||||
|
||||
Requirements:
|
||||
- Migration must not already be reverted
|
||||
- Source model must exist
|
||||
"""
|
||||
try:
|
||||
re_enable = request.re_enable_source_model if request else True
|
||||
result = await llm_db.revert_migration(
|
||||
migration_id,
|
||||
re_enable_source_model=re_enable,
|
||||
)
|
||||
await _refresh_runtime_state()
|
||||
logger.info(
|
||||
"Reverted migration '%s': %d nodes restored from '%s' to '%s' "
|
||||
"(%d already changed, source re-enabled=%s)",
|
||||
migration_id,
|
||||
result.nodes_reverted,
|
||||
result.target_model_slug,
|
||||
result.source_model_slug,
|
||||
result.nodes_already_changed,
|
||||
result.source_model_re_enabled,
|
||||
)
|
||||
return result
|
||||
except ValueError as exc:
|
||||
logger.warning("Migration revert validation failed: %s", exc)
|
||||
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to revert migration %s: %s", migration_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to revert migration",
|
||||
) from exc
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Creator Management Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/creators",
|
||||
summary="List model creators",
|
||||
response_model=llm_model.LlmCreatorsResponse,
|
||||
)
|
||||
async def list_llm_creators():
|
||||
"""
|
||||
List all model creators.
|
||||
|
||||
Creators are organizations that create/train models (e.g., OpenAI, Meta, Anthropic).
|
||||
This is distinct from providers who host/serve the models (e.g., OpenRouter).
|
||||
"""
|
||||
try:
|
||||
creators = await llm_db.list_creators()
|
||||
return llm_model.LlmCreatorsResponse(creators=creators)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to list creators: %s", exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to list creators",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/creators/{creator_id}",
|
||||
summary="Get creator details",
|
||||
response_model=llm_model.LlmModelCreator,
|
||||
)
|
||||
async def get_llm_creator(creator_id: str):
|
||||
"""Get details of a specific model creator."""
|
||||
try:
|
||||
creator = await llm_db.get_creator(creator_id)
|
||||
if not creator:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404, detail=f"Creator '{creator_id}' not found"
|
||||
)
|
||||
return creator
|
||||
except fastapi.HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to get creator %s: %s", creator_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to get creator",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/creators",
|
||||
summary="Create model creator",
|
||||
response_model=llm_model.LlmModelCreator,
|
||||
)
|
||||
async def create_llm_creator(request: llm_model.UpsertLlmCreatorRequest):
|
||||
"""
|
||||
Create a new model creator.
|
||||
|
||||
A creator represents an organization that creates/trains AI models,
|
||||
such as OpenAI, Anthropic, Meta, or Google.
|
||||
"""
|
||||
try:
|
||||
creator = await llm_db.upsert_creator(request=request)
|
||||
await _refresh_runtime_state()
|
||||
logger.info("Created model creator '%s' (%s)", creator.display_name, creator.id)
|
||||
return creator
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to create creator: %s", exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to create creator",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/creators/{creator_id}",
|
||||
summary="Update model creator",
|
||||
response_model=llm_model.LlmModelCreator,
|
||||
)
|
||||
async def update_llm_creator(
|
||||
creator_id: str,
|
||||
request: llm_model.UpsertLlmCreatorRequest,
|
||||
):
|
||||
"""Update an existing model creator."""
|
||||
try:
|
||||
creator = await llm_db.upsert_creator(request=request, creator_id=creator_id)
|
||||
await _refresh_runtime_state()
|
||||
logger.info("Updated model creator '%s' (%s)", creator.display_name, creator_id)
|
||||
return creator
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to update creator %s: %s", creator_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to update creator",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/creators/{creator_id}",
|
||||
summary="Delete model creator",
|
||||
response_model=dict,
|
||||
)
|
||||
async def delete_llm_creator(creator_id: str):
|
||||
"""
|
||||
Delete a model creator.
|
||||
|
||||
This will remove the creator association from all models that reference it
|
||||
(sets creatorId to NULL), but will not delete the models themselves.
|
||||
"""
|
||||
try:
|
||||
await llm_db.delete_creator(creator_id)
|
||||
await _refresh_runtime_state()
|
||||
logger.info("Deleted model creator '%s'", creator_id)
|
||||
return {"success": True, "message": f"Creator '{creator_id}' deleted"}
|
||||
except ValueError as exc:
|
||||
logger.warning("Creator deletion validation failed: %s", exc)
|
||||
raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to delete creator %s: %s", creator_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to delete creator",
|
||||
) from exc
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Recommended Model Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/recommended-model",
|
||||
summary="Get recommended model",
|
||||
response_model=llm_model.RecommendedModelResponse,
|
||||
)
|
||||
async def get_recommended_model():
|
||||
"""
|
||||
Get the currently recommended LLM model.
|
||||
|
||||
The recommended model is shown to users as the default/suggested option
|
||||
in model selection dropdowns.
|
||||
"""
|
||||
try:
|
||||
model = await llm_db.get_recommended_model()
|
||||
return llm_model.RecommendedModelResponse(
|
||||
model=model,
|
||||
slug=model.slug if model else None,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to get recommended model: %s", exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to get recommended model",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/recommended-model",
|
||||
summary="Set recommended model",
|
||||
response_model=llm_model.SetRecommendedModelResponse,
|
||||
)
|
||||
async def set_recommended_model(request: llm_model.SetRecommendedModelRequest):
|
||||
"""
|
||||
Set a model as the recommended model.
|
||||
|
||||
This clears the recommended flag from any other model and sets it on
|
||||
the specified model. The model must be enabled to be set as recommended.
|
||||
|
||||
The recommended model is displayed to users as the default/suggested
|
||||
option in model selection dropdowns throughout the platform.
|
||||
"""
|
||||
try:
|
||||
model, previous_slug = await llm_db.set_recommended_model(request.model_id)
|
||||
await _refresh_runtime_state()
|
||||
logger.info(
|
||||
"Set recommended model to '%s' (previous: %s)",
|
||||
model.slug,
|
||||
previous_slug or "none",
|
||||
)
|
||||
return llm_model.SetRecommendedModelResponse(
|
||||
model=model,
|
||||
previous_recommended_slug=previous_slug,
|
||||
message=f"Model '{model.display_name}' is now the recommended model",
|
||||
)
|
||||
except ValueError as exc:
|
||||
logger.warning("Set recommended model validation failed: %s", exc)
|
||||
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to set recommended model: %s", exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to set recommended model",
|
||||
) from exc
|
||||
@@ -1,491 +0,0 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.api.features.admin.llm_routes as llm_routes
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
from backend.util.models import Pagination
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(llm_routes.router, prefix="/admin/llm")
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
"""Setup admin auth overrides for all tests in this module"""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_list_llm_providers_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful listing of LLM providers"""
|
||||
# Mock the database function
|
||||
mock_providers = [
|
||||
{
|
||||
"id": "provider-1",
|
||||
"name": "openai",
|
||||
"display_name": "OpenAI",
|
||||
"description": "OpenAI LLM provider",
|
||||
"supports_tools": True,
|
||||
"supports_json_output": True,
|
||||
"supports_reasoning": False,
|
||||
"supports_parallel_tool": True,
|
||||
"metadata": {},
|
||||
"models": [],
|
||||
},
|
||||
{
|
||||
"id": "provider-2",
|
||||
"name": "anthropic",
|
||||
"display_name": "Anthropic",
|
||||
"description": "Anthropic LLM provider",
|
||||
"supports_tools": True,
|
||||
"supports_json_output": True,
|
||||
"supports_reasoning": False,
|
||||
"supports_parallel_tool": True,
|
||||
"metadata": {},
|
||||
"models": [],
|
||||
},
|
||||
]
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.list_providers",
|
||||
new=AsyncMock(return_value=mock_providers),
|
||||
)
|
||||
|
||||
response = client.get("/admin/llm/providers")
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert len(response_data["providers"]) == 2
|
||||
assert response_data["providers"][0]["name"] == "openai"
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"list_llm_providers_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_list_llm_models_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful listing of LLM models with pagination"""
|
||||
# Mock the database function - now returns LlmModelsResponse
|
||||
mock_model = llm_model.LlmModel(
|
||||
id="model-1",
|
||||
slug="gpt-4o",
|
||||
display_name="GPT-4o",
|
||||
description="GPT-4 Optimized",
|
||||
provider_id="provider-1",
|
||||
context_window=128000,
|
||||
max_output_tokens=16384,
|
||||
is_enabled=True,
|
||||
capabilities={},
|
||||
metadata={},
|
||||
costs=[
|
||||
llm_model.LlmModelCost(
|
||||
id="cost-1",
|
||||
credit_cost=10,
|
||||
credential_provider="openai",
|
||||
metadata={},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
mock_response = llm_model.LlmModelsResponse(
|
||||
models=[mock_model],
|
||||
pagination=Pagination(
|
||||
total_items=1,
|
||||
total_pages=1,
|
||||
current_page=1,
|
||||
page_size=50,
|
||||
),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.list_models",
|
||||
new=AsyncMock(return_value=mock_response),
|
||||
)
|
||||
|
||||
response = client.get("/admin/llm/models")
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert len(response_data["models"]) == 1
|
||||
assert response_data["models"][0]["slug"] == "gpt-4o"
|
||||
assert response_data["pagination"]["total_items"] == 1
|
||||
assert response_data["pagination"]["page_size"] == 50
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"list_llm_models_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_create_llm_provider_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful creation of LLM provider"""
|
||||
mock_provider = {
|
||||
"id": "new-provider-id",
|
||||
"name": "groq",
|
||||
"display_name": "Groq",
|
||||
"description": "Groq LLM provider",
|
||||
"supports_tools": True,
|
||||
"supports_json_output": True,
|
||||
"supports_reasoning": False,
|
||||
"supports_parallel_tool": False,
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.upsert_provider",
|
||||
new=AsyncMock(return_value=mock_provider),
|
||||
)
|
||||
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"name": "groq",
|
||||
"display_name": "Groq",
|
||||
"description": "Groq LLM provider",
|
||||
"supports_tools": True,
|
||||
"supports_json_output": True,
|
||||
"supports_reasoning": False,
|
||||
"supports_parallel_tool": False,
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
response = client.post("/admin/llm/providers", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["name"] == "groq"
|
||||
assert response_data["display_name"] == "Groq"
|
||||
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"create_llm_provider_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_create_llm_model_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful creation of LLM model"""
|
||||
mock_model = {
|
||||
"id": "new-model-id",
|
||||
"slug": "gpt-4.1-mini",
|
||||
"display_name": "GPT-4.1 Mini",
|
||||
"description": "Latest GPT-4.1 Mini model",
|
||||
"provider_id": "provider-1",
|
||||
"context_window": 128000,
|
||||
"max_output_tokens": 16384,
|
||||
"is_enabled": True,
|
||||
"capabilities": {},
|
||||
"metadata": {},
|
||||
"costs": [
|
||||
{
|
||||
"id": "cost-id",
|
||||
"credit_cost": 5,
|
||||
"credential_provider": "openai",
|
||||
"metadata": {},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.create_model",
|
||||
new=AsyncMock(return_value=mock_model),
|
||||
)
|
||||
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"slug": "gpt-4.1-mini",
|
||||
"display_name": "GPT-4.1 Mini",
|
||||
"description": "Latest GPT-4.1 Mini model",
|
||||
"provider_id": "provider-1",
|
||||
"context_window": 128000,
|
||||
"max_output_tokens": 16384,
|
||||
"is_enabled": True,
|
||||
"capabilities": {},
|
||||
"metadata": {},
|
||||
"costs": [
|
||||
{
|
||||
"credit_cost": 5,
|
||||
"credential_provider": "openai",
|
||||
"metadata": {},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
response = client.post("/admin/llm/models", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["slug"] == "gpt-4.1-mini"
|
||||
assert response_data["is_enabled"] is True
|
||||
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"create_llm_model_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_update_llm_model_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful update of LLM model"""
|
||||
mock_model = {
|
||||
"id": "model-1",
|
||||
"slug": "gpt-4o",
|
||||
"display_name": "GPT-4o Updated",
|
||||
"description": "Updated description",
|
||||
"provider_id": "provider-1",
|
||||
"context_window": 256000,
|
||||
"max_output_tokens": 32768,
|
||||
"is_enabled": True,
|
||||
"capabilities": {},
|
||||
"metadata": {},
|
||||
"costs": [
|
||||
{
|
||||
"id": "cost-1",
|
||||
"credit_cost": 15,
|
||||
"credential_provider": "openai",
|
||||
"metadata": {},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.update_model",
|
||||
new=AsyncMock(return_value=mock_model),
|
||||
)
|
||||
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"display_name": "GPT-4o Updated",
|
||||
"description": "Updated description",
|
||||
"context_window": 256000,
|
||||
"max_output_tokens": 32768,
|
||||
}
|
||||
|
||||
response = client.patch("/admin/llm/models/model-1", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["display_name"] == "GPT-4o Updated"
|
||||
assert response_data["context_window"] == 256000
|
||||
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"update_llm_model_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_toggle_llm_model_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful toggling of LLM model enabled status"""
|
||||
# Create a proper mock model object
|
||||
mock_model = llm_model.LlmModel(
|
||||
id="model-1",
|
||||
slug="gpt-4o",
|
||||
display_name="GPT-4o",
|
||||
description="GPT-4 Optimized",
|
||||
provider_id="provider-1",
|
||||
context_window=128000,
|
||||
max_output_tokens=16384,
|
||||
is_enabled=False,
|
||||
capabilities={},
|
||||
metadata={},
|
||||
costs=[],
|
||||
)
|
||||
|
||||
# Create a proper ToggleLlmModelResponse
|
||||
mock_response = llm_model.ToggleLlmModelResponse(
|
||||
model=mock_model,
|
||||
nodes_migrated=0,
|
||||
migrated_to_slug=None,
|
||||
migration_id=None,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.toggle_model",
|
||||
new=AsyncMock(return_value=mock_response),
|
||||
)
|
||||
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
request_data = {"is_enabled": False}
|
||||
|
||||
response = client.patch("/admin/llm/models/model-1/toggle", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["model"]["is_enabled"] is False
|
||||
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"toggle_llm_model_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_delete_llm_model_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful deletion of LLM model with migration"""
|
||||
# Create a proper DeleteLlmModelResponse
|
||||
mock_response = llm_model.DeleteLlmModelResponse(
|
||||
deleted_model_slug="gpt-3.5-turbo",
|
||||
deleted_model_display_name="GPT-3.5 Turbo",
|
||||
replacement_model_slug="gpt-4o-mini",
|
||||
nodes_migrated=42,
|
||||
message="Successfully deleted model 'GPT-3.5 Turbo' (gpt-3.5-turbo) "
|
||||
"and migrated 42 workflow node(s) to 'gpt-4o-mini'.",
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.delete_model",
|
||||
new=AsyncMock(return_value=mock_response),
|
||||
)
|
||||
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
response = client.delete(
|
||||
"/admin/llm/models/model-1?replacement_model_slug=gpt-4o-mini"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["deleted_model_slug"] == "gpt-3.5-turbo"
|
||||
assert response_data["nodes_migrated"] == 42
|
||||
assert response_data["replacement_model_slug"] == "gpt-4o-mini"
|
||||
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"delete_llm_model_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_delete_llm_model_validation_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Test deletion fails with proper error when validation fails"""
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.delete_model",
|
||||
new=AsyncMock(side_effect=ValueError("Replacement model 'invalid' not found")),
|
||||
)
|
||||
|
||||
response = client.delete("/admin/llm/models/model-1?replacement_model_slug=invalid")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Replacement model 'invalid' not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_delete_llm_model_no_replacement_with_usage(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Test deletion fails when nodes exist but no replacement is provided"""
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.delete_model",
|
||||
new=AsyncMock(
|
||||
side_effect=ValueError(
|
||||
"Cannot delete model 'test-model': 5 workflow node(s) are using it. "
|
||||
"Please provide a replacement_model_slug to migrate them."
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
response = client.delete("/admin/llm/models/model-1")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "workflow node(s) are using it" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_delete_llm_model_no_replacement_no_usage(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Test deletion succeeds when no nodes use the model and no replacement is provided"""
|
||||
mock_response = llm_model.DeleteLlmModelResponse(
|
||||
deleted_model_slug="unused-model",
|
||||
deleted_model_display_name="Unused Model",
|
||||
replacement_model_slug=None,
|
||||
nodes_migrated=0,
|
||||
message="Successfully deleted model 'Unused Model' (unused-model). No workflows were using this model.",
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.delete_model",
|
||||
new=AsyncMock(return_value=mock_response),
|
||||
)
|
||||
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
response = client.delete("/admin/llm/models/model-1")
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["deleted_model_slug"] == "unused-model"
|
||||
assert response_data["nodes_migrated"] == 0
|
||||
assert response_data["replacement_model_slug"] is None
|
||||
mock_refresh.assert_called_once()
|
||||
@@ -15,7 +15,6 @@ from backend.blocks import load_all_blocks
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.data.llm_registry import get_all_model_slugs_for_validation
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
@@ -32,14 +31,7 @@ from .model import (
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_llm_models() -> list[str]:
|
||||
"""Get LLM model names for search matching from the registry."""
|
||||
return [
|
||||
slug.lower().replace("-", " ") for slug in get_all_model_slugs_for_validation()
|
||||
]
|
||||
|
||||
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
|
||||
|
||||
MAX_LIBRARY_AGENT_RESULTS = 100
|
||||
MAX_MARKETPLACE_AGENT_RESULTS = 100
|
||||
@@ -504,8 +496,8 @@ async def _get_static_counts():
|
||||
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
|
||||
for field in schema_cls.model_fields.values():
|
||||
if field.annotation == LlmModel:
|
||||
# Check if query matches any value in llm_models from registry
|
||||
if any(query in name for name in _get_llm_models()):
|
||||
# Check if query matches any value in llm_models
|
||||
if any(query in name for name in llm_models):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@@ -33,9 +33,15 @@ class ChatConfig(BaseSettings):
|
||||
|
||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||
max_retries: int = Field(default=3, description="Maximum number of retries")
|
||||
max_agent_runs: int = Field(default=3, description="Maximum number of agent runs")
|
||||
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
||||
max_agent_schedules: int = Field(
|
||||
default=3, description="Maximum number of agent schedules"
|
||||
default=30, description="Maximum number of agent schedules"
|
||||
)
|
||||
|
||||
# Long-running operation configuration
|
||||
long_running_operation_ttl: int = Field(
|
||||
default=600,
|
||||
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
||||
)
|
||||
|
||||
# Langfuse Prompt Management Configuration
|
||||
|
||||
@@ -247,3 +247,45 @@ async def get_chat_session_message_count(session_id: str) -> int:
|
||||
"""Get the number of messages in a chat session."""
|
||||
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
||||
return count
|
||||
|
||||
|
||||
async def update_tool_message_content(
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
new_content: str,
|
||||
) -> bool:
|
||||
"""Update the content of a tool message in chat history.
|
||||
|
||||
Used by background tasks to update pending operation messages with final results.
|
||||
|
||||
Args:
|
||||
session_id: The chat session ID.
|
||||
tool_call_id: The tool call ID to find the message.
|
||||
new_content: The new content to set.
|
||||
|
||||
Returns:
|
||||
True if a message was updated, False otherwise.
|
||||
"""
|
||||
try:
|
||||
result = await PrismaChatMessage.prisma().update_many(
|
||||
where={
|
||||
"sessionId": session_id,
|
||||
"toolCallId": tool_call_id,
|
||||
},
|
||||
data={
|
||||
"content": new_content,
|
||||
},
|
||||
)
|
||||
if result == 0:
|
||||
logger.warning(
|
||||
f"No message found to update for session {session_id}, "
|
||||
f"tool_call_id {tool_call_id}"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update tool message for session {session_id}, "
|
||||
f"tool_call_id {tool_call_id}: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -295,6 +295,21 @@ async def cache_chat_session(session: ChatSession) -> None:
|
||||
await _cache_session(session)
|
||||
|
||||
|
||||
async def invalidate_session_cache(session_id: str) -> None:
|
||||
"""Invalidate a chat session from Redis cache.
|
||||
|
||||
Used by background tasks to ensure fresh data is loaded on next access.
|
||||
This is best-effort - Redis failures are logged but don't fail the operation.
|
||||
"""
|
||||
try:
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
except Exception as e:
|
||||
# Best-effort: log but don't fail - cache will expire naturally
|
||||
logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")
|
||||
|
||||
|
||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from the database."""
|
||||
prisma_session = await chat_db.get_chat_session(session_id)
|
||||
|
||||
@@ -113,7 +113,7 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
|
||||
toolCallId: str = Field(..., description="Tool call ID this responds to")
|
||||
output: str | dict[str, Any] = Field(..., description="Tool execution output")
|
||||
# Additional fields for internal use (not part of AI SDK spec but useful)
|
||||
# Keep these for internal backend use
|
||||
toolName: str | None = Field(
|
||||
default=None, description="Name of the tool that was executed"
|
||||
)
|
||||
@@ -121,6 +121,15 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
default=True, description="Whether the tool execution succeeded"
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format, excluding non-spec fields."""
|
||||
import json
|
||||
data = {
|
||||
"type": self.type.value,
|
||||
"toolCallId": self.toolCallId,
|
||||
"output": self.output,
|
||||
}
|
||||
return f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
# ========== Other ==========
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from openai import (
|
||||
)
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import (
|
||||
format_understanding_for_prompt,
|
||||
get_business_understanding,
|
||||
@@ -24,6 +25,7 @@ from backend.data.understanding import (
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from . import db as chat_db
|
||||
from .config import ChatConfig
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
@@ -31,6 +33,7 @@ from .model import (
|
||||
Usage,
|
||||
cache_chat_session,
|
||||
get_chat_session,
|
||||
invalidate_session_cache,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
@@ -48,8 +51,13 @@ from .response_model import (
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from .tools import execute_tool, tools
|
||||
from .tools.models import ErrorResponse
|
||||
from .tools import execute_tool, get_tool, tools
|
||||
from .tools.models import (
|
||||
ErrorResponse,
|
||||
OperationInProgressResponse,
|
||||
OperationPendingResponse,
|
||||
OperationStartedResponse,
|
||||
)
|
||||
from .tracking import track_user_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -61,11 +69,126 @@ client = openai.AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
|
||||
langfuse = get_client()
|
||||
|
||||
# Redis key prefix for tracking running long-running operations
|
||||
# Used for idempotency across Kubernetes pods - prevents duplicate executions on browser refresh
|
||||
RUNNING_OPERATION_PREFIX = "chat:running_operation:"
|
||||
|
||||
class LangfuseNotConfiguredError(Exception):
|
||||
"""Raised when Langfuse is required but not configured."""
|
||||
# Default system prompt used when Langfuse is not configured
|
||||
# This is a snapshot of the "CoPilot Prompt" from Langfuse (version 11)
|
||||
DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations.
|
||||
|
||||
pass
|
||||
Here is everything you know about the current user from previous interactions:
|
||||
|
||||
<users_information>
|
||||
{users_information}
|
||||
</users_information>
|
||||
|
||||
## YOUR CORE MANDATE
|
||||
|
||||
You are action-oriented. Your success is measured by:
|
||||
- **Value Delivery**: Does the user think "wow, that was amazing" or "what was the point"?
|
||||
- **Demonstrable Proof**: Show working automations, not descriptions of what's possible
|
||||
- **Time Saved**: Focus on tangible efficiency gains
|
||||
- **Quality Output**: Deliver results that meet or exceed expectations
|
||||
|
||||
## YOUR WORKFLOW
|
||||
|
||||
Adapt flexibly to the conversation context. Not every interaction requires all stages:
|
||||
|
||||
1. **Explore & Understand**: Learn about the user's business, tasks, and goals. Use `add_understanding` to capture important context that will improve future conversations.
|
||||
|
||||
2. **Assess Automation Potential**: Help the user understand whether and how AI can automate their task.
|
||||
|
||||
3. **Prepare for AI**: Provide brief, actionable guidance on prerequisites (data, access, etc.).
|
||||
|
||||
4. **Discover or Create Agents**:
|
||||
- **Always check the user's library first** with `find_library_agent` (these may be customized to their needs)
|
||||
- Search the marketplace with `find_agent` for pre-built automations
|
||||
- Find reusable components with `find_block`
|
||||
- Create custom solutions with `create_agent` if nothing suitable exists
|
||||
- Modify existing library agents with `edit_agent`
|
||||
|
||||
5. **Execute**: Run automations immediately, schedule them, or set up webhooks using `run_agent`. Test specific components with `run_block`.
|
||||
|
||||
6. **Show Results**: Display outputs using `agent_output`.
|
||||
|
||||
## AVAILABLE TOOLS
|
||||
|
||||
**Understanding & Discovery:**
|
||||
- `add_understanding`: Create a memory about the user's business or use cases for future sessions
|
||||
- `search_docs`: Search platform documentation for specific technical information
|
||||
- `get_doc_page`: Retrieve full text of a specific documentation page
|
||||
|
||||
**Agent Discovery:**
|
||||
- `find_library_agent`: Search the user's existing agents (CHECK HERE FIRST—these may be customized)
|
||||
- `find_agent`: Search the marketplace for pre-built automations
|
||||
- `find_block`: Find pre-written code units that perform specific tasks (agents are built from blocks)
|
||||
|
||||
**Agent Creation & Editing:**
|
||||
- `create_agent`: Create a new automation agent
|
||||
- `edit_agent`: Modify an agent in the user's library
|
||||
|
||||
**Execution & Output:**
|
||||
- `run_agent`: Run an agent now, schedule it, or set up a webhook trigger
|
||||
- `run_block`: Test or run a specific block independently
|
||||
- `agent_output`: View results from previous agent runs
|
||||
|
||||
## BEHAVIORAL GUIDELINES
|
||||
|
||||
**Be Concise:**
|
||||
- Target 2-5 short lines maximum
|
||||
- Make every word count—no repetition or filler
|
||||
- Use lightweight structure for scannability (bullets, numbered lists, short prompts)
|
||||
- Avoid jargon (blocks, slugs, cron) unless the user asks
|
||||
|
||||
**Be Proactive:**
|
||||
- Suggest next steps before being asked
|
||||
- Anticipate needs based on conversation context and user information
|
||||
- Look for opportunities to expand scope when relevant
|
||||
- Reveal capabilities through action, not explanation
|
||||
|
||||
**Use Tools Effectively:**
|
||||
- Select the right tool for each task
|
||||
- **Always check `find_library_agent` before searching the marketplace**
|
||||
- Use `add_understanding` to capture valuable business context
|
||||
- When tool calls fail, try alternative approaches
|
||||
|
||||
## CRITICAL REMINDER
|
||||
|
||||
You are NOT a chatbot. You are NOT documentation. You are a partner who helps busy business owners get value quickly by showing proof through working automations. Bias toward action over explanation."""
|
||||
|
||||
# Module-level set to hold strong references to background tasks.
|
||||
# This prevents asyncio from garbage collecting tasks before they complete.
|
||||
# Tasks are automatically removed on completion via done_callback.
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
async def _mark_operation_started(tool_call_id: str) -> bool:
|
||||
"""Mark a long-running operation as started (Redis-based).
|
||||
|
||||
Returns True if successfully marked (operation was not already running),
|
||||
False if operation was already running (lost race condition).
|
||||
Raises exception if Redis is unavailable (fail-closed).
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
key = f"{RUNNING_OPERATION_PREFIX}{tool_call_id}"
|
||||
# SETNX with TTL - atomic "set if not exists"
|
||||
result = await redis.set(key, "1", ex=config.long_running_operation_ttl, nx=True)
|
||||
return result is not None
|
||||
|
||||
|
||||
async def _mark_operation_completed(tool_call_id: str) -> None:
|
||||
"""Mark a long-running operation as completed (remove Redis key).
|
||||
|
||||
This is best-effort - if Redis fails, the TTL will eventually clean up.
|
||||
"""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
key = f"{RUNNING_OPERATION_PREFIX}{tool_call_id}"
|
||||
await redis.delete(key)
|
||||
except Exception as e:
|
||||
# Non-critical: TTL will clean up eventually
|
||||
logger.warning(f"Failed to delete running operation key {tool_call_id}: {e}")
|
||||
|
||||
|
||||
def _is_langfuse_configured() -> bool:
|
||||
@@ -75,6 +198,30 @@ def _is_langfuse_configured() -> bool:
|
||||
)
|
||||
|
||||
|
||||
async def _get_system_prompt_template(context: str) -> str:
|
||||
"""Get the system prompt, trying Langfuse first with fallback to default.
|
||||
|
||||
Args:
|
||||
context: The user context/information to compile into the prompt.
|
||||
|
||||
Returns:
|
||||
The compiled system prompt string.
|
||||
"""
|
||||
if _is_langfuse_configured():
|
||||
try:
|
||||
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
||||
# Use asyncio.to_thread to avoid blocking the event loop
|
||||
prompt = await asyncio.to_thread(
|
||||
langfuse.get_prompt, config.langfuse_prompt_name, cache_ttl_seconds=0
|
||||
)
|
||||
return prompt.compile(users_information=context)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch prompt from Langfuse, using default: {e}")
|
||||
|
||||
# Fallback to default prompt
|
||||
return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
|
||||
|
||||
|
||||
async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
||||
"""Build the full system prompt including business understanding if available.
|
||||
|
||||
@@ -83,12 +230,8 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
||||
If "default" and this is the user's first session, will use "onboarding" instead.
|
||||
|
||||
Returns:
|
||||
Tuple of (compiled prompt string, Langfuse prompt object for tracing)
|
||||
Tuple of (compiled prompt string, business understanding object)
|
||||
"""
|
||||
|
||||
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
||||
prompt = langfuse.get_prompt(config.langfuse_prompt_name, cache_ttl_seconds=0)
|
||||
|
||||
# If user is authenticated, try to fetch their business understanding
|
||||
understanding = None
|
||||
if user_id:
|
||||
@@ -97,12 +240,13 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch business understanding: {e}")
|
||||
understanding = None
|
||||
|
||||
if understanding:
|
||||
context = format_understanding_for_prompt(understanding)
|
||||
else:
|
||||
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
|
||||
|
||||
compiled = prompt.compile(users_information=context)
|
||||
compiled = await _get_system_prompt_template(context)
|
||||
return compiled, understanding
|
||||
|
||||
|
||||
@@ -186,6 +330,7 @@ async def stream_chat_completion(
|
||||
retry_count: int = 0,
|
||||
session: ChatSession | None = None,
|
||||
context: dict[str, str] | None = None, # {url: str, content: str}
|
||||
_continuation_message_id: str | None = None, # Internal: reuse message ID for tool call continuations
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Main entry point for streaming chat completions with database handling.
|
||||
|
||||
@@ -210,16 +355,6 @@ async def stream_chat_completion(
|
||||
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
|
||||
)
|
||||
|
||||
# Check if Langfuse is configured - required for chat functionality
|
||||
if not _is_langfuse_configured():
|
||||
logger.error("Chat request failed: Langfuse is not configured")
|
||||
yield StreamError(
|
||||
errorText="Chat service is not available. Langfuse must be configured "
|
||||
"with LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY environment variables."
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
# Only fetch from Redis if session not provided (initial call)
|
||||
if session is None:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
@@ -315,6 +450,7 @@ async def stream_chat_completion(
|
||||
has_yielded_end = False
|
||||
has_yielded_error = False
|
||||
has_done_tool_call = False
|
||||
has_long_running_tool_call = False # Track if we had a long-running tool call
|
||||
has_received_text = False
|
||||
text_streaming_ended = False
|
||||
tool_response_messages: list[ChatMessage] = []
|
||||
@@ -323,11 +459,15 @@ async def stream_chat_completion(
|
||||
# Generate unique IDs for AI SDK protocol
|
||||
import uuid as uuid_module
|
||||
|
||||
message_id = str(uuid_module.uuid4())
|
||||
# Reuse message ID for continuations (tool call follow-ups) to avoid duplicate messages
|
||||
is_continuation = _continuation_message_id is not None
|
||||
message_id = _continuation_message_id or str(uuid_module.uuid4())
|
||||
text_block_id = str(uuid_module.uuid4())
|
||||
|
||||
# Yield message start
|
||||
yield StreamStart(messageId=message_id)
|
||||
# Only yield message start for the initial call, not for continuations
|
||||
# This prevents the AI SDK from creating duplicate message objects
|
||||
if not is_continuation:
|
||||
yield StreamStart(messageId=message_id)
|
||||
|
||||
try:
|
||||
async for chunk in _stream_chat_chunks(
|
||||
@@ -336,7 +476,6 @@ async def stream_chat_completion(
|
||||
system_prompt=system_prompt,
|
||||
text_block_id=text_block_id,
|
||||
):
|
||||
|
||||
if isinstance(chunk, StreamTextStart):
|
||||
# Emit text-start before first text delta
|
||||
if not has_received_text:
|
||||
@@ -394,13 +533,34 @@ async def stream_chat_completion(
|
||||
if isinstance(chunk.output, str)
|
||||
else orjson.dumps(chunk.output).decode("utf-8")
|
||||
)
|
||||
tool_response_messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=result_content,
|
||||
tool_call_id=chunk.toolCallId,
|
||||
# Skip saving long-running operation responses - messages already saved in _yield_tool_call
|
||||
# Use JSON parsing instead of substring matching to avoid false positives
|
||||
is_long_running_response = False
|
||||
try:
|
||||
parsed = orjson.loads(result_content)
|
||||
if isinstance(parsed, dict) and parsed.get("type") in (
|
||||
"operation_started",
|
||||
"operation_in_progress",
|
||||
):
|
||||
is_long_running_response = True
|
||||
except (orjson.JSONDecodeError, TypeError):
|
||||
pass # Not JSON or not a dict - treat as regular response
|
||||
if is_long_running_response:
|
||||
# Remove from accumulated_tool_calls since assistant message was already saved
|
||||
accumulated_tool_calls[:] = [
|
||||
tc
|
||||
for tc in accumulated_tool_calls
|
||||
if tc["id"] != chunk.toolCallId
|
||||
]
|
||||
has_long_running_tool_call = True
|
||||
else:
|
||||
tool_response_messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=result_content,
|
||||
tool_call_id=chunk.toolCallId,
|
||||
)
|
||||
)
|
||||
)
|
||||
has_done_tool_call = True
|
||||
# Track if any tool execution failed
|
||||
if not chunk.success:
|
||||
@@ -535,6 +695,7 @@ async def stream_chat_completion(
|
||||
retry_count=retry_count + 1,
|
||||
session=session,
|
||||
context=context,
|
||||
_continuation_message_id=message_id, # Reuse message ID since start was already sent
|
||||
):
|
||||
yield chunk
|
||||
return # Exit after retry to avoid double-saving in finally block
|
||||
@@ -576,7 +737,14 @@ async def stream_chat_completion(
|
||||
logger.info(
|
||||
f"Extended session messages, new message_count={len(session.messages)}"
|
||||
)
|
||||
if messages_to_save or has_appended_streaming_message:
|
||||
# Save if there are regular (non-long-running) tool responses or streaming message.
|
||||
# Long-running tools save their own state, but we still need to save regular tools
|
||||
# that may be in the same response.
|
||||
has_regular_tool_responses = len(tool_response_messages) > 0
|
||||
if has_regular_tool_responses or (
|
||||
not has_long_running_tool_call
|
||||
and (messages_to_save or has_appended_streaming_message)
|
||||
):
|
||||
await upsert_chat_session(session)
|
||||
else:
|
||||
logger.info(
|
||||
@@ -585,7 +753,9 @@ async def stream_chat_completion(
|
||||
)
|
||||
|
||||
# If we did a tool call, stream the chat completion again to get the next response
|
||||
if has_done_tool_call:
|
||||
# Skip only if ALL tools were long-running (they handle their own completion)
|
||||
has_regular_tools = len(tool_response_messages) > 0
|
||||
if has_done_tool_call and (has_regular_tools or not has_long_running_tool_call):
|
||||
logger.info(
|
||||
"Tool call executed, streaming chat completion again to get assistant response"
|
||||
)
|
||||
@@ -595,6 +765,7 @@ async def stream_chat_completion(
|
||||
session=session, # Pass session object to avoid Redis refetch
|
||||
context=context,
|
||||
tool_call_response=str(tool_response_messages),
|
||||
_continuation_message_id=message_id, # Reuse message ID to avoid duplicates
|
||||
):
|
||||
yield chunk
|
||||
|
||||
@@ -725,6 +896,114 @@ async def _summarize_messages(
|
||||
return summary or "No summary available."
|
||||
|
||||
|
||||
def _ensure_tool_pairs_intact(
|
||||
recent_messages: list[dict],
|
||||
all_messages: list[dict],
|
||||
start_index: int,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Ensure tool_call/tool_response pairs stay together after slicing.
|
||||
|
||||
When slicing messages for context compaction, a naive slice can separate
|
||||
an assistant message containing tool_calls from its corresponding tool
|
||||
response messages. This causes API validation errors (e.g., Anthropic's
|
||||
"unexpected tool_use_id found in tool_result blocks").
|
||||
|
||||
This function checks for orphan tool responses in the slice and extends
|
||||
backwards to include their corresponding assistant messages.
|
||||
|
||||
Args:
|
||||
recent_messages: The sliced messages to validate
|
||||
all_messages: The complete message list (for looking up missing assistants)
|
||||
start_index: The index in all_messages where recent_messages begins
|
||||
|
||||
Returns:
|
||||
A potentially extended list of messages with tool pairs intact
|
||||
"""
|
||||
if not recent_messages:
|
||||
return recent_messages
|
||||
|
||||
# Collect all tool_call_ids from assistant messages in the slice
|
||||
available_tool_call_ids: set[str] = set()
|
||||
for msg in recent_messages:
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
tc_id = tc.get("id")
|
||||
if tc_id:
|
||||
available_tool_call_ids.add(tc_id)
|
||||
|
||||
# Find orphan tool responses (tool messages whose tool_call_id is missing)
|
||||
orphan_tool_call_ids: set[str] = set()
|
||||
for msg in recent_messages:
|
||||
if msg.get("role") == "tool":
|
||||
tc_id = msg.get("tool_call_id")
|
||||
if tc_id and tc_id not in available_tool_call_ids:
|
||||
orphan_tool_call_ids.add(tc_id)
|
||||
|
||||
if not orphan_tool_call_ids:
|
||||
# No orphans, slice is valid
|
||||
return recent_messages
|
||||
|
||||
# Find the assistant messages that contain the orphan tool_call_ids
|
||||
# Search backwards from start_index in all_messages
|
||||
messages_to_prepend: list[dict] = []
|
||||
for i in range(start_index - 1, -1, -1):
|
||||
msg = all_messages[i]
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
msg_tool_ids = {tc.get("id") for tc in msg["tool_calls"] if tc.get("id")}
|
||||
if msg_tool_ids & orphan_tool_call_ids:
|
||||
# This assistant message has tool_calls we need
|
||||
# Also collect its contiguous tool responses that follow it
|
||||
assistant_and_responses: list[dict] = [msg]
|
||||
|
||||
# Scan forward from this assistant to collect tool responses
|
||||
for j in range(i + 1, start_index):
|
||||
following_msg = all_messages[j]
|
||||
if following_msg.get("role") == "tool":
|
||||
tool_id = following_msg.get("tool_call_id")
|
||||
if tool_id and tool_id in msg_tool_ids:
|
||||
assistant_and_responses.append(following_msg)
|
||||
else:
|
||||
# Stop at first non-tool message
|
||||
break
|
||||
|
||||
# Prepend the assistant and its tool responses (maintain order)
|
||||
messages_to_prepend = assistant_and_responses + messages_to_prepend
|
||||
# Mark these as found
|
||||
orphan_tool_call_ids -= msg_tool_ids
|
||||
# Also add this assistant's tool_call_ids to available set
|
||||
available_tool_call_ids |= msg_tool_ids
|
||||
|
||||
if not orphan_tool_call_ids:
|
||||
# Found all missing assistants
|
||||
break
|
||||
|
||||
if orphan_tool_call_ids:
|
||||
# Some tool_call_ids couldn't be resolved - remove those tool responses
|
||||
# This shouldn't happen in normal operation but handles edge cases
|
||||
logger.warning(
|
||||
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
|
||||
"Removing orphan tool responses."
|
||||
)
|
||||
recent_messages = [
|
||||
msg
|
||||
for msg in recent_messages
|
||||
if not (
|
||||
msg.get("role") == "tool"
|
||||
and msg.get("tool_call_id") in orphan_tool_call_ids
|
||||
)
|
||||
]
|
||||
|
||||
if messages_to_prepend:
|
||||
logger.info(
|
||||
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
|
||||
f"tool_call/tool_response pairs"
|
||||
)
|
||||
return messages_to_prepend + recent_messages
|
||||
|
||||
return recent_messages
|
||||
|
||||
|
||||
async def _stream_chat_chunks(
|
||||
session: ChatSession,
|
||||
tools: list[ChatCompletionToolParam],
|
||||
@@ -816,7 +1095,15 @@ async def _stream_chat_chunks(
|
||||
# Always attempt mitigation when over limit, even with few messages
|
||||
if messages:
|
||||
# Split messages based on whether system prompt exists
|
||||
recent_messages = messages[-KEEP_RECENT:]
|
||||
# Calculate start index for the slice
|
||||
slice_start = max(0, len(messages_dict) - KEEP_RECENT)
|
||||
recent_messages = messages_dict[-KEEP_RECENT:]
|
||||
|
||||
# Ensure tool_call/tool_response pairs stay together
|
||||
# This prevents API errors from orphan tool responses
|
||||
recent_messages = _ensure_tool_pairs_intact(
|
||||
recent_messages, messages_dict, slice_start
|
||||
)
|
||||
|
||||
if has_system_prompt:
|
||||
# Keep system prompt separate, summarize everything between system and recent
|
||||
@@ -903,6 +1190,13 @@ async def _stream_chat_chunks(
|
||||
if len(recent_messages) >= keep_count
|
||||
else recent_messages
|
||||
)
|
||||
# Ensure tool pairs stay intact in the reduced slice
|
||||
reduced_slice_start = max(
|
||||
0, len(recent_messages) - keep_count
|
||||
)
|
||||
reduced_recent = _ensure_tool_pairs_intact(
|
||||
reduced_recent, recent_messages, reduced_slice_start
|
||||
)
|
||||
if has_system_prompt:
|
||||
messages = [
|
||||
system_msg,
|
||||
@@ -961,7 +1255,10 @@ async def _stream_chat_chunks(
|
||||
|
||||
# Create a base list excluding system prompt to avoid duplication
|
||||
# This is the pool of messages we'll slice from in the loop
|
||||
base_msgs = messages[1:] if has_system_prompt else messages
|
||||
# Use messages_dict for type consistency with _ensure_tool_pairs_intact
|
||||
base_msgs = (
|
||||
messages_dict[1:] if has_system_prompt else messages_dict
|
||||
)
|
||||
|
||||
# Try progressively smaller keep counts
|
||||
new_token_count = token_count # Initialize with current count
|
||||
@@ -984,6 +1281,12 @@ async def _stream_chat_chunks(
|
||||
# Slice from base_msgs to get recent messages (without system prompt)
|
||||
recent_messages = base_msgs[-keep_count:]
|
||||
|
||||
# Ensure tool pairs stay intact in the reduced slice
|
||||
reduced_slice_start = max(0, len(base_msgs) - keep_count)
|
||||
recent_messages = _ensure_tool_pairs_intact(
|
||||
recent_messages, base_msgs, reduced_slice_start
|
||||
)
|
||||
|
||||
if has_system_prompt:
|
||||
messages = [system_msg] + recent_messages
|
||||
else:
|
||||
@@ -1260,17 +1563,19 @@ async def _yield_tool_call(
|
||||
"""
|
||||
Yield a tool call and its execution result.
|
||||
|
||||
For long-running tools, yields heartbeat events every 15 seconds to keep
|
||||
the SSE connection alive through proxies and load balancers.
|
||||
For tools marked with `is_long_running=True` (like agent generation), spawns a
|
||||
background task so the operation survives SSE disconnections. For other tools,
|
||||
yields heartbeat events every 15 seconds to keep the SSE connection alive.
|
||||
|
||||
Raises:
|
||||
orjson.JSONDecodeError: If tool call arguments cannot be parsed as JSON
|
||||
KeyError: If expected tool call fields are missing
|
||||
TypeError: If tool call structure is invalid
|
||||
"""
|
||||
import uuid as uuid_module
|
||||
|
||||
tool_name = tool_calls[yield_idx]["function"]["name"]
|
||||
tool_call_id = tool_calls[yield_idx]["id"]
|
||||
logger.info(f"Yielding tool call: {tool_calls[yield_idx]}")
|
||||
|
||||
# Parse tool call arguments - handle empty arguments gracefully
|
||||
raw_arguments = tool_calls[yield_idx]["function"]["arguments"]
|
||||
@@ -1285,7 +1590,151 @@ async def _yield_tool_call(
|
||||
input=arguments,
|
||||
)
|
||||
|
||||
# Run tool execution in background task with heartbeats to keep connection alive
|
||||
# Check if this tool is long-running (survives SSE disconnection)
|
||||
tool = get_tool(tool_name)
|
||||
if tool and tool.is_long_running:
|
||||
# Atomic check-and-set: returns False if operation already running (lost race)
|
||||
if not await _mark_operation_started(tool_call_id):
|
||||
logger.info(
|
||||
f"Tool call {tool_call_id} already in progress, returning status"
|
||||
)
|
||||
# Build dynamic message based on tool name
|
||||
if tool_name == "create_agent":
|
||||
in_progress_msg = "Agent creation already in progress. Please wait..."
|
||||
elif tool_name == "edit_agent":
|
||||
in_progress_msg = "Agent edit already in progress. Please wait..."
|
||||
else:
|
||||
in_progress_msg = f"{tool_name} already in progress. Please wait..."
|
||||
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=OperationInProgressResponse(
|
||||
message=in_progress_msg,
|
||||
tool_call_id=tool_call_id,
|
||||
).model_dump_json(),
|
||||
success=True,
|
||||
)
|
||||
return
|
||||
|
||||
# Generate operation ID
|
||||
operation_id = str(uuid_module.uuid4())
|
||||
|
||||
# Build a user-friendly message based on tool and arguments
|
||||
if tool_name == "create_agent":
|
||||
agent_desc = arguments.get("description", "")
|
||||
# Truncate long descriptions for the message
|
||||
desc_preview = (
|
||||
(agent_desc[:100] + "...") if len(agent_desc) > 100 else agent_desc
|
||||
)
|
||||
pending_msg = (
|
||||
f"Creating your agent: {desc_preview}"
|
||||
if desc_preview
|
||||
else "Creating agent... This may take a few minutes."
|
||||
)
|
||||
started_msg = (
|
||||
"Agent creation started. You can close this tab - "
|
||||
"check your library in a few minutes."
|
||||
)
|
||||
elif tool_name == "edit_agent":
|
||||
changes = arguments.get("changes", "")
|
||||
changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes
|
||||
pending_msg = (
|
||||
f"Editing agent: {changes_preview}"
|
||||
if changes_preview
|
||||
else "Editing agent... This may take a few minutes."
|
||||
)
|
||||
started_msg = (
|
||||
"Agent edit started. You can close this tab - "
|
||||
"check your library in a few minutes."
|
||||
)
|
||||
else:
|
||||
pending_msg = f"Running {tool_name}... This may take a few minutes."
|
||||
started_msg = (
|
||||
f"{tool_name} started. You can close this tab - "
|
||||
"check back in a few minutes."
|
||||
)
|
||||
|
||||
# Track appended messages for rollback on failure
|
||||
assistant_message: ChatMessage | None = None
|
||||
pending_message: ChatMessage | None = None
|
||||
|
||||
# Wrap session save and task creation in try-except to release lock on failure
|
||||
try:
|
||||
# Save assistant message with tool_call FIRST (required by LLM)
|
||||
assistant_message = ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[tool_calls[yield_idx]],
|
||||
)
|
||||
session.messages.append(assistant_message)
|
||||
|
||||
# Then save pending tool result
|
||||
pending_message = ChatMessage(
|
||||
role="tool",
|
||||
content=OperationPendingResponse(
|
||||
message=pending_msg,
|
||||
operation_id=operation_id,
|
||||
tool_name=tool_name,
|
||||
).model_dump_json(),
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
session.messages.append(pending_message)
|
||||
await upsert_chat_session(session)
|
||||
logger.info(
|
||||
f"Saved pending operation {operation_id} for tool {tool_name} "
|
||||
f"in session {session.session_id}"
|
||||
)
|
||||
|
||||
# Store task reference in module-level set to prevent GC before completion
|
||||
task = asyncio.create_task(
|
||||
_execute_long_running_tool(
|
||||
tool_name=tool_name,
|
||||
parameters=arguments,
|
||||
tool_call_id=tool_call_id,
|
||||
operation_id=operation_id,
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
)
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
except Exception as e:
|
||||
# Roll back appended messages to prevent data corruption on subsequent saves
|
||||
if (
|
||||
pending_message
|
||||
and session.messages
|
||||
and session.messages[-1] == pending_message
|
||||
):
|
||||
session.messages.pop()
|
||||
if (
|
||||
assistant_message
|
||||
and session.messages
|
||||
and session.messages[-1] == assistant_message
|
||||
):
|
||||
session.messages.pop()
|
||||
|
||||
# Release the Redis lock since the background task won't be spawned
|
||||
await _mark_operation_completed(tool_call_id)
|
||||
logger.error(
|
||||
f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
# Return immediately - don't wait for completion
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=OperationStartedResponse(
|
||||
message=started_msg,
|
||||
operation_id=operation_id,
|
||||
tool_name=tool_name,
|
||||
).model_dump_json(),
|
||||
success=True,
|
||||
)
|
||||
return
|
||||
|
||||
# Normal flow: Run tool execution in background task with heartbeats
|
||||
tool_task = asyncio.create_task(
|
||||
execute_tool(
|
||||
tool_name=tool_name,
|
||||
@@ -1335,3 +1784,190 @@ async def _yield_tool_call(
|
||||
)
|
||||
|
||||
yield tool_execution_response
|
||||
|
||||
|
||||
async def _execute_long_running_tool(
|
||||
tool_name: str,
|
||||
parameters: dict[str, Any],
|
||||
tool_call_id: str,
|
||||
operation_id: str,
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
) -> None:
|
||||
"""Execute a long-running tool in background and update chat history with result.
|
||||
|
||||
This function runs independently of the SSE connection, so the operation
|
||||
survives if the user closes their browser tab.
|
||||
"""
|
||||
try:
|
||||
# Load fresh session (not stale reference)
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
logger.error(f"Session {session_id} not found for background tool")
|
||||
return
|
||||
|
||||
# Execute the actual tool
|
||||
result = await execute_tool(
|
||||
tool_name=tool_name,
|
||||
parameters=parameters,
|
||||
tool_call_id=tool_call_id,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Update the pending message with result
|
||||
await _update_pending_operation(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
result=(
|
||||
result.output
|
||||
if isinstance(result.output, str)
|
||||
else orjson.dumps(result.output).decode("utf-8")
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(f"Background tool {tool_name} completed for session {session_id}")
|
||||
|
||||
# Generate LLM continuation so user sees response when they poll/refresh
|
||||
await _generate_llm_continuation(session_id=session_id, user_id=user_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True)
|
||||
error_response = ErrorResponse(
|
||||
message=f"Tool {tool_name} failed: {str(e)}",
|
||||
)
|
||||
await _update_pending_operation(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
result=error_response.model_dump_json(),
|
||||
)
|
||||
finally:
|
||||
await _mark_operation_completed(tool_call_id)
|
||||
|
||||
|
||||
async def _update_pending_operation(
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
result: str,
|
||||
) -> None:
|
||||
"""Update the pending tool message with final result.
|
||||
|
||||
This is called by background tasks when long-running operations complete.
|
||||
"""
|
||||
# Update the message in database
|
||||
updated = await chat_db.update_tool_message_content(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
new_content=result,
|
||||
)
|
||||
|
||||
if updated:
|
||||
# Invalidate Redis cache so next load gets fresh data
|
||||
# Wrap in try/except to prevent cache failures from triggering error handling
|
||||
# that would overwrite our successful DB update
|
||||
try:
|
||||
await invalidate_session_cache(session_id)
|
||||
except Exception as e:
|
||||
# Non-critical: cache will eventually be refreshed on next load
|
||||
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
|
||||
logger.info(
|
||||
f"Updated pending operation for tool_call_id {tool_call_id} "
|
||||
f"in session {session_id}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to update pending operation for tool_call_id {tool_call_id} "
|
||||
f"in session {session_id}"
|
||||
)
|
||||
|
||||
|
||||
async def _generate_llm_continuation(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
) -> None:
|
||||
"""Generate an LLM response after a long-running tool completes.
|
||||
|
||||
This is called by background tasks to continue the conversation
|
||||
after a tool result is saved. The response is saved to the database
|
||||
so users see it when they refresh or poll.
|
||||
"""
|
||||
try:
|
||||
# Load fresh session from DB (bypass cache to get the updated tool result)
|
||||
await invalidate_session_cache(session_id)
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
logger.error(f"Session {session_id} not found for LLM continuation")
|
||||
return
|
||||
|
||||
# Build system prompt
|
||||
system_prompt, _ = await _build_system_prompt(user_id)
|
||||
|
||||
# Build messages in OpenAI format
|
||||
messages = session.to_openai_messages()
|
||||
if system_prompt:
|
||||
from openai.types.chat import ChatCompletionSystemMessageParam
|
||||
|
||||
system_message = ChatCompletionSystemMessageParam(
|
||||
role="system",
|
||||
content=system_prompt,
|
||||
)
|
||||
messages = [system_message] + messages
|
||||
|
||||
# Build extra_body for tracing
|
||||
extra_body: dict[str, Any] = {
|
||||
"posthogProperties": {
|
||||
"environment": settings.config.app_env.value,
|
||||
},
|
||||
}
|
||||
if user_id:
|
||||
extra_body["user"] = user_id[:128]
|
||||
extra_body["posthogDistinctId"] = user_id
|
||||
if session_id:
|
||||
extra_body["session_id"] = session_id[:128]
|
||||
|
||||
# Make non-streaming LLM call (no tools - just text response)
|
||||
from typing import cast
|
||||
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
# No tools parameter = text-only response (no tool calls)
|
||||
response = await client.chat.completions.create(
|
||||
model=config.model,
|
||||
messages=cast(list[ChatCompletionMessageParam], messages),
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
if response.choices and response.choices[0].message.content:
|
||||
assistant_content = response.choices[0].message.content
|
||||
|
||||
# Reload session from DB to avoid race condition with user messages
|
||||
# that may have been sent while we were generating the LLM response
|
||||
fresh_session = await get_chat_session(session_id, user_id)
|
||||
if not fresh_session:
|
||||
logger.error(
|
||||
f"Session {session_id} disappeared during LLM continuation"
|
||||
)
|
||||
return
|
||||
|
||||
# Save assistant message to database
|
||||
assistant_message = ChatMessage(
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
)
|
||||
fresh_session.messages.append(assistant_message)
|
||||
|
||||
# Save to database (not cache) to persist the response
|
||||
await upsert_chat_session(fresh_session)
|
||||
|
||||
# Invalidate cache so next poll/refresh gets fresh data
|
||||
await invalidate_session_cache(session_id)
|
||||
|
||||
logger.info(
|
||||
f"Generated LLM continuation for session {session_id}, "
|
||||
f"response length: {len(assistant_content)}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"LLM continuation returned empty response for {session_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
|
||||
|
||||
@@ -49,6 +49,11 @@ tools: list[ChatCompletionToolParam] = [
|
||||
]
|
||||
|
||||
|
||||
def get_tool(tool_name: str) -> BaseTool | None:
|
||||
"""Get a tool instance by name."""
|
||||
return TOOL_REGISTRY.get(tool_name)
|
||||
|
||||
|
||||
async def execute_tool(
|
||||
tool_name: str,
|
||||
parameters: dict[str, Any],
|
||||
@@ -57,7 +62,7 @@ async def execute_tool(
|
||||
tool_call_id: str,
|
||||
) -> "StreamToolOutputAvailable":
|
||||
"""Execute a tool by name."""
|
||||
tool = TOOL_REGISTRY.get(tool_name)
|
||||
tool = get_tool(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
|
||||
|
||||
@@ -36,6 +36,16 @@ class BaseTool:
|
||||
"""Whether this tool requires authentication."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_long_running(self) -> bool:
|
||||
"""Whether this tool is long-running and should execute in background.
|
||||
|
||||
Long-running tools (like agent generation) are executed via background
|
||||
tasks to survive SSE disconnections. The result is persisted to chat
|
||||
history and visible when the user refreshes.
|
||||
"""
|
||||
return False
|
||||
|
||||
def as_openai_tool(self) -> ChatCompletionToolParam:
|
||||
"""Convert to OpenAI tool format."""
|
||||
return ChatCompletionToolParam(
|
||||
|
||||
@@ -42,6 +42,10 @@ class CreateAgentTool(BaseTool):
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_long_running(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
|
||||
@@ -42,6 +42,10 @@ class EditAgentTool(BaseTool):
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_long_running(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
|
||||
@@ -28,6 +28,10 @@ class ResponseType(str, Enum):
|
||||
BLOCK_OUTPUT = "block_output"
|
||||
DOC_SEARCH_RESULTS = "doc_search_results"
|
||||
DOC_PAGE = "doc_page"
|
||||
# Long-running operation types
|
||||
OPERATION_STARTED = "operation_started"
|
||||
OPERATION_PENDING = "operation_pending"
|
||||
OPERATION_IN_PROGRESS = "operation_in_progress"
|
||||
|
||||
|
||||
# Base response model
|
||||
@@ -334,3 +338,39 @@ class BlockOutputResponse(ToolResponseBase):
|
||||
block_name: str
|
||||
outputs: dict[str, list[Any]]
|
||||
success: bool = True
|
||||
|
||||
|
||||
# Long-running operation models
|
||||
class OperationStartedResponse(ToolResponseBase):
|
||||
"""Response when a long-running operation has been started in the background.
|
||||
|
||||
This is returned immediately to the client while the operation continues
|
||||
to execute. The user can close the tab and check back later.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||
operation_id: str
|
||||
tool_name: str
|
||||
|
||||
|
||||
class OperationPendingResponse(ToolResponseBase):
|
||||
"""Response stored in chat history while a long-running operation is executing.
|
||||
|
||||
This is persisted to the database so users see a pending state when they
|
||||
refresh before the operation completes.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_PENDING
|
||||
operation_id: str
|
||||
tool_name: str
|
||||
|
||||
|
||||
class OperationInProgressResponse(ToolResponseBase):
|
||||
"""Response when an operation is already in progress.
|
||||
|
||||
Returned for idempotency when the same tool_call_id is requested again
|
||||
while the background task is still running.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
|
||||
tool_call_id: str
|
||||
|
||||
@@ -21,7 +21,7 @@ from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
from backend.util.settings import Config
|
||||
@@ -64,11 +64,11 @@ async def list_library_agents(
|
||||
|
||||
if page < 1 or page_size < 1:
|
||||
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
|
||||
raise DatabaseError("Invalid pagination input")
|
||||
raise InvalidInputError("Invalid pagination input")
|
||||
|
||||
if search_term and len(search_term.strip()) > 100:
|
||||
logger.warning(f"Search term too long: {repr(search_term)}")
|
||||
raise DatabaseError("Search term is too long")
|
||||
raise InvalidInputError("Search term is too long")
|
||||
|
||||
where_clause: prisma.types.LibraryAgentWhereInput = {
|
||||
"userId": user_id,
|
||||
@@ -175,7 +175,7 @@ async def list_favorite_library_agents(
|
||||
|
||||
if page < 1 or page_size < 1:
|
||||
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
|
||||
raise DatabaseError("Invalid pagination input")
|
||||
raise InvalidInputError("Invalid pagination input")
|
||||
|
||||
where_clause: prisma.types.LibraryAgentWhereInput = {
|
||||
"userId": user_id,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
from typing import Literal, Optional
|
||||
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
@@ -6,15 +5,11 @@ from fastapi import APIRouter, Body, HTTPException, Query, Security, status
|
||||
from fastapi.responses import Response
|
||||
from prisma.enums import OnboardingStep
|
||||
|
||||
import backend.api.features.store.exceptions as store_exceptions
|
||||
from backend.data.onboarding import complete_onboarding_step
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .. import db as library_db
|
||||
from .. import model as library_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/agents",
|
||||
tags=["library", "private"],
|
||||
@@ -26,10 +21,6 @@ router = APIRouter(
|
||||
"",
|
||||
summary="List Library Agents",
|
||||
response_model=library_model.LibraryAgentResponse,
|
||||
responses={
|
||||
200: {"description": "List of library agents"},
|
||||
500: {"description": "Server error", "content": {"application/json": {}}},
|
||||
},
|
||||
)
|
||||
async def list_library_agents(
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
@@ -53,43 +44,19 @@ async def list_library_agents(
|
||||
) -> library_model.LibraryAgentResponse:
|
||||
"""
|
||||
Get all agents in the user's library (both created and saved).
|
||||
|
||||
Args:
|
||||
user_id: ID of the authenticated user.
|
||||
search_term: Optional search term to filter agents by name/description.
|
||||
filter_by: List of filters to apply (favorites, created by user).
|
||||
sort_by: List of sorting criteria (created date, updated date).
|
||||
page: Page number to retrieve.
|
||||
page_size: Number of agents per page.
|
||||
|
||||
Returns:
|
||||
A LibraryAgentResponse containing agents and pagination metadata.
|
||||
|
||||
Raises:
|
||||
HTTPException: If a server/database error occurs.
|
||||
"""
|
||||
try:
|
||||
return await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=search_term,
|
||||
sort_by=sort_by,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not list library agents for user #{user_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
) from e
|
||||
return await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=search_term,
|
||||
sort_by=sort_by,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/favorites",
|
||||
summary="List Favorite Library Agents",
|
||||
responses={
|
||||
500: {"description": "Server error", "content": {"application/json": {}}},
|
||||
},
|
||||
)
|
||||
async def list_favorite_library_agents(
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
@@ -106,30 +73,12 @@ async def list_favorite_library_agents(
|
||||
) -> library_model.LibraryAgentResponse:
|
||||
"""
|
||||
Get all favorite agents in the user's library.
|
||||
|
||||
Args:
|
||||
user_id: ID of the authenticated user.
|
||||
page: Page number to retrieve.
|
||||
page_size: Number of agents per page.
|
||||
|
||||
Returns:
|
||||
A LibraryAgentResponse containing favorite agents and pagination metadata.
|
||||
|
||||
Raises:
|
||||
HTTPException: If a server/database error occurs.
|
||||
"""
|
||||
try:
|
||||
return await library_db.list_favorite_library_agents(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not list favorite library agents for user #{user_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
) from e
|
||||
return await library_db.list_favorite_library_agents(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{library_agent_id}", summary="Get Library Agent")
|
||||
@@ -162,10 +111,6 @@ async def get_library_agent_by_graph_id(
|
||||
summary="Get Agent By Store ID",
|
||||
tags=["store", "library"],
|
||||
response_model=library_model.LibraryAgent | None,
|
||||
responses={
|
||||
200: {"description": "Library agent found"},
|
||||
404: {"description": "Agent not found"},
|
||||
},
|
||||
)
|
||||
async def get_library_agent_by_store_listing_version_id(
|
||||
store_listing_version_id: str,
|
||||
@@ -174,32 +119,15 @@ async def get_library_agent_by_store_listing_version_id(
|
||||
"""
|
||||
Get Library Agent from Store Listing Version ID.
|
||||
"""
|
||||
try:
|
||||
return await library_db.get_library_agent_by_store_version_id(
|
||||
store_listing_version_id, user_id
|
||||
)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not fetch library agent from store version ID: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
) from e
|
||||
return await library_db.get_library_agent_by_store_version_id(
|
||||
store_listing_version_id, user_id
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
summary="Add Marketplace Agent",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
responses={
|
||||
201: {"description": "Agent added successfully"},
|
||||
404: {"description": "Store listing version not found"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def add_marketplace_agent_to_library(
|
||||
store_listing_version_id: str = Body(embed=True),
|
||||
@@ -210,59 +138,19 @@ async def add_marketplace_agent_to_library(
|
||||
) -> library_model.LibraryAgent:
|
||||
"""
|
||||
Add an agent from the marketplace to the user's library.
|
||||
|
||||
Args:
|
||||
store_listing_version_id: ID of the store listing version to add.
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
library_model.LibraryAgent: Agent added to the library
|
||||
|
||||
Raises:
|
||||
HTTPException(404): If the listing version is not found.
|
||||
HTTPException(500): If a server/database error occurs.
|
||||
"""
|
||||
try:
|
||||
agent = await library_db.add_store_agent_to_library(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if source != "onboarding":
|
||||
await complete_onboarding_step(
|
||||
user_id, OnboardingStep.MARKETPLACE_ADD_AGENT
|
||||
)
|
||||
return agent
|
||||
|
||||
except store_exceptions.AgentNotFoundError as e:
|
||||
logger.warning(
|
||||
f"Could not find store listing version {store_listing_version_id} "
|
||||
"to add to library"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Database error while adding agent to library: {e}", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Inspect DB logs for details."},
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while adding agent to library: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"message": str(e),
|
||||
"hint": "Check server logs for more information.",
|
||||
},
|
||||
) from e
|
||||
agent = await library_db.add_store_agent_to_library(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if source != "onboarding":
|
||||
await complete_onboarding_step(user_id, OnboardingStep.MARKETPLACE_ADD_AGENT)
|
||||
return agent
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{library_agent_id}",
|
||||
summary="Update Library Agent",
|
||||
responses={
|
||||
200: {"description": "Agent updated successfully"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def update_library_agent(
|
||||
library_agent_id: str,
|
||||
@@ -271,52 +159,21 @@ async def update_library_agent(
|
||||
) -> library_model.LibraryAgent:
|
||||
"""
|
||||
Update the library agent with the given fields.
|
||||
|
||||
Args:
|
||||
library_agent_id: ID of the library agent to update.
|
||||
payload: Fields to update (auto_update_version, is_favorite, etc.).
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Raises:
|
||||
HTTPException(500): If a server/database error occurs.
|
||||
"""
|
||||
try:
|
||||
return await library_db.update_library_agent(
|
||||
library_agent_id=library_agent_id,
|
||||
user_id=user_id,
|
||||
auto_update_version=payload.auto_update_version,
|
||||
graph_version=payload.graph_version,
|
||||
is_favorite=payload.is_favorite,
|
||||
is_archived=payload.is_archived,
|
||||
settings=payload.settings,
|
||||
)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Database error while updating library agent: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Verify DB connection."},
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while updating library agent: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Check server logs."},
|
||||
) from e
|
||||
return await library_db.update_library_agent(
|
||||
library_agent_id=library_agent_id,
|
||||
user_id=user_id,
|
||||
auto_update_version=payload.auto_update_version,
|
||||
graph_version=payload.graph_version,
|
||||
is_favorite=payload.is_favorite,
|
||||
is_archived=payload.is_archived,
|
||||
settings=payload.settings,
|
||||
)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{library_agent_id}",
|
||||
summary="Delete Library Agent",
|
||||
responses={
|
||||
204: {"description": "Agent deleted successfully"},
|
||||
404: {"description": "Agent not found"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def delete_library_agent(
|
||||
library_agent_id: str,
|
||||
@@ -324,28 +181,11 @@ async def delete_library_agent(
|
||||
) -> Response:
|
||||
"""
|
||||
Soft-delete the specified library agent.
|
||||
|
||||
Args:
|
||||
library_agent_id: ID of the library agent to delete.
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
204 No Content if successful.
|
||||
|
||||
Raises:
|
||||
HTTPException(404): If the agent does not exist.
|
||||
HTTPException(500): If a server/database error occurs.
|
||||
"""
|
||||
try:
|
||||
await library_db.delete_library_agent(
|
||||
library_agent_id=library_agent_id, user_id=user_id
|
||||
)
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
) from e
|
||||
await library_db.delete_library_agent(
|
||||
library_agent_id=library_agent_id, user_id=user_id
|
||||
)
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
@router.post("/{library_agent_id}/fork", summary="Fork Library Agent")
|
||||
|
||||
@@ -118,21 +118,6 @@ async def test_get_library_agents_success(
|
||||
)
|
||||
|
||||
|
||||
def test_get_library_agents_error(mocker: pytest_mock.MockFixture, test_user_id: str):
|
||||
mock_db_call = mocker.patch("backend.api.features.library.db.list_library_agents")
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
response = client.get("/agents?search_term=test")
|
||||
assert response.status_code == 500
|
||||
mock_db_call.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
search_term="test",
|
||||
sort_by=library_model.LibraryAgentSort.UPDATED_AT,
|
||||
page=1,
|
||||
page_size=15,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_favorite_library_agents_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
@@ -190,23 +175,6 @@ async def test_get_favorite_library_agents_success(
|
||||
)
|
||||
|
||||
|
||||
def test_get_favorite_library_agents_error(
|
||||
mocker: pytest_mock.MockFixture, test_user_id: str
|
||||
):
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.api.features.library.db.list_favorite_library_agents"
|
||||
)
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
response = client.get("/agents/favorites")
|
||||
assert response.status_code == 500
|
||||
mock_db_call.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
page=1,
|
||||
page_size=15,
|
||||
)
|
||||
|
||||
|
||||
def test_add_agent_to_library_success(
|
||||
mocker: pytest_mock.MockFixture, test_user_id: str
|
||||
):
|
||||
@@ -258,19 +226,3 @@ def test_add_agent_to_library_success(
|
||||
store_listing_version_id="test-version-id", user_id=test_user_id
|
||||
)
|
||||
mock_complete_onboarding.assert_awaited_once()
|
||||
|
||||
|
||||
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture, test_user_id: str):
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.api.features.library.db.add_store_agent_to_library"
|
||||
)
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
response = client.post(
|
||||
"/agents", json={"store_listing_version_id": "test-version-id"}
|
||||
)
|
||||
assert response.status_code == 500
|
||||
assert "detail" in response.json() # Verify error response structure
|
||||
mock_db_call.assert_called_once_with(
|
||||
store_listing_version_id="test-version-id", user_id=test_user_id
|
||||
)
|
||||
|
||||
@@ -454,6 +454,7 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
|
||||
total_processed = 0
|
||||
total_success = 0
|
||||
total_failed = 0
|
||||
all_errors: dict[str, int] = {} # Aggregate errors across all content types
|
||||
|
||||
# Process content types in explicit order
|
||||
processing_order = [
|
||||
@@ -499,23 +500,12 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
|
||||
success = sum(1 for result in results if result is True)
|
||||
failed = len(results) - success
|
||||
|
||||
# Aggregate unique errors to avoid Sentry spam
|
||||
# Aggregate errors across all content types
|
||||
if failed > 0:
|
||||
# Group errors by type and message
|
||||
error_summary: dict[str, int] = {}
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
error_key = f"{type(result).__name__}: {str(result)}"
|
||||
error_summary[error_key] = error_summary.get(error_key, 0) + 1
|
||||
|
||||
# Log aggregated error summary
|
||||
error_details = ", ".join(
|
||||
f"{error} ({count}x)" for error, count in error_summary.items()
|
||||
)
|
||||
logger.error(
|
||||
f"{content_type.value}: {failed}/{len(results)} embeddings failed. "
|
||||
f"Errors: {error_details}"
|
||||
)
|
||||
all_errors[error_key] = all_errors.get(error_key, 0) + 1
|
||||
|
||||
results_by_type[content_type.value] = {
|
||||
"processed": len(missing_items),
|
||||
@@ -542,6 +532,13 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
# Log aggregated errors once at the end
|
||||
if all_errors:
|
||||
error_details = ", ".join(
|
||||
f"{error} ({count}x)" for error, count in all_errors.items()
|
||||
)
|
||||
logger.error(f"Embedding backfill errors: {error_details}")
|
||||
|
||||
return {
|
||||
"by_type": results_by_type,
|
||||
"totals": {
|
||||
|
||||
@@ -393,7 +393,6 @@ async def get_creators(
|
||||
@router.get(
|
||||
"/creator/{username}",
|
||||
summary="Get creator details",
|
||||
operation_id="getV2GetCreatorDetails",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.CreatorDetails,
|
||||
)
|
||||
|
||||
@@ -261,14 +261,36 @@ async def get_onboarding_agents(
|
||||
return await get_recommended_agents(user_id)
|
||||
|
||||
|
||||
class OnboardingStatusResponse(pydantic.BaseModel):
|
||||
"""Response for onboarding status check."""
|
||||
|
||||
is_onboarding_enabled: bool
|
||||
is_chat_enabled: bool
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/onboarding/enabled",
|
||||
summary="Is onboarding enabled",
|
||||
tags=["onboarding", "public"],
|
||||
dependencies=[Security(requires_user)],
|
||||
response_model=OnboardingStatusResponse,
|
||||
)
|
||||
async def is_onboarding_enabled() -> bool:
|
||||
return await onboarding_enabled()
|
||||
async def is_onboarding_enabled(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> OnboardingStatusResponse:
|
||||
# Check if chat is enabled for user
|
||||
is_chat_enabled = await is_feature_enabled(Flag.CHAT, user_id, False)
|
||||
|
||||
# If chat is enabled, skip legacy onboarding
|
||||
if is_chat_enabled:
|
||||
return OnboardingStatusResponse(
|
||||
is_onboarding_enabled=False,
|
||||
is_chat_enabled=True,
|
||||
)
|
||||
|
||||
return OnboardingStatusResponse(
|
||||
is_onboarding_enabled=await onboarding_enabled(),
|
||||
is_chat_enabled=False,
|
||||
)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
|
||||
@@ -18,7 +18,6 @@ from prisma.errors import PrismaError
|
||||
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.llm_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.builder
|
||||
import backend.api.features.builder.routes
|
||||
@@ -38,11 +37,9 @@ import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.server.v2.llm.routes as public_llm_routes
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block_cost_config import refresh_llm_costs
|
||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
@@ -112,27 +109,11 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Refresh LLM registry before initializing blocks so blocks can use registry data
|
||||
await llm_registry.refresh_llm_registry()
|
||||
refresh_llm_costs()
|
||||
|
||||
# Clear block schema caches so they're regenerated with updated discriminator_mapping
|
||||
from backend.data.block import BlockSchema
|
||||
|
||||
BlockSchema.clear_all_schema_caches()
|
||||
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
# migrate_llm_models uses registry default model
|
||||
from backend.blocks.llm import LlmModel
|
||||
|
||||
default_model_slug = llm_registry.get_default_model_slug()
|
||||
if default_model_slug:
|
||||
await backend.data.graph.migrate_llm_models(LlmModel(default_model_slug))
|
||||
else:
|
||||
logger.warning("Skipping LLM model migration: no default model available")
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
with launch_darkly_context():
|
||||
@@ -317,16 +298,6 @@ app.include_router(
|
||||
tags=["v2", "executions", "review"],
|
||||
prefix="/api/review",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.llm_routes.router,
|
||||
tags=["v2", "admin", "llm"],
|
||||
prefix="/api/llm/admin",
|
||||
)
|
||||
app.include_router(
|
||||
public_llm_routes.router,
|
||||
tags=["v2", "llm"],
|
||||
prefix="/api",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.library.routes.router, tags=["v2"], prefix="/api/library"
|
||||
)
|
||||
|
||||
@@ -77,39 +77,7 @@ async def event_broadcaster(manager: ConnectionManager):
|
||||
payload=notification.payload,
|
||||
)
|
||||
|
||||
async def registry_refresh_worker():
|
||||
"""Listen for LLM registry refresh notifications and broadcast to all clients."""
|
||||
from backend.data.llm_registry import REGISTRY_REFRESH_CHANNEL
|
||||
from backend.data.redis_client import connect_async
|
||||
|
||||
redis = await connect_async()
|
||||
pubsub = redis.pubsub()
|
||||
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
|
||||
logger.info(
|
||||
"Subscribed to LLM registry refresh notifications for WebSocket broadcast"
|
||||
)
|
||||
|
||||
async for message in pubsub.listen():
|
||||
if (
|
||||
message["type"] == "message"
|
||||
and message["channel"] == REGISTRY_REFRESH_CHANNEL
|
||||
):
|
||||
logger.info(
|
||||
"Broadcasting LLM registry refresh to all WebSocket clients"
|
||||
)
|
||||
await manager.broadcast_to_all(
|
||||
method=WSMethod.NOTIFICATION,
|
||||
data={
|
||||
"type": "LLM_REGISTRY_REFRESH",
|
||||
"event": "registry_updated",
|
||||
},
|
||||
)
|
||||
|
||||
await asyncio.gather(
|
||||
execution_worker(),
|
||||
notification_worker(),
|
||||
registry_refresh_worker(),
|
||||
)
|
||||
await asyncio.gather(execution_worker(), notification_worker())
|
||||
|
||||
|
||||
async def authenticate_websocket(websocket: WebSocket) -> str:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks.llm import (
|
||||
DEFAULT_LLM_MODEL,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
AIBlockBase,
|
||||
@@ -9,7 +10,6 @@ from backend.blocks.llm import (
|
||||
LlmModel,
|
||||
LLMResponse,
|
||||
llm_call,
|
||||
llm_model_schema_extra,
|
||||
)
|
||||
from backend.data.block import (
|
||||
BlockCategory,
|
||||
@@ -50,10 +50,9 @@ class AIConditionBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=LlmModel.default,
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for evaluating the condition.",
|
||||
advanced=False,
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
|
||||
@@ -83,7 +82,7 @@ class AIConditionBlock(AIBlockBase):
|
||||
"condition": "the input is an email address",
|
||||
"yes_value": "Valid email",
|
||||
"no_value": "Not an email",
|
||||
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
|
||||
@@ -4,19 +4,17 @@ import logging
|
||||
import re
|
||||
import secrets
|
||||
from abc import ABC
|
||||
from enum import Enum
|
||||
from enum import Enum, EnumMeta
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Iterable, List, Literal, Optional
|
||||
from typing import Any, Iterable, List, Literal, NamedTuple, Optional
|
||||
|
||||
import anthropic
|
||||
import ollama
|
||||
import openai
|
||||
from anthropic.types import ToolParam
|
||||
from groq import AsyncGroq
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, SecretStr
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -24,7 +22,6 @@ from backend.data.block import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.llm_registry import ModelMetadata
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -69,123 +66,114 @@ TEST_CREDENTIALS_INPUT = {
|
||||
|
||||
|
||||
def AICredentialsField() -> AICredentials:
|
||||
"""
|
||||
Returns a CredentialsField for LLM providers.
|
||||
The discriminator_mapping will be refreshed when the schema is generated
|
||||
if it's empty, ensuring the LLM registry is loaded.
|
||||
"""
|
||||
# Get the mapping now - it may be empty initially, but will be refreshed
|
||||
# when the schema is generated via CredentialsMetaInput._add_json_schema_extra
|
||||
mapping = llm_registry.get_llm_discriminator_mapping()
|
||||
|
||||
return CredentialsField(
|
||||
description="API key for the LLM provider.",
|
||||
discriminator="model",
|
||||
discriminator_mapping=mapping, # May be empty initially, refreshed later
|
||||
discriminator_mapping={
|
||||
model.value: model.metadata.provider for model in LlmModel
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def llm_model_schema_extra() -> dict[str, Any]:
|
||||
return {"options": llm_registry.get_llm_model_schema_options()}
|
||||
class ModelMetadata(NamedTuple):
|
||||
provider: str
|
||||
context_window: int
|
||||
max_output_tokens: int | None
|
||||
display_name: str
|
||||
provider_name: str
|
||||
creator_name: str
|
||||
price_tier: Literal[1, 2, 3]
|
||||
|
||||
|
||||
class LlmModelMeta(type):
|
||||
"""
|
||||
Metaclass for LlmModel that enables attribute-style access to dynamic models.
|
||||
|
||||
This allows code like `LlmModel.GPT4O` to work by converting the attribute
|
||||
name to a slug format:
|
||||
- GPT4O -> gpt-4o
|
||||
- GPT4O_MINI -> gpt-4o-mini
|
||||
- CLAUDE_3_5_SONNET -> claude-3-5-sonnet
|
||||
"""
|
||||
|
||||
def __getattr__(cls, name: str):
|
||||
# Don't intercept private/dunder attributes
|
||||
if name.startswith("_"):
|
||||
raise AttributeError(f"type object 'LlmModel' has no attribute '{name}'")
|
||||
|
||||
# Convert attribute name to slug format:
|
||||
# 1. Lowercase: GPT4O -> gpt4o
|
||||
# 2. Underscores to hyphens: GPT4O_MINI -> gpt4o-mini
|
||||
slug = name.lower().replace("_", "-")
|
||||
|
||||
# Check for exact match in registry first (e.g., "o1" stays "o1")
|
||||
registry_slugs = llm_registry.get_dynamic_model_slugs()
|
||||
if slug in registry_slugs:
|
||||
return cls(slug)
|
||||
|
||||
# If no exact match, try inserting hyphen between letter and digit
|
||||
# e.g., gpt4o -> gpt-4o
|
||||
transformed_slug = re.sub(r"([a-z])(\d)", r"\1-\2", slug)
|
||||
return cls(transformed_slug)
|
||||
|
||||
def __iter__(cls):
|
||||
"""Iterate over all models from the registry.
|
||||
|
||||
Yields LlmModel instances for each model in the dynamic registry.
|
||||
Used by __get_pydantic_json_schema__ to build model metadata.
|
||||
"""
|
||||
for model in llm_registry.iter_dynamic_models():
|
||||
yield cls(model.slug)
|
||||
class LlmModelMeta(EnumMeta):
|
||||
pass
|
||||
|
||||
|
||||
class LlmModel(str, metaclass=LlmModelMeta):
|
||||
"""
|
||||
Dynamic LLM model type that accepts any model slug from the registry.
|
||||
|
||||
This is a string subclass (not an Enum) that allows any model slug value.
|
||||
All models are managed via the LLM Registry in the database.
|
||||
|
||||
Usage:
|
||||
model = LlmModel("gpt-4o") # Direct construction
|
||||
model = LlmModel.GPT4O # Attribute access (converted to "gpt-4o")
|
||||
model.value # Returns the slug string
|
||||
model.provider # Returns the provider from registry
|
||||
"""
|
||||
|
||||
def __new__(cls, value: str):
|
||||
if isinstance(value, LlmModel):
|
||||
return value
|
||||
return str.__new__(cls, value)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, source_type: Any, handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""
|
||||
Tell Pydantic how to validate LlmModel.
|
||||
|
||||
Accepts strings and converts them to LlmModel instances.
|
||||
"""
|
||||
return core_schema.no_info_after_validator_function(
|
||||
cls, # The validator function (LlmModel constructor)
|
||||
core_schema.str_schema(), # Accept string input
|
||||
serialization=core_schema.to_string_ser_schema(), # Serialize as string
|
||||
)
|
||||
|
||||
@property
|
||||
def value(self) -> str:
|
||||
"""Return the model slug (for compatibility with enum-style access)."""
|
||||
return str(self)
|
||||
|
||||
@classmethod
|
||||
def default(cls) -> "LlmModel":
|
||||
"""
|
||||
Get the default model from the registry.
|
||||
|
||||
Returns the recommended model if set, otherwise gpt-4o if available
|
||||
and enabled, otherwise the first enabled model from the registry.
|
||||
Falls back to "gpt-4o" if registry is empty (e.g., at module import time).
|
||||
"""
|
||||
from backend.data.llm_registry import get_default_model_slug
|
||||
|
||||
slug = get_default_model_slug()
|
||||
if slug is None:
|
||||
# Registry is empty (e.g., at module import time before DB connection).
|
||||
# Fall back to gpt-4o for backward compatibility.
|
||||
slug = "gpt-4o"
|
||||
return cls(slug)
|
||||
class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
# OpenAI models
|
||||
O3_MINI = "o3-mini"
|
||||
O3 = "o3-2025-04-16"
|
||||
O1 = "o1"
|
||||
O1_MINI = "o1-mini"
|
||||
# GPT-5 models
|
||||
GPT5_2 = "gpt-5.2-2025-12-11"
|
||||
GPT5_1 = "gpt-5.1-2025-11-13"
|
||||
GPT5 = "gpt-5-2025-08-07"
|
||||
GPT5_MINI = "gpt-5-mini-2025-08-07"
|
||||
GPT5_NANO = "gpt-5-nano-2025-08-07"
|
||||
GPT5_CHAT = "gpt-5-chat-latest"
|
||||
GPT41 = "gpt-4.1-2025-04-14"
|
||||
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
|
||||
GPT4O_MINI = "gpt-4o-mini"
|
||||
GPT4O = "gpt-4o"
|
||||
GPT4_TURBO = "gpt-4-turbo"
|
||||
GPT3_5_TURBO = "gpt-3.5-turbo"
|
||||
# Anthropic models
|
||||
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
|
||||
CLAUDE_4_OPUS = "claude-opus-4-20250514"
|
||||
CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
|
||||
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
||||
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
|
||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||
# AI/ML API models
|
||||
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
||||
AIML_API_LLAMA3_1_70B = "nvidia/llama-3.1-nemotron-70b-instruct"
|
||||
AIML_API_LLAMA3_3_70B = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
|
||||
AIML_API_META_LLAMA_3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
|
||||
AIML_API_LLAMA_3_2_3B = "meta-llama/Llama-3.2-3B-Instruct-Turbo"
|
||||
# Groq models
|
||||
LLAMA3_3_70B = "llama-3.3-70b-versatile"
|
||||
LLAMA3_1_8B = "llama-3.1-8b-instant"
|
||||
# Ollama models
|
||||
OLLAMA_LLAMA3_3 = "llama3.3"
|
||||
OLLAMA_LLAMA3_2 = "llama3.2"
|
||||
OLLAMA_LLAMA3_8B = "llama3"
|
||||
OLLAMA_LLAMA3_405B = "llama3.1:405b"
|
||||
OLLAMA_DOLPHIN = "dolphin-mistral:latest"
|
||||
# OpenRouter models
|
||||
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
|
||||
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
|
||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
|
||||
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
|
||||
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
|
||||
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
|
||||
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
|
||||
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
|
||||
MISTRAL_NEMO = "mistralai/mistral-nemo"
|
||||
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
|
||||
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
|
||||
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
|
||||
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
|
||||
PERPLEXITY_SONAR = "perplexity/sonar"
|
||||
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
|
||||
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
|
||||
AMAZON_NOVA_LITE_V1 = "amazon/nova-lite-v1"
|
||||
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
|
||||
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
|
||||
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
|
||||
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
|
||||
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
|
||||
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
|
||||
GROK_4 = "x-ai/grok-4"
|
||||
GROK_4_FAST = "x-ai/grok-4-fast"
|
||||
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
|
||||
GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
|
||||
KIMI_K2 = "moonshotai/kimi-k2"
|
||||
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
|
||||
QWEN3_CODER = "qwen/qwen3-coder"
|
||||
# Llama API models
|
||||
LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8"
|
||||
LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
|
||||
LLAMA_API_LLAMA3_3_8B = "Llama-3.3-8B-Instruct"
|
||||
LLAMA_API_LLAMA3_3_70B = "Llama-3.3-70B-Instruct"
|
||||
# v0 by Vercel models
|
||||
V0_1_5_MD = "v0-1.5-md"
|
||||
V0_1_5_LG = "v0-1.5-lg"
|
||||
V0_1_0_MD = "v0-1.0-md"
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, schema, handler):
|
||||
@@ -193,15 +181,7 @@ class LlmModel(str, metaclass=LlmModelMeta):
|
||||
llm_model_metadata = {}
|
||||
for model in cls:
|
||||
model_name = model.value
|
||||
# Skip disabled models - only show enabled models in the picker
|
||||
if not llm_registry.is_model_enabled(model_name):
|
||||
continue
|
||||
# Use registry directly with None check to gracefully handle
|
||||
# missing metadata during startup/import before registry is populated
|
||||
metadata = llm_registry.get_llm_model_metadata(model_name)
|
||||
if metadata is None:
|
||||
# Skip models without metadata (registry not yet populated)
|
||||
continue
|
||||
metadata = model.metadata
|
||||
llm_model_metadata[model_name] = {
|
||||
"creator": metadata.creator_name,
|
||||
"creator_name": metadata.creator_name,
|
||||
@@ -217,12 +197,7 @@ class LlmModel(str, metaclass=LlmModelMeta):
|
||||
|
||||
@property
|
||||
def metadata(self) -> ModelMetadata:
|
||||
metadata = llm_registry.get_llm_model_metadata(self.value)
|
||||
if metadata:
|
||||
return metadata
|
||||
raise ValueError(
|
||||
f"Missing metadata for model: {self.value}. Model not found in LLM registry."
|
||||
)
|
||||
return MODEL_METADATA[self]
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
@@ -237,11 +212,300 @@ class LlmModel(str, metaclass=LlmModelMeta):
|
||||
return self.metadata.max_output_tokens
|
||||
|
||||
|
||||
# MODEL_METADATA removed - all models now come from the database via llm_registry
|
||||
MODEL_METADATA = {
|
||||
# https://platform.openai.com/docs/models
|
||||
LlmModel.O3: ModelMetadata("openai", 200000, 100000, "O3", "OpenAI", "OpenAI", 2),
|
||||
LlmModel.O3_MINI: ModelMetadata(
|
||||
"openai", 200000, 100000, "O3 Mini", "OpenAI", "OpenAI", 1
|
||||
), # o3-mini-2025-01-31
|
||||
LlmModel.O1: ModelMetadata(
|
||||
"openai", 200000, 100000, "O1", "OpenAI", "OpenAI", 3
|
||||
), # o1-2024-12-17
|
||||
LlmModel.O1_MINI: ModelMetadata(
|
||||
"openai", 128000, 65536, "O1 Mini", "OpenAI", "OpenAI", 2
|
||||
), # o1-mini-2024-09-12
|
||||
# GPT-5 models
|
||||
LlmModel.GPT5_2: ModelMetadata(
|
||||
"openai", 400000, 128000, "GPT-5.2", "OpenAI", "OpenAI", 3
|
||||
),
|
||||
LlmModel.GPT5_1: ModelMetadata(
|
||||
"openai", 400000, 128000, "GPT-5.1", "OpenAI", "OpenAI", 2
|
||||
),
|
||||
LlmModel.GPT5: ModelMetadata(
|
||||
"openai", 400000, 128000, "GPT-5", "OpenAI", "OpenAI", 1
|
||||
),
|
||||
LlmModel.GPT5_MINI: ModelMetadata(
|
||||
"openai", 400000, 128000, "GPT-5 Mini", "OpenAI", "OpenAI", 1
|
||||
),
|
||||
LlmModel.GPT5_NANO: ModelMetadata(
|
||||
"openai", 400000, 128000, "GPT-5 Nano", "OpenAI", "OpenAI", 1
|
||||
),
|
||||
LlmModel.GPT5_CHAT: ModelMetadata(
|
||||
"openai", 400000, 16384, "GPT-5 Chat Latest", "OpenAI", "OpenAI", 2
|
||||
),
|
||||
LlmModel.GPT41: ModelMetadata(
|
||||
"openai", 1047576, 32768, "GPT-4.1", "OpenAI", "OpenAI", 1
|
||||
),
|
||||
LlmModel.GPT41_MINI: ModelMetadata(
|
||||
"openai", 1047576, 32768, "GPT-4.1 Mini", "OpenAI", "OpenAI", 1
|
||||
),
|
||||
LlmModel.GPT4O_MINI: ModelMetadata(
|
||||
"openai", 128000, 16384, "GPT-4o Mini", "OpenAI", "OpenAI", 1
|
||||
), # gpt-4o-mini-2024-07-18
|
||||
LlmModel.GPT4O: ModelMetadata(
|
||||
"openai", 128000, 16384, "GPT-4o", "OpenAI", "OpenAI", 2
|
||||
), # gpt-4o-2024-08-06
|
||||
LlmModel.GPT4_TURBO: ModelMetadata(
|
||||
"openai", 128000, 4096, "GPT-4 Turbo", "OpenAI", "OpenAI", 3
|
||||
), # gpt-4-turbo-2024-04-09
|
||||
LlmModel.GPT3_5_TURBO: ModelMetadata(
|
||||
"openai", 16385, 4096, "GPT-3.5 Turbo", "OpenAI", "OpenAI", 1
|
||||
), # gpt-3.5-turbo-0125
|
||||
# https://docs.anthropic.com/en/docs/about-claude/models
|
||||
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 32000, "Claude Opus 4.1", "Anthropic", "Anthropic", 3
|
||||
), # claude-opus-4-1-20250805
|
||||
LlmModel.CLAUDE_4_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 32000, "Claude Opus 4", "Anthropic", "Anthropic", 3
|
||||
), # claude-4-opus-20250514
|
||||
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
|
||||
), # claude-4-sonnet-20250514
|
||||
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
|
||||
), # claude-opus-4-5-20251101
|
||||
LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Sonnet 4.5", "Anthropic", "Anthropic", 3
|
||||
), # claude-sonnet-4-5-20250929
|
||||
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
|
||||
), # claude-haiku-4-5-20251001
|
||||
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude 3.7 Sonnet", "Anthropic", "Anthropic", 2
|
||||
), # claude-3-7-sonnet-20250219
|
||||
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
|
||||
"anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
|
||||
), # claude-3-haiku-20240307
|
||||
# https://docs.aimlapi.com/api-overview/model-database/text-models
|
||||
LlmModel.AIML_API_QWEN2_5_72B: ModelMetadata(
|
||||
"aiml_api", 32000, 8000, "Qwen 2.5 72B Instruct Turbo", "AI/ML", "Qwen", 1
|
||||
),
|
||||
LlmModel.AIML_API_LLAMA3_1_70B: ModelMetadata(
|
||||
"aiml_api",
|
||||
128000,
|
||||
40000,
|
||||
"Llama 3.1 Nemotron 70B Instruct",
|
||||
"AI/ML",
|
||||
"Nvidia",
|
||||
1,
|
||||
),
|
||||
LlmModel.AIML_API_LLAMA3_3_70B: ModelMetadata(
|
||||
"aiml_api", 128000, None, "Llama 3.3 70B Instruct Turbo", "AI/ML", "Meta", 1
|
||||
),
|
||||
LlmModel.AIML_API_META_LLAMA_3_1_70B: ModelMetadata(
|
||||
"aiml_api", 131000, 2000, "Llama 3.1 70B Instruct Turbo", "AI/ML", "Meta", 1
|
||||
),
|
||||
LlmModel.AIML_API_LLAMA_3_2_3B: ModelMetadata(
|
||||
"aiml_api", 128000, None, "Llama 3.2 3B Instruct Turbo", "AI/ML", "Meta", 1
|
||||
),
|
||||
# https://console.groq.com/docs/models
|
||||
LlmModel.LLAMA3_3_70B: ModelMetadata(
|
||||
"groq", 128000, 32768, "Llama 3.3 70B Versatile", "Groq", "Meta", 1
|
||||
),
|
||||
LlmModel.LLAMA3_1_8B: ModelMetadata(
|
||||
"groq", 128000, 8192, "Llama 3.1 8B Instant", "Groq", "Meta", 1
|
||||
),
|
||||
# https://ollama.com/library
|
||||
LlmModel.OLLAMA_LLAMA3_3: ModelMetadata(
|
||||
"ollama", 8192, None, "Llama 3.3", "Ollama", "Meta", 1
|
||||
),
|
||||
LlmModel.OLLAMA_LLAMA3_2: ModelMetadata(
|
||||
"ollama", 8192, None, "Llama 3.2", "Ollama", "Meta", 1
|
||||
),
|
||||
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata(
|
||||
"ollama", 8192, None, "Llama 3", "Ollama", "Meta", 1
|
||||
),
|
||||
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata(
|
||||
"ollama", 8192, None, "Llama 3.1 405B", "Ollama", "Meta", 1
|
||||
),
|
||||
LlmModel.OLLAMA_DOLPHIN: ModelMetadata(
|
||||
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
|
||||
),
|
||||
# https://openrouter.ai/models
|
||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
1050000,
|
||||
8192,
|
||||
"Gemini 2.5 Pro Preview 03.25",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
|
||||
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
|
||||
),
|
||||
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
|
||||
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
|
||||
),
|
||||
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
|
||||
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
|
||||
),
|
||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65535,
|
||||
"Gemini 2.5 Flash Lite Preview 06.17",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
1,
|
||||
),
|
||||
LlmModel.GEMINI_2_0_FLASH_LITE: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
8192,
|
||||
"Gemini 2.0 Flash Lite 001",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
1,
|
||||
),
|
||||
LlmModel.MISTRAL_NEMO: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
|
||||
),
|
||||
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
|
||||
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
|
||||
),
|
||||
LlmModel.DEEPSEEK_R1_0528: ModelMetadata(
|
||||
"open_router", 163840, 163840, "DeepSeek R1 0528", "OpenRouter", "DeepSeek", 1
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR: ModelMetadata(
|
||||
"open_router", 127000, 8000, "Sonar", "OpenRouter", "Perplexity", 1
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
|
||||
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
16000,
|
||||
"Sonar Deep Research",
|
||||
"OpenRouter",
|
||||
"Perplexity",
|
||||
3,
|
||||
),
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata(
|
||||
"open_router",
|
||||
131000,
|
||||
4096,
|
||||
"Hermes 3 Llama 3.1 405B",
|
||||
"OpenRouter",
|
||||
"Nous Research",
|
||||
1,
|
||||
),
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata(
|
||||
"open_router",
|
||||
12288,
|
||||
12288,
|
||||
"Hermes 3 Llama 3.1 70B",
|
||||
"OpenRouter",
|
||||
"Nous Research",
|
||||
1,
|
||||
),
|
||||
LlmModel.OPENAI_GPT_OSS_120B: ModelMetadata(
|
||||
"open_router", 131072, 131072, "GPT-OSS 120B", "OpenRouter", "OpenAI", 1
|
||||
),
|
||||
LlmModel.OPENAI_GPT_OSS_20B: ModelMetadata(
|
||||
"open_router", 131072, 32768, "GPT-OSS 20B", "OpenRouter", "OpenAI", 1
|
||||
),
|
||||
LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata(
|
||||
"open_router", 300000, 5120, "Nova Lite V1", "OpenRouter", "Amazon", 1
|
||||
),
|
||||
LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata(
|
||||
"open_router", 128000, 5120, "Nova Micro V1", "OpenRouter", "Amazon", 1
|
||||
),
|
||||
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata(
|
||||
"open_router", 300000, 5120, "Nova Pro V1", "OpenRouter", "Amazon", 1
|
||||
),
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
|
||||
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
|
||||
),
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
|
||||
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
|
||||
),
|
||||
LlmModel.META_LLAMA_4_SCOUT: ModelMetadata(
|
||||
"open_router", 131072, 131072, "Llama 4 Scout", "OpenRouter", "Meta", 1
|
||||
),
|
||||
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
|
||||
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
|
||||
),
|
||||
LlmModel.GROK_4: ModelMetadata(
|
||||
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
|
||||
),
|
||||
LlmModel.GROK_4_FAST: ModelMetadata(
|
||||
"open_router", 2000000, 30000, "Grok 4 Fast", "OpenRouter", "xAI", 1
|
||||
),
|
||||
LlmModel.GROK_4_1_FAST: ModelMetadata(
|
||||
"open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1
|
||||
),
|
||||
LlmModel.GROK_CODE_FAST_1: ModelMetadata(
|
||||
"open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1
|
||||
),
|
||||
LlmModel.KIMI_K2: ModelMetadata(
|
||||
"open_router", 131000, 131000, "Kimi K2", "OpenRouter", "Moonshot AI", 1
|
||||
),
|
||||
LlmModel.QWEN3_235B_A22B_THINKING: ModelMetadata(
|
||||
"open_router",
|
||||
262144,
|
||||
262144,
|
||||
"Qwen 3 235B A22B Thinking 2507",
|
||||
"OpenRouter",
|
||||
"Qwen",
|
||||
1,
|
||||
),
|
||||
LlmModel.QWEN3_CODER: ModelMetadata(
|
||||
"open_router", 262144, 262144, "Qwen 3 Coder", "OpenRouter", "Qwen", 3
|
||||
),
|
||||
# Llama API models
|
||||
LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata(
|
||||
"llama_api",
|
||||
128000,
|
||||
4028,
|
||||
"Llama 4 Scout 17B 16E Instruct FP8",
|
||||
"Llama API",
|
||||
"Meta",
|
||||
1,
|
||||
),
|
||||
LlmModel.LLAMA_API_LLAMA4_MAVERICK: ModelMetadata(
|
||||
"llama_api",
|
||||
128000,
|
||||
4028,
|
||||
"Llama 4 Maverick 17B 128E Instruct FP8",
|
||||
"Llama API",
|
||||
"Meta",
|
||||
1,
|
||||
),
|
||||
LlmModel.LLAMA_API_LLAMA3_3_8B: ModelMetadata(
|
||||
"llama_api", 128000, 4028, "Llama 3.3 8B Instruct", "Llama API", "Meta", 1
|
||||
),
|
||||
LlmModel.LLAMA_API_LLAMA3_3_70B: ModelMetadata(
|
||||
"llama_api", 128000, 4028, "Llama 3.3 70B Instruct", "Llama API", "Meta", 1
|
||||
),
|
||||
# v0 by Vercel models
|
||||
LlmModel.V0_1_5_MD: ModelMetadata("v0", 128000, 64000, "v0 1.5 MD", "V0", "V0", 1),
|
||||
LlmModel.V0_1_5_LG: ModelMetadata("v0", 512000, 64000, "v0 1.5 LG", "V0", "V0", 1),
|
||||
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000, "v0 1.0 MD", "V0", "V0", 1),
|
||||
}
|
||||
|
||||
# Default model constant for backward compatibility
|
||||
# Uses the dynamic registry to get the default model
|
||||
DEFAULT_LLM_MODEL = LlmModel.default()
|
||||
DEFAULT_LLM_MODEL = LlmModel.GPT5_2
|
||||
|
||||
for model in LlmModel:
|
||||
if model not in MODEL_METADATA:
|
||||
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
@@ -334,10 +598,7 @@ def get_parallel_tool_calls_param(
|
||||
llm_model: LlmModel, parallel_tool_calls: bool | None
|
||||
):
|
||||
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
|
||||
# Check for o-series models (o1, o1-mini, o3-mini, etc.) which don't support
|
||||
# parallel tool calls. Use regex to avoid false positives like "openai/gpt-oss".
|
||||
is_o_series = re.match(r"^o\d", llm_model) is not None
|
||||
if is_o_series or parallel_tool_calls is None:
|
||||
if llm_model.startswith("o") or parallel_tool_calls is None:
|
||||
return openai.NOT_GIVEN
|
||||
return parallel_tool_calls
|
||||
|
||||
@@ -373,98 +634,19 @@ async def llm_call(
|
||||
- prompt_tokens: The number of tokens used in the prompt.
|
||||
- completion_tokens: The number of tokens used in the completion.
|
||||
"""
|
||||
# Get model metadata and check if enabled - with fallback support
|
||||
# The model we'll actually use (may differ if original is disabled)
|
||||
model_to_use = llm_model.value
|
||||
|
||||
# Check if model is in registry and if it's enabled
|
||||
from backend.data.llm_registry import (
|
||||
get_fallback_model_for_disabled,
|
||||
get_model_info,
|
||||
)
|
||||
|
||||
model_info = get_model_info(llm_model.value)
|
||||
|
||||
if model_info and not model_info.is_enabled:
|
||||
# Model is disabled - try to find a fallback from the same provider
|
||||
fallback = get_fallback_model_for_disabled(llm_model.value)
|
||||
if fallback:
|
||||
logger.warning(
|
||||
f"Model '{llm_model.value}' is disabled. Using fallback model '{fallback.slug}' from the same provider ({fallback.metadata.provider})."
|
||||
)
|
||||
model_to_use = fallback.slug
|
||||
# Use fallback model's metadata
|
||||
provider = fallback.metadata.provider
|
||||
context_window = fallback.metadata.context_window
|
||||
model_max_output = fallback.metadata.max_output_tokens or int(2**15)
|
||||
else:
|
||||
# No fallback available - raise error
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' is disabled and no fallback model "
|
||||
f"from the same provider is available. Please enable the model or "
|
||||
f"select a different model in the block configuration."
|
||||
)
|
||||
else:
|
||||
# Model is enabled or not in registry (legacy/static model)
|
||||
try:
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
except ValueError:
|
||||
# Model not in cache - try refreshing the registry once if we have DB access
|
||||
logger.warning(f"Model {llm_model.value} not found in registry cache")
|
||||
|
||||
# Try refreshing the registry if we have database access
|
||||
from backend.data.db import is_connected
|
||||
|
||||
if is_connected():
|
||||
try:
|
||||
logger.info(
|
||||
f"Refreshing LLM registry and retrying lookup for {llm_model.value}"
|
||||
)
|
||||
await llm_registry.refresh_llm_registry()
|
||||
# Try again after refresh
|
||||
try:
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
logger.info(
|
||||
f"Successfully loaded model {llm_model.value} metadata after registry refresh"
|
||||
)
|
||||
except ValueError:
|
||||
# Still not found after refresh
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' not found in registry after refresh. "
|
||||
"Please ensure the model is added and enabled in the LLM registry via the admin UI."
|
||||
)
|
||||
except Exception as refresh_exc:
|
||||
logger.error(f"Failed to refresh LLM registry: {refresh_exc}")
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' not found in registry and failed to refresh. "
|
||||
"Please ensure the model is added to the LLM registry via the admin UI."
|
||||
) from refresh_exc
|
||||
else:
|
||||
# No DB access (e.g., in executor without direct DB connection)
|
||||
# The registry should have been loaded on startup
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' not found in registry cache. "
|
||||
"The registry may need to be refreshed. Please contact support or try again later."
|
||||
)
|
||||
|
||||
# Create effective model for model-specific parameter resolution (e.g., o-series check)
|
||||
# This uses the resolved model_to_use which may differ from llm_model if fallback occurred
|
||||
effective_model = LlmModel(model_to_use)
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
|
||||
if compress_prompt_to_fit:
|
||||
prompt = compress_prompt(
|
||||
messages=prompt,
|
||||
target_tokens=context_window // 2,
|
||||
target_tokens=llm_model.context_window // 2,
|
||||
lossy_ok=True,
|
||||
)
|
||||
|
||||
# Calculate available tokens based on context window and input length
|
||||
estimated_input_tokens = estimate_token_count(prompt)
|
||||
# model_max_output already set above
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
user_max = max_tokens or model_max_output
|
||||
available_tokens = max(context_window - estimated_input_tokens, 0)
|
||||
max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
|
||||
@@ -475,14 +657,14 @@ async def llm_call(
|
||||
response_format = None
|
||||
|
||||
parallel_tool_calls = get_parallel_tool_calls_param(
|
||||
effective_model, parallel_tool_calls
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
if force_json_output:
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
response = await oai_client.chat.completions.create(
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_completion_tokens=max_tokens,
|
||||
@@ -529,7 +711,7 @@ async def llm_call(
|
||||
)
|
||||
try:
|
||||
resp = await client.messages.create(
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
@@ -593,7 +775,7 @@ async def llm_call(
|
||||
client = AsyncGroq(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = {"type": "json_object"} if force_json_output else None
|
||||
response = await client.chat.completions.create(
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
@@ -615,7 +797,7 @@ async def llm_call(
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
response = await client.generate(
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
prompt=f"{sys_messages}\n\n{usr_messages}",
|
||||
stream=False,
|
||||
options={"num_ctx": max_tokens},
|
||||
@@ -637,7 +819,7 @@ async def llm_call(
|
||||
)
|
||||
|
||||
parallel_tool_calls_param = get_parallel_tool_calls_param(
|
||||
effective_model, parallel_tool_calls
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
@@ -645,7 +827,7 @@ async def llm_call(
|
||||
"HTTP-Referer": "https://agpt.co",
|
||||
"X-Title": "AutoGPT",
|
||||
},
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
@@ -679,7 +861,7 @@ async def llm_call(
|
||||
)
|
||||
|
||||
parallel_tool_calls_param = get_parallel_tool_calls_param(
|
||||
effective_model, parallel_tool_calls
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
@@ -687,7 +869,7 @@ async def llm_call(
|
||||
"HTTP-Referer": "https://agpt.co",
|
||||
"X-Title": "AutoGPT",
|
||||
},
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
@@ -714,7 +896,7 @@ async def llm_call(
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "aiml_api":
|
||||
client = openai.AsyncOpenAI(
|
||||
client = openai.OpenAI(
|
||||
base_url="https://api.aimlapi.com/v2",
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
default_headers={
|
||||
@@ -724,8 +906,8 @@ async def llm_call(
|
||||
},
|
||||
)
|
||||
|
||||
completion = await client.chat.completions.create(
|
||||
model=model_to_use,
|
||||
completion = client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
@@ -753,11 +935,11 @@ async def llm_call(
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
parallel_tool_calls_param = get_parallel_tool_calls_param(
|
||||
effective_model, parallel_tool_calls
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
@@ -808,10 +990,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=LlmModel.default,
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
force_json_output: bool = SchemaField(
|
||||
title="Restrict LLM to pure JSON output",
|
||||
@@ -874,7 +1055,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
input_schema=AIStructuredResponseGeneratorBlock.Input,
|
||||
output_schema=AIStructuredResponseGeneratorBlock.Output,
|
||||
test_input={
|
||||
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expected_format": {
|
||||
"key1": "value1",
|
||||
@@ -1240,10 +1421,9 @@ class AITextGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=LlmModel.default,
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
sys_prompt: str = SchemaField(
|
||||
@@ -1337,9 +1517,8 @@ class AITextSummarizerBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=LlmModel.default,
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for summarizing the text.",
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
focus: str = SchemaField(
|
||||
title="Focus",
|
||||
@@ -1555,9 +1734,8 @@ class AIConversationBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=LlmModel.default,
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for the conversation.",
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
max_tokens: int | None = SchemaField(
|
||||
@@ -1594,7 +1772,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
},
|
||||
{"role": "user", "content": "Where was it played?"},
|
||||
],
|
||||
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
@@ -1657,10 +1835,9 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=LlmModel.default,
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for generating the list.",
|
||||
advanced=True,
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
max_retries: int = SchemaField(
|
||||
@@ -1715,7 +1892,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
|
||||
"fictional worlds."
|
||||
),
|
||||
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"max_retries": 3,
|
||||
"force_json_output": False,
|
||||
|
||||
@@ -226,10 +226,9 @@ class SmartDecisionMakerBlock(Block):
|
||||
)
|
||||
model: llm.LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=llm.LlmModel.default,
|
||||
default=llm.DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
json_schema_extra=llm.llm_model_schema_extra(),
|
||||
)
|
||||
credentials: llm.AICredentials = llm.AICredentialsField()
|
||||
multiple_tool_calls: bool = SchemaField(
|
||||
|
||||
@@ -10,13 +10,13 @@ import stagehand.main
|
||||
from stagehand import Stagehand
|
||||
|
||||
from backend.blocks.llm import (
|
||||
MODEL_METADATA,
|
||||
AICredentials,
|
||||
AICredentialsField,
|
||||
LlmModel,
|
||||
ModelMetadata,
|
||||
)
|
||||
from backend.blocks.stagehand._config import stagehand as stagehand_provider
|
||||
from backend.data import llm_registry
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -91,7 +91,7 @@ class StagehandRecommendedLlmModel(str, Enum):
|
||||
Returns the provider name for the model in the required format for Stagehand:
|
||||
provider/model_name
|
||||
"""
|
||||
model_metadata = self.metadata
|
||||
model_metadata = MODEL_METADATA[LlmModel(self.value)]
|
||||
model_name = self.value
|
||||
|
||||
if len(model_name.split("/")) == 1 and not self.value.startswith(
|
||||
@@ -107,23 +107,19 @@ class StagehandRecommendedLlmModel(str, Enum):
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
return self.metadata.provider
|
||||
return MODEL_METADATA[LlmModel(self.value)].provider
|
||||
|
||||
@property
|
||||
def metadata(self) -> ModelMetadata:
|
||||
metadata = llm_registry.get_llm_model_metadata(self.value)
|
||||
if metadata:
|
||||
return metadata
|
||||
# Fallback to LlmModel enum if registry lookup fails
|
||||
return LlmModel(self.value).metadata
|
||||
return MODEL_METADATA[LlmModel(self.value)]
|
||||
|
||||
@property
|
||||
def context_window(self) -> int:
|
||||
return self.metadata.context_window
|
||||
return MODEL_METADATA[LlmModel(self.value)].context_window
|
||||
|
||||
@property
|
||||
def max_output_tokens(self) -> int | None:
|
||||
return self.metadata.max_output_tokens
|
||||
return MODEL_METADATA[LlmModel(self.value)].max_output_tokens
|
||||
|
||||
|
||||
class StagehandObserveBlock(Block):
|
||||
|
||||
@@ -25,7 +25,6 @@ from prisma.models import AgentBlock
|
||||
from prisma.types import AgentBlockCreateInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.llm_registry import update_schema_with_llm_registry
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
@@ -144,59 +143,35 @@ class BlockInfo(BaseModel):
|
||||
|
||||
|
||||
class BlockSchema(BaseModel):
|
||||
cached_jsonschema: ClassVar[dict[str, Any] | None] = None
|
||||
|
||||
@classmethod
|
||||
def clear_schema_cache(cls) -> None:
|
||||
"""Clear the cached JSON schema for this class."""
|
||||
# Use None instead of {} because {} is truthy and would prevent regeneration
|
||||
cls.cached_jsonschema = None # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def clear_all_schema_caches() -> None:
|
||||
"""Clear cached JSON schemas for all BlockSchema subclasses."""
|
||||
|
||||
def clear_recursive(cls: type) -> None:
|
||||
"""Recursively clear cache for class and all subclasses."""
|
||||
if hasattr(cls, "clear_schema_cache"):
|
||||
cls.clear_schema_cache()
|
||||
for subclass in cls.__subclasses__():
|
||||
clear_recursive(subclass)
|
||||
|
||||
clear_recursive(BlockSchema)
|
||||
cached_jsonschema: ClassVar[dict[str, Any]]
|
||||
|
||||
@classmethod
|
||||
def jsonschema(cls) -> dict[str, Any]:
|
||||
# Generate schema if not cached
|
||||
if not cls.cached_jsonschema:
|
||||
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
|
||||
if cls.cached_jsonschema:
|
||||
return cls.cached_jsonschema
|
||||
|
||||
def ref_to_dict(obj):
|
||||
if isinstance(obj, dict):
|
||||
# OpenAPI <3.1 does not support sibling fields that has a $ref key
|
||||
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
|
||||
keys = {"allOf", "anyOf", "oneOf"}
|
||||
one_key = next(
|
||||
(k for k in keys if k in obj and len(obj[k]) == 1), None
|
||||
)
|
||||
if one_key:
|
||||
obj.update(obj[one_key][0])
|
||||
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
|
||||
|
||||
return {
|
||||
key: ref_to_dict(value)
|
||||
for key, value in obj.items()
|
||||
if not key.startswith("$") and key != one_key
|
||||
}
|
||||
elif isinstance(obj, list):
|
||||
return [ref_to_dict(item) for item in obj]
|
||||
def ref_to_dict(obj):
|
||||
if isinstance(obj, dict):
|
||||
# OpenAPI <3.1 does not support sibling fields that has a $ref key
|
||||
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
|
||||
keys = {"allOf", "anyOf", "oneOf"}
|
||||
one_key = next((k for k in keys if k in obj and len(obj[k]) == 1), None)
|
||||
if one_key:
|
||||
obj.update(obj[one_key][0])
|
||||
|
||||
return obj
|
||||
return {
|
||||
key: ref_to_dict(value)
|
||||
for key, value in obj.items()
|
||||
if not key.startswith("$") and key != one_key
|
||||
}
|
||||
elif isinstance(obj, list):
|
||||
return [ref_to_dict(item) for item in obj]
|
||||
|
||||
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
|
||||
return obj
|
||||
|
||||
# Always post-process to ensure LLM registry data is up-to-date
|
||||
# This refreshes model options and discriminator mappings even if schema was cached
|
||||
update_schema_with_llm_registry(cls.cached_jsonschema, cls)
|
||||
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
|
||||
|
||||
return cls.cached_jsonschema
|
||||
|
||||
@@ -259,7 +234,7 @@ class BlockSchema(BaseModel):
|
||||
super().__pydantic_init_subclass__(**kwargs)
|
||||
|
||||
# Reset cached JSON schema to prevent inheriting it from parent class
|
||||
cls.cached_jsonschema = None
|
||||
cls.cached_jsonschema = {}
|
||||
|
||||
credentials_fields = cls.get_credentials_fields()
|
||||
|
||||
@@ -898,28 +873,6 @@ def is_block_auth_configured(
|
||||
|
||||
|
||||
async def initialize_blocks() -> None:
|
||||
# Refresh LLM registry before initializing blocks so blocks can use registry data
|
||||
# This ensures the registry cache is populated even in executor context
|
||||
try:
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block_cost_config import refresh_llm_costs
|
||||
|
||||
# Only refresh if we have DB access (check if Prisma is connected)
|
||||
from backend.data.db import is_connected
|
||||
|
||||
if is_connected():
|
||||
await llm_registry.refresh_llm_registry()
|
||||
refresh_llm_costs()
|
||||
logger.info("LLM registry refreshed during block initialization")
|
||||
else:
|
||||
logger.warning(
|
||||
"Prisma not connected, skipping LLM registry refresh during block initialization"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to refresh LLM registry during block initialization: %s", exc
|
||||
)
|
||||
|
||||
# First, sync all provider costs to blocks
|
||||
# Imported here to avoid circular import
|
||||
from backend.sdk.cost_integration import sync_all_provider_costs
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
from typing import Type
|
||||
|
||||
from backend.blocks.ai_image_customizer import AIImageCustomizerBlock, GeminiImageModel
|
||||
@@ -24,18 +23,19 @@ from backend.blocks.ideogram import IdeogramModelBlock
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
from backend.blocks.jina.search import ExtractWebsiteContentBlock, SearchTheWebBlock
|
||||
from backend.blocks.llm import (
|
||||
MODEL_METADATA,
|
||||
AIConversationBlock,
|
||||
AIListGeneratorBlock,
|
||||
AIStructuredResponseGeneratorBlock,
|
||||
AITextGeneratorBlock,
|
||||
AITextSummarizerBlock,
|
||||
LlmModel,
|
||||
)
|
||||
from backend.blocks.replicate.flux_advanced import ReplicateFluxAdvancedModelBlock
|
||||
from backend.blocks.replicate.replicate_block import ReplicateModelBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block import Block, BlockCost, BlockCostType
|
||||
from backend.integrations.credentials_store import (
|
||||
aiml_api_credentials,
|
||||
@@ -55,63 +55,210 @@ from backend.integrations.credentials_store import (
|
||||
v0_credentials,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# =============== Configure the cost for each LLM Model call =============== #
|
||||
|
||||
PROVIDER_CREDENTIALS = {
|
||||
"openai": openai_credentials,
|
||||
"anthropic": anthropic_credentials,
|
||||
"groq": groq_credentials,
|
||||
"open_router": open_router_credentials,
|
||||
"llama_api": llama_api_credentials,
|
||||
"aiml_api": aiml_api_credentials,
|
||||
"v0": v0_credentials,
|
||||
MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.O3: 4,
|
||||
LlmModel.O3_MINI: 2,
|
||||
LlmModel.O1: 16,
|
||||
LlmModel.O1_MINI: 4,
|
||||
# GPT-5 models
|
||||
LlmModel.GPT5_2: 6,
|
||||
LlmModel.GPT5_1: 5,
|
||||
LlmModel.GPT5: 2,
|
||||
LlmModel.GPT5_MINI: 1,
|
||||
LlmModel.GPT5_NANO: 1,
|
||||
LlmModel.GPT5_CHAT: 5,
|
||||
LlmModel.GPT41: 2,
|
||||
LlmModel.GPT41_MINI: 1,
|
||||
LlmModel.GPT4O_MINI: 1,
|
||||
LlmModel.GPT4O: 3,
|
||||
LlmModel.GPT4_TURBO: 10,
|
||||
LlmModel.GPT3_5_TURBO: 1,
|
||||
LlmModel.CLAUDE_4_1_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_SONNET: 5,
|
||||
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
||||
LlmModel.CLAUDE_4_5_OPUS: 14,
|
||||
LlmModel.CLAUDE_4_5_SONNET: 9,
|
||||
LlmModel.CLAUDE_3_7_SONNET: 5,
|
||||
LlmModel.CLAUDE_3_HAIKU: 1,
|
||||
LlmModel.AIML_API_QWEN2_5_72B: 1,
|
||||
LlmModel.AIML_API_LLAMA3_1_70B: 1,
|
||||
LlmModel.AIML_API_LLAMA3_3_70B: 1,
|
||||
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
|
||||
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
|
||||
LlmModel.LLAMA3_3_70B: 1,
|
||||
LlmModel.LLAMA3_1_8B: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_3: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_2: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_8B: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_405B: 1,
|
||||
LlmModel.OLLAMA_DOLPHIN: 1,
|
||||
LlmModel.OPENAI_GPT_OSS_120B: 1,
|
||||
LlmModel.OPENAI_GPT_OSS_20B: 1,
|
||||
LlmModel.GEMINI_2_5_PRO: 4,
|
||||
LlmModel.GEMINI_3_PRO_PREVIEW: 5,
|
||||
LlmModel.GEMINI_2_5_FLASH: 1,
|
||||
LlmModel.GEMINI_2_0_FLASH: 1,
|
||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
|
||||
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
|
||||
LlmModel.MISTRAL_NEMO: 1,
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: 1,
|
||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: 3,
|
||||
LlmModel.DEEPSEEK_CHAT: 2,
|
||||
LlmModel.DEEPSEEK_R1_0528: 1,
|
||||
LlmModel.PERPLEXITY_SONAR: 1,
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: 5,
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
|
||||
LlmModel.AMAZON_NOVA_LITE_V1: 1,
|
||||
LlmModel.AMAZON_NOVA_MICRO_V1: 1,
|
||||
LlmModel.AMAZON_NOVA_PRO_V1: 1,
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
|
||||
LlmModel.META_LLAMA_4_SCOUT: 1,
|
||||
LlmModel.META_LLAMA_4_MAVERICK: 1,
|
||||
LlmModel.LLAMA_API_LLAMA_4_SCOUT: 1,
|
||||
LlmModel.LLAMA_API_LLAMA4_MAVERICK: 1,
|
||||
LlmModel.LLAMA_API_LLAMA3_3_8B: 1,
|
||||
LlmModel.LLAMA_API_LLAMA3_3_70B: 1,
|
||||
LlmModel.GROK_4: 9,
|
||||
LlmModel.GROK_4_FAST: 1,
|
||||
LlmModel.GROK_4_1_FAST: 1,
|
||||
LlmModel.GROK_CODE_FAST_1: 1,
|
||||
LlmModel.KIMI_K2: 1,
|
||||
LlmModel.QWEN3_235B_A22B_THINKING: 1,
|
||||
LlmModel.QWEN3_CODER: 9,
|
||||
# v0 by Vercel models
|
||||
LlmModel.V0_1_5_MD: 1,
|
||||
LlmModel.V0_1_5_LG: 2,
|
||||
LlmModel.V0_1_0_MD: 1,
|
||||
}
|
||||
|
||||
# =============== Configure the cost for each LLM Model call =============== #
|
||||
# All LLM costs now come from the database via llm_registry
|
||||
|
||||
LLM_COST: list[BlockCost] = []
|
||||
for model in LlmModel:
|
||||
if model not in MODEL_COST:
|
||||
raise ValueError(f"Missing MODEL_COST for model: {model}")
|
||||
|
||||
|
||||
def _build_llm_costs_from_registry() -> list[BlockCost]:
|
||||
"""Build BlockCost list from all models in the LLM registry."""
|
||||
costs: list[BlockCost] = []
|
||||
for model in llm_registry.iter_dynamic_models():
|
||||
for cost in model.costs:
|
||||
credentials = PROVIDER_CREDENTIALS.get(cost.credential_provider)
|
||||
if not credentials:
|
||||
logger.warning(
|
||||
"Skipping cost entry for %s due to unknown credentials provider %s",
|
||||
model.slug,
|
||||
cost.credential_provider,
|
||||
)
|
||||
continue
|
||||
cost_filter = {
|
||||
"model": model.slug,
|
||||
LLM_COST = (
|
||||
# Anthropic Models
|
||||
[
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": credentials.id,
|
||||
"provider": credentials.provider,
|
||||
"type": credentials.type,
|
||||
"id": anthropic_credentials.id,
|
||||
"provider": anthropic_credentials.provider,
|
||||
"type": anthropic_credentials.type,
|
||||
},
|
||||
}
|
||||
costs.append(
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter=cost_filter,
|
||||
cost_amount=cost.credit_cost,
|
||||
)
|
||||
)
|
||||
return costs
|
||||
|
||||
|
||||
def refresh_llm_costs() -> None:
|
||||
"""Refresh LLM costs from the registry. All costs now come from the database."""
|
||||
LLM_COST.clear()
|
||||
LLM_COST.extend(_build_llm_costs_from_registry())
|
||||
|
||||
|
||||
# Initial load will happen after registry is refreshed at startup
|
||||
# Don't call refresh_llm_costs() here - it will be called after registry refresh
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "anthropic"
|
||||
]
|
||||
# OpenAI Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": openai_credentials.id,
|
||||
"provider": openai_credentials.provider,
|
||||
"type": openai_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "openai"
|
||||
]
|
||||
# Groq Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {"id": groq_credentials.id},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "groq"
|
||||
]
|
||||
# Open Router Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": open_router_credentials.id,
|
||||
"provider": open_router_credentials.provider,
|
||||
"type": open_router_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "open_router"
|
||||
]
|
||||
# Llama API Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": llama_api_credentials.id,
|
||||
"provider": llama_api_credentials.provider,
|
||||
"type": llama_api_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "llama_api"
|
||||
]
|
||||
# v0 by Vercel Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": v0_credentials.id,
|
||||
"provider": v0_credentials.provider,
|
||||
"type": v0_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "v0"
|
||||
]
|
||||
# AI/ML Api Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": aiml_api_credentials.id,
|
||||
"provider": aiml_api_credentials.provider,
|
||||
"type": aiml_api_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "aiml_api"
|
||||
]
|
||||
)
|
||||
|
||||
# =============== This is the exhaustive list of cost for each Block =============== #
|
||||
|
||||
|
||||
@@ -1511,10 +1511,8 @@ async def migrate_llm_models(migrate_to: LlmModel):
|
||||
if field.annotation == LlmModel:
|
||||
llm_model_fields[block.id] = field_name
|
||||
|
||||
# Get all model slugs from the registry (dynamic, not hardcoded enum)
|
||||
from backend.data import llm_registry
|
||||
|
||||
enum_values = list(llm_registry.get_all_model_slugs_for_validation())
|
||||
# Convert enum values to a list of strings for the SQL query
|
||||
enum_values = [v.value for v in LlmModel]
|
||||
escaped_enum_values = repr(tuple(enum_values)) # hack but works
|
||||
|
||||
# Update each block
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
"""
|
||||
LLM Registry module for managing LLM models, providers, and costs dynamically.
|
||||
|
||||
This module provides a database-driven registry system for LLM models,
|
||||
replacing hardcoded model configurations with a flexible admin-managed system.
|
||||
"""
|
||||
|
||||
from backend.data.llm_registry.model import ModelMetadata
|
||||
|
||||
# Re-export for backwards compatibility
|
||||
from backend.data.llm_registry.notifications import (
|
||||
REGISTRY_REFRESH_CHANNEL,
|
||||
publish_registry_refresh_notification,
|
||||
subscribe_to_registry_refresh,
|
||||
)
|
||||
from backend.data.llm_registry.registry import (
|
||||
RegistryModel,
|
||||
RegistryModelCost,
|
||||
RegistryModelCreator,
|
||||
get_all_model_slugs_for_validation,
|
||||
get_default_model_slug,
|
||||
get_dynamic_model_slugs,
|
||||
get_fallback_model_for_disabled,
|
||||
get_llm_discriminator_mapping,
|
||||
get_llm_model_cost,
|
||||
get_llm_model_metadata,
|
||||
get_llm_model_schema_options,
|
||||
get_model_info,
|
||||
is_model_enabled,
|
||||
iter_dynamic_models,
|
||||
refresh_llm_registry,
|
||||
register_static_costs,
|
||||
register_static_metadata,
|
||||
)
|
||||
from backend.data.llm_registry.schema_utils import (
|
||||
is_llm_model_field,
|
||||
refresh_llm_discriminator_mapping,
|
||||
refresh_llm_model_options,
|
||||
update_schema_with_llm_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Types
|
||||
"ModelMetadata",
|
||||
"RegistryModel",
|
||||
"RegistryModelCost",
|
||||
"RegistryModelCreator",
|
||||
# Registry functions
|
||||
"get_all_model_slugs_for_validation",
|
||||
"get_default_model_slug",
|
||||
"get_dynamic_model_slugs",
|
||||
"get_fallback_model_for_disabled",
|
||||
"get_llm_discriminator_mapping",
|
||||
"get_llm_model_cost",
|
||||
"get_llm_model_metadata",
|
||||
"get_llm_model_schema_options",
|
||||
"get_model_info",
|
||||
"is_model_enabled",
|
||||
"iter_dynamic_models",
|
||||
"refresh_llm_registry",
|
||||
"register_static_costs",
|
||||
"register_static_metadata",
|
||||
# Notifications
|
||||
"REGISTRY_REFRESH_CHANNEL",
|
||||
"publish_registry_refresh_notification",
|
||||
"subscribe_to_registry_refresh",
|
||||
# Schema utilities
|
||||
"is_llm_model_field",
|
||||
"refresh_llm_discriminator_mapping",
|
||||
"refresh_llm_model_options",
|
||||
"update_schema_with_llm_registry",
|
||||
]
|
||||
@@ -1,25 +0,0 @@
|
||||
"""Type definitions for LLM model metadata."""
|
||||
|
||||
from typing import Literal, NamedTuple
|
||||
|
||||
|
||||
class ModelMetadata(NamedTuple):
|
||||
"""Metadata for an LLM model.
|
||||
|
||||
Attributes:
|
||||
provider: The provider identifier (e.g., "openai", "anthropic")
|
||||
context_window: Maximum context window size in tokens
|
||||
max_output_tokens: Maximum output tokens (None if unlimited)
|
||||
display_name: Human-readable name for the model
|
||||
provider_name: Human-readable provider name (e.g., "OpenAI", "Anthropic")
|
||||
creator_name: Name of the organization that created the model
|
||||
price_tier: Relative cost tier (1=cheapest, 2=medium, 3=expensive)
|
||||
"""
|
||||
|
||||
provider: str
|
||||
context_window: int
|
||||
max_output_tokens: int | None
|
||||
display_name: str
|
||||
provider_name: str
|
||||
creator_name: str
|
||||
price_tier: Literal[1, 2, 3]
|
||||
@@ -1,89 +0,0 @@
|
||||
"""
|
||||
Redis pub/sub notifications for LLM registry updates.
|
||||
|
||||
When models are added/updated/removed via the admin UI, this module
|
||||
publishes notifications to Redis that all executor services subscribe to,
|
||||
ensuring they refresh their registry cache in real-time.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.data.redis_client import connect_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis channel name for LLM registry refresh notifications
|
||||
REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"
|
||||
|
||||
|
||||
async def publish_registry_refresh_notification() -> None:
|
||||
"""
|
||||
Publish a notification to Redis that the LLM registry has been updated.
|
||||
All executor services subscribed to this channel will refresh their registry.
|
||||
"""
|
||||
try:
|
||||
redis = await connect_async()
|
||||
await redis.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
|
||||
logger.info("Published LLM registry refresh notification to Redis")
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to publish LLM registry refresh notification: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def subscribe_to_registry_refresh(
|
||||
on_refresh: Any, # Async callable that takes no args
|
||||
) -> None:
|
||||
"""
|
||||
Subscribe to Redis notifications for LLM registry updates.
|
||||
This runs in a loop and processes messages as they arrive.
|
||||
|
||||
Args:
|
||||
on_refresh: Async callable to execute when a refresh notification is received
|
||||
"""
|
||||
try:
|
||||
redis = await connect_async()
|
||||
pubsub = redis.pubsub()
|
||||
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
|
||||
logger.info(
|
||||
"Subscribed to LLM registry refresh notifications on channel: %s",
|
||||
REGISTRY_REFRESH_CHANNEL,
|
||||
)
|
||||
|
||||
# Process messages in a loop
|
||||
while True:
|
||||
try:
|
||||
message = await pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=1.0
|
||||
)
|
||||
if (
|
||||
message
|
||||
and message["type"] == "message"
|
||||
and message["channel"] == REGISTRY_REFRESH_CHANNEL
|
||||
):
|
||||
logger.info("Received LLM registry refresh notification")
|
||||
try:
|
||||
await on_refresh()
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Error refreshing LLM registry from notification: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Error processing registry refresh message: %s", exc, exc_info=True
|
||||
)
|
||||
# Continue listening even if one message fails
|
||||
await asyncio.sleep(1)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to subscribe to LLM registry refresh notifications: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
@@ -1,388 +0,0 @@
|
||||
"""Core LLM registry implementation for managing models dynamically."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Iterable
|
||||
|
||||
import prisma.models
|
||||
|
||||
from backend.data.llm_registry.model import ModelMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _json_to_dict(value: Any) -> dict[str, Any]:
|
||||
"""Convert Prisma Json type to dict, with fallback to empty dict."""
|
||||
if value is None:
|
||||
return {}
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
# Prisma Json type should always be a dict at runtime
|
||||
return dict(value) if value else {}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryModelCost:
|
||||
"""Cost configuration for an LLM model."""
|
||||
|
||||
credit_cost: int
|
||||
credential_provider: str
|
||||
credential_id: str | None
|
||||
credential_type: str | None
|
||||
currency: str | None
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryModelCreator:
|
||||
"""Creator information for an LLM model."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: str | None
|
||||
website_url: str | None
|
||||
logo_url: str | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryModel:
|
||||
"""Represents a model in the LLM registry."""
|
||||
|
||||
slug: str
|
||||
display_name: str
|
||||
description: str | None
|
||||
metadata: ModelMetadata
|
||||
capabilities: dict[str, Any]
|
||||
extra_metadata: dict[str, Any]
|
||||
provider_display_name: str
|
||||
is_enabled: bool
|
||||
is_recommended: bool = False
|
||||
costs: tuple[RegistryModelCost, ...] = field(default_factory=tuple)
|
||||
creator: RegistryModelCreator | None = None
|
||||
|
||||
|
||||
_static_metadata: dict[str, ModelMetadata] = {}
|
||||
_static_costs: dict[str, int] = {}
|
||||
_dynamic_models: dict[str, RegistryModel] = {}
|
||||
_schema_options: list[dict[str, str]] = []
|
||||
_discriminator_mapping: dict[str, str] = {}
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def register_static_metadata(metadata: dict[Any, ModelMetadata]) -> None:
|
||||
"""Register static metadata for legacy models (deprecated)."""
|
||||
_static_metadata.update({str(key): value for key, value in metadata.items()})
|
||||
_refresh_cached_schema()
|
||||
|
||||
|
||||
def register_static_costs(costs: dict[Any, int]) -> None:
|
||||
"""Register static costs for legacy models (deprecated)."""
|
||||
_static_costs.update({str(key): value for key, value in costs.items()})
|
||||
|
||||
|
||||
def _build_schema_options() -> list[dict[str, str]]:
|
||||
"""Build schema options for model selection dropdown. Only includes enabled models."""
|
||||
options: list[dict[str, str]] = []
|
||||
# Only include enabled models in the dropdown options
|
||||
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
|
||||
if model.is_enabled:
|
||||
options.append(
|
||||
{
|
||||
"label": model.display_name,
|
||||
"value": model.slug,
|
||||
"group": model.metadata.provider,
|
||||
"description": model.description or "",
|
||||
}
|
||||
)
|
||||
|
||||
for slug, metadata in _static_metadata.items():
|
||||
if slug in _dynamic_models:
|
||||
continue
|
||||
options.append(
|
||||
{
|
||||
"label": slug,
|
||||
"value": slug,
|
||||
"group": metadata.provider,
|
||||
"description": "",
|
||||
}
|
||||
)
|
||||
return options
|
||||
|
||||
|
||||
async def refresh_llm_registry() -> None:
|
||||
"""Refresh the LLM registry from the database. Loads all models (enabled and disabled)."""
|
||||
async with _lock:
|
||||
try:
|
||||
records = await prisma.models.LlmModel.prisma().find_many(
|
||||
include={
|
||||
"Provider": True,
|
||||
"Costs": True,
|
||||
"Creator": True,
|
||||
}
|
||||
)
|
||||
logger.debug("Found %d LLM model records in database", len(records))
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to refresh LLM registry from DB: %s", exc, exc_info=True
|
||||
)
|
||||
return
|
||||
|
||||
dynamic: dict[str, RegistryModel] = {}
|
||||
for record in records:
|
||||
provider_name = (
|
||||
record.Provider.name if record.Provider else record.providerId
|
||||
)
|
||||
provider_display_name = (
|
||||
record.Provider.displayName if record.Provider else record.providerId
|
||||
)
|
||||
# Creator name: prefer Creator.name, fallback to provider display name
|
||||
creator_name = (
|
||||
record.Creator.name if record.Creator else provider_display_name
|
||||
)
|
||||
# Price tier: default to 1 (cheapest) if not set
|
||||
price_tier = getattr(record, "priceTier", 1) or 1
|
||||
# Clamp to valid range 1-3
|
||||
price_tier = max(1, min(3, price_tier))
|
||||
|
||||
metadata = ModelMetadata(
|
||||
provider=provider_name,
|
||||
context_window=record.contextWindow,
|
||||
max_output_tokens=record.maxOutputTokens,
|
||||
display_name=record.displayName,
|
||||
provider_name=provider_display_name,
|
||||
creator_name=creator_name,
|
||||
price_tier=price_tier, # type: ignore[arg-type]
|
||||
)
|
||||
costs = tuple(
|
||||
RegistryModelCost(
|
||||
credit_cost=cost.creditCost,
|
||||
credential_provider=cost.credentialProvider,
|
||||
credential_id=cost.credentialId,
|
||||
credential_type=cost.credentialType,
|
||||
currency=cost.currency,
|
||||
metadata=_json_to_dict(cost.metadata),
|
||||
)
|
||||
for cost in (record.Costs or [])
|
||||
)
|
||||
|
||||
# Map creator if present
|
||||
creator = None
|
||||
if record.Creator:
|
||||
creator = RegistryModelCreator(
|
||||
id=record.Creator.id,
|
||||
name=record.Creator.name,
|
||||
display_name=record.Creator.displayName,
|
||||
description=record.Creator.description,
|
||||
website_url=record.Creator.websiteUrl,
|
||||
logo_url=record.Creator.logoUrl,
|
||||
)
|
||||
|
||||
dynamic[record.slug] = RegistryModel(
|
||||
slug=record.slug,
|
||||
display_name=record.displayName,
|
||||
description=record.description,
|
||||
metadata=metadata,
|
||||
capabilities=_json_to_dict(record.capabilities),
|
||||
extra_metadata=_json_to_dict(record.metadata),
|
||||
provider_display_name=(
|
||||
record.Provider.displayName
|
||||
if record.Provider
|
||||
else record.providerId
|
||||
),
|
||||
is_enabled=record.isEnabled,
|
||||
is_recommended=record.isRecommended,
|
||||
costs=costs,
|
||||
creator=creator,
|
||||
)
|
||||
|
||||
# Atomic swap - build new structures then replace references
|
||||
# This ensures readers never see partially updated state
|
||||
global _dynamic_models
|
||||
_dynamic_models = dynamic
|
||||
_refresh_cached_schema()
|
||||
logger.info(
|
||||
"LLM registry refreshed with %s dynamic models (enabled: %s, disabled: %s)",
|
||||
len(dynamic),
|
||||
sum(1 for m in dynamic.values() if m.is_enabled),
|
||||
sum(1 for m in dynamic.values() if not m.is_enabled),
|
||||
)
|
||||
|
||||
|
||||
def _refresh_cached_schema() -> None:
|
||||
"""Refresh cached schema options and discriminator mapping."""
|
||||
global _schema_options, _discriminator_mapping
|
||||
|
||||
# Build new structures
|
||||
new_options = _build_schema_options()
|
||||
new_mapping = {
|
||||
slug: entry.metadata.provider for slug, entry in _dynamic_models.items()
|
||||
}
|
||||
for slug, metadata in _static_metadata.items():
|
||||
new_mapping.setdefault(slug, metadata.provider)
|
||||
|
||||
# Atomic swap - replace references to ensure readers see consistent state
|
||||
_schema_options = new_options
|
||||
_discriminator_mapping = new_mapping
|
||||
|
||||
|
||||
def get_llm_model_metadata(slug: str) -> ModelMetadata | None:
|
||||
"""Get model metadata by slug. Checks dynamic models first, then static metadata."""
|
||||
if slug in _dynamic_models:
|
||||
return _dynamic_models[slug].metadata
|
||||
return _static_metadata.get(slug)
|
||||
|
||||
|
||||
def get_llm_model_cost(slug: str) -> tuple[RegistryModelCost, ...]:
|
||||
"""Get model cost configuration by slug."""
|
||||
if slug in _dynamic_models:
|
||||
return _dynamic_models[slug].costs
|
||||
cost_value = _static_costs.get(slug)
|
||||
if cost_value is None:
|
||||
return tuple()
|
||||
return (
|
||||
RegistryModelCost(
|
||||
credit_cost=cost_value,
|
||||
credential_provider="static",
|
||||
credential_id=None,
|
||||
credential_type=None,
|
||||
currency=None,
|
||||
metadata={},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_llm_model_schema_options() -> list[dict[str, str]]:
|
||||
"""
|
||||
Get schema options for LLM model selection dropdown.
|
||||
|
||||
Returns a copy of cached schema options that are refreshed when the registry is
|
||||
updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
|
||||
"""
|
||||
# Return a copy to prevent external mutation
|
||||
return list(_schema_options)
|
||||
|
||||
|
||||
def get_llm_discriminator_mapping() -> dict[str, str]:
|
||||
"""
|
||||
Get discriminator mapping for LLM models.
|
||||
|
||||
Returns a copy of cached discriminator mapping that is refreshed when the registry
|
||||
is updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
|
||||
"""
|
||||
# Return a copy to prevent external mutation
|
||||
return dict(_discriminator_mapping)
|
||||
|
||||
|
||||
def get_dynamic_model_slugs() -> set[str]:
|
||||
"""Get all dynamic model slugs from the registry."""
|
||||
return set(_dynamic_models.keys())
|
||||
|
||||
|
||||
def get_all_model_slugs_for_validation() -> set[str]:
|
||||
"""
|
||||
Get ALL model slugs (both enabled and disabled) for validation purposes.
|
||||
|
||||
This is used for JSON schema enum validation - we need to accept any known
|
||||
model value (even disabled ones) so that existing graphs don't fail validation.
|
||||
The actual fallback/enforcement happens at runtime in llm_call().
|
||||
"""
|
||||
all_slugs = set(_dynamic_models.keys())
|
||||
all_slugs.update(_static_metadata.keys())
|
||||
return all_slugs
|
||||
|
||||
|
||||
def iter_dynamic_models() -> Iterable[RegistryModel]:
|
||||
"""Iterate over all dynamic models in the registry."""
|
||||
return tuple(_dynamic_models.values())
|
||||
|
||||
|
||||
def get_fallback_model_for_disabled(disabled_model_slug: str) -> RegistryModel | None:
|
||||
"""
|
||||
Find a fallback model when the requested model is disabled.
|
||||
|
||||
Looks for an enabled model from the same provider. Prefers models with
|
||||
similar names or capabilities if possible.
|
||||
|
||||
Args:
|
||||
disabled_model_slug: The slug of the disabled model
|
||||
|
||||
Returns:
|
||||
An enabled RegistryModel from the same provider, or None if no fallback found
|
||||
"""
|
||||
disabled_model = _dynamic_models.get(disabled_model_slug)
|
||||
if not disabled_model:
|
||||
return None
|
||||
|
||||
provider = disabled_model.metadata.provider
|
||||
|
||||
# Find all enabled models from the same provider
|
||||
candidates = [
|
||||
model
|
||||
for model in _dynamic_models.values()
|
||||
if model.is_enabled and model.metadata.provider == provider
|
||||
]
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
# Sort by: prefer models with similar context window, then by name
|
||||
candidates.sort(
|
||||
key=lambda m: (
|
||||
abs(m.metadata.context_window - disabled_model.metadata.context_window),
|
||||
m.display_name.lower(),
|
||||
)
|
||||
)
|
||||
|
||||
return candidates[0]
|
||||
|
||||
|
||||
def is_model_enabled(model_slug: str) -> bool:
|
||||
"""Check if a model is enabled in the registry."""
|
||||
model = _dynamic_models.get(model_slug)
|
||||
if not model:
|
||||
# Model not in registry - assume it's a static/legacy model and allow it
|
||||
return True
|
||||
return model.is_enabled
|
||||
|
||||
|
||||
def get_model_info(model_slug: str) -> RegistryModel | None:
|
||||
"""Get model info from the registry."""
|
||||
return _dynamic_models.get(model_slug)
|
||||
|
||||
|
||||
def get_default_model_slug() -> str | None:
|
||||
"""
|
||||
Get the default model slug to use for block defaults.
|
||||
|
||||
Returns the recommended model if set (configured via admin UI),
|
||||
otherwise returns the first enabled model alphabetically.
|
||||
Returns None if no models are available or enabled.
|
||||
"""
|
||||
# Return the recommended model if one is set and enabled
|
||||
for model in _dynamic_models.values():
|
||||
if model.is_recommended and model.is_enabled:
|
||||
return model.slug
|
||||
|
||||
# No recommended model set - find first enabled model alphabetically
|
||||
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
|
||||
if model.is_enabled:
|
||||
logger.warning(
|
||||
"No recommended model set, using '%s' as default",
|
||||
model.slug,
|
||||
)
|
||||
return model.slug
|
||||
|
||||
# No enabled models available
|
||||
if _dynamic_models:
|
||||
logger.error(
|
||||
"No enabled models found in registry (%d models registered but all disabled)",
|
||||
len(_dynamic_models),
|
||||
)
|
||||
else:
|
||||
logger.error("No models registered in LLM registry")
|
||||
|
||||
return None
|
||||
@@ -1,130 +0,0 @@
|
||||
"""
|
||||
Helper utilities for LLM registry integration with block schemas.
|
||||
|
||||
This module handles the dynamic injection of discriminator mappings
|
||||
and model options from the LLM registry into block schemas.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.data.llm_registry.registry import (
|
||||
get_all_model_slugs_for_validation,
|
||||
get_default_model_slug,
|
||||
get_llm_discriminator_mapping,
|
||||
get_llm_model_schema_options,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_llm_model_field(field_name: str, field_info: Any) -> bool:
|
||||
"""
|
||||
Check if a field is an LLM model selection field.
|
||||
|
||||
Returns True if the field has 'options' in json_schema_extra
|
||||
(set by llm_model_schema_extra() in blocks/llm.py).
|
||||
"""
|
||||
if not hasattr(field_info, "json_schema_extra"):
|
||||
return False
|
||||
|
||||
extra = field_info.json_schema_extra
|
||||
if isinstance(extra, dict):
|
||||
return "options" in extra
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def refresh_llm_model_options(field_schema: dict[str, Any]) -> None:
|
||||
"""
|
||||
Refresh LLM model options from the registry.
|
||||
|
||||
Updates 'options' (for frontend dropdown) to show only enabled models,
|
||||
but keeps the 'enum' (for validation) inclusive of ALL known models.
|
||||
|
||||
This is important because:
|
||||
- Options: What users see in the dropdown (enabled models only)
|
||||
- Enum: What values pass validation (all known models, including disabled)
|
||||
|
||||
Existing graphs may have disabled models selected - they should pass validation
|
||||
and the fallback logic in llm_call() will handle using an alternative model.
|
||||
"""
|
||||
fresh_options = get_llm_model_schema_options()
|
||||
if not fresh_options:
|
||||
return
|
||||
|
||||
# Update options array (UI dropdown) - only enabled models
|
||||
if "options" in field_schema:
|
||||
field_schema["options"] = fresh_options
|
||||
|
||||
all_known_slugs = get_all_model_slugs_for_validation()
|
||||
if all_known_slugs and "enum" in field_schema:
|
||||
existing_enum = set(field_schema.get("enum", []))
|
||||
combined_enum = existing_enum | all_known_slugs
|
||||
field_schema["enum"] = sorted(combined_enum)
|
||||
|
||||
# Set the default value from the registry (gpt-4o if available, else first enabled)
|
||||
# This ensures new blocks have a sensible default pre-selected
|
||||
default_slug = get_default_model_slug()
|
||||
if default_slug:
|
||||
field_schema["default"] = default_slug
|
||||
|
||||
|
||||
def refresh_llm_discriminator_mapping(field_schema: dict[str, Any]) -> None:
|
||||
"""
|
||||
Refresh discriminator_mapping for fields that use model-based discrimination.
|
||||
|
||||
The discriminator is already set when AICredentialsField() creates the field.
|
||||
We only need to refresh the mapping when models are added/removed.
|
||||
"""
|
||||
if field_schema.get("discriminator") != "model":
|
||||
return
|
||||
|
||||
# Always refresh the mapping to get latest models
|
||||
fresh_mapping = get_llm_discriminator_mapping()
|
||||
if fresh_mapping is not None:
|
||||
field_schema["discriminator_mapping"] = fresh_mapping
|
||||
|
||||
|
||||
def update_schema_with_llm_registry(
|
||||
schema: dict[str, Any], model_class: type | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Update a JSON schema with current LLM registry data.
|
||||
|
||||
Refreshes:
|
||||
1. Model options for LLM model selection fields (dropdown choices)
|
||||
2. Discriminator mappings for credentials fields (model → provider)
|
||||
|
||||
Args:
|
||||
schema: The JSON schema to update (mutated in-place)
|
||||
model_class: The Pydantic model class (optional, for field introspection)
|
||||
"""
|
||||
properties = schema.get("properties", {})
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
if not isinstance(field_schema, dict):
|
||||
continue
|
||||
|
||||
# Refresh model options for LLM model fields
|
||||
if model_class and hasattr(model_class, "model_fields"):
|
||||
field_info = model_class.model_fields.get(field_name)
|
||||
if field_info and is_llm_model_field(field_name, field_info):
|
||||
try:
|
||||
refresh_llm_model_options(field_schema)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to refresh LLM options for field %s: %s",
|
||||
field_name,
|
||||
exc,
|
||||
)
|
||||
|
||||
# Refresh discriminator mapping for fields that use model discrimination
|
||||
try:
|
||||
refresh_llm_discriminator_mapping(field_schema)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to refresh discriminator mapping for field %s: %s",
|
||||
field_name,
|
||||
exc,
|
||||
)
|
||||
@@ -40,7 +40,6 @@ from pydantic_core import (
|
||||
)
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.data.llm_registry import update_schema_with_llm_registry
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.json import loads as json_loads
|
||||
from backend.util.settings import Secrets
|
||||
@@ -545,9 +544,7 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
else:
|
||||
schema["credentials_provider"] = allowed_providers
|
||||
schema["credentials_types"] = model_class.allowed_cred_types()
|
||||
|
||||
# Ensure LLM discriminators are populated (delegates to shared helper)
|
||||
update_schema_with_llm_registry(schema, model_class)
|
||||
# Do not return anything, just mutate schema in place
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra=_add_json_schema_extra, # type: ignore
|
||||
@@ -696,20 +693,16 @@ def CredentialsField(
|
||||
This is enforced by the `BlockSchema` base class.
|
||||
"""
|
||||
|
||||
# Build field_schema_extra - always include discriminator and mapping if discriminator is set
|
||||
field_schema_extra: dict[str, Any] = {}
|
||||
|
||||
# Always include discriminator if provided
|
||||
if discriminator is not None:
|
||||
field_schema_extra["discriminator"] = discriminator
|
||||
# Always include discriminator_mapping when discriminator is set (even if empty initially)
|
||||
field_schema_extra["discriminator_mapping"] = discriminator_mapping or {}
|
||||
|
||||
# Include other optional fields (only if not None)
|
||||
if required_scopes:
|
||||
field_schema_extra["credentials_scopes"] = list(required_scopes)
|
||||
if discriminator_values:
|
||||
field_schema_extra["discriminator_values"] = discriminator_values
|
||||
field_schema_extra = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"credentials_scopes": list(required_scopes) or None,
|
||||
"discriminator": discriminator,
|
||||
"discriminator_mapping": discriminator_mapping,
|
||||
"discriminator_values": discriminator_values,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
# Merge any json_schema_extra passed in kwargs
|
||||
if "json_schema_extra" in kwargs:
|
||||
|
||||
@@ -41,6 +41,7 @@ FrontendOnboardingStep = Literal[
|
||||
OnboardingStep.AGENT_NEW_RUN,
|
||||
OnboardingStep.AGENT_INPUT,
|
||||
OnboardingStep.CONGRATS,
|
||||
OnboardingStep.VISIT_COPILOT,
|
||||
OnboardingStep.MARKETPLACE_VISIT,
|
||||
OnboardingStep.BUILDER_OPEN,
|
||||
]
|
||||
@@ -122,6 +123,9 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
||||
async def _reward_user(user_id: str, onboarding: UserOnboarding, step: OnboardingStep):
|
||||
reward = 0
|
||||
match step:
|
||||
# Welcome bonus for visiting copilot ($5 = 500 credits)
|
||||
case OnboardingStep.VISIT_COPILOT:
|
||||
reward = 500
|
||||
# Reward user when they clicked New Run during onboarding
|
||||
# This is because they need credits before scheduling a run (next step)
|
||||
# This is seen as a reward for the GET_RESULTS step in the wallet
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
"""
|
||||
Helper functions for LLM registry initialization in executor context.
|
||||
|
||||
These functions handle refreshing the LLM registry when the executor starts
|
||||
and subscribing to real-time updates via Redis pub/sub.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from backend.data import db, llm_registry
|
||||
from backend.data.block import BlockSchema, initialize_blocks
|
||||
from backend.data.block_cost_config import refresh_llm_costs
|
||||
from backend.data.llm_registry import subscribe_to_registry_refresh
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def initialize_registry_for_executor() -> None:
|
||||
"""
|
||||
Initialize blocks and refresh LLM registry in the executor context.
|
||||
|
||||
This must run in the executor's event loop to have access to the database.
|
||||
"""
|
||||
try:
|
||||
# Connect to database if not already connected
|
||||
if not db.is_connected():
|
||||
await db.connect()
|
||||
logger.info("[GraphExecutor] Connected to database for registry refresh")
|
||||
|
||||
# Initialize blocks (internally refreshes LLM registry and costs)
|
||||
await initialize_blocks()
|
||||
logger.info("[GraphExecutor] Blocks initialized")
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[GraphExecutor] Failed to refresh LLM registry on startup: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def refresh_registry_on_notification() -> None:
|
||||
"""Refresh LLM registry when notified via Redis pub/sub."""
|
||||
try:
|
||||
# Ensure DB is connected
|
||||
if not db.is_connected():
|
||||
await db.connect()
|
||||
|
||||
# Refresh registry and costs
|
||||
await llm_registry.refresh_llm_registry()
|
||||
refresh_llm_costs()
|
||||
|
||||
# Clear block schema caches so they regenerate with new model options
|
||||
BlockSchema.clear_all_schema_caches()
|
||||
|
||||
logger.info("[GraphExecutor] LLM registry refreshed from notification")
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"[GraphExecutor] Failed to refresh LLM registry from notification: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def subscribe_to_registry_updates() -> None:
|
||||
"""Subscribe to Redis pub/sub for LLM registry refresh notifications."""
|
||||
await subscribe_to_registry_refresh(refresh_registry_on_notification)
|
||||
@@ -702,20 +702,6 @@ class ExecutionProcessor:
|
||||
)
|
||||
self.node_execution_thread.start()
|
||||
self.node_evaluation_thread.start()
|
||||
|
||||
# Initialize LLM registry and subscribe to updates
|
||||
from backend.executor.llm_registry_init import (
|
||||
initialize_registry_for_executor,
|
||||
subscribe_to_registry_updates,
|
||||
)
|
||||
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
initialize_registry_for_executor(), self.node_execution_loop
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
subscribe_to_registry_updates(), self.node_execution_loop
|
||||
)
|
||||
|
||||
logger.info(f"[GraphExecutor] {self.tid} started")
|
||||
|
||||
@error_logged(swallow=False)
|
||||
|
||||
@@ -1,935 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Iterable, Sequence, cast
|
||||
|
||||
import prisma
|
||||
import prisma.models
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
from backend.util.models import Pagination
|
||||
|
||||
|
||||
def _json_dict(value: Any | None) -> dict[str, Any]:
|
||||
if not value:
|
||||
return {}
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
return {}
|
||||
|
||||
|
||||
def _map_cost(record: prisma.models.LlmModelCost) -> llm_model.LlmModelCost:
|
||||
return llm_model.LlmModelCost(
|
||||
id=record.id,
|
||||
unit=record.unit,
|
||||
credit_cost=record.creditCost,
|
||||
credential_provider=record.credentialProvider,
|
||||
credential_id=record.credentialId,
|
||||
credential_type=record.credentialType,
|
||||
currency=record.currency,
|
||||
metadata=_json_dict(record.metadata),
|
||||
)
|
||||
|
||||
|
||||
def _map_creator(
|
||||
record: prisma.models.LlmModelCreator,
|
||||
) -> llm_model.LlmModelCreator:
|
||||
return llm_model.LlmModelCreator(
|
||||
id=record.id,
|
||||
name=record.name,
|
||||
display_name=record.displayName,
|
||||
description=record.description,
|
||||
website_url=record.websiteUrl,
|
||||
logo_url=record.logoUrl,
|
||||
metadata=_json_dict(record.metadata),
|
||||
)
|
||||
|
||||
|
||||
def _map_model(record: prisma.models.LlmModel) -> llm_model.LlmModel:
|
||||
costs = []
|
||||
if record.Costs:
|
||||
costs = [_map_cost(cost) for cost in record.Costs]
|
||||
|
||||
creator = None
|
||||
if hasattr(record, "Creator") and record.Creator:
|
||||
creator = _map_creator(record.Creator)
|
||||
|
||||
return llm_model.LlmModel(
|
||||
id=record.id,
|
||||
slug=record.slug,
|
||||
display_name=record.displayName,
|
||||
description=record.description,
|
||||
provider_id=record.providerId,
|
||||
creator_id=record.creatorId,
|
||||
creator=creator,
|
||||
context_window=record.contextWindow,
|
||||
max_output_tokens=record.maxOutputTokens,
|
||||
is_enabled=record.isEnabled,
|
||||
is_recommended=record.isRecommended,
|
||||
capabilities=_json_dict(record.capabilities),
|
||||
metadata=_json_dict(record.metadata),
|
||||
costs=costs,
|
||||
)
|
||||
|
||||
|
||||
def _map_provider(record: prisma.models.LlmProvider) -> llm_model.LlmProvider:
|
||||
models: list[llm_model.LlmModel] = []
|
||||
if record.Models:
|
||||
models = [_map_model(model) for model in record.Models]
|
||||
|
||||
return llm_model.LlmProvider(
|
||||
id=record.id,
|
||||
name=record.name,
|
||||
display_name=record.displayName,
|
||||
description=record.description,
|
||||
default_credential_provider=record.defaultCredentialProvider,
|
||||
default_credential_id=record.defaultCredentialId,
|
||||
default_credential_type=record.defaultCredentialType,
|
||||
supports_tools=record.supportsTools,
|
||||
supports_json_output=record.supportsJsonOutput,
|
||||
supports_reasoning=record.supportsReasoning,
|
||||
supports_parallel_tool=record.supportsParallelTool,
|
||||
metadata=_json_dict(record.metadata),
|
||||
models=models,
|
||||
)
|
||||
|
||||
|
||||
async def list_providers(
|
||||
include_models: bool = True, enabled_only: bool = False
|
||||
) -> list[llm_model.LlmProvider]:
|
||||
"""
|
||||
List all LLM providers.
|
||||
|
||||
Args:
|
||||
include_models: Whether to include models for each provider
|
||||
enabled_only: If True, only include enabled models (for public routes)
|
||||
"""
|
||||
include: Any = None
|
||||
if include_models:
|
||||
model_where = {"isEnabled": True} if enabled_only else None
|
||||
include = {
|
||||
"Models": {
|
||||
"include": {"Costs": True, "Creator": True},
|
||||
"where": model_where,
|
||||
}
|
||||
}
|
||||
records = await prisma.models.LlmProvider.prisma().find_many(include=include)
|
||||
return [_map_provider(record) for record in records]
|
||||
|
||||
|
||||
async def upsert_provider(
|
||||
request: llm_model.UpsertLlmProviderRequest,
|
||||
provider_id: str | None = None,
|
||||
) -> llm_model.LlmProvider:
|
||||
data: Any = {
|
||||
"name": request.name,
|
||||
"displayName": request.display_name,
|
||||
"description": request.description,
|
||||
"defaultCredentialProvider": request.default_credential_provider,
|
||||
"defaultCredentialId": request.default_credential_id,
|
||||
"defaultCredentialType": request.default_credential_type,
|
||||
"supportsTools": request.supports_tools,
|
||||
"supportsJsonOutput": request.supports_json_output,
|
||||
"supportsReasoning": request.supports_reasoning,
|
||||
"supportsParallelTool": request.supports_parallel_tool,
|
||||
"metadata": prisma.Json(request.metadata or {}),
|
||||
}
|
||||
include: Any = {"Models": {"include": {"Costs": True, "Creator": True}}}
|
||||
if provider_id:
|
||||
record = await prisma.models.LlmProvider.prisma().update(
|
||||
where={"id": provider_id},
|
||||
data=data,
|
||||
include=include,
|
||||
)
|
||||
else:
|
||||
record = await prisma.models.LlmProvider.prisma().create(
|
||||
data=data,
|
||||
include=include,
|
||||
)
|
||||
if record is None:
|
||||
raise ValueError("Failed to create/update provider")
|
||||
return _map_provider(record)
|
||||
|
||||
|
||||
async def delete_provider(provider_id: str) -> bool:
|
||||
"""
|
||||
Delete an LLM provider.
|
||||
|
||||
A provider can only be deleted if it has no associated models.
|
||||
Due to onDelete: Restrict on LlmModel.Provider, the database will
|
||||
block deletion if models exist.
|
||||
|
||||
Args:
|
||||
provider_id: UUID of the provider to delete
|
||||
|
||||
Returns:
|
||||
True if deleted successfully
|
||||
|
||||
Raises:
|
||||
ValueError: If provider not found or has associated models
|
||||
"""
|
||||
# Check if provider exists
|
||||
provider = await prisma.models.LlmProvider.prisma().find_unique(
|
||||
where={"id": provider_id},
|
||||
include={"Models": True},
|
||||
)
|
||||
if not provider:
|
||||
raise ValueError(f"Provider with id '{provider_id}' not found")
|
||||
|
||||
# Check if provider has any models
|
||||
model_count = len(provider.Models) if provider.Models else 0
|
||||
if model_count > 0:
|
||||
raise ValueError(
|
||||
f"Cannot delete provider '{provider.displayName}' because it has "
|
||||
f"{model_count} model(s). Delete all models first."
|
||||
)
|
||||
|
||||
# Safe to delete
|
||||
await prisma.models.LlmProvider.prisma().delete(where={"id": provider_id})
|
||||
return True
|
||||
|
||||
|
||||
async def list_models(
|
||||
provider_id: str | None = None,
|
||||
enabled_only: bool = False,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> llm_model.LlmModelsResponse:
|
||||
"""
|
||||
List LLM models with pagination.
|
||||
|
||||
Args:
|
||||
provider_id: Optional filter by provider ID
|
||||
enabled_only: If True, only return enabled models (for public routes)
|
||||
page: Page number (1-indexed)
|
||||
page_size: Number of models per page
|
||||
"""
|
||||
where: Any = {}
|
||||
if provider_id:
|
||||
where["providerId"] = provider_id
|
||||
if enabled_only:
|
||||
where["isEnabled"] = True
|
||||
|
||||
# Get total count for pagination
|
||||
total_items = await prisma.models.LlmModel.prisma().count(
|
||||
where=where if where else None
|
||||
)
|
||||
|
||||
# Calculate pagination
|
||||
skip = (page - 1) * page_size
|
||||
total_pages = (total_items + page_size - 1) // page_size if total_items > 0 else 0
|
||||
|
||||
records = await prisma.models.LlmModel.prisma().find_many(
|
||||
where=where if where else None,
|
||||
include={"Costs": True, "Creator": True},
|
||||
skip=skip,
|
||||
take=page_size,
|
||||
)
|
||||
models = [_map_model(record) for record in records]
|
||||
|
||||
return llm_model.LlmModelsResponse(
|
||||
models=models,
|
||||
pagination=Pagination(
|
||||
total_items=total_items,
|
||||
total_pages=total_pages,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _cost_create_payload(
|
||||
costs: Sequence[llm_model.LlmModelCostInput],
|
||||
) -> dict[str, Iterable[dict[str, Any]]]:
|
||||
|
||||
create_items = []
|
||||
for cost in costs:
|
||||
item: dict[str, Any] = {
|
||||
"unit": cost.unit,
|
||||
"creditCost": cost.credit_cost,
|
||||
"credentialProvider": cost.credential_provider,
|
||||
}
|
||||
# Only include optional fields if they have values
|
||||
if cost.credential_id:
|
||||
item["credentialId"] = cost.credential_id
|
||||
if cost.credential_type:
|
||||
item["credentialType"] = cost.credential_type
|
||||
if cost.currency:
|
||||
item["currency"] = cost.currency
|
||||
# Handle metadata - use Prisma Json type
|
||||
if cost.metadata is not None and cost.metadata != {}:
|
||||
item["metadata"] = prisma.Json(cost.metadata)
|
||||
create_items.append(item)
|
||||
return {"create": create_items}
|
||||
|
||||
|
||||
async def create_model(
|
||||
request: llm_model.CreateLlmModelRequest,
|
||||
) -> llm_model.LlmModel:
|
||||
data: Any = {
|
||||
"slug": request.slug,
|
||||
"displayName": request.display_name,
|
||||
"description": request.description,
|
||||
"Provider": {"connect": {"id": request.provider_id}},
|
||||
"contextWindow": request.context_window,
|
||||
"maxOutputTokens": request.max_output_tokens,
|
||||
"isEnabled": request.is_enabled,
|
||||
"capabilities": prisma.Json(request.capabilities or {}),
|
||||
"metadata": prisma.Json(request.metadata or {}),
|
||||
"Costs": _cost_create_payload(request.costs),
|
||||
}
|
||||
if request.creator_id:
|
||||
data["Creator"] = {"connect": {"id": request.creator_id}}
|
||||
|
||||
record = await prisma.models.LlmModel.prisma().create(
|
||||
data=data,
|
||||
include={"Costs": True, "Creator": True, "Provider": True},
|
||||
)
|
||||
return _map_model(record)
|
||||
|
||||
|
||||
async def update_model(
|
||||
model_id: str,
|
||||
request: llm_model.UpdateLlmModelRequest,
|
||||
) -> llm_model.LlmModel:
|
||||
# Build scalar field updates (non-relation fields)
|
||||
scalar_data: Any = {}
|
||||
if request.display_name is not None:
|
||||
scalar_data["displayName"] = request.display_name
|
||||
if request.description is not None:
|
||||
scalar_data["description"] = request.description
|
||||
if request.context_window is not None:
|
||||
scalar_data["contextWindow"] = request.context_window
|
||||
if request.max_output_tokens is not None:
|
||||
scalar_data["maxOutputTokens"] = request.max_output_tokens
|
||||
if request.is_enabled is not None:
|
||||
scalar_data["isEnabled"] = request.is_enabled
|
||||
if request.capabilities is not None:
|
||||
scalar_data["capabilities"] = request.capabilities
|
||||
if request.metadata is not None:
|
||||
scalar_data["metadata"] = request.metadata
|
||||
# Foreign keys can be updated directly as scalar fields
|
||||
if request.provider_id is not None:
|
||||
scalar_data["providerId"] = request.provider_id
|
||||
if request.creator_id is not None:
|
||||
# Empty string means remove the creator
|
||||
scalar_data["creatorId"] = request.creator_id if request.creator_id else None
|
||||
|
||||
# If we have costs to update, we need to handle them separately
|
||||
# because nested writes have different constraints
|
||||
if request.costs is not None:
|
||||
# Wrap cost replacement in a transaction for atomicity
|
||||
async with transaction() as tx:
|
||||
# First update scalar fields
|
||||
if scalar_data:
|
||||
await tx.llmmodel.update(
|
||||
where={"id": model_id},
|
||||
data=scalar_data,
|
||||
)
|
||||
# Then handle costs: delete existing and create new
|
||||
await tx.llmmodelcost.delete_many(where={"llmModelId": model_id})
|
||||
if request.costs:
|
||||
cost_payload = _cost_create_payload(request.costs)
|
||||
for cost_item in cost_payload["create"]:
|
||||
cost_item["llmModelId"] = model_id
|
||||
await tx.llmmodelcost.create(data=cast(Any, cost_item))
|
||||
# Fetch the updated record (outside transaction)
|
||||
record = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model_id},
|
||||
include={"Costs": True, "Creator": True},
|
||||
)
|
||||
else:
|
||||
# No costs update - simple update
|
||||
record = await prisma.models.LlmModel.prisma().update(
|
||||
where={"id": model_id},
|
||||
data=scalar_data,
|
||||
include={"Costs": True, "Creator": True},
|
||||
)
|
||||
|
||||
if not record:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
return _map_model(record)
|
||||
|
||||
|
||||
async def toggle_model(
|
||||
model_id: str,
|
||||
is_enabled: bool,
|
||||
migrate_to_slug: str | None = None,
|
||||
migration_reason: str | None = None,
|
||||
custom_credit_cost: int | None = None,
|
||||
) -> llm_model.ToggleLlmModelResponse:
|
||||
"""
|
||||
Toggle a model's enabled status, optionally migrating workflows when disabling.
|
||||
|
||||
Args:
|
||||
model_id: UUID of the model to toggle
|
||||
is_enabled: New enabled status
|
||||
migrate_to_slug: If disabling and this is provided, migrate all workflows
|
||||
using this model to the specified replacement model
|
||||
migration_reason: Optional reason for the migration (e.g., "Provider outage")
|
||||
custom_credit_cost: Optional custom pricing override for migrated workflows.
|
||||
When set, the billing system should use this cost instead
|
||||
of the target model's cost for affected nodes.
|
||||
|
||||
Returns:
|
||||
ToggleLlmModelResponse with the updated model and optional migration stats
|
||||
"""
|
||||
import json
|
||||
|
||||
# Get the model being toggled
|
||||
model = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model_id}, include={"Costs": True}
|
||||
)
|
||||
if not model:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
|
||||
nodes_migrated = 0
|
||||
migration_id: str | None = None
|
||||
|
||||
# If disabling with migration, perform migration first
|
||||
if not is_enabled and migrate_to_slug:
|
||||
# Validate replacement model exists and is enabled
|
||||
replacement = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"slug": migrate_to_slug}
|
||||
)
|
||||
if not replacement:
|
||||
raise ValueError(f"Replacement model '{migrate_to_slug}' not found")
|
||||
if not replacement.isEnabled:
|
||||
raise ValueError(
|
||||
f"Replacement model '{migrate_to_slug}' is disabled. "
|
||||
f"Please enable it before using it as a replacement."
|
||||
)
|
||||
|
||||
# Perform all operations atomically within a single transaction
|
||||
# This ensures no nodes are missed between query and update
|
||||
async with transaction() as tx:
|
||||
# Get the IDs of nodes that will be migrated (inside transaction for consistency)
|
||||
node_ids_result = await tx.query_raw(
|
||||
"""
|
||||
SELECT id
|
||||
FROM "AgentNode"
|
||||
WHERE "constantInput"::jsonb->>'model' = $1
|
||||
FOR UPDATE
|
||||
""",
|
||||
model.slug,
|
||||
)
|
||||
migrated_node_ids = (
|
||||
[row["id"] for row in node_ids_result] if node_ids_result else []
|
||||
)
|
||||
nodes_migrated = len(migrated_node_ids)
|
||||
|
||||
if nodes_migrated > 0:
|
||||
# Update by IDs to ensure we only update the exact nodes we queried
|
||||
# Use JSON array and jsonb_array_elements_text for safe parameterization
|
||||
node_ids_json = json.dumps(migrated_node_ids)
|
||||
await tx.execute_raw(
|
||||
"""
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
to_jsonb($1::text)
|
||||
)
|
||||
WHERE id::text IN (
|
||||
SELECT jsonb_array_elements_text($2::jsonb)
|
||||
)
|
||||
""",
|
||||
migrate_to_slug,
|
||||
node_ids_json,
|
||||
)
|
||||
|
||||
record = await tx.llmmodel.update(
|
||||
where={"id": model_id},
|
||||
data={"isEnabled": is_enabled},
|
||||
include={"Costs": True},
|
||||
)
|
||||
|
||||
# Create migration record for revert capability
|
||||
if nodes_migrated > 0:
|
||||
migration_data: Any = {
|
||||
"sourceModelSlug": model.slug,
|
||||
"targetModelSlug": migrate_to_slug,
|
||||
"reason": migration_reason,
|
||||
"migratedNodeIds": json.dumps(migrated_node_ids),
|
||||
"nodeCount": nodes_migrated,
|
||||
"customCreditCost": custom_credit_cost,
|
||||
}
|
||||
migration_record = await tx.llmmodelmigration.create(
|
||||
data=migration_data
|
||||
)
|
||||
migration_id = migration_record.id
|
||||
else:
|
||||
# Simple toggle without migration
|
||||
record = await prisma.models.LlmModel.prisma().update(
|
||||
where={"id": model_id},
|
||||
data={"isEnabled": is_enabled},
|
||||
include={"Costs": True},
|
||||
)
|
||||
|
||||
if record is None:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
return llm_model.ToggleLlmModelResponse(
|
||||
model=_map_model(record),
|
||||
nodes_migrated=nodes_migrated,
|
||||
migrated_to_slug=migrate_to_slug if nodes_migrated > 0 else None,
|
||||
migration_id=migration_id,
|
||||
)
|
||||
|
||||
|
||||
async def get_model_usage(model_id: str) -> llm_model.LlmModelUsageResponse:
|
||||
"""Get usage count for a model."""
|
||||
import prisma as prisma_module
|
||||
|
||||
model = await prisma.models.LlmModel.prisma().find_unique(where={"id": model_id})
|
||||
if not model:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
|
||||
count_result = await prisma_module.get_client().query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM "AgentNode"
|
||||
WHERE "constantInput"::jsonb->>'model' = $1
|
||||
""",
|
||||
model.slug,
|
||||
)
|
||||
node_count = int(count_result[0]["count"]) if count_result else 0
|
||||
|
||||
return llm_model.LlmModelUsageResponse(model_slug=model.slug, node_count=node_count)
|
||||
|
||||
|
||||
async def delete_model(
|
||||
model_id: str, replacement_model_slug: str | None = None
|
||||
) -> llm_model.DeleteLlmModelResponse:
|
||||
"""
|
||||
Delete a model and optionally migrate all AgentNodes using it to a replacement model.
|
||||
|
||||
This performs an atomic operation within a database transaction:
|
||||
1. Validates the model exists
|
||||
2. Counts affected nodes
|
||||
3. If nodes exist, validates replacement model and migrates them
|
||||
4. Deletes the LlmModel record (CASCADE deletes costs)
|
||||
|
||||
Args:
|
||||
model_id: UUID of the model to delete
|
||||
replacement_model_slug: Slug of the model to migrate to (required only if nodes use this model)
|
||||
|
||||
Returns:
|
||||
DeleteLlmModelResponse with migration stats
|
||||
|
||||
Raises:
|
||||
ValueError: If model not found, nodes exist but no replacement provided,
|
||||
replacement not found, or replacement is disabled
|
||||
"""
|
||||
# 1. Get the model being deleted (validation - outside transaction)
|
||||
model = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model_id}, include={"Costs": True}
|
||||
)
|
||||
if not model:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
|
||||
deleted_slug = model.slug
|
||||
deleted_display_name = model.displayName
|
||||
|
||||
# 2. Count affected nodes first to determine if replacement is needed
|
||||
import prisma as prisma_module
|
||||
|
||||
count_result = await prisma_module.get_client().query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM "AgentNode"
|
||||
WHERE "constantInput"::jsonb->>'model' = $1
|
||||
""",
|
||||
deleted_slug,
|
||||
)
|
||||
nodes_to_migrate = int(count_result[0]["count"]) if count_result else 0
|
||||
|
||||
# 3. Validate replacement model only if there are nodes to migrate
|
||||
if nodes_to_migrate > 0:
|
||||
if not replacement_model_slug:
|
||||
raise ValueError(
|
||||
f"Cannot delete model '{deleted_slug}': {nodes_to_migrate} workflow node(s) "
|
||||
f"are using it. Please provide a replacement_model_slug to migrate them."
|
||||
)
|
||||
replacement = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"slug": replacement_model_slug}
|
||||
)
|
||||
if not replacement:
|
||||
raise ValueError(f"Replacement model '{replacement_model_slug}' not found")
|
||||
if not replacement.isEnabled:
|
||||
raise ValueError(
|
||||
f"Replacement model '{replacement_model_slug}' is disabled. "
|
||||
f"Please enable it before using it as a replacement."
|
||||
)
|
||||
|
||||
# 4. Perform migration (if needed) and deletion atomically within a transaction
|
||||
async with transaction() as tx:
|
||||
# Migrate all AgentNode.constantInput->model to replacement
|
||||
if nodes_to_migrate > 0 and replacement_model_slug:
|
||||
await tx.execute_raw(
|
||||
"""
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
to_jsonb($1::text)
|
||||
)
|
||||
WHERE "constantInput"::jsonb->>'model' = $2
|
||||
""",
|
||||
replacement_model_slug,
|
||||
deleted_slug,
|
||||
)
|
||||
|
||||
# Delete the model (CASCADE will delete costs automatically)
|
||||
await tx.llmmodel.delete(where={"id": model_id})
|
||||
|
||||
# Build appropriate message based on whether migration happened
|
||||
if nodes_to_migrate > 0:
|
||||
message = (
|
||||
f"Successfully deleted model '{deleted_display_name}' ({deleted_slug}) "
|
||||
f"and migrated {nodes_to_migrate} workflow node(s) to '{replacement_model_slug}'."
|
||||
)
|
||||
else:
|
||||
message = (
|
||||
f"Successfully deleted model '{deleted_display_name}' ({deleted_slug}). "
|
||||
f"No workflows were using this model."
|
||||
)
|
||||
|
||||
return llm_model.DeleteLlmModelResponse(
|
||||
deleted_model_slug=deleted_slug,
|
||||
deleted_model_display_name=deleted_display_name,
|
||||
replacement_model_slug=replacement_model_slug,
|
||||
nodes_migrated=nodes_to_migrate,
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
def _map_migration(
|
||||
record: prisma.models.LlmModelMigration,
|
||||
) -> llm_model.LlmModelMigration:
|
||||
return llm_model.LlmModelMigration(
|
||||
id=record.id,
|
||||
source_model_slug=record.sourceModelSlug,
|
||||
target_model_slug=record.targetModelSlug,
|
||||
reason=record.reason,
|
||||
node_count=record.nodeCount,
|
||||
custom_credit_cost=record.customCreditCost,
|
||||
is_reverted=record.isReverted,
|
||||
created_at=record.createdAt.isoformat(),
|
||||
reverted_at=record.revertedAt.isoformat() if record.revertedAt else None,
|
||||
)
|
||||
|
||||
|
||||
async def list_migrations(
|
||||
include_reverted: bool = False,
|
||||
) -> list[llm_model.LlmModelMigration]:
|
||||
"""
|
||||
List model migrations, optionally including reverted ones.
|
||||
|
||||
Args:
|
||||
include_reverted: If True, include reverted migrations. Default is False.
|
||||
|
||||
Returns:
|
||||
List of LlmModelMigration records
|
||||
"""
|
||||
where: Any = None if include_reverted else {"isReverted": False}
|
||||
records = await prisma.models.LlmModelMigration.prisma().find_many(
|
||||
where=where,
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
return [_map_migration(record) for record in records]
|
||||
|
||||
|
||||
async def get_migration(migration_id: str) -> llm_model.LlmModelMigration | None:
|
||||
"""Get a specific migration by ID."""
|
||||
record = await prisma.models.LlmModelMigration.prisma().find_unique(
|
||||
where={"id": migration_id}
|
||||
)
|
||||
return _map_migration(record) if record else None
|
||||
|
||||
|
||||
async def revert_migration(
|
||||
migration_id: str,
|
||||
re_enable_source_model: bool = True,
|
||||
) -> llm_model.RevertMigrationResponse:
|
||||
"""
|
||||
Revert a model migration, restoring affected nodes to their original model.
|
||||
|
||||
This only reverts the specific nodes that were migrated, not all nodes
|
||||
currently using the target model.
|
||||
|
||||
Args:
|
||||
migration_id: UUID of the migration to revert
|
||||
re_enable_source_model: Whether to re-enable the source model if it's disabled
|
||||
|
||||
Returns:
|
||||
RevertMigrationResponse with revert stats
|
||||
|
||||
Raises:
|
||||
ValueError: If migration not found, already reverted, or source model not available
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Get the migration record
|
||||
migration = await prisma.models.LlmModelMigration.prisma().find_unique(
|
||||
where={"id": migration_id}
|
||||
)
|
||||
if not migration:
|
||||
raise ValueError(f"Migration with id '{migration_id}' not found")
|
||||
|
||||
if migration.isReverted:
|
||||
raise ValueError(
|
||||
f"Migration '{migration_id}' has already been reverted "
|
||||
f"on {migration.revertedAt.isoformat() if migration.revertedAt else 'unknown date'}"
|
||||
)
|
||||
|
||||
# Check if source model exists
|
||||
source_model = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"slug": migration.sourceModelSlug}
|
||||
)
|
||||
if not source_model:
|
||||
raise ValueError(
|
||||
f"Source model '{migration.sourceModelSlug}' no longer exists. "
|
||||
f"Cannot revert migration."
|
||||
)
|
||||
|
||||
# Get the migrated node IDs (Prisma auto-parses JSONB to list)
|
||||
migrated_node_ids: list[str] = (
|
||||
migration.migratedNodeIds
|
||||
if isinstance(migration.migratedNodeIds, list)
|
||||
else json.loads(migration.migratedNodeIds) # type: ignore
|
||||
)
|
||||
if not migrated_node_ids:
|
||||
raise ValueError("No nodes to revert in this migration")
|
||||
|
||||
# Track if we need to re-enable the source model
|
||||
source_model_was_disabled = not source_model.isEnabled
|
||||
should_re_enable = source_model_was_disabled and re_enable_source_model
|
||||
source_model_re_enabled = False
|
||||
|
||||
# Perform revert atomically
|
||||
async with transaction() as tx:
|
||||
# Re-enable the source model if requested and it was disabled
|
||||
if should_re_enable:
|
||||
await tx.llmmodel.update(
|
||||
where={"id": source_model.id},
|
||||
data={"isEnabled": True},
|
||||
)
|
||||
source_model_re_enabled = True
|
||||
|
||||
# Update only the specific nodes that were migrated
|
||||
# We need to check that they still have the target model (haven't been changed since)
|
||||
# Use a single batch update for efficiency
|
||||
# Use JSON array and jsonb_array_elements_text for safe parameterization
|
||||
node_ids_json = json.dumps(migrated_node_ids)
|
||||
result = await tx.execute_raw(
|
||||
"""
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
to_jsonb($1::text)
|
||||
)
|
||||
WHERE id::text IN (
|
||||
SELECT jsonb_array_elements_text($2::jsonb)
|
||||
)
|
||||
AND "constantInput"::jsonb->>'model' = $3
|
||||
""",
|
||||
migration.sourceModelSlug,
|
||||
node_ids_json,
|
||||
migration.targetModelSlug,
|
||||
)
|
||||
nodes_reverted = result if result else 0
|
||||
|
||||
# Mark migration as reverted
|
||||
await tx.llmmodelmigration.update(
|
||||
where={"id": migration_id},
|
||||
data={
|
||||
"isReverted": True,
|
||||
"revertedAt": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
|
||||
# Calculate nodes that were already changed since migration
|
||||
nodes_already_changed = len(migrated_node_ids) - nodes_reverted
|
||||
|
||||
# Build appropriate message
|
||||
message_parts = [
|
||||
f"Successfully reverted migration: {nodes_reverted} node(s) restored "
|
||||
f"from '{migration.targetModelSlug}' to '{migration.sourceModelSlug}'."
|
||||
]
|
||||
if nodes_already_changed > 0:
|
||||
message_parts.append(
|
||||
f" {nodes_already_changed} node(s) were already changed and not reverted."
|
||||
)
|
||||
if source_model_re_enabled:
|
||||
message_parts.append(
|
||||
f" Model '{migration.sourceModelSlug}' has been re-enabled."
|
||||
)
|
||||
|
||||
return llm_model.RevertMigrationResponse(
|
||||
migration_id=migration_id,
|
||||
source_model_slug=migration.sourceModelSlug,
|
||||
target_model_slug=migration.targetModelSlug,
|
||||
nodes_reverted=nodes_reverted,
|
||||
nodes_already_changed=nodes_already_changed,
|
||||
source_model_re_enabled=source_model_re_enabled,
|
||||
message="".join(message_parts),
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Creator CRUD operations
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def list_creators() -> list[llm_model.LlmModelCreator]:
|
||||
"""List all LLM model creators."""
|
||||
records = await prisma.models.LlmModelCreator.prisma().find_many(
|
||||
order={"displayName": "asc"}
|
||||
)
|
||||
return [_map_creator(record) for record in records]
|
||||
|
||||
|
||||
async def get_creator(creator_id: str) -> llm_model.LlmModelCreator | None:
|
||||
"""Get a specific creator by ID."""
|
||||
record = await prisma.models.LlmModelCreator.prisma().find_unique(
|
||||
where={"id": creator_id}
|
||||
)
|
||||
return _map_creator(record) if record else None
|
||||
|
||||
|
||||
async def upsert_creator(
|
||||
request: llm_model.UpsertLlmCreatorRequest,
|
||||
creator_id: str | None = None,
|
||||
) -> llm_model.LlmModelCreator:
|
||||
"""Create or update a model creator."""
|
||||
data: Any = {
|
||||
"name": request.name,
|
||||
"displayName": request.display_name,
|
||||
"description": request.description,
|
||||
"websiteUrl": request.website_url,
|
||||
"logoUrl": request.logo_url,
|
||||
"metadata": prisma.Json(request.metadata or {}),
|
||||
}
|
||||
if creator_id:
|
||||
record = await prisma.models.LlmModelCreator.prisma().update(
|
||||
where={"id": creator_id},
|
||||
data=data,
|
||||
)
|
||||
else:
|
||||
record = await prisma.models.LlmModelCreator.prisma().create(data=data)
|
||||
if record is None:
|
||||
raise ValueError("Failed to create/update creator")
|
||||
return _map_creator(record)
|
||||
|
||||
|
||||
async def delete_creator(creator_id: str) -> bool:
|
||||
"""
|
||||
Delete a model creator.
|
||||
|
||||
This will set creatorId to NULL on all associated models (due to onDelete: SetNull).
|
||||
|
||||
Args:
|
||||
creator_id: UUID of the creator to delete
|
||||
|
||||
Returns:
|
||||
True if deleted successfully
|
||||
|
||||
Raises:
|
||||
ValueError: If creator not found
|
||||
"""
|
||||
creator = await prisma.models.LlmModelCreator.prisma().find_unique(
|
||||
where={"id": creator_id}
|
||||
)
|
||||
if not creator:
|
||||
raise ValueError(f"Creator with id '{creator_id}' not found")
|
||||
|
||||
await prisma.models.LlmModelCreator.prisma().delete(where={"id": creator_id})
|
||||
return True
|
||||
|
||||
|
||||
async def get_recommended_model() -> llm_model.LlmModel | None:
|
||||
"""
|
||||
Get the currently recommended LLM model.
|
||||
|
||||
Returns:
|
||||
The recommended model, or None if no model is marked as recommended.
|
||||
"""
|
||||
record = await prisma.models.LlmModel.prisma().find_first(
|
||||
where={"isRecommended": True, "isEnabled": True},
|
||||
include={"Costs": True, "Creator": True},
|
||||
)
|
||||
return _map_model(record) if record else None
|
||||
|
||||
|
||||
async def set_recommended_model(
|
||||
model_id: str,
|
||||
) -> tuple[llm_model.LlmModel, str | None]:
|
||||
"""
|
||||
Set a model as the recommended model.
|
||||
|
||||
This will clear the isRecommended flag from any other model and set it
|
||||
on the specified model. The model must be enabled.
|
||||
|
||||
Args:
|
||||
model_id: UUID of the model to set as recommended
|
||||
|
||||
Returns:
|
||||
Tuple of (the updated model, previous recommended model slug or None)
|
||||
|
||||
Raises:
|
||||
ValueError: If model not found or not enabled
|
||||
"""
|
||||
# First, verify the model exists and is enabled
|
||||
target_model = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model_id}
|
||||
)
|
||||
if not target_model:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
if not target_model.isEnabled:
|
||||
raise ValueError(
|
||||
f"Cannot set disabled model '{target_model.slug}' as recommended"
|
||||
)
|
||||
|
||||
# Get the current recommended model (if any)
|
||||
current_recommended = await prisma.models.LlmModel.prisma().find_first(
|
||||
where={"isRecommended": True}
|
||||
)
|
||||
previous_slug = current_recommended.slug if current_recommended else None
|
||||
|
||||
# Use a transaction to ensure atomicity
|
||||
async with transaction() as tx:
|
||||
# Clear isRecommended from all models
|
||||
await tx.llmmodel.update_many(
|
||||
where={"isRecommended": True},
|
||||
data={"isRecommended": False},
|
||||
)
|
||||
# Set the new recommended model
|
||||
await tx.llmmodel.update(
|
||||
where={"id": model_id},
|
||||
data={"isRecommended": True},
|
||||
)
|
||||
|
||||
# Fetch and return the updated model
|
||||
updated_record = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model_id},
|
||||
include={"Costs": True, "Creator": True},
|
||||
)
|
||||
if not updated_record:
|
||||
raise ValueError("Failed to fetch updated model")
|
||||
|
||||
return _map_model(updated_record), previous_slug
|
||||
|
||||
|
||||
async def get_recommended_model_slug() -> str | None:
|
||||
"""
|
||||
Get the slug of the currently recommended LLM model.
|
||||
|
||||
Returns:
|
||||
The slug of the recommended model, or None if no model is marked as recommended.
|
||||
"""
|
||||
record = await prisma.models.LlmModel.prisma().find_first(
|
||||
where={"isRecommended": True, "isEnabled": True},
|
||||
)
|
||||
return record.slug if record else None
|
||||
@@ -1,235 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
import prisma.enums
|
||||
import pydantic
|
||||
|
||||
from backend.util.models import Pagination
|
||||
|
||||
# Pattern for valid model slugs: alphanumeric start, then alphanumeric, dots, underscores, slashes, hyphens
|
||||
SLUG_PATTERN = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._/-]*$")
|
||||
|
||||
|
||||
class LlmModelCost(pydantic.BaseModel):
|
||||
id: str
|
||||
unit: prisma.enums.LlmCostUnit = prisma.enums.LlmCostUnit.RUN
|
||||
credit_cost: int
|
||||
credential_provider: str
|
||||
credential_id: Optional[str] = None
|
||||
credential_type: Optional[str] = None
|
||||
currency: Optional[str] = None
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class LlmModelCreator(pydantic.BaseModel):
|
||||
"""Represents the organization that created/trained the model (e.g., OpenAI, Meta)."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
website_url: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class LlmModel(pydantic.BaseModel):
|
||||
id: str
|
||||
slug: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
provider_id: str
|
||||
creator_id: Optional[str] = None
|
||||
creator: Optional[LlmModelCreator] = None
|
||||
context_window: int
|
||||
max_output_tokens: Optional[int] = None
|
||||
is_enabled: bool = True
|
||||
is_recommended: bool = False
|
||||
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
costs: list[LlmModelCost] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class LlmProvider(pydantic.BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
default_credential_provider: Optional[str] = None
|
||||
default_credential_id: Optional[str] = None
|
||||
default_credential_type: Optional[str] = None
|
||||
supports_tools: bool = True
|
||||
supports_json_output: bool = True
|
||||
supports_reasoning: bool = False
|
||||
supports_parallel_tool: bool = False
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
models: list[LlmModel] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class LlmProvidersResponse(pydantic.BaseModel):
|
||||
providers: list[LlmProvider]
|
||||
|
||||
|
||||
class LlmModelsResponse(pydantic.BaseModel):
|
||||
models: list[LlmModel]
|
||||
pagination: Optional[Pagination] = None
|
||||
|
||||
|
||||
class LlmCreatorsResponse(pydantic.BaseModel):
|
||||
creators: list[LlmModelCreator]
|
||||
|
||||
|
||||
class UpsertLlmProviderRequest(pydantic.BaseModel):
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
default_credential_provider: Optional[str] = None
|
||||
default_credential_id: Optional[str] = None
|
||||
default_credential_type: Optional[str] = "api_key"
|
||||
supports_tools: bool = True
|
||||
supports_json_output: bool = True
|
||||
supports_reasoning: bool = False
|
||||
supports_parallel_tool: bool = False
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class UpsertLlmCreatorRequest(pydantic.BaseModel):
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
website_url: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class LlmModelCostInput(pydantic.BaseModel):
|
||||
unit: prisma.enums.LlmCostUnit = prisma.enums.LlmCostUnit.RUN
|
||||
credit_cost: int
|
||||
credential_provider: str
|
||||
credential_id: Optional[str] = None
|
||||
credential_type: Optional[str] = "api_key"
|
||||
currency: Optional[str] = None
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class CreateLlmModelRequest(pydantic.BaseModel):
|
||||
slug: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
provider_id: str
|
||||
creator_id: Optional[str] = None
|
||||
context_window: int
|
||||
max_output_tokens: Optional[int] = None
|
||||
is_enabled: bool = True
|
||||
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
costs: list[LlmModelCostInput]
|
||||
|
||||
@pydantic.field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str) -> str:
|
||||
if not v or len(v) > 100:
|
||||
raise ValueError("Slug must be 1-100 characters")
|
||||
if not SLUG_PATTERN.match(v):
|
||||
raise ValueError(
|
||||
"Slug must start with alphanumeric and contain only "
|
||||
"alphanumeric characters, dots, underscores, slashes, or hyphens"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class UpdateLlmModelRequest(pydantic.BaseModel):
|
||||
display_name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
context_window: Optional[int] = None
|
||||
max_output_tokens: Optional[int] = None
|
||||
is_enabled: Optional[bool] = None
|
||||
capabilities: Optional[dict[str, Any]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
provider_id: Optional[str] = None
|
||||
creator_id: Optional[str] = None
|
||||
costs: Optional[list[LlmModelCostInput]] = None
|
||||
|
||||
|
||||
class ToggleLlmModelRequest(pydantic.BaseModel):
|
||||
is_enabled: bool
|
||||
migrate_to_slug: Optional[str] = None
|
||||
migration_reason: Optional[str] = None # e.g., "Provider outage"
|
||||
# Custom pricing override for migrated workflows. When set, billing should use
|
||||
# this cost instead of the target model's cost for affected nodes.
|
||||
# See LlmModelMigration in schema.prisma for full documentation.
|
||||
custom_credit_cost: Optional[int] = None
|
||||
|
||||
|
||||
class ToggleLlmModelResponse(pydantic.BaseModel):
|
||||
model: LlmModel
|
||||
nodes_migrated: int = 0
|
||||
migrated_to_slug: Optional[str] = None
|
||||
migration_id: Optional[str] = None # ID of the migration record for revert
|
||||
|
||||
|
||||
class DeleteLlmModelResponse(pydantic.BaseModel):
|
||||
deleted_model_slug: str
|
||||
deleted_model_display_name: str
|
||||
replacement_model_slug: Optional[str] = None
|
||||
nodes_migrated: int
|
||||
message: str
|
||||
|
||||
|
||||
class LlmModelUsageResponse(pydantic.BaseModel):
|
||||
model_slug: str
|
||||
node_count: int
|
||||
|
||||
|
||||
# Migration tracking models
|
||||
class LlmModelMigration(pydantic.BaseModel):
|
||||
id: str
|
||||
source_model_slug: str
|
||||
target_model_slug: str
|
||||
reason: Optional[str] = None
|
||||
node_count: int
|
||||
# Custom pricing override - billing should use this instead of target model's cost
|
||||
custom_credit_cost: Optional[int] = None
|
||||
is_reverted: bool = False
|
||||
created_at: datetime
|
||||
reverted_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class LlmMigrationsResponse(pydantic.BaseModel):
|
||||
migrations: list[LlmModelMigration]
|
||||
|
||||
|
||||
class RevertMigrationRequest(pydantic.BaseModel):
|
||||
re_enable_source_model: bool = (
|
||||
True # Whether to re-enable the source model if disabled
|
||||
)
|
||||
|
||||
|
||||
class RevertMigrationResponse(pydantic.BaseModel):
|
||||
migration_id: str
|
||||
source_model_slug: str
|
||||
target_model_slug: str
|
||||
nodes_reverted: int
|
||||
nodes_already_changed: int = (
|
||||
0 # Nodes that were modified since migration (not reverted)
|
||||
)
|
||||
source_model_re_enabled: bool = False # Whether the source model was re-enabled
|
||||
message: str
|
||||
|
||||
|
||||
class SetRecommendedModelRequest(pydantic.BaseModel):
|
||||
model_id: str
|
||||
|
||||
|
||||
class SetRecommendedModelResponse(pydantic.BaseModel):
|
||||
model: LlmModel
|
||||
previous_recommended_slug: Optional[str] = None
|
||||
message: str
|
||||
|
||||
|
||||
class RecommendedModelResponse(pydantic.BaseModel):
|
||||
model: Optional[LlmModel] = None
|
||||
slug: Optional[str] = None
|
||||
@@ -1,29 +0,0 @@
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
|
||||
from backend.server.v2.llm import db as llm_db
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
|
||||
router = fastapi.APIRouter(
|
||||
prefix="/llm",
|
||||
tags=["llm"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/models", response_model=llm_model.LlmModelsResponse)
|
||||
async def list_models(
|
||||
page: int = fastapi.Query(default=1, ge=1, description="Page number (1-indexed)"),
|
||||
page_size: int = fastapi.Query(
|
||||
default=50, ge=1, le=100, description="Number of models per page"
|
||||
),
|
||||
):
|
||||
"""List all enabled LLM models available to users."""
|
||||
return await llm_db.list_models(enabled_only=True, page=page, page_size=page_size)
|
||||
|
||||
|
||||
@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
|
||||
async def list_providers():
|
||||
"""List all LLM providers with their enabled models."""
|
||||
providers = await llm_db.list_providers(include_models=True, enabled_only=True)
|
||||
return llm_model.LlmProvidersResponse(providers=providers)
|
||||
@@ -135,6 +135,12 @@ class GraphValidationError(ValueError):
|
||||
)
|
||||
|
||||
|
||||
class InvalidInputError(ValueError):
|
||||
"""Raised when user input validation fails (e.g., search term too long)"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseError(Exception):
|
||||
"""Raised when there is an error interacting with the database"""
|
||||
|
||||
|
||||
@@ -359,8 +359,8 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="The port for the Agent Generator service",
|
||||
)
|
||||
agentgenerator_timeout: int = Field(
|
||||
default=120,
|
||||
description="The timeout in seconds for Agent Generator service requests",
|
||||
default=600,
|
||||
description="The timeout in seconds for Agent Generator service requests (includes retries for rate limits)",
|
||||
)
|
||||
|
||||
enable_example_blocks: bool = Field(
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
-- CreateEnum
|
||||
CREATE TYPE "LlmCostUnit" AS ENUM ('RUN', 'TOKENS');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmProvider" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"name" TEXT NOT NULL,
|
||||
"displayName" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"defaultCredentialProvider" TEXT,
|
||||
"defaultCredentialId" TEXT,
|
||||
"defaultCredentialType" TEXT,
|
||||
"supportsTools" BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
"supportsJsonOutput" BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
"supportsReasoning" BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
"supportsParallelTool" BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
|
||||
CONSTRAINT "LlmProvider_pkey" PRIMARY KEY ("id"),
|
||||
CONSTRAINT "LlmProvider_name_key" UNIQUE ("name")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmModel" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"slug" TEXT NOT NULL,
|
||||
"displayName" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"providerId" TEXT NOT NULL,
|
||||
"contextWindow" INTEGER NOT NULL,
|
||||
"maxOutputTokens" INTEGER,
|
||||
"isEnabled" BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
"capabilities" JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
|
||||
CONSTRAINT "LlmModel_pkey" PRIMARY KEY ("id"),
|
||||
CONSTRAINT "LlmModel_slug_key" UNIQUE ("slug")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmModelCost" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"unit" "LlmCostUnit" NOT NULL DEFAULT 'RUN',
|
||||
"creditCost" INTEGER NOT NULL,
|
||||
"credentialProvider" TEXT NOT NULL,
|
||||
"credentialId" TEXT,
|
||||
"credentialType" TEXT,
|
||||
"currency" TEXT,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
"llmModelId" TEXT NOT NULL,
|
||||
|
||||
CONSTRAINT "LlmModelCost_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModel_providerId_isEnabled_idx" ON "LlmModel"("providerId", "isEnabled");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModel_slug_idx" ON "LlmModel"("slug");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelCost_llmModelId_idx" ON "LlmModelCost"("llmModelId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelCost_credentialProvider_idx" ON "LlmModelCost"("credentialProvider");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "LlmModelCost_llmModelId_credentialProvider_unit_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "unit");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_providerId_fkey" FOREIGN KEY ("providerId") REFERENCES "LlmProvider"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LlmModelCost" ADD CONSTRAINT "LlmModelCost_llmModelId_fkey" FOREIGN KEY ("llmModelId") REFERENCES "LlmModel"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
@@ -1,226 +0,0 @@
|
||||
-- Seed LLM Registry from existing hard-coded data
|
||||
-- This migration populates the LlmProvider, LlmModel, and LlmModelCost tables
|
||||
-- with data from the existing MODEL_METADATA and MODEL_COST dictionaries
|
||||
|
||||
-- Insert Providers
|
||||
INSERT INTO "LlmProvider" ("id", "name", "displayName", "description", "defaultCredentialProvider", "defaultCredentialType", "supportsTools", "supportsJsonOutput", "supportsReasoning", "supportsParallelTool", "metadata")
|
||||
VALUES
|
||||
(gen_random_uuid(), 'openai', 'OpenAI', 'OpenAI language models', 'openai', 'api_key', true, true, true, true, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'anthropic', 'Anthropic', 'Anthropic Claude models', 'anthropic', 'api_key', true, true, true, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'groq', 'Groq', 'Groq inference API', 'groq', 'api_key', false, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'open_router', 'OpenRouter', 'OpenRouter unified API', 'open_router', 'api_key', true, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'aiml_api', 'AI/ML API', 'AI/ML API models', 'aiml_api', 'api_key', false, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'ollama', 'Ollama', 'Ollama local models', 'ollama', 'api_key', false, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'llama_api', 'Llama API', 'Llama API models', 'llama_api', 'api_key', false, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'v0', 'v0', 'v0 by Vercel models', 'v0', 'api_key', true, true, false, false, '{}'::jsonb)
|
||||
ON CONFLICT ("name") DO NOTHING;
|
||||
|
||||
-- Insert Models (using CTEs to reference provider IDs)
|
||||
WITH provider_ids AS (
|
||||
SELECT "id", "name" FROM "LlmProvider"
|
||||
)
|
||||
INSERT INTO "LlmModel" ("id", "slug", "displayName", "description", "providerId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
|
||||
SELECT
|
||||
gen_random_uuid(),
|
||||
model_slug,
|
||||
model_display_name,
|
||||
NULL,
|
||||
p."id",
|
||||
context_window,
|
||||
max_output_tokens,
|
||||
true,
|
||||
'{}'::jsonb,
|
||||
'{}'::jsonb
|
||||
FROM (VALUES
|
||||
-- OpenAI models
|
||||
('o3', 'O3', 'openai', 200000, 100000),
|
||||
('o3-mini', 'O3 Mini', 'openai', 200000, 100000),
|
||||
('o1', 'O1', 'openai', 200000, 100000),
|
||||
('o1-mini', 'O1 Mini', 'openai', 128000, 65536),
|
||||
('gpt-5-2025-08-07', 'GPT 5', 'openai', 400000, 128000),
|
||||
('gpt-5.1-2025-11-13', 'GPT 5.1', 'openai', 400000, 128000),
|
||||
('gpt-5-mini-2025-08-07', 'GPT 5 Mini', 'openai', 400000, 128000),
|
||||
('gpt-5-nano-2025-08-07', 'GPT 5 Nano', 'openai', 400000, 128000),
|
||||
('gpt-5-chat-latest', 'GPT 5 Chat', 'openai', 400000, 16384),
|
||||
('gpt-4.1-2025-04-14', 'GPT 4.1', 'openai', 1000000, 32768),
|
||||
('gpt-4.1-mini-2025-04-14', 'GPT 4.1 Mini', 'openai', 1047576, 32768),
|
||||
('gpt-4o-mini', 'GPT 4o Mini', 'openai', 128000, 16384),
|
||||
('gpt-4o', 'GPT 4o', 'openai', 128000, 16384),
|
||||
('gpt-4-turbo', 'GPT 4 Turbo', 'openai', 128000, 4096),
|
||||
('gpt-3.5-turbo', 'GPT 3.5 Turbo', 'openai', 16385, 4096),
|
||||
-- Anthropic models
|
||||
('claude-opus-4-1-20250805', 'Claude 4.1 Opus', 'anthropic', 200000, 32000),
|
||||
('claude-opus-4-20250514', 'Claude 4 Opus', 'anthropic', 200000, 32000),
|
||||
('claude-sonnet-4-20250514', 'Claude 4 Sonnet', 'anthropic', 200000, 64000),
|
||||
('claude-opus-4-5-20251101', 'Claude 4.5 Opus', 'anthropic', 200000, 64000),
|
||||
('claude-sonnet-4-5-20250929', 'Claude 4.5 Sonnet', 'anthropic', 200000, 64000),
|
||||
('claude-haiku-4-5-20251001', 'Claude 4.5 Haiku', 'anthropic', 200000, 64000),
|
||||
('claude-3-7-sonnet-20250219', 'Claude 3.7 Sonnet', 'anthropic', 200000, 64000),
|
||||
('claude-3-haiku-20240307', 'Claude 3 Haiku', 'anthropic', 200000, 4096),
|
||||
-- AI/ML API models
|
||||
('Qwen/Qwen2.5-72B-Instruct-Turbo', 'Qwen 2.5 72B', 'aiml_api', 32000, 8000),
|
||||
('nvidia/llama-3.1-nemotron-70b-instruct', 'Llama 3.1 Nemotron 70B', 'aiml_api', 128000, 40000),
|
||||
('meta-llama/Llama-3.3-70B-Instruct-Turbo', 'Llama 3.3 70B', 'aiml_api', 128000, NULL),
|
||||
('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 'Meta Llama 3.1 70B', 'aiml_api', 131000, 2000),
|
||||
('meta-llama/Llama-3.2-3B-Instruct-Turbo', 'Llama 3.2 3B', 'aiml_api', 128000, NULL),
|
||||
-- Groq models
|
||||
('llama-3.3-70b-versatile', 'Llama 3.3 70B', 'groq', 128000, 32768),
|
||||
('llama-3.1-8b-instant', 'Llama 3.1 8B', 'groq', 128000, 8192),
|
||||
-- Ollama models
|
||||
('llama3.3', 'Llama 3.3', 'ollama', 8192, NULL),
|
||||
('llama3.2', 'Llama 3.2', 'ollama', 8192, NULL),
|
||||
('llama3', 'Llama 3', 'ollama', 8192, NULL),
|
||||
('llama3.1:405b', 'Llama 3.1 405B', 'ollama', 8192, NULL),
|
||||
('dolphin-mistral:latest', 'Dolphin Mistral', 'ollama', 32768, NULL),
|
||||
-- OpenRouter models
|
||||
('google/gemini-2.5-pro-preview-03-25', 'Gemini 2.5 Pro', 'open_router', 1050000, 8192),
|
||||
('google/gemini-3-pro-preview', 'Gemini 3 Pro Preview', 'open_router', 1048576, 65535),
|
||||
('google/gemini-2.5-flash', 'Gemini 2.5 Flash', 'open_router', 1048576, 65535),
|
||||
('google/gemini-2.0-flash-001', 'Gemini 2.0 Flash', 'open_router', 1048576, 8192),
|
||||
('google/gemini-2.5-flash-lite-preview-06-17', 'Gemini 2.5 Flash Lite Preview', 'open_router', 1048576, 65535),
|
||||
('google/gemini-2.0-flash-lite-001', 'Gemini 2.0 Flash Lite', 'open_router', 1048576, 8192),
|
||||
('mistralai/mistral-nemo', 'Mistral Nemo', 'open_router', 128000, 4096),
|
||||
('cohere/command-r-08-2024', 'Command R', 'open_router', 128000, 4096),
|
||||
('cohere/command-r-plus-08-2024', 'Command R Plus', 'open_router', 128000, 4096),
|
||||
('deepseek/deepseek-chat', 'DeepSeek Chat', 'open_router', 64000, 2048),
|
||||
('deepseek/deepseek-r1-0528', 'DeepSeek R1', 'open_router', 163840, 163840),
|
||||
('perplexity/sonar', 'Perplexity Sonar', 'open_router', 127000, 8000),
|
||||
('perplexity/sonar-pro', 'Perplexity Sonar Pro', 'open_router', 200000, 8000),
|
||||
('perplexity/sonar-deep-research', 'Perplexity Sonar Deep Research', 'open_router', 128000, 16000),
|
||||
('nousresearch/hermes-3-llama-3.1-405b', 'Hermes 3 Llama 3.1 405B', 'open_router', 131000, 4096),
|
||||
('nousresearch/hermes-3-llama-3.1-70b', 'Hermes 3 Llama 3.1 70B', 'open_router', 12288, 12288),
|
||||
('openai/gpt-oss-120b', 'GPT OSS 120B', 'open_router', 131072, 131072),
|
||||
('openai/gpt-oss-20b', 'GPT OSS 20B', 'open_router', 131072, 32768),
|
||||
('amazon/nova-lite-v1', 'Amazon Nova Lite', 'open_router', 300000, 5120),
|
||||
('amazon/nova-micro-v1', 'Amazon Nova Micro', 'open_router', 128000, 5120),
|
||||
('amazon/nova-pro-v1', 'Amazon Nova Pro', 'open_router', 300000, 5120),
|
||||
('microsoft/wizardlm-2-8x22b', 'WizardLM 2 8x22B', 'open_router', 65536, 4096),
|
||||
('gryphe/mythomax-l2-13b', 'MythoMax L2 13B', 'open_router', 4096, 4096),
|
||||
('meta-llama/llama-4-scout', 'Llama 4 Scout', 'open_router', 131072, 131072),
|
||||
('meta-llama/llama-4-maverick', 'Llama 4 Maverick', 'open_router', 1048576, 1000000),
|
||||
('x-ai/grok-4', 'Grok 4', 'open_router', 256000, 256000),
|
||||
('x-ai/grok-4-fast', 'Grok 4 Fast', 'open_router', 2000000, 30000),
|
||||
('x-ai/grok-4.1-fast', 'Grok 4.1 Fast', 'open_router', 2000000, 30000),
|
||||
('x-ai/grok-code-fast-1', 'Grok Code Fast 1', 'open_router', 256000, 10000),
|
||||
('moonshotai/kimi-k2', 'Kimi K2', 'open_router', 131000, 131000),
|
||||
('qwen/qwen3-235b-a22b-thinking-2507', 'Qwen 3 235B Thinking', 'open_router', 262144, 262144),
|
||||
('qwen/qwen3-coder', 'Qwen 3 Coder', 'open_router', 262144, 262144),
|
||||
-- Llama API models
|
||||
('Llama-4-Scout-17B-16E-Instruct-FP8', 'Llama 4 Scout', 'llama_api', 128000, 4028),
|
||||
('Llama-4-Maverick-17B-128E-Instruct-FP8', 'Llama 4 Maverick', 'llama_api', 128000, 4028),
|
||||
('Llama-3.3-8B-Instruct', 'Llama 3.3 8B', 'llama_api', 128000, 4028),
|
||||
('Llama-3.3-70B-Instruct', 'Llama 3.3 70B', 'llama_api', 128000, 4028),
|
||||
-- v0 models
|
||||
('v0-1.5-md', 'v0 1.5 MD', 'v0', 128000, 64000),
|
||||
('v0-1.5-lg', 'v0 1.5 LG', 'v0', 512000, 64000),
|
||||
('v0-1.0-md', 'v0 1.0 MD', 'v0', 128000, 64000)
|
||||
) AS models(model_slug, model_display_name, provider_name, context_window, max_output_tokens)
|
||||
JOIN provider_ids p ON p."name" = models.provider_name
|
||||
ON CONFLICT ("slug") DO NOTHING;
|
||||
|
||||
-- Insert Costs (using CTEs to reference model IDs)
|
||||
WITH model_ids AS (
|
||||
SELECT "id", "slug", "providerId" FROM "LlmModel"
|
||||
),
|
||||
provider_ids AS (
|
||||
SELECT "id", "name" FROM "LlmProvider"
|
||||
)
|
||||
INSERT INTO "LlmModelCost" ("id", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
|
||||
SELECT
|
||||
gen_random_uuid(),
|
||||
'RUN'::"LlmCostUnit",
|
||||
cost,
|
||||
p."name",
|
||||
NULL,
|
||||
'api_key',
|
||||
NULL,
|
||||
'{}'::jsonb,
|
||||
m."id"
|
||||
FROM (VALUES
|
||||
-- OpenAI costs
|
||||
('o3', 4),
|
||||
('o3-mini', 2),
|
||||
('o1', 16),
|
||||
('o1-mini', 4),
|
||||
('gpt-5-2025-08-07', 2),
|
||||
('gpt-5.1-2025-11-13', 5),
|
||||
('gpt-5-mini-2025-08-07', 1),
|
||||
('gpt-5-nano-2025-08-07', 1),
|
||||
('gpt-5-chat-latest', 5),
|
||||
('gpt-4.1-2025-04-14', 2),
|
||||
('gpt-4.1-mini-2025-04-14', 1),
|
||||
('gpt-4o-mini', 1),
|
||||
('gpt-4o', 3),
|
||||
('gpt-4-turbo', 10),
|
||||
('gpt-3.5-turbo', 1),
|
||||
-- Anthropic costs
|
||||
('claude-opus-4-1-20250805', 21),
|
||||
('claude-opus-4-20250514', 21),
|
||||
('claude-sonnet-4-20250514', 5),
|
||||
('claude-haiku-4-5-20251001', 4),
|
||||
('claude-opus-4-5-20251101', 14),
|
||||
('claude-sonnet-4-5-20250929', 9),
|
||||
('claude-3-7-sonnet-20250219', 5),
|
||||
('claude-3-haiku-20240307', 1),
|
||||
-- AI/ML API costs
|
||||
('Qwen/Qwen2.5-72B-Instruct-Turbo', 1),
|
||||
('nvidia/llama-3.1-nemotron-70b-instruct', 1),
|
||||
('meta-llama/Llama-3.3-70B-Instruct-Turbo', 1),
|
||||
('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 1),
|
||||
('meta-llama/Llama-3.2-3B-Instruct-Turbo', 1),
|
||||
-- Groq costs
|
||||
('llama-3.3-70b-versatile', 1),
|
||||
('llama-3.1-8b-instant', 1),
|
||||
-- Ollama costs
|
||||
('llama3.3', 1),
|
||||
('llama3.2', 1),
|
||||
('llama3', 1),
|
||||
('llama3.1:405b', 1),
|
||||
('dolphin-mistral:latest', 1),
|
||||
-- OpenRouter costs
|
||||
('google/gemini-2.5-pro-preview-03-25', 4),
|
||||
('google/gemini-3-pro-preview', 5),
|
||||
('mistralai/mistral-nemo', 1),
|
||||
('cohere/command-r-08-2024', 1),
|
||||
('cohere/command-r-plus-08-2024', 3),
|
||||
('deepseek/deepseek-chat', 2),
|
||||
('perplexity/sonar', 1),
|
||||
('perplexity/sonar-pro', 5),
|
||||
('perplexity/sonar-deep-research', 10),
|
||||
('nousresearch/hermes-3-llama-3.1-405b', 1),
|
||||
('nousresearch/hermes-3-llama-3.1-70b', 1),
|
||||
('amazon/nova-lite-v1', 1),
|
||||
('amazon/nova-micro-v1', 1),
|
||||
('amazon/nova-pro-v1', 1),
|
||||
('microsoft/wizardlm-2-8x22b', 1),
|
||||
('gryphe/mythomax-l2-13b', 1),
|
||||
('meta-llama/llama-4-scout', 1),
|
||||
('meta-llama/llama-4-maverick', 1),
|
||||
('x-ai/grok-4', 9),
|
||||
('x-ai/grok-4-fast', 1),
|
||||
('x-ai/grok-4.1-fast', 1),
|
||||
('x-ai/grok-code-fast-1', 1),
|
||||
('moonshotai/kimi-k2', 1),
|
||||
('qwen/qwen3-235b-a22b-thinking-2507', 1),
|
||||
('qwen/qwen3-coder', 9),
|
||||
('google/gemini-2.5-flash', 1),
|
||||
('google/gemini-2.0-flash-001', 1),
|
||||
('google/gemini-2.5-flash-lite-preview-06-17', 1),
|
||||
('google/gemini-2.0-flash-lite-001', 1),
|
||||
('deepseek/deepseek-r1-0528', 1),
|
||||
('openai/gpt-oss-120b', 1),
|
||||
('openai/gpt-oss-20b', 1),
|
||||
-- Llama API costs
|
||||
('Llama-4-Scout-17B-16E-Instruct-FP8', 1),
|
||||
('Llama-4-Maverick-17B-128E-Instruct-FP8', 1),
|
||||
('Llama-3.3-8B-Instruct', 1),
|
||||
('Llama-3.3-70B-Instruct', 1),
|
||||
-- v0 costs
|
||||
('v0-1.5-md', 1),
|
||||
('v0-1.5-lg', 2),
|
||||
('v0-1.0-md', 1)
|
||||
) AS costs(model_slug, cost)
|
||||
JOIN model_ids m ON m."slug" = costs.model_slug
|
||||
JOIN provider_ids p ON p."id" = m."providerId"
|
||||
ON CONFLICT ("llmModelId", "credentialProvider", "unit") DO NOTHING;
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmModelMigration" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"sourceModelSlug" TEXT NOT NULL,
|
||||
"targetModelSlug" TEXT NOT NULL,
|
||||
"reason" TEXT,
|
||||
"migratedNodeIds" JSONB NOT NULL DEFAULT '[]',
|
||||
"nodeCount" INTEGER NOT NULL,
|
||||
"customCreditCost" INTEGER,
|
||||
"isReverted" BOOLEAN NOT NULL DEFAULT false,
|
||||
"revertedAt" TIMESTAMP(3),
|
||||
|
||||
CONSTRAINT "LlmModelMigration_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelMigration_sourceModelSlug_idx" ON "LlmModelMigration"("sourceModelSlug");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelMigration_targetModelSlug_idx" ON "LlmModelMigration"("targetModelSlug");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelMigration_isReverted_idx" ON "LlmModelMigration"("isReverted");
|
||||
@@ -1,127 +0,0 @@
|
||||
-- Add LlmModelCreator table
|
||||
-- Creator represents who made/trained the model (e.g., OpenAI, Meta)
|
||||
-- This is distinct from Provider who hosts/serves the model (e.g., OpenRouter)
|
||||
|
||||
-- Create the LlmModelCreator table
|
||||
CREATE TABLE "LlmModelCreator" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"name" TEXT NOT NULL,
|
||||
"displayName" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"websiteUrl" TEXT,
|
||||
"logoUrl" TEXT,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
|
||||
CONSTRAINT "LlmModelCreator_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- Create unique index on name
|
||||
CREATE UNIQUE INDEX "LlmModelCreator_name_key" ON "LlmModelCreator"("name");
|
||||
|
||||
-- Add creatorId column to LlmModel
|
||||
ALTER TABLE "LlmModel" ADD COLUMN "creatorId" TEXT;
|
||||
|
||||
-- Add foreign key constraint
|
||||
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey"
|
||||
FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
|
||||
-- Create index on creatorId
|
||||
CREATE INDEX "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");
|
||||
|
||||
-- Seed creators based on known model creators
|
||||
INSERT INTO "LlmModelCreator" ("id", "updatedAt", "name", "displayName", "description", "websiteUrl", "metadata")
|
||||
VALUES
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'openai', 'OpenAI', 'Creator of GPT models', 'https://openai.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'anthropic', 'Anthropic', 'Creator of Claude models', 'https://anthropic.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'meta', 'Meta', 'Creator of Llama models', 'https://ai.meta.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'google', 'Google', 'Creator of Gemini models', 'https://deepmind.google', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'mistral', 'Mistral AI', 'Creator of Mistral models', 'https://mistral.ai', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'cohere', 'Cohere', 'Creator of Command models', 'https://cohere.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'deepseek', 'DeepSeek', 'Creator of DeepSeek models', 'https://deepseek.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'perplexity', 'Perplexity AI', 'Creator of Sonar models', 'https://perplexity.ai', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'qwen', 'Qwen (Alibaba)', 'Creator of Qwen models', 'https://qwenlm.github.io', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'xai', 'xAI', 'Creator of Grok models', 'https://x.ai', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'amazon', 'Amazon', 'Creator of Nova models', 'https://aws.amazon.com/bedrock', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'microsoft', 'Microsoft', 'Creator of WizardLM models', 'https://microsoft.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'moonshot', 'Moonshot AI', 'Creator of Kimi models', 'https://moonshot.cn', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'nvidia', 'NVIDIA', 'Creator of Nemotron models', 'https://nvidia.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'nous_research', 'Nous Research', 'Creator of Hermes models', 'https://nousresearch.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'vercel', 'Vercel', 'Creator of v0 models', 'https://vercel.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'cognitive_computations', 'Cognitive Computations', 'Creator of Dolphin models', 'https://erichartford.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'gryphe', 'Gryphe', 'Creator of MythoMax models', 'https://huggingface.co/Gryphe', '{}')
|
||||
ON CONFLICT ("name") DO NOTHING;
|
||||
|
||||
-- Update existing models with their creators
|
||||
-- OpenAI models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'openai')
|
||||
WHERE "slug" LIKE 'gpt-%' OR "slug" LIKE 'o1%' OR "slug" LIKE 'o3%' OR "slug" LIKE 'openai/%';
|
||||
|
||||
-- Anthropic models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'anthropic')
|
||||
WHERE "slug" LIKE 'claude-%';
|
||||
|
||||
-- Meta/Llama models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'meta')
|
||||
WHERE "slug" LIKE 'llama%' OR "slug" LIKE 'Llama%' OR "slug" LIKE 'meta-llama/%' OR "slug" LIKE '%/llama-%';
|
||||
|
||||
-- Google models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'google')
|
||||
WHERE "slug" LIKE 'google/%' OR "slug" LIKE 'gemini%';
|
||||
|
||||
-- Mistral models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'mistral')
|
||||
WHERE "slug" LIKE 'mistral%' OR "slug" LIKE 'mistralai/%';
|
||||
|
||||
-- Cohere models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'cohere')
|
||||
WHERE "slug" LIKE 'cohere/%' OR "slug" LIKE 'command-%';
|
||||
|
||||
-- DeepSeek models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'deepseek')
|
||||
WHERE "slug" LIKE 'deepseek/%' OR "slug" LIKE 'deepseek-%';
|
||||
|
||||
-- Perplexity models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'perplexity')
|
||||
WHERE "slug" LIKE 'perplexity/%' OR "slug" LIKE 'sonar%';
|
||||
|
||||
-- Qwen models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'qwen')
|
||||
WHERE "slug" LIKE 'Qwen/%' OR "slug" LIKE 'qwen/%';
|
||||
|
||||
-- xAI/Grok models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'xai')
|
||||
WHERE "slug" LIKE 'x-ai/%' OR "slug" LIKE 'grok%';
|
||||
|
||||
-- Amazon models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'amazon')
|
||||
WHERE "slug" LIKE 'amazon/%' OR "slug" LIKE 'nova-%';
|
||||
|
||||
-- Microsoft models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'microsoft')
|
||||
WHERE "slug" LIKE 'microsoft/%' OR "slug" LIKE 'wizardlm%';
|
||||
|
||||
-- Moonshot models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'moonshot')
|
||||
WHERE "slug" LIKE 'moonshotai/%' OR "slug" LIKE 'kimi%';
|
||||
|
||||
-- NVIDIA models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'nvidia')
|
||||
WHERE "slug" LIKE 'nvidia/%' OR "slug" LIKE '%nemotron%';
|
||||
|
||||
-- Nous Research models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'nous_research')
|
||||
WHERE "slug" LIKE 'nousresearch/%' OR "slug" LIKE 'hermes%';
|
||||
|
||||
-- Vercel/v0 models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'vercel')
|
||||
WHERE "slug" LIKE 'v0-%';
|
||||
|
||||
-- Dolphin models (Cognitive Computations / Eric Hartford)
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'cognitive_computations')
|
||||
WHERE "slug" LIKE 'dolphin-%';
|
||||
|
||||
-- Gryphe models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'gryphe')
|
||||
WHERE "slug" LIKE 'gryphe/%' OR "slug" LIKE 'mythomax%';
|
||||
@@ -1,4 +0,0 @@
|
||||
-- CreateIndex
|
||||
-- Index for efficient LLM model lookups on AgentNode.constantInput->>'model'
|
||||
-- This improves performance of model migration queries in the LLM registry
|
||||
CREATE INDEX "AgentNode_constantInput_model_idx" ON "AgentNode" ((("constantInput"->>'model')));
|
||||
@@ -1,52 +0,0 @@
|
||||
-- Add GPT-5.2 model and update O3 slug
|
||||
-- This migration adds the new GPT-5.2 model added in dev branch
|
||||
|
||||
-- Update O3 slug to match dev branch format
|
||||
UPDATE "LlmModel"
|
||||
SET "slug" = 'o3-2025-04-16'
|
||||
WHERE "slug" = 'o3';
|
||||
|
||||
-- Update cost reference for O3 if needed
|
||||
-- (costs are linked by model ID, so no update needed)
|
||||
|
||||
-- Add GPT-5.2 model
|
||||
WITH provider_id AS (
|
||||
SELECT "id" FROM "LlmProvider" WHERE "name" = 'openai'
|
||||
)
|
||||
INSERT INTO "LlmModel" ("id", "slug", "displayName", "description", "providerId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
|
||||
SELECT
|
||||
gen_random_uuid(),
|
||||
'gpt-5.2-2025-12-11',
|
||||
'GPT 5.2',
|
||||
'OpenAI GPT-5.2 model',
|
||||
p."id",
|
||||
400000,
|
||||
128000,
|
||||
true,
|
||||
'{}'::jsonb,
|
||||
'{}'::jsonb
|
||||
FROM provider_id p
|
||||
ON CONFLICT ("slug") DO NOTHING;
|
||||
|
||||
-- Add cost for GPT-5.2
|
||||
WITH model_id AS (
|
||||
SELECT m."id", p."name" as provider_name
|
||||
FROM "LlmModel" m
|
||||
JOIN "LlmProvider" p ON p."id" = m."providerId"
|
||||
WHERE m."slug" = 'gpt-5.2-2025-12-11'
|
||||
)
|
||||
INSERT INTO "LlmModelCost" ("id", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
|
||||
SELECT
|
||||
gen_random_uuid(),
|
||||
'RUN'::"LlmCostUnit",
|
||||
3, -- Same cost tier as GPT-5.1
|
||||
m.provider_name,
|
||||
NULL,
|
||||
'api_key',
|
||||
NULL,
|
||||
'{}'::jsonb,
|
||||
m."id"
|
||||
FROM model_id m
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM "LlmModelCost" c WHERE c."llmModelId" = m."id"
|
||||
);
|
||||
@@ -1,11 +0,0 @@
|
||||
-- Add isRecommended field to LlmModel table
|
||||
-- This allows admins to mark a model as the recommended default
|
||||
|
||||
ALTER TABLE "LlmModel" ADD COLUMN "isRecommended" BOOLEAN NOT NULL DEFAULT false;
|
||||
|
||||
-- Set gpt-4o-mini as the default recommended model (if it exists)
|
||||
UPDATE "LlmModel" SET "isRecommended" = true WHERE "slug" = 'gpt-4o-mini' AND "isEnabled" = true;
|
||||
|
||||
-- Create unique partial index to enforce only one recommended model at the database level
|
||||
-- This prevents multiple rows from having isRecommended = true
|
||||
CREATE UNIQUE INDEX "LlmModel_single_recommended_idx" ON "LlmModel" ("isRecommended") WHERE "isRecommended" = true;
|
||||
@@ -1,61 +0,0 @@
|
||||
-- Add new columns to LlmModel table for extended model metadata
|
||||
-- These columns support the LLM Picker UI enhancements
|
||||
|
||||
-- Add priceTier column: 1=cheapest, 2=medium, 3=expensive
|
||||
ALTER TABLE "LlmModel" ADD COLUMN IF NOT EXISTS "priceTier" INTEGER NOT NULL DEFAULT 1;
|
||||
|
||||
-- Add creatorId column for model creator relationship (if not exists)
|
||||
ALTER TABLE "LlmModel" ADD COLUMN IF NOT EXISTS "creatorId" TEXT;
|
||||
|
||||
-- Add isRecommended column (if not exists)
|
||||
ALTER TABLE "LlmModel" ADD COLUMN IF NOT EXISTS "isRecommended" BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
-- Add index on creatorId if not exists
|
||||
CREATE INDEX IF NOT EXISTS "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");
|
||||
|
||||
-- Add foreign key for creatorId if not exists
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'LlmModel_creatorId_fkey') THEN
|
||||
-- Only add FK if LlmModelCreator table exists
|
||||
IF EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'LlmModelCreator') THEN
|
||||
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey"
|
||||
FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
END IF;
|
||||
END IF;
|
||||
END $$;
|
||||
|
||||
-- Update priceTier values for existing models based on original MODEL_METADATA
|
||||
-- Tier 1 = cheapest, Tier 2 = medium, Tier 3 = expensive
|
||||
|
||||
-- OpenAI models
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'o3';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'o3-mini';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'o1';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'o1-mini';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'gpt-5.2';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'gpt-5.1';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-5';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-5-mini';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-5-nano';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'gpt-5-chat-latest';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" LIKE 'gpt-4.1%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-4o-mini';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'gpt-4o';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'gpt-4-turbo';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-3.5-turbo';
|
||||
|
||||
-- Anthropic models
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE 'claude-opus%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE 'claude-sonnet%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE 'claude%-4-5-sonnet%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE 'claude%-haiku%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'claude-3-haiku-20240307';
|
||||
|
||||
-- OpenRouter models - Pro/expensive tiers
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE 'google/gemini%-pro%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE '%command-r-plus%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE '%sonar-pro%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE '%sonar-deep-research%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'x-ai/grok-4';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE '%qwen3-coder%';
|
||||
@@ -0,0 +1,2 @@
|
||||
-- AlterEnum
|
||||
ALTER TYPE "OnboardingStep" ADD VALUE 'VISIT_COPILOT';
|
||||
@@ -81,6 +81,7 @@ enum OnboardingStep {
|
||||
AGENT_INPUT
|
||||
CONGRATS
|
||||
// First Wins
|
||||
VISIT_COPILOT
|
||||
GET_RESULTS
|
||||
MARKETPLACE_VISIT
|
||||
MARKETPLACE_ADD_AGENT
|
||||
@@ -1094,153 +1095,6 @@ enum APIKeyStatus {
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
///////////// LLM REGISTRY AND BILLING DATA /////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// LlmCostUnit: Defines how LLM MODEL costs are calculated (per run or per token).
|
||||
// This is distinct from BlockCostType (in backend/data/block.py) which defines
|
||||
// how BLOCK EXECUTION costs are calculated (per run, per byte, or per second).
|
||||
// LlmCostUnit is for pricing individual LLM model API calls in the registry,
|
||||
// while BlockCostType is for billing platform block executions.
|
||||
enum LlmCostUnit {
|
||||
RUN
|
||||
TOKENS
|
||||
}
|
||||
|
||||
model LlmModelCreator {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
name String @unique // e.g., "openai", "anthropic", "meta"
|
||||
displayName String // e.g., "OpenAI", "Anthropic", "Meta"
|
||||
description String?
|
||||
websiteUrl String? // Link to creator's website
|
||||
logoUrl String? // URL to creator's logo
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
Models LlmModel[]
|
||||
}
|
||||
|
||||
model LlmProvider {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
name String @unique
|
||||
displayName String
|
||||
description String?
|
||||
|
||||
defaultCredentialProvider String?
|
||||
defaultCredentialId String?
|
||||
defaultCredentialType String?
|
||||
|
||||
supportsTools Boolean @default(true)
|
||||
supportsJsonOutput Boolean @default(true)
|
||||
supportsReasoning Boolean @default(false)
|
||||
supportsParallelTool Boolean @default(false)
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
Models LlmModel[]
|
||||
}
|
||||
|
||||
model LlmModel {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
slug String @unique
|
||||
displayName String
|
||||
description String?
|
||||
|
||||
providerId String
|
||||
Provider LlmProvider @relation(fields: [providerId], references: [id], onDelete: Restrict)
|
||||
|
||||
// Creator is the organization that created/trained the model (e.g., OpenAI, Meta)
|
||||
// This is distinct from the provider who hosts/serves the model (e.g., OpenRouter)
|
||||
creatorId String?
|
||||
Creator LlmModelCreator? @relation(fields: [creatorId], references: [id], onDelete: SetNull)
|
||||
|
||||
contextWindow Int
|
||||
maxOutputTokens Int?
|
||||
priceTier Int @default(1) // 1=cheapest, 2=medium, 3=expensive
|
||||
isEnabled Boolean @default(true)
|
||||
isRecommended Boolean @default(false)
|
||||
|
||||
capabilities Json @default("{}")
|
||||
metadata Json @default("{}")
|
||||
|
||||
Costs LlmModelCost[]
|
||||
|
||||
@@index([providerId, isEnabled])
|
||||
@@index([creatorId])
|
||||
@@index([slug])
|
||||
}
|
||||
|
||||
model LlmModelCost {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
unit LlmCostUnit @default(RUN)
|
||||
|
||||
creditCost Int
|
||||
|
||||
credentialProvider String
|
||||
credentialId String?
|
||||
credentialType String?
|
||||
currency String?
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
llmModelId String
|
||||
Model LlmModel @relation(fields: [llmModelId], references: [id], onDelete: Cascade)
|
||||
|
||||
@@unique([llmModelId, credentialProvider, unit])
|
||||
@@index([llmModelId])
|
||||
@@index([credentialProvider])
|
||||
}
|
||||
|
||||
// Tracks model migrations for revert capability
|
||||
// When a model is disabled with migration, we record which nodes were affected
|
||||
// so they can be reverted when the original model is back online
|
||||
model LlmModelMigration {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
sourceModelSlug String // The original model that was disabled
|
||||
targetModelSlug String // The model workflows were migrated to
|
||||
reason String? // Why the migration happened (e.g., "Provider outage")
|
||||
|
||||
// Track affected nodes as JSON array of node IDs
|
||||
// Format: ["node-uuid-1", "node-uuid-2", ...]
|
||||
migratedNodeIds Json @default("[]")
|
||||
nodeCount Int // Number of nodes migrated
|
||||
|
||||
// Custom pricing override for migrated workflows during the migration period.
|
||||
// Use case: When migrating users from an expensive model (e.g., GPT-4) to a cheaper
|
||||
// one (e.g., GPT-3.5), you may want to temporarily maintain the original pricing
|
||||
// to avoid billing surprises, or offer a discount during the transition.
|
||||
//
|
||||
// IMPORTANT: This field is intended for integration with the billing system.
|
||||
// When billing calculates costs for nodes affected by this migration, it should
|
||||
// check if customCreditCost is set and use it instead of the target model's cost.
|
||||
// If null, the target model's normal cost applies.
|
||||
//
|
||||
// TODO: Integrate with billing system to apply this override during cost calculation.
|
||||
customCreditCost Int?
|
||||
|
||||
// Revert tracking
|
||||
isReverted Boolean @default(false)
|
||||
revertedAt DateTime?
|
||||
|
||||
@@index([sourceModelSlug])
|
||||
@@index([targetModelSlug])
|
||||
@@index([isReverted])
|
||||
}
|
||||
////////////// OAUTH PROVIDER TABLES //////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -30,6 +30,7 @@
|
||||
"defaults"
|
||||
],
|
||||
"dependencies": {
|
||||
"@ai-sdk/react": "3.0.61",
|
||||
"@faker-js/faker": "10.0.0",
|
||||
"@hookform/resolvers": "5.2.2",
|
||||
"@next/third-parties": "15.4.6",
|
||||
@@ -60,6 +61,10 @@
|
||||
"@rjsf/utils": "6.1.2",
|
||||
"@rjsf/validator-ajv8": "6.1.2",
|
||||
"@sentry/nextjs": "10.27.0",
|
||||
"@streamdown/cjk": "1.0.1",
|
||||
"@streamdown/code": "1.0.1",
|
||||
"@streamdown/math": "1.0.1",
|
||||
"@streamdown/mermaid": "1.0.1",
|
||||
"@supabase/ssr": "0.7.0",
|
||||
"@supabase/supabase-js": "2.78.0",
|
||||
"@tanstack/react-query": "5.90.6",
|
||||
@@ -68,6 +73,7 @@
|
||||
"@vercel/analytics": "1.5.0",
|
||||
"@vercel/speed-insights": "1.2.0",
|
||||
"@xyflow/react": "12.9.2",
|
||||
"ai": "6.0.59",
|
||||
"boring-avatars": "1.11.2",
|
||||
"class-variance-authority": "0.7.1",
|
||||
"clsx": "2.1.1",
|
||||
@@ -112,9 +118,11 @@
|
||||
"remark-math": "6.0.0",
|
||||
"shepherd.js": "14.5.1",
|
||||
"sonner": "2.0.7",
|
||||
"streamdown": "2.1.0",
|
||||
"tailwind-merge": "2.6.0",
|
||||
"tailwind-scrollbar": "3.1.0",
|
||||
"tailwindcss-animate": "1.0.7",
|
||||
"use-stick-to-bottom": "1.1.2",
|
||||
"uuid": "11.1.0",
|
||||
"vaul": "1.1.2",
|
||||
"zod": "3.25.76",
|
||||
|
||||
1161
autogpt_platform/frontend/pnpm-lock.yaml
generated
1161
autogpt_platform/frontend/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@@ -2,8 +2,9 @@
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useEffect } from "react";
|
||||
import { resolveResponse, shouldShowOnboarding } from "@/app/api/helpers";
|
||||
import { resolveResponse, getOnboardingStatus } from "@/app/api/helpers";
|
||||
import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
|
||||
export default function OnboardingPage() {
|
||||
const router = useRouter();
|
||||
@@ -11,10 +12,13 @@ export default function OnboardingPage() {
|
||||
useEffect(() => {
|
||||
async function redirectToStep() {
|
||||
try {
|
||||
// Check if onboarding is enabled
|
||||
const isEnabled = await shouldShowOnboarding();
|
||||
if (!isEnabled) {
|
||||
router.replace("/");
|
||||
// Check if onboarding is enabled (also gets chat flag for redirect)
|
||||
const { shouldShowOnboarding, isChatEnabled } =
|
||||
await getOnboardingStatus();
|
||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||
|
||||
if (!shouldShowOnboarding) {
|
||||
router.replace(homepageRoute);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -22,7 +26,7 @@ export default function OnboardingPage() {
|
||||
|
||||
// Handle completed onboarding
|
||||
if (onboarding.completedSteps.includes("GET_RESULTS")) {
|
||||
router.replace("/");
|
||||
router.replace(homepageRoute);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
"use client";
|
||||
|
||||
import { Sidebar } from "@/components/__legacy__/Sidebar";
|
||||
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
|
||||
import { Cpu } from "@phosphor-icons/react";
|
||||
|
||||
import { IconSliders } from "@/components/__legacy__/ui/icons";
|
||||
|
||||
@@ -29,11 +26,6 @@ const sidebarLinkGroups = [
|
||||
href: "/admin/execution-analytics",
|
||||
icon: <FileText className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "LLM Registry",
|
||||
href: "/admin/llms",
|
||||
icon: <Cpu size={24} />,
|
||||
},
|
||||
{
|
||||
text: "Admin User Management",
|
||||
href: "/admin/settings",
|
||||
|
||||
@@ -1,493 +0,0 @@
|
||||
"use server";
|
||||
|
||||
import { revalidatePath } from "next/cache";
|
||||
|
||||
// Generated API functions
|
||||
import {
|
||||
getV2ListLlmProviders,
|
||||
postV2CreateLlmProvider,
|
||||
patchV2UpdateLlmProvider,
|
||||
deleteV2DeleteLlmProvider,
|
||||
getV2ListLlmModels,
|
||||
postV2CreateLlmModel,
|
||||
patchV2UpdateLlmModel,
|
||||
patchV2ToggleLlmModelAvailability,
|
||||
deleteV2DeleteLlmModelAndMigrateWorkflows,
|
||||
getV2GetModelUsageCount,
|
||||
getV2ListModelMigrations,
|
||||
postV2RevertAModelMigration,
|
||||
getV2ListModelCreators,
|
||||
postV2CreateModelCreator,
|
||||
patchV2UpdateModelCreator,
|
||||
deleteV2DeleteModelCreator,
|
||||
postV2SetRecommendedModel,
|
||||
} from "@/app/api/__generated__/endpoints/admin/admin";
|
||||
|
||||
// Generated types
|
||||
import type { LlmProvidersResponse } from "@/app/api/__generated__/models/llmProvidersResponse";
|
||||
import type { LlmModelsResponse } from "@/app/api/__generated__/models/llmModelsResponse";
|
||||
import type { UpsertLlmProviderRequest } from "@/app/api/__generated__/models/upsertLlmProviderRequest";
|
||||
import type { CreateLlmModelRequest } from "@/app/api/__generated__/models/createLlmModelRequest";
|
||||
import type { UpdateLlmModelRequest } from "@/app/api/__generated__/models/updateLlmModelRequest";
|
||||
import type { ToggleLlmModelRequest } from "@/app/api/__generated__/models/toggleLlmModelRequest";
|
||||
import type { LlmMigrationsResponse } from "@/app/api/__generated__/models/llmMigrationsResponse";
|
||||
import type { LlmCreatorsResponse } from "@/app/api/__generated__/models/llmCreatorsResponse";
|
||||
import type { UpsertLlmCreatorRequest } from "@/app/api/__generated__/models/upsertLlmCreatorRequest";
|
||||
import type { LlmModelUsageResponse } from "@/app/api/__generated__/models/llmModelUsageResponse";
|
||||
import { LlmCostUnit } from "@/app/api/__generated__/models/llmCostUnit";
|
||||
|
||||
const ADMIN_LLM_PATH = "/admin/llms";
|
||||
|
||||
// =============================================================================
|
||||
// Utilities
|
||||
// =============================================================================
|
||||
|
||||
/**
|
||||
* Extracts and validates a required string field from FormData.
|
||||
* Throws an error if the field is missing or empty.
|
||||
*/
|
||||
function getRequiredFormField(
|
||||
formData: FormData,
|
||||
fieldName: string,
|
||||
displayName?: string,
|
||||
): string {
|
||||
const raw = formData.get(fieldName);
|
||||
const value = raw ? String(raw).trim() : "";
|
||||
if (!value) {
|
||||
throw new Error(`${displayName || fieldName} is required`);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts and validates a required positive number field from FormData.
|
||||
* Throws an error if the field is missing, empty, or not a positive number.
|
||||
*/
|
||||
function getRequiredPositiveNumber(
|
||||
formData: FormData,
|
||||
fieldName: string,
|
||||
displayName?: string,
|
||||
): number {
|
||||
const raw = formData.get(fieldName);
|
||||
const value = Number(raw);
|
||||
if (raw === null || raw === "" || !Number.isFinite(value) || value <= 0) {
|
||||
throw new Error(`${displayName || fieldName} must be a positive number`);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts and validates a required number field from FormData.
|
||||
* Throws an error if the field is missing, empty, or not a finite number.
|
||||
*/
|
||||
function getRequiredNumber(
|
||||
formData: FormData,
|
||||
fieldName: string,
|
||||
displayName?: string,
|
||||
): number {
|
||||
const raw = formData.get(fieldName);
|
||||
const value = Number(raw);
|
||||
if (raw === null || raw === "" || !Number.isFinite(value)) {
|
||||
throw new Error(`${displayName || fieldName} is required`);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Provider Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function fetchLlmProviders(): Promise<LlmProvidersResponse> {
|
||||
const response = await getV2ListLlmProviders({ include_models: true });
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to fetch LLM providers");
|
||||
}
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function createLlmProviderAction(formData: FormData) {
|
||||
const payload: UpsertLlmProviderRequest = {
|
||||
name: String(formData.get("name") || "").trim(),
|
||||
display_name: String(formData.get("display_name") || "").trim(),
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
default_credential_provider: formData.get("default_credential_provider")
|
||||
? String(formData.get("default_credential_provider")).trim()
|
||||
: undefined,
|
||||
default_credential_id: formData.get("default_credential_id")
|
||||
? String(formData.get("default_credential_id")).trim()
|
||||
: undefined,
|
||||
default_credential_type: formData.get("default_credential_type")
|
||||
? String(formData.get("default_credential_type")).trim()
|
||||
: "api_key",
|
||||
supports_tools: formData.getAll("supports_tools").includes("on"),
|
||||
supports_json_output: formData
|
||||
.getAll("supports_json_output")
|
||||
.includes("on"),
|
||||
supports_reasoning: formData.getAll("supports_reasoning").includes("on"),
|
||||
supports_parallel_tool: formData
|
||||
.getAll("supports_parallel_tool")
|
||||
.includes("on"),
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
const response = await postV2CreateLlmProvider(payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to create LLM provider");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function deleteLlmProviderAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const providerId = getRequiredFormField(
|
||||
formData,
|
||||
"provider_id",
|
||||
"Provider id",
|
||||
);
|
||||
|
||||
const response = await deleteV2DeleteLlmProvider(providerId);
|
||||
if (response.status !== 200) {
|
||||
const errorData = response.data as { detail?: string };
|
||||
throw new Error(errorData?.detail || "Failed to delete provider");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function updateLlmProviderAction(formData: FormData) {
|
||||
const providerId = getRequiredFormField(
|
||||
formData,
|
||||
"provider_id",
|
||||
"Provider id",
|
||||
);
|
||||
|
||||
const payload: UpsertLlmProviderRequest = {
|
||||
name: String(formData.get("name") || "").trim(),
|
||||
display_name: String(formData.get("display_name") || "").trim(),
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
default_credential_provider: formData.get("default_credential_provider")
|
||||
? String(formData.get("default_credential_provider")).trim()
|
||||
: undefined,
|
||||
default_credential_id: formData.get("default_credential_id")
|
||||
? String(formData.get("default_credential_id")).trim()
|
||||
: undefined,
|
||||
default_credential_type: formData.get("default_credential_type")
|
||||
? String(formData.get("default_credential_type")).trim()
|
||||
: "api_key",
|
||||
supports_tools: formData.getAll("supports_tools").includes("on"),
|
||||
supports_json_output: formData
|
||||
.getAll("supports_json_output")
|
||||
.includes("on"),
|
||||
supports_reasoning: formData.getAll("supports_reasoning").includes("on"),
|
||||
supports_parallel_tool: formData
|
||||
.getAll("supports_parallel_tool")
|
||||
.includes("on"),
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
const response = await patchV2UpdateLlmProvider(providerId, payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to update LLM provider");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Model Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function fetchLlmModels(): Promise<LlmModelsResponse> {
|
||||
const response = await getV2ListLlmModels();
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to fetch LLM models");
|
||||
}
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function createLlmModelAction(formData: FormData) {
|
||||
const providerId = getRequiredFormField(formData, "provider_id", "Provider");
|
||||
const creatorId = formData.get("creator_id");
|
||||
const contextWindow = getRequiredPositiveNumber(
|
||||
formData,
|
||||
"context_window",
|
||||
"Context window",
|
||||
);
|
||||
const creditCost = getRequiredNumber(formData, "credit_cost", "Credit cost");
|
||||
|
||||
// Fetch provider to get default credentials
|
||||
const providersResponse = await getV2ListLlmProviders({
|
||||
include_models: false,
|
||||
});
|
||||
if (providersResponse.status !== 200) {
|
||||
throw new Error("Failed to fetch providers");
|
||||
}
|
||||
const provider = providersResponse.data.providers.find(
|
||||
(p) => p.id === providerId,
|
||||
);
|
||||
|
||||
if (!provider) {
|
||||
throw new Error("Provider not found");
|
||||
}
|
||||
|
||||
const payload: CreateLlmModelRequest = {
|
||||
slug: String(formData.get("slug") || "").trim(),
|
||||
display_name: String(formData.get("display_name") || "").trim(),
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
provider_id: providerId,
|
||||
creator_id: creatorId ? String(creatorId) : undefined,
|
||||
context_window: contextWindow,
|
||||
max_output_tokens: formData.get("max_output_tokens")
|
||||
? Number(formData.get("max_output_tokens"))
|
||||
: undefined,
|
||||
is_enabled: formData.getAll("is_enabled").includes("on"),
|
||||
capabilities: {},
|
||||
metadata: {},
|
||||
costs: [
|
||||
{
|
||||
unit: (formData.get("unit") as LlmCostUnit) || LlmCostUnit.RUN,
|
||||
credit_cost: creditCost,
|
||||
credential_provider:
|
||||
provider.default_credential_provider || provider.name,
|
||||
credential_id: provider.default_credential_id || undefined,
|
||||
credential_type: provider.default_credential_type || "api_key",
|
||||
metadata: {},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const response = await postV2CreateLlmModel(payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to create LLM model");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function updateLlmModelAction(formData: FormData) {
|
||||
const modelId = getRequiredFormField(formData, "model_id", "Model id");
|
||||
const creatorId = formData.get("creator_id");
|
||||
|
||||
const payload: UpdateLlmModelRequest = {
|
||||
display_name: formData.get("display_name")
|
||||
? String(formData.get("display_name"))
|
||||
: undefined,
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
provider_id: formData.get("provider_id")
|
||||
? String(formData.get("provider_id"))
|
||||
: undefined,
|
||||
creator_id: creatorId ? String(creatorId) : undefined,
|
||||
context_window: formData.get("context_window")
|
||||
? Number(formData.get("context_window"))
|
||||
: undefined,
|
||||
max_output_tokens: formData.get("max_output_tokens")
|
||||
? Number(formData.get("max_output_tokens"))
|
||||
: undefined,
|
||||
is_enabled: formData.has("is_enabled")
|
||||
? formData.getAll("is_enabled").includes("on")
|
||||
: undefined,
|
||||
costs: formData.get("credit_cost")
|
||||
? [
|
||||
{
|
||||
unit: (formData.get("unit") as LlmCostUnit) || LlmCostUnit.RUN,
|
||||
credit_cost: Number(formData.get("credit_cost")),
|
||||
credential_provider: String(
|
||||
formData.get("credential_provider") || "",
|
||||
).trim(),
|
||||
credential_id: formData.get("credential_id")
|
||||
? String(formData.get("credential_id"))
|
||||
: undefined,
|
||||
credential_type: formData.get("credential_type")
|
||||
? String(formData.get("credential_type"))
|
||||
: undefined,
|
||||
metadata: {},
|
||||
},
|
||||
]
|
||||
: undefined,
|
||||
};
|
||||
|
||||
const response = await patchV2UpdateLlmModel(modelId, payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to update LLM model");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function toggleLlmModelAction(formData: FormData): Promise<void> {
|
||||
const modelId = getRequiredFormField(formData, "model_id", "Model id");
|
||||
const shouldEnable = formData.get("is_enabled") === "true";
|
||||
const migrateToSlug = formData.get("migrate_to_slug");
|
||||
const migrationReason = formData.get("migration_reason");
|
||||
const customCreditCost = formData.get("custom_credit_cost");
|
||||
|
||||
const payload: ToggleLlmModelRequest = {
|
||||
is_enabled: shouldEnable,
|
||||
migrate_to_slug: migrateToSlug ? String(migrateToSlug) : undefined,
|
||||
migration_reason: migrationReason ? String(migrationReason) : undefined,
|
||||
custom_credit_cost: customCreditCost ? Number(customCreditCost) : undefined,
|
||||
};
|
||||
|
||||
const response = await patchV2ToggleLlmModelAvailability(modelId, payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to toggle LLM model");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function deleteLlmModelAction(formData: FormData): Promise<void> {
|
||||
const modelId = getRequiredFormField(formData, "model_id", "Model id");
|
||||
const rawReplacement = formData.get("replacement_model_slug");
|
||||
const replacementModelSlug =
|
||||
rawReplacement && String(rawReplacement).trim()
|
||||
? String(rawReplacement).trim()
|
||||
: undefined;
|
||||
|
||||
const response = await deleteV2DeleteLlmModelAndMigrateWorkflows(modelId, {
|
||||
replacement_model_slug: replacementModelSlug,
|
||||
});
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to delete model");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function fetchLlmModelUsage(
|
||||
modelId: string,
|
||||
): Promise<LlmModelUsageResponse> {
|
||||
const response = await getV2GetModelUsageCount(modelId);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to fetch model usage");
|
||||
}
|
||||
return response.data;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Migration Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function fetchLlmMigrations(
|
||||
includeReverted: boolean = false,
|
||||
): Promise<LlmMigrationsResponse> {
|
||||
const response = await getV2ListModelMigrations({
|
||||
include_reverted: includeReverted,
|
||||
});
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to fetch migrations");
|
||||
}
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function revertLlmMigrationAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const migrationId = getRequiredFormField(
|
||||
formData,
|
||||
"migration_id",
|
||||
"Migration id",
|
||||
);
|
||||
|
||||
const response = await postV2RevertAModelMigration(migrationId, null);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to revert migration");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Creator Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function fetchLlmCreators(): Promise<LlmCreatorsResponse> {
|
||||
const response = await getV2ListModelCreators();
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to fetch creators");
|
||||
}
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function createLlmCreatorAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const payload: UpsertLlmCreatorRequest = {
|
||||
name: String(formData.get("name") || "").trim(),
|
||||
display_name: String(formData.get("display_name") || "").trim(),
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
website_url: formData.get("website_url")
|
||||
? String(formData.get("website_url")).trim()
|
||||
: undefined,
|
||||
logo_url: formData.get("logo_url")
|
||||
? String(formData.get("logo_url")).trim()
|
||||
: undefined,
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
const response = await postV2CreateModelCreator(payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to create creator");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function updateLlmCreatorAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const creatorId = getRequiredFormField(formData, "creator_id", "Creator id");
|
||||
|
||||
const payload: UpsertLlmCreatorRequest = {
|
||||
name: String(formData.get("name") || "").trim(),
|
||||
display_name: String(formData.get("display_name") || "").trim(),
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
website_url: formData.get("website_url")
|
||||
? String(formData.get("website_url")).trim()
|
||||
: undefined,
|
||||
logo_url: formData.get("logo_url")
|
||||
? String(formData.get("logo_url")).trim()
|
||||
: undefined,
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
const response = await patchV2UpdateModelCreator(creatorId, payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to update creator");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function deleteLlmCreatorAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const creatorId = getRequiredFormField(formData, "creator_id", "Creator id");
|
||||
|
||||
const response = await deleteV2DeleteModelCreator(creatorId);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to delete creator");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Recommended Model Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function setRecommendedModelAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const modelId = getRequiredFormField(formData, "model_id", "Model id");
|
||||
|
||||
const response = await postV2SetRecommendedModel({ model_id: modelId });
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to set recommended model");
|
||||
}
|
||||
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
@@ -1,147 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { createLlmCreatorAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
export function AddCreatorModal() {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await createLlmCreatorAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to create creator");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Add Creator"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "512px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="primary" size="small">
|
||||
Add Creator
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
Add a new model creator (the organization that made/trained the
|
||||
model).
|
||||
</div>
|
||||
|
||||
<form action={handleSubmit} className="space-y-4">
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Name (slug) <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="name"
|
||||
required
|
||||
name="name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="openai"
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Lowercase identifier (e.g., openai, meta, anthropic)
|
||||
</p>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="display_name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Display Name <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="display_name"
|
||||
required
|
||||
name="display_name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="OpenAI"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="description"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Description
|
||||
</label>
|
||||
<textarea
|
||||
id="description"
|
||||
name="description"
|
||||
rows={2}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="Creator of GPT models..."
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="website_url"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Website URL
|
||||
</label>
|
||||
<input
|
||||
id="website_url"
|
||||
name="website_url"
|
||||
type="url"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="https://openai.com"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Creating..." : "Add Creator"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,314 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
|
||||
import { createLlmModelAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
interface Props {
|
||||
providers: LlmProvider[];
|
||||
creators: LlmModelCreator[];
|
||||
}
|
||||
|
||||
export function AddModelModal({ providers, creators }: Props) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [selectedCreatorId, setSelectedCreatorId] = useState("");
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await createLlmModelAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to create model");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
// When provider changes, auto-select matching creator if one exists
|
||||
function handleProviderChange(providerId: string) {
|
||||
const provider = providers.find((p) => p.id === providerId);
|
||||
if (provider) {
|
||||
// Find creator with same name as provider (e.g., "openai" -> "openai")
|
||||
const matchingCreator = creators.find((c) => c.name === provider.name);
|
||||
if (matchingCreator) {
|
||||
setSelectedCreatorId(matchingCreator.id);
|
||||
} else {
|
||||
// No matching creator (e.g., OpenRouter hosts other creators' models)
|
||||
setSelectedCreatorId("");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Add Model"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="primary" size="small">
|
||||
Add Model
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
Register a new model slug, metadata, and pricing.
|
||||
</div>
|
||||
|
||||
<form action={handleSubmit} className="space-y-6">
|
||||
{/* Basic Information */}
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Basic Information
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Core model details
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="slug"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Model Slug <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="slug"
|
||||
required
|
||||
name="slug"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="gpt-4.1-mini-2025-04-14"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="display_name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Display Name <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="display_name"
|
||||
required
|
||||
name="display_name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="GPT 4.1 Mini"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="description"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Description
|
||||
</label>
|
||||
<textarea
|
||||
id="description"
|
||||
name="description"
|
||||
rows={3}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="Optional description..."
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Model Configuration */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Model Configuration
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Model capabilities and limits
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="provider_id"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Provider <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<select
|
||||
id="provider_id"
|
||||
required
|
||||
name="provider_id"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
defaultValue=""
|
||||
onChange={(e) => handleProviderChange(e.target.value)}
|
||||
>
|
||||
<option value="" disabled>
|
||||
Select provider
|
||||
</option>
|
||||
{providers.map((provider) => (
|
||||
<option key={provider.id} value={provider.id}>
|
||||
{provider.display_name} ({provider.name})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Who hosts/serves the model
|
||||
</p>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="creator_id"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Creator
|
||||
</label>
|
||||
<select
|
||||
id="creator_id"
|
||||
name="creator_id"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
value={selectedCreatorId}
|
||||
onChange={(e) => setSelectedCreatorId(e.target.value)}
|
||||
>
|
||||
<option value="">No creator selected</option>
|
||||
{creators.map((creator) => (
|
||||
<option key={creator.id} value={creator.id}>
|
||||
{creator.display_name} ({creator.name})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Who made/trained the model (e.g., OpenAI, Meta)
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="context_window"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Context Window <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="context_window"
|
||||
required
|
||||
type="number"
|
||||
name="context_window"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="128000"
|
||||
min={1}
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="max_output_tokens"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Max Output Tokens
|
||||
</label>
|
||||
<input
|
||||
id="max_output_tokens"
|
||||
type="number"
|
||||
name="max_output_tokens"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="16384"
|
||||
min={1}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Pricing */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">Pricing</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Credit cost per run (credentials are managed via the provider)
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-1">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="credit_cost"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Credit Cost <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="credit_cost"
|
||||
required
|
||||
type="number"
|
||||
name="credit_cost"
|
||||
step="1"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="5"
|
||||
min={0}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Credit cost is always in platform credits. Credentials are
|
||||
inherited from the selected provider.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Enabled Toggle */}
|
||||
<div className="flex items-center gap-3 border-t border-border pt-6">
|
||||
<input type="hidden" name="is_enabled" value="off" />
|
||||
<input
|
||||
id="is_enabled"
|
||||
type="checkbox"
|
||||
name="is_enabled"
|
||||
defaultChecked
|
||||
className="h-4 w-4 rounded border-input"
|
||||
/>
|
||||
<label
|
||||
htmlFor="is_enabled"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Enabled by default
|
||||
</label>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Creating..." : "Save Model"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,268 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { createLlmProviderAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
export function AddProviderModal() {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await createLlmProviderAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(
|
||||
err instanceof Error ? err.message : "Failed to create provider",
|
||||
);
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Add Provider"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="primary" size="small">
|
||||
Add Provider
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
Define a new upstream provider and default credential information.
|
||||
</div>
|
||||
|
||||
{/* Setup Instructions */}
|
||||
<div className="mb-6 rounded-lg border border-primary/30 bg-primary/5 p-4">
|
||||
<div className="space-y-2">
|
||||
<h4 className="text-sm font-semibold text-foreground">
|
||||
Before Adding a Provider
|
||||
</h4>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
To use a new provider, you must first configure its credentials in
|
||||
the backend:
|
||||
</p>
|
||||
<ol className="list-inside list-decimal space-y-1 text-xs text-muted-foreground">
|
||||
<li>
|
||||
Add the credential to{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono">
|
||||
backend/integrations/credentials_store.py
|
||||
</code>{" "}
|
||||
with a UUID, provider name, and settings secret reference
|
||||
</li>
|
||||
<li>
|
||||
Add it to the{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono">
|
||||
PROVIDER_CREDENTIALS
|
||||
</code>{" "}
|
||||
dictionary in{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono">
|
||||
backend/data/block_cost_config.py
|
||||
</code>
|
||||
</li>
|
||||
<li>
|
||||
Use the <strong>same provider name</strong> in the
|
||||
"Credential Provider" field below that matches the key
|
||||
in{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono">
|
||||
PROVIDER_CREDENTIALS
|
||||
</code>
|
||||
</li>
|
||||
</ol>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<form action={handleSubmit} className="space-y-6">
|
||||
{/* Basic Information */}
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Basic Information
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Core provider details
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Provider Slug <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="name"
|
||||
required
|
||||
name="name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="e.g. openai"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="display_name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Display Name <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="display_name"
|
||||
required
|
||||
name="display_name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="OpenAI"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="description"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Description
|
||||
</label>
|
||||
<textarea
|
||||
id="description"
|
||||
name="description"
|
||||
rows={3}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="Optional description..."
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Default Credentials */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Default Credentials
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Credential provider name that matches the key in{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
|
||||
PROVIDER_CREDENTIALS
|
||||
</code>
|
||||
</p>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="default_credential_provider"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Credential Provider <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="default_credential_provider"
|
||||
name="default_credential_provider"
|
||||
required
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="openai"
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
<strong>Important:</strong> This must exactly match the key in
|
||||
the{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
|
||||
PROVIDER_CREDENTIALS
|
||||
</code>{" "}
|
||||
dictionary in{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
|
||||
block_cost_config.py
|
||||
</code>
|
||||
. Common values: "openai", "anthropic",
|
||||
"groq", "open_router", etc.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Capabilities */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Capabilities
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Provider feature flags
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-3 sm:grid-cols-2">
|
||||
{[
|
||||
{ name: "supports_tools", label: "Supports tools" },
|
||||
{ name: "supports_json_output", label: "Supports JSON output" },
|
||||
{ name: "supports_reasoning", label: "Supports reasoning" },
|
||||
{
|
||||
name: "supports_parallel_tool",
|
||||
label: "Supports parallel tool calls",
|
||||
},
|
||||
].map(({ name, label }) => (
|
||||
<div
|
||||
key={name}
|
||||
className="flex items-center gap-3 rounded-md border border-border bg-muted/30 px-4 py-3 transition-colors hover:bg-muted/50"
|
||||
>
|
||||
<input type="hidden" name={name} value="off" />
|
||||
<input
|
||||
id={name}
|
||||
type="checkbox"
|
||||
name={name}
|
||||
defaultChecked={
|
||||
name !== "supports_reasoning" &&
|
||||
name !== "supports_parallel_tool"
|
||||
}
|
||||
className="h-4 w-4 rounded border-input"
|
||||
/>
|
||||
<label
|
||||
htmlFor={name}
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
{label}
|
||||
</label>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Creating..." : "Save Provider"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,195 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/atoms/Table/Table";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { updateLlmCreatorAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { DeleteCreatorModal } from "./DeleteCreatorModal";
|
||||
|
||||
export function CreatorsTable({ creators }: { creators: LlmModelCreator[] }) {
|
||||
if (!creators.length) {
|
||||
return (
|
||||
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
|
||||
No creators registered yet.
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Creator</TableHead>
|
||||
<TableHead>Description</TableHead>
|
||||
<TableHead>Website</TableHead>
|
||||
<TableHead>Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{creators.map((creator) => (
|
||||
<TableRow key={creator.id}>
|
||||
<TableCell>
|
||||
<div className="font-medium">{creator.display_name}</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{creator.name}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<span className="text-sm text-muted-foreground">
|
||||
{creator.description || "—"}
|
||||
</span>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{creator.website_url ? (
|
||||
<a
|
||||
href={creator.website_url}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-sm text-primary hover:underline"
|
||||
>
|
||||
{(() => {
|
||||
try {
|
||||
return new URL(creator.website_url).hostname;
|
||||
} catch {
|
||||
return creator.website_url;
|
||||
}
|
||||
})()}
|
||||
</a>
|
||||
) : (
|
||||
<span className="text-muted-foreground">—</span>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
<EditCreatorModal creator={creator} />
|
||||
<DeleteCreatorModal creator={creator} />
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function EditCreatorModal({ creator }: { creator: LlmModelCreator }) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await updateLlmCreatorAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to update creator");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Edit Creator"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "512px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="outline" size="small" className="min-w-0">
|
||||
Edit
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<form action={handleSubmit} className="space-y-4">
|
||||
<input type="hidden" name="creator_id" value={creator.id} />
|
||||
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Name (slug)</label>
|
||||
<input
|
||||
required
|
||||
name="name"
|
||||
defaultValue={creator.name}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Display Name</label>
|
||||
<input
|
||||
required
|
||||
name="display_name"
|
||||
defaultValue={creator.display_name}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Description</label>
|
||||
<textarea
|
||||
name="description"
|
||||
rows={2}
|
||||
defaultValue={creator.description ?? ""}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Website URL</label>
|
||||
<input
|
||||
name="website_url"
|
||||
type="url"
|
||||
defaultValue={creator.website_url ?? ""}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Updating..." : "Update"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,107 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
|
||||
import { deleteLlmCreatorAction } from "../actions";
|
||||
|
||||
export function DeleteCreatorModal({ creator }: { creator: LlmModelCreator }) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isDeleting, setIsDeleting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleDelete(formData: FormData) {
|
||||
setIsDeleting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await deleteLlmCreatorAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to delete creator");
|
||||
} finally {
|
||||
setIsDeleting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Delete Creator"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "480px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
className="min-w-0 text-destructive hover:bg-destructive/10"
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="space-y-4">
|
||||
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
|
||||
<div className="flex items-start gap-3">
|
||||
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
|
||||
⚠️
|
||||
</div>
|
||||
<div className="text-sm text-foreground">
|
||||
<p className="font-semibold">You are about to delete:</p>
|
||||
<p className="mt-1">
|
||||
<span className="font-medium">{creator.display_name}</span>{" "}
|
||||
<span className="text-muted-foreground">
|
||||
({creator.name})
|
||||
</span>
|
||||
</p>
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
Models using this creator will have their creator field
|
||||
cleared. This is safe and won't affect model
|
||||
functionality.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<form action={handleDelete} className="space-y-4">
|
||||
<input type="hidden" name="creator_id" value={creator.id} />
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isDeleting}
|
||||
type="button"
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
size="small"
|
||||
disabled={isDeleting}
|
||||
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
|
||||
>
|
||||
{isDeleting ? "Deleting..." : "Delete Creator"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,224 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
|
||||
import { deleteLlmModelAction, fetchLlmModelUsage } from "../actions";
|
||||
|
||||
export function DeleteModelModal({
|
||||
model,
|
||||
availableModels,
|
||||
}: {
|
||||
model: LlmModel;
|
||||
availableModels: LlmModel[];
|
||||
}) {
|
||||
const router = useRouter();
|
||||
const [open, setOpen] = useState(false);
|
||||
const [selectedReplacement, setSelectedReplacement] = useState<string>("");
|
||||
const [isDeleting, setIsDeleting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [usageCount, setUsageCount] = useState<number | null>(null);
|
||||
const [usageLoading, setUsageLoading] = useState(false);
|
||||
const [usageError, setUsageError] = useState<string | null>(null);
|
||||
|
||||
// Filter out the current model and disabled models from replacement options
|
||||
const replacementOptions = availableModels.filter(
|
||||
(m) => m.id !== model.id && m.is_enabled,
|
||||
);
|
||||
|
||||
// Check if migration is required (has blocks using this model)
|
||||
const requiresMigration = usageCount !== null && usageCount > 0;
|
||||
|
||||
async function fetchUsage() {
|
||||
setUsageLoading(true);
|
||||
setUsageError(null);
|
||||
try {
|
||||
const usage = await fetchLlmModelUsage(model.id);
|
||||
setUsageCount(usage.node_count);
|
||||
} catch (err) {
|
||||
console.error("Failed to fetch model usage:", err);
|
||||
setUsageError("Failed to load usage count");
|
||||
setUsageCount(null);
|
||||
} finally {
|
||||
setUsageLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleDelete(formData: FormData) {
|
||||
setIsDeleting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await deleteLlmModelAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to delete model");
|
||||
} finally {
|
||||
setIsDeleting(false);
|
||||
}
|
||||
}
|
||||
|
||||
// Determine if delete button should be enabled
|
||||
const canDelete =
|
||||
!isDeleting &&
|
||||
!usageLoading &&
|
||||
usageCount !== null &&
|
||||
(requiresMigration
|
||||
? selectedReplacement && replacementOptions.length > 0
|
||||
: true);
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Delete Model"
|
||||
controlled={{
|
||||
isOpen: open,
|
||||
set: async (isOpen) => {
|
||||
setOpen(isOpen);
|
||||
if (isOpen) {
|
||||
setUsageCount(null);
|
||||
setUsageError(null);
|
||||
setError(null);
|
||||
setSelectedReplacement("");
|
||||
await fetchUsage();
|
||||
}
|
||||
},
|
||||
}}
|
||||
styling={{ maxWidth: "600px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
className="min-w-0 text-destructive hover:bg-destructive/10"
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
{requiresMigration
|
||||
? "This action cannot be undone. All workflows using this model will be migrated to the replacement model you select."
|
||||
: "This action cannot be undone."}
|
||||
</div>
|
||||
|
||||
<div className="space-y-4">
|
||||
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
|
||||
<div className="flex items-start gap-3">
|
||||
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
|
||||
⚠️
|
||||
</div>
|
||||
<div className="text-sm text-foreground">
|
||||
<p className="font-semibold">You are about to delete:</p>
|
||||
<p className="mt-1">
|
||||
<span className="font-medium">{model.display_name}</span>{" "}
|
||||
<span className="text-muted-foreground">({model.slug})</span>
|
||||
</p>
|
||||
{usageLoading && (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
Loading usage count...
|
||||
</p>
|
||||
)}
|
||||
{usageError && (
|
||||
<p className="mt-2 text-destructive">{usageError}</p>
|
||||
)}
|
||||
{!usageLoading && !usageError && usageCount !== null && (
|
||||
<p className="mt-2 font-semibold">
|
||||
Impact: {usageCount} block{usageCount !== 1 ? "s" : ""}{" "}
|
||||
currently use this model
|
||||
</p>
|
||||
)}
|
||||
{requiresMigration && (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
All workflows currently using this model will be
|
||||
automatically updated to use the replacement model you
|
||||
choose below.
|
||||
</p>
|
||||
)}
|
||||
{!usageLoading && usageCount === 0 && (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
No workflows are using this model. It can be safely deleted.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<form action={handleDelete} className="space-y-4">
|
||||
<input type="hidden" name="model_id" value={model.id} />
|
||||
<input
|
||||
type="hidden"
|
||||
name="replacement_model_slug"
|
||||
value={selectedReplacement}
|
||||
/>
|
||||
|
||||
{requiresMigration && (
|
||||
<label className="text-sm font-medium">
|
||||
<span className="mb-2 block">
|
||||
Select Replacement Model{" "}
|
||||
<span className="text-destructive">*</span>
|
||||
</span>
|
||||
<select
|
||||
required
|
||||
value={selectedReplacement}
|
||||
onChange={(e) => setSelectedReplacement(e.target.value)}
|
||||
className="w-full rounded border border-input bg-background p-2 text-sm"
|
||||
>
|
||||
<option value="">-- Choose a replacement model --</option>
|
||||
{replacementOptions.map((m) => (
|
||||
<option key={m.id} value={m.slug}>
|
||||
{m.display_name} ({m.slug})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
{replacementOptions.length === 0 && (
|
||||
<p className="mt-2 text-xs text-destructive">
|
||||
No replacement models available. You must have at least one
|
||||
other enabled model before deleting this one.
|
||||
</p>
|
||||
)}
|
||||
</label>
|
||||
)}
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setSelectedReplacement("");
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isDeleting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
size="small"
|
||||
disabled={!canDelete}
|
||||
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
|
||||
>
|
||||
{isDeleting
|
||||
? "Deleting..."
|
||||
: requiresMigration
|
||||
? "Delete and Migrate"
|
||||
: "Delete"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,129 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
import { deleteLlmProviderAction } from "../actions";
|
||||
|
||||
export function DeleteProviderModal({ provider }: { provider: LlmProvider }) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isDeleting, setIsDeleting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
const modelCount = provider.models?.length ?? 0;
|
||||
const hasModels = modelCount > 0;
|
||||
|
||||
async function handleDelete(formData: FormData) {
|
||||
setIsDeleting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await deleteLlmProviderAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(
|
||||
err instanceof Error ? err.message : "Failed to delete provider",
|
||||
);
|
||||
} finally {
|
||||
setIsDeleting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Delete Provider"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "480px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
className="min-w-0 text-destructive hover:bg-destructive/10"
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="space-y-4">
|
||||
<div
|
||||
className={`rounded-lg border p-4 ${
|
||||
hasModels
|
||||
? "border-destructive/30 bg-destructive/10"
|
||||
: "border-amber-500/30 bg-amber-500/10 dark:border-amber-400/30 dark:bg-amber-400/10"
|
||||
}`}
|
||||
>
|
||||
<div className="flex items-start gap-3">
|
||||
<div
|
||||
className={`flex-shrink-0 ${
|
||||
hasModels
|
||||
? "text-destructive"
|
||||
: "text-amber-600 dark:text-amber-400"
|
||||
}`}
|
||||
>
|
||||
{hasModels ? "🚫" : "⚠️"}
|
||||
</div>
|
||||
<div className="text-sm text-foreground">
|
||||
<p className="font-semibold">You are about to delete:</p>
|
||||
<p className="mt-1">
|
||||
<span className="font-medium">{provider.display_name}</span>{" "}
|
||||
<span className="text-muted-foreground">
|
||||
({provider.name})
|
||||
</span>
|
||||
</p>
|
||||
{hasModels ? (
|
||||
<p className="mt-2 text-destructive">
|
||||
This provider has {modelCount} model(s). You must delete all
|
||||
models before you can delete this provider.
|
||||
</p>
|
||||
) : (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
This provider has no models and can be safely deleted.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<form action={handleDelete} className="space-y-4">
|
||||
<input type="hidden" name="provider_id" value={provider.id} />
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isDeleting}
|
||||
type="button"
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
size="small"
|
||||
disabled={isDeleting || hasModels}
|
||||
className="bg-destructive text-destructive-foreground hover:bg-destructive/90 disabled:opacity-50"
|
||||
>
|
||||
{isDeleting ? "Deleting..." : "Delete Provider"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,288 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
|
||||
import { toggleLlmModelAction, fetchLlmModelUsage } from "../actions";
|
||||
|
||||
export function DisableModelModal({
|
||||
model,
|
||||
availableModels,
|
||||
}: {
|
||||
model: LlmModel;
|
||||
availableModels: LlmModel[];
|
||||
}) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isDisabling, setIsDisabling] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [usageCount, setUsageCount] = useState<number | null>(null);
|
||||
const [selectedMigration, setSelectedMigration] = useState<string>("");
|
||||
const [wantsMigration, setWantsMigration] = useState(false);
|
||||
const [migrationReason, setMigrationReason] = useState("");
|
||||
const [customCreditCost, setCustomCreditCost] = useState<string>("");
|
||||
|
||||
// Filter out the current model and disabled models from replacement options
|
||||
const migrationOptions = availableModels.filter(
|
||||
(m) => m.id !== model.id && m.is_enabled,
|
||||
);
|
||||
|
||||
async function fetchUsage() {
|
||||
try {
|
||||
const usage = await fetchLlmModelUsage(model.id);
|
||||
setUsageCount(usage.node_count);
|
||||
} catch {
|
||||
setUsageCount(null);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleDisable(formData: FormData) {
|
||||
setIsDisabling(true);
|
||||
setError(null);
|
||||
try {
|
||||
await toggleLlmModelAction(formData);
|
||||
setOpen(false);
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to disable model");
|
||||
} finally {
|
||||
setIsDisabling(false);
|
||||
}
|
||||
}
|
||||
|
||||
function resetState() {
|
||||
setError(null);
|
||||
setSelectedMigration("");
|
||||
setWantsMigration(false);
|
||||
setMigrationReason("");
|
||||
setCustomCreditCost("");
|
||||
}
|
||||
|
||||
const hasUsage = usageCount !== null && usageCount > 0;
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Disable Model"
|
||||
controlled={{
|
||||
isOpen: open,
|
||||
set: async (isOpen) => {
|
||||
setOpen(isOpen);
|
||||
if (isOpen) {
|
||||
setUsageCount(null);
|
||||
resetState();
|
||||
await fetchUsage();
|
||||
}
|
||||
},
|
||||
}}
|
||||
styling={{ maxWidth: "600px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
className="min-w-0"
|
||||
>
|
||||
Disable
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
Disabling a model will hide it from users when creating new workflows.
|
||||
</div>
|
||||
|
||||
<div className="space-y-4">
|
||||
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
|
||||
<div className="flex items-start gap-3">
|
||||
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
|
||||
⚠️
|
||||
</div>
|
||||
<div className="text-sm text-foreground">
|
||||
<p className="font-semibold">You are about to disable:</p>
|
||||
<p className="mt-1">
|
||||
<span className="font-medium">{model.display_name}</span>{" "}
|
||||
<span className="text-muted-foreground">({model.slug})</span>
|
||||
</p>
|
||||
{usageCount === null ? (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
Loading usage data...
|
||||
</p>
|
||||
) : usageCount > 0 ? (
|
||||
<p className="mt-2 font-semibold">
|
||||
Impact: {usageCount} block{usageCount !== 1 ? "s" : ""}{" "}
|
||||
currently use this model
|
||||
</p>
|
||||
) : (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
No workflows are currently using this model.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{hasUsage && (
|
||||
<div className="space-y-4 rounded-lg border border-border bg-muted/50 p-4">
|
||||
<label className="flex items-start gap-3">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={wantsMigration}
|
||||
onChange={(e) => {
|
||||
setWantsMigration(e.target.checked);
|
||||
if (!e.target.checked) {
|
||||
setSelectedMigration("");
|
||||
}
|
||||
}}
|
||||
className="mt-1"
|
||||
/>
|
||||
<div className="text-sm">
|
||||
<span className="font-medium">
|
||||
Migrate existing workflows to another model
|
||||
</span>
|
||||
<p className="mt-1 text-muted-foreground">
|
||||
Creates a revertible migration record. If unchecked,
|
||||
existing workflows will use automatic fallback to an enabled
|
||||
model from the same provider.
|
||||
</p>
|
||||
</div>
|
||||
</label>
|
||||
|
||||
{wantsMigration && (
|
||||
<div className="space-y-4 border-t border-border pt-4">
|
||||
<label className="block text-sm font-medium">
|
||||
<span className="mb-2 block">
|
||||
Replacement Model{" "}
|
||||
<span className="text-destructive">*</span>
|
||||
</span>
|
||||
<select
|
||||
required
|
||||
value={selectedMigration}
|
||||
onChange={(e) => setSelectedMigration(e.target.value)}
|
||||
className="w-full rounded border border-input bg-background p-2 text-sm"
|
||||
>
|
||||
<option value="">-- Choose a replacement model --</option>
|
||||
{migrationOptions.map((m) => (
|
||||
<option key={m.id} value={m.slug}>
|
||||
{m.display_name} ({m.slug})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
{migrationOptions.length === 0 && (
|
||||
<p className="mt-2 text-xs text-destructive">
|
||||
No other enabled models available for migration.
|
||||
</p>
|
||||
)}
|
||||
</label>
|
||||
|
||||
<label className="block text-sm font-medium">
|
||||
<span className="mb-2 block">
|
||||
Migration Reason{" "}
|
||||
<span className="font-normal text-muted-foreground">
|
||||
(optional)
|
||||
</span>
|
||||
</span>
|
||||
<input
|
||||
type="text"
|
||||
value={migrationReason}
|
||||
onChange={(e) => setMigrationReason(e.target.value)}
|
||||
placeholder="e.g., Provider outage, Cost reduction"
|
||||
className="w-full rounded border border-input bg-background p-2 text-sm"
|
||||
/>
|
||||
<p className="mt-1 text-xs text-muted-foreground">
|
||||
Helps track why the migration was made
|
||||
</p>
|
||||
</label>
|
||||
|
||||
<label className="block text-sm font-medium">
|
||||
<span className="mb-2 block">
|
||||
Custom Credit Cost{" "}
|
||||
<span className="font-normal text-muted-foreground">
|
||||
(optional)
|
||||
</span>
|
||||
</span>
|
||||
<input
|
||||
type="number"
|
||||
min="0"
|
||||
value={customCreditCost}
|
||||
onChange={(e) => setCustomCreditCost(e.target.value)}
|
||||
placeholder="Leave blank to use target model's cost"
|
||||
className="w-full rounded border border-input bg-background p-2 text-sm"
|
||||
/>
|
||||
<p className="mt-1 text-xs text-muted-foreground">
|
||||
Override pricing for migrated workflows. When set, billing
|
||||
will use this cost instead of the target model's
|
||||
cost.
|
||||
</p>
|
||||
</label>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<form action={handleDisable} className="space-y-4">
|
||||
<input type="hidden" name="model_id" value={model.id} />
|
||||
<input type="hidden" name="is_enabled" value="false" />
|
||||
{wantsMigration && selectedMigration && (
|
||||
<>
|
||||
<input
|
||||
type="hidden"
|
||||
name="migrate_to_slug"
|
||||
value={selectedMigration}
|
||||
/>
|
||||
{migrationReason && (
|
||||
<input
|
||||
type="hidden"
|
||||
name="migration_reason"
|
||||
value={migrationReason}
|
||||
/>
|
||||
)}
|
||||
{customCreditCost && (
|
||||
<input
|
||||
type="hidden"
|
||||
name="custom_credit_cost"
|
||||
value={customCreditCost}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
resetState();
|
||||
}}
|
||||
disabled={isDisabling}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
size="small"
|
||||
disabled={
|
||||
isDisabling ||
|
||||
(wantsMigration && !selectedMigration) ||
|
||||
usageCount === null
|
||||
}
|
||||
>
|
||||
{isDisabling
|
||||
? "Disabling..."
|
||||
: wantsMigration && selectedMigration
|
||||
? "Disable & Migrate"
|
||||
: "Disable Model"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,223 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
|
||||
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
import { updateLlmModelAction } from "../actions";
|
||||
|
||||
/**
 * Modal for editing an existing LLM model's metadata and pricing.
 *
 * Submits through `updateLlmModelAction` (hidden `model_id` identifies the
 * row), closes the dialog on success, and calls `router.refresh()` so the
 * parent server component re-renders with the updated data.
 */
export function EditModelModal({
  model,
  providers,
  creators,
}: {
  model: LlmModel;
  providers: LlmProvider[];
  creators: LlmModelCreator[];
}) {
  const router = useRouter();
  const [open, setOpen] = useState(false);
  const [isSubmitting, setIsSubmitting] = useState(false);
  const [error, setError] = useState<string | null>(null);
  // Only the first cost entry is edited here — assumes one cost per model.
  // TODO confirm models never carry multiple cost rows.
  const cost = model.costs?.[0];
  const provider = providers.find((p) => p.id === model.provider_id);

  // Form-action handler: runs the server action, then closes + refreshes.
  async function handleSubmit(formData: FormData) {
    setIsSubmitting(true);
    setError(null);
    try {
      await updateLlmModelAction(formData);
      setOpen(false);
      router.refresh();
    } catch (err) {
      setError(err instanceof Error ? err.message : "Failed to update model");
    } finally {
      setIsSubmitting(false);
    }
  }

  return (
    <Dialog
      title="Edit Model"
      controlled={{ isOpen: open, set: setOpen }}
      styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
    >
      <Dialog.Trigger>
        <Button variant="outline" size="small" className="min-w-0">
          Edit
        </Button>
      </Dialog.Trigger>
      <Dialog.Content>
        <div className="mb-4 text-sm text-muted-foreground">
          Update model metadata and pricing information.
        </div>
        {error && (
          <div className="mb-4 rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
            {error}
          </div>
        )}
        <form action={handleSubmit} className="space-y-4">
          <input type="hidden" name="model_id" value={model.id} />

          <div className="grid gap-4 md:grid-cols-2">
            <label className="text-sm font-medium">
              Display Name
              <input
                required
                name="display_name"
                defaultValue={model.display_name}
                className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
              />
            </label>
            <label className="text-sm font-medium">
              Provider
              <select
                required
                name="provider_id"
                className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
                defaultValue={model.provider_id}
              >
                {providers.map((p) => (
                  <option key={p.id} value={p.id}>
                    {p.display_name} ({p.name})
                  </option>
                ))}
              </select>
              <span className="text-xs text-muted-foreground">
                Who hosts/serves the model
              </span>
            </label>
          </div>

          <div className="grid gap-4 md:grid-cols-2">
            <label className="text-sm font-medium">
              Creator
              <select
                name="creator_id"
                className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
                defaultValue={model.creator_id ?? ""}
              >
                <option value="">No creator selected</option>
                {creators.map((c) => (
                  <option key={c.id} value={c.id}>
                    {c.display_name} ({c.name})
                  </option>
                ))}
              </select>
              <span className="text-xs text-muted-foreground">
                Who made/trained the model (e.g., OpenAI, Meta)
              </span>
            </label>
          </div>

          <label className="text-sm font-medium">
            Description
            <textarea
              name="description"
              rows={2}
              defaultValue={model.description ?? ""}
              className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
              placeholder="Optional description..."
            />
          </label>

          <div className="grid gap-4 md:grid-cols-2">
            <label className="text-sm font-medium">
              Context Window
              <input
                required
                type="number"
                name="context_window"
                defaultValue={model.context_window}
                className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
                min={1}
              />
            </label>
            <label className="text-sm font-medium">
              Max Output Tokens
              <input
                type="number"
                name="max_output_tokens"
                defaultValue={model.max_output_tokens ?? undefined}
                className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
                min={1}
              />
            </label>
          </div>

          <div className="grid gap-4 md:grid-cols-2">
            <label className="text-sm font-medium">
              Credit Cost
              <input
                required
                type="number"
                name="credit_cost"
                defaultValue={cost?.credit_cost ?? 0}
                className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
                min={0}
              />
              <span className="text-xs text-muted-foreground">
                Credits charged per run
              </span>
            </label>
            <label className="text-sm font-medium">
              Credential Provider
              {/* Falls back to the hosting provider's slug when no cost row
                  exists yet. */}
              <select
                required
                name="credential_provider"
                defaultValue={cost?.credential_provider ?? provider?.name ?? ""}
                className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
              >
                <option value="" disabled>
                  Select provider
                </option>
                {providers.map((p) => (
                  <option key={p.id} value={p.name}>
                    {p.display_name} ({p.name})
                  </option>
                ))}
              </select>
              <span className="text-xs text-muted-foreground">
                Must match a key in PROVIDER_CREDENTIALS
              </span>
            </label>
          </div>
          {/* Hidden defaults for credential_type and unit */}
          <input
            type="hidden"
            name="credential_type"
            value={
              cost?.credential_type ??
              provider?.default_credential_type ??
              "api_key"
            }
          />
          <input type="hidden" name="unit" value={cost?.unit ?? "RUN"} />

          <Dialog.Footer>
            <Button
              type="button"
              variant="ghost"
              size="small"
              onClick={() => setOpen(false)}
              disabled={isSubmitting}
            >
              Cancel
            </Button>
            <Button
              variant="primary"
              size="small"
              type="submit"
              disabled={isSubmitting}
            >
              {isSubmitting ? "Updating..." : "Update Model"}
            </Button>
          </Dialog.Footer>
        </form>
      </Dialog.Content>
    </Dialog>
  );
}
|
||||
@@ -1,263 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { updateLlmProviderAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
|
||||
/**
 * Modal for editing an LLM provider: basic info, default credential
 * configuration, and capability flags.
 *
 * Submits through `updateLlmProviderAction`, then closes the dialog and
 * calls `router.refresh()` to re-render the parent server component.
 */
export function EditProviderModal({ provider }: { provider: LlmProvider }) {
  const [open, setOpen] = useState(false);
  const [isSubmitting, setIsSubmitting] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const router = useRouter();

  // Form-action handler: runs the server action, then closes + refreshes.
  async function handleSubmit(formData: FormData) {
    setIsSubmitting(true);
    setError(null);
    try {
      await updateLlmProviderAction(formData);
      setOpen(false);
      router.refresh();
    } catch (err) {
      setError(
        err instanceof Error ? err.message : "Failed to update provider",
      );
    } finally {
      setIsSubmitting(false);
    }
  }

  return (
    <Dialog
      title="Edit Provider"
      controlled={{ isOpen: open, set: setOpen }}
      styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
    >
      <Dialog.Trigger>
        <Button variant="outline" size="small">
          Edit
        </Button>
      </Dialog.Trigger>
      <Dialog.Content>
        <div className="mb-4 text-sm text-muted-foreground">
          Update provider configuration and capabilities.
        </div>

        <form action={handleSubmit} className="space-y-6">
          <input type="hidden" name="provider_id" value={provider.id} />

          {/* Basic Information */}
          <div className="space-y-4">
            <div className="space-y-1">
              <h3 className="text-sm font-semibold text-foreground">
                Basic Information
              </h3>
              <p className="text-xs text-muted-foreground">
                Core provider details
              </p>
            </div>
            <div className="grid gap-4 sm:grid-cols-2">
              <div className="space-y-2">
                <label
                  htmlFor="name"
                  className="text-sm font-medium text-foreground"
                >
                  Provider Slug <span className="text-destructive">*</span>
                </label>
                <input
                  id="name"
                  required
                  name="name"
                  defaultValue={provider.name}
                  className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
                  placeholder="e.g. openai"
                />
              </div>
              <div className="space-y-2">
                <label
                  htmlFor="display_name"
                  className="text-sm font-medium text-foreground"
                >
                  Display Name <span className="text-destructive">*</span>
                </label>
                <input
                  id="display_name"
                  required
                  name="display_name"
                  defaultValue={provider.display_name}
                  className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
                  placeholder="OpenAI"
                />
              </div>
            </div>
            <div className="space-y-2">
              <label
                htmlFor="description"
                className="text-sm font-medium text-foreground"
              >
                Description
              </label>
              <textarea
                id="description"
                name="description"
                rows={3}
                defaultValue={provider.description ?? ""}
                className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
                placeholder="Optional description..."
              />
            </div>
          </div>

          {/* Default Credentials */}
          <div className="space-y-4 border-t border-border pt-6">
            <div className="space-y-1">
              <h3 className="text-sm font-semibold text-foreground">
                Default Credentials
              </h3>
              <p className="text-xs text-muted-foreground">
                Credential provider name that matches the key in{" "}
                <code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
                  PROVIDER_CREDENTIALS
                </code>
              </p>
            </div>
            <div className="grid gap-4 sm:grid-cols-2">
              <div className="space-y-2">
                <label
                  htmlFor="default_credential_provider"
                  className="text-sm font-medium text-foreground"
                >
                  Credential Provider
                </label>
                <input
                  id="default_credential_provider"
                  name="default_credential_provider"
                  defaultValue={provider.default_credential_provider ?? ""}
                  className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
                  placeholder="openai"
                />
              </div>
              <div className="space-y-2">
                <label
                  htmlFor="default_credential_id"
                  className="text-sm font-medium text-foreground"
                >
                  Credential ID
                </label>
                <input
                  id="default_credential_id"
                  name="default_credential_id"
                  defaultValue={provider.default_credential_id ?? ""}
                  className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
                  placeholder="Optional credential ID"
                />
              </div>
            </div>
            <div className="space-y-2">
              <label
                htmlFor="default_credential_type"
                className="text-sm font-medium text-foreground"
              >
                Credential Type
              </label>
              <input
                id="default_credential_type"
                name="default_credential_type"
                defaultValue={provider.default_credential_type ?? "api_key"}
                className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
                placeholder="api_key"
              />
            </div>
          </div>

          {/* Capabilities */}
          <div className="space-y-4 border-t border-border pt-6">
            <div className="space-y-1">
              <h3 className="text-sm font-semibold text-foreground">
                Capabilities
              </h3>
              <p className="text-xs text-muted-foreground">
                Provider feature flags
              </p>
            </div>
            <div className="grid gap-3 sm:grid-cols-2">
              {[
                {
                  name: "supports_tools",
                  label: "Supports tools",
                  checked: provider.supports_tools,
                },
                {
                  name: "supports_json_output",
                  label: "Supports JSON output",
                  checked: provider.supports_json_output,
                },
                {
                  name: "supports_reasoning",
                  label: "Supports reasoning",
                  checked: provider.supports_reasoning,
                },
                {
                  name: "supports_parallel_tool",
                  label: "Supports parallel tool calls",
                  checked: provider.supports_parallel_tool,
                },
              ].map(({ name, label, checked }) => (
                <div
                  key={name}
                  className="flex items-center gap-3 rounded-md border border-border bg-muted/30 px-4 py-3 transition-colors hover:bg-muted/50"
                >
                  {/* NOTE(review): hidden "off" + checkbox share one name so an
                      unchecked box still posts a value. FormData.get(name)
                      returns the FIRST entry ("off") even when checked —
                      assumes the server action reads the LAST value or uses
                      getAll(); confirm against updateLlmProviderAction. */}
                  <input type="hidden" name={name} value="off" />
                  <input
                    id={name}
                    type="checkbox"
                    name={name}
                    defaultChecked={checked}
                    className="h-4 w-4 rounded border-input"
                  />
                  <label
                    htmlFor={name}
                    className="text-sm font-medium text-foreground"
                  >
                    {label}
                  </label>
                </div>
              ))}
            </div>
          </div>

          {error && (
            <div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
              {error}
            </div>
          )}

          <Dialog.Footer>
            <Button
              variant="ghost"
              size="small"
              type="button"
              onClick={() => {
                setOpen(false);
                setError(null);
              }}
              disabled={isSubmitting}
            >
              Cancel
            </Button>
            <Button
              variant="primary"
              size="small"
              type="submit"
              disabled={isSubmitting}
            >
              {isSubmitting ? "Saving..." : "Save Changes"}
            </Button>
          </Dialog.Footer>
        </form>
      </Dialog.Content>
    </Dialog>
  );
}
|
||||
@@ -1,131 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
|
||||
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
|
||||
import type { LlmModelMigration } from "@/app/api/__generated__/models/llmModelMigration";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
import { ErrorBoundary } from "@/components/molecules/ErrorBoundary/ErrorBoundary";
|
||||
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||
import { AddProviderModal } from "./AddProviderModal";
|
||||
import { AddModelModal } from "./AddModelModal";
|
||||
import { AddCreatorModal } from "./AddCreatorModal";
|
||||
import { ProviderList } from "./ProviderList";
|
||||
import { ModelsTable } from "./ModelsTable";
|
||||
import { MigrationsTable } from "./MigrationsTable";
|
||||
import { CreatorsTable } from "./CreatorsTable";
|
||||
import { RecommendedModelSelector } from "./RecommendedModelSelector";
|
||||
|
||||
interface Props {
|
||||
providers: LlmProvider[];
|
||||
models: LlmModel[];
|
||||
migrations: LlmModelMigration[];
|
||||
creators: LlmModelCreator[];
|
||||
}
|
||||
|
||||
function AdminErrorFallback() {
|
||||
return (
|
||||
<div className="mx-auto max-w-xl p-6">
|
||||
<ErrorCard
|
||||
responseError={{
|
||||
message:
|
||||
"An error occurred while loading the LLM Registry. Please refresh the page.",
|
||||
}}
|
||||
context="llm-registry"
|
||||
onRetry={() => window.location.reload()}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * Top-level layout for the admin LLM Registry page.
 *
 * Pure presentation: arranges the migrations, providers/creators, and models
 * sections and wires the pre-fetched server data into their tables/modals.
 * Wrapped in an ErrorBoundary so a render crash shows AdminErrorFallback
 * instead of taking down the page.
 */
export function LlmRegistryDashboard({
  providers,
  models,
  migrations,
  creators,
}: Props) {
  return (
    <ErrorBoundary fallback={<AdminErrorFallback />} context="llm-registry">
      <div className="mx-auto p-6">
        <div className="flex flex-col gap-6">
          {/* Header */}
          <div>
            <h1 className="text-3xl font-bold">LLM Registry</h1>
            <p className="text-muted-foreground">
              Manage providers, creators, models, and credit pricing
            </p>
          </div>

          {/* Active Migrations Section - Only show if there are migrations */}
          {migrations.length > 0 && (
            <div className="rounded-lg border border-primary/30 bg-primary/5 p-6 shadow-sm">
              <div className="mb-4">
                <h2 className="text-xl font-semibold">Active Migrations</h2>
                <p className="mt-1 text-sm text-muted-foreground">
                  These migrations can be reverted to restore workflows to their
                  original model
                </p>
              </div>
              <MigrationsTable migrations={migrations} />
            </div>
          )}

          {/* Providers & Creators Section - Side by Side */}
          <div className="grid gap-6 lg:grid-cols-2">
            {/* Providers */}
            <div className="rounded-lg border bg-card p-6 shadow-sm">
              <div className="mb-4 flex items-center justify-between">
                <div>
                  <h2 className="text-xl font-semibold">Providers</h2>
                  <p className="mt-1 text-sm text-muted-foreground">
                    Who hosts/serves the models
                  </p>
                </div>
                <AddProviderModal />
              </div>
              <ProviderList providers={providers} />
            </div>

            {/* Creators */}
            <div className="rounded-lg border bg-card p-6 shadow-sm">
              <div className="mb-4 flex items-center justify-between">
                <div>
                  <h2 className="text-xl font-semibold">Creators</h2>
                  <p className="mt-1 text-sm text-muted-foreground">
                    Who made/trained the models
                  </p>
                </div>
                <AddCreatorModal />
              </div>
              <CreatorsTable creators={creators} />
            </div>
          </div>

          {/* Models Section */}
          <div className="rounded-lg border bg-card p-6 shadow-sm">
            <div className="mb-4 flex items-center justify-between">
              <div>
                <h2 className="text-xl font-semibold">Models</h2>
                <p className="mt-1 text-sm text-muted-foreground">
                  Toggle availability, adjust context windows, and update credit
                  pricing
                </p>
              </div>
              <AddModelModal providers={providers} creators={creators} />
            </div>

            {/* Recommended Model Selector */}
            <div className="mb-6">
              <RecommendedModelSelector models={models} />
            </div>

            <ModelsTable
              models={models}
              providers={providers}
              creators={creators}
            />
          </div>
        </div>
      </div>
    </ErrorBoundary>
  );
}
|
||||
@@ -1,133 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import type { LlmModelMigration } from "@/app/api/__generated__/models/llmModelMigration";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/atoms/Table/Table";
|
||||
import { revertLlmMigrationAction } from "../actions";
|
||||
|
||||
/**
 * Table of active (revertible) model migrations.
 *
 * Renders an empty-state hint when there are no migrations; otherwise one
 * MigrationRow per entry, keyed by migration id.
 */
export function MigrationsTable({
  migrations,
}: {
  migrations: LlmModelMigration[];
}) {
  // Empty state: explain how migrations come to exist.
  if (!migrations.length) {
    return (
      <div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
        No active migrations. Migrations are created when you disable a model
        with the "Migrate existing workflows" option.
      </div>
    );
  }

  return (
    <div className="rounded-lg border">
      <Table>
        <TableHeader>
          <TableRow>
            <TableHead>Migration</TableHead>
            <TableHead>Reason</TableHead>
            <TableHead>Nodes Affected</TableHead>
            <TableHead>Custom Cost</TableHead>
            <TableHead>Created</TableHead>
            <TableHead className="text-right">Actions</TableHead>
          </TableRow>
        </TableHeader>
        <TableBody>
          {migrations.map((migration) => (
            <MigrationRow key={migration.id} migration={migration} />
          ))}
        </TableBody>
      </Table>
    </div>
  );
}
|
||||
|
||||
/**
 * One row of the migrations table, with a per-row "Revert" form and an
 * inline error row shown beneath it when the revert action fails.
 */
function MigrationRow({ migration }: { migration: LlmModelMigration }) {
  const [isReverting, setIsReverting] = useState(false);
  const [error, setError] = useState<string | null>(null);

  // Form-action handler for the Revert button.
  // NOTE(review): no router.refresh() here, unlike the edit modals —
  // presumably revertLlmMigrationAction revalidates the page itself; confirm.
  async function handleRevert(formData: FormData) {
    setIsReverting(true);
    setError(null);
    try {
      await revertLlmMigrationAction(formData);
    } catch (err) {
      setError(
        err instanceof Error ? err.message : "Failed to revert migration",
      );
    } finally {
      setIsReverting(false);
    }
  }

  const createdDate = new Date(migration.created_at);

  return (
    <>
      <TableRow>
        <TableCell>
          <div className="text-sm">
            <span className="font-medium">{migration.source_model_slug}</span>
            <span className="mx-2 text-muted-foreground">→</span>
            <span className="font-medium">{migration.target_model_slug}</span>
          </div>
        </TableCell>
        <TableCell>
          <div className="text-sm text-muted-foreground">
            {migration.reason || "—"}
          </div>
        </TableCell>
        <TableCell>
          <div className="text-sm">{migration.node_count}</div>
        </TableCell>
        <TableCell>
          <div className="text-sm">
            {/* Explicit null/undefined check so a 0-credit override still shows. */}
            {migration.custom_credit_cost !== null &&
            migration.custom_credit_cost !== undefined
              ? `${migration.custom_credit_cost} credits`
              : "—"}
          </div>
        </TableCell>
        <TableCell>
          <div className="text-sm text-muted-foreground">
            {createdDate.toLocaleDateString()}{" "}
            {createdDate.toLocaleTimeString([], {
              hour: "2-digit",
              minute: "2-digit",
            })}
          </div>
        </TableCell>
        <TableCell className="text-right">
          <form action={handleRevert} className="inline">
            <input type="hidden" name="migration_id" value={migration.id} />
            <Button
              type="submit"
              variant="outline"
              size="small"
              disabled={isReverting}
            >
              {isReverting ? "Reverting..." : "Revert"}
            </Button>
          </form>
        </TableCell>
      </TableRow>
      {error && (
        <TableRow>
          <TableCell colSpan={6}>
            <div className="rounded border border-destructive/30 bg-destructive/10 p-2 text-sm text-destructive">
              {error}
            </div>
          </TableCell>
        </TableRow>
      )}
    </>
  );
}
|
||||
@@ -1,262 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect, useRef } from "react";
|
||||
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
|
||||
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/atoms/Table/Table";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { toggleLlmModelAction } from "../actions";
|
||||
import { DeleteModelModal } from "./DeleteModelModal";
|
||||
import { DisableModelModal } from "./DisableModelModal";
|
||||
import { EditModelModal } from "./EditModelModal";
|
||||
import { Star, Spinner } from "@phosphor-icons/react";
|
||||
import { getV2ListLlmModels } from "@/app/api/__generated__/endpoints/admin/admin";
|
||||
|
||||
const PAGE_SIZE = 50;
|
||||
|
||||
export function ModelsTable({
|
||||
models: initialModels,
|
||||
providers,
|
||||
creators,
|
||||
}: {
|
||||
models: LlmModel[];
|
||||
providers: LlmProvider[];
|
||||
creators: LlmModelCreator[];
|
||||
}) {
|
||||
const [models, setModels] = useState<LlmModel[]>(initialModels);
|
||||
const [currentPage, setCurrentPage] = useState(1);
|
||||
const [hasMore, setHasMore] = useState(initialModels.length === PAGE_SIZE);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const loadedPagesRef = useRef(1);
|
||||
|
||||
// Sync with parent when initialModels changes (e.g., after enable/disable)
|
||||
// Re-fetch all loaded pages to preserve expanded state
|
||||
useEffect(() => {
|
||||
async function refetchAllPages() {
|
||||
const pagesToLoad = loadedPagesRef.current;
|
||||
|
||||
if (pagesToLoad === 1) {
|
||||
// Only first page loaded, just use initialModels
|
||||
setModels(initialModels);
|
||||
setHasMore(initialModels.length === PAGE_SIZE);
|
||||
return;
|
||||
}
|
||||
|
||||
// Re-fetch all pages we had loaded
|
||||
const allModels: LlmModel[] = [...initialModels];
|
||||
let lastPageHadFullResults = initialModels.length === PAGE_SIZE;
|
||||
|
||||
for (let page = 2; page <= pagesToLoad; page++) {
|
||||
try {
|
||||
const response = await getV2ListLlmModels({
|
||||
page,
|
||||
page_size: PAGE_SIZE,
|
||||
});
|
||||
if (response.status === 200) {
|
||||
allModels.push(...response.data.models);
|
||||
lastPageHadFullResults = response.data.models.length === PAGE_SIZE;
|
||||
}
|
||||
} catch (err) {
|
||||
console.error(`Error refetching page ${page}:`, err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
setModels(allModels);
|
||||
setHasMore(lastPageHadFullResults);
|
||||
}
|
||||
|
||||
refetchAllPages();
|
||||
}, [initialModels]);
|
||||
|
||||
async function loadMore() {
|
||||
if (isLoading) return;
|
||||
setIsLoading(true);
|
||||
|
||||
try {
|
||||
const nextPage = currentPage + 1;
|
||||
const response = await getV2ListLlmModels({
|
||||
page: nextPage,
|
||||
page_size: PAGE_SIZE,
|
||||
});
|
||||
|
||||
if (response.status === 200) {
|
||||
setModels((prev) => [...prev, ...response.data.models]);
|
||||
setCurrentPage(nextPage);
|
||||
loadedPagesRef.current = nextPage;
|
||||
setHasMore(response.data.models.length === PAGE_SIZE);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error("Error loading more models:", err);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}
|
||||
if (!models.length) {
|
||||
return (
|
||||
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
|
||||
No models registered yet.
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const providerLookup = new Map(
|
||||
providers.map((provider) => [provider.id, provider]),
|
||||
);
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className="rounded-lg border">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Model</TableHead>
|
||||
<TableHead>Provider</TableHead>
|
||||
<TableHead>Creator</TableHead>
|
||||
<TableHead>Context Window</TableHead>
|
||||
<TableHead>Max Output</TableHead>
|
||||
<TableHead>Cost</TableHead>
|
||||
<TableHead>Status</TableHead>
|
||||
<TableHead>Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{models.map((model) => {
|
||||
const cost = model.costs?.[0];
|
||||
const provider = providerLookup.get(model.provider_id);
|
||||
return (
|
||||
<TableRow
|
||||
key={model.id}
|
||||
className={model.is_enabled ? "" : "opacity-60"}
|
||||
>
|
||||
<TableCell>
|
||||
<div className="font-medium">{model.display_name}</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{model.slug}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{provider ? (
|
||||
<>
|
||||
<div>{provider.display_name}</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{provider.name}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
model.provider_id
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{model.creator ? (
|
||||
<>
|
||||
<div>{model.creator.display_name}</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{model.creator.name}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<span className="text-muted-foreground">—</span>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>{model.context_window.toLocaleString()}</TableCell>
|
||||
<TableCell>
|
||||
{model.max_output_tokens
|
||||
? model.max_output_tokens.toLocaleString()
|
||||
: "—"}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{cost ? (
|
||||
<>
|
||||
<div className="font-medium">
|
||||
{cost.credit_cost} credits
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{cost.credential_provider}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
"—"
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex flex-col gap-1">
|
||||
<span
|
||||
className={`inline-flex rounded-full px-2.5 py-1 text-xs font-semibold ${
|
||||
model.is_enabled
|
||||
? "bg-primary/10 text-primary"
|
||||
: "bg-muted text-muted-foreground"
|
||||
}`}
|
||||
>
|
||||
{model.is_enabled ? "Enabled" : "Disabled"}
|
||||
</span>
|
||||
{model.is_recommended && (
|
||||
<span className="inline-flex items-center gap-1 rounded-full bg-amber-500/10 px-2.5 py-1 text-xs font-semibold text-amber-600 dark:text-amber-400">
|
||||
<Star size={12} weight="fill" />
|
||||
Recommended
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
{model.is_enabled ? (
|
||||
<DisableModelModal
|
||||
model={model}
|
||||
availableModels={models}
|
||||
/>
|
||||
) : (
|
||||
<EnableModelButton modelId={model.id} />
|
||||
)}
|
||||
<EditModelModal
|
||||
model={model}
|
||||
providers={providers}
|
||||
creators={creators}
|
||||
/>
|
||||
<DeleteModelModal model={model} availableModels={models} />
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
);
|
||||
})}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
|
||||
{hasMore && (
|
||||
<div className="mt-4 flex justify-center">
|
||||
<Button onClick={loadMore} disabled={isLoading} variant="outline">
|
||||
{isLoading ? (
|
||||
<>
|
||||
<Spinner className="mr-2 h-4 w-4 animate-spin" />
|
||||
Loading...
|
||||
</>
|
||||
) : (
|
||||
"Load More"
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function EnableModelButton({ modelId }: { modelId: string }) {
|
||||
return (
|
||||
<form action={toggleLlmModelAction} className="inline">
|
||||
<input type="hidden" name="model_id" value={modelId} />
|
||||
<input type="hidden" name="is_enabled" value="true" />
|
||||
<Button type="submit" variant="outline" size="small" className="min-w-0">
|
||||
Enable
|
||||
</Button>
|
||||
</form>
|
||||
);
|
||||
}
|
||||
@@ -1,94 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/atoms/Table/Table";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
import { DeleteProviderModal } from "./DeleteProviderModal";
|
||||
import { EditProviderModal } from "./EditProviderModal";
|
||||
|
||||
export function ProviderList({ providers }: { providers: LlmProvider[] }) {
|
||||
if (!providers.length) {
|
||||
return (
|
||||
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
|
||||
No providers configured yet.
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>Display Name</TableHead>
|
||||
<TableHead>Default Credential</TableHead>
|
||||
<TableHead>Capabilities</TableHead>
|
||||
<TableHead>Models</TableHead>
|
||||
<TableHead className="w-[100px]">Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{providers.map((provider) => (
|
||||
<TableRow key={provider.id}>
|
||||
<TableCell className="font-medium">{provider.name}</TableCell>
|
||||
<TableCell>{provider.display_name}</TableCell>
|
||||
<TableCell>
|
||||
{provider.default_credential_provider
|
||||
? `${provider.default_credential_provider} (${provider.default_credential_id ?? "id?"})`
|
||||
: "—"}
|
||||
</TableCell>
|
||||
<TableCell className="text-sm text-muted-foreground">
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{provider.supports_tools && (
|
||||
<span className="rounded bg-muted px-2 py-0.5 text-xs">
|
||||
Tools
|
||||
</span>
|
||||
)}
|
||||
{provider.supports_json_output && (
|
||||
<span className="rounded bg-muted px-2 py-0.5 text-xs">
|
||||
JSON
|
||||
</span>
|
||||
)}
|
||||
{provider.supports_reasoning && (
|
||||
<span className="rounded bg-muted px-2 py-0.5 text-xs">
|
||||
Reasoning
|
||||
</span>
|
||||
)}
|
||||
{provider.supports_parallel_tool && (
|
||||
<span className="rounded bg-muted px-2 py-0.5 text-xs">
|
||||
Parallel Tools
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell className="text-sm">
|
||||
<span
|
||||
className={
|
||||
(provider.models?.length ?? 0) > 0
|
||||
? "text-foreground"
|
||||
: "text-muted-foreground"
|
||||
}
|
||||
>
|
||||
{provider.models?.length ?? 0}
|
||||
</span>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex gap-2">
|
||||
<EditProviderModal provider={provider} />
|
||||
<DeleteProviderModal provider={provider} />
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,87 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { setRecommendedModelAction } from "../actions";
|
||||
import { Star } from "@phosphor-icons/react";
|
||||
|
||||
export function RecommendedModelSelector({ models }: { models: LlmModel[] }) {
|
||||
const router = useRouter();
|
||||
const enabledModels = models.filter((m) => m.is_enabled);
|
||||
const currentRecommended = models.find((m) => m.is_recommended);
|
||||
|
||||
const [selectedModelId, setSelectedModelId] = useState<string>(
|
||||
currentRecommended?.id || "",
|
||||
);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const hasChanges = selectedModelId !== (currentRecommended?.id || "");
|
||||
|
||||
async function handleSave() {
|
||||
if (!selectedModelId) return;
|
||||
|
||||
setIsSaving(true);
|
||||
setError(null);
|
||||
try {
|
||||
const formData = new FormData();
|
||||
formData.set("model_id", selectedModelId);
|
||||
await setRecommendedModelAction(formData);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to save");
|
||||
} finally {
|
||||
setIsSaving(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border border-border bg-card p-4">
|
||||
<div className="mb-3 flex items-center gap-2">
|
||||
<Star size={20} weight="fill" className="text-amber-500" />
|
||||
<h3 className="text-sm font-semibold">Recommended Model</h3>
|
||||
</div>
|
||||
<p className="mb-3 text-xs text-muted-foreground">
|
||||
The recommended model is shown as the default suggestion in model
|
||||
selection dropdowns throughout the platform.
|
||||
</p>
|
||||
|
||||
<div className="flex items-center gap-3">
|
||||
<select
|
||||
value={selectedModelId}
|
||||
onChange={(e) => setSelectedModelId(e.target.value)}
|
||||
className="flex-1 rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
disabled={isSaving}
|
||||
>
|
||||
<option value="">-- Select a model --</option>
|
||||
{enabledModels.map((model) => (
|
||||
<option key={model.id} value={model.id}>
|
||||
{model.display_name} ({model.slug})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
|
||||
<Button
|
||||
type="button"
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={handleSave}
|
||||
disabled={!hasChanges || !selectedModelId || isSaving}
|
||||
>
|
||||
{isSaving ? "Saving..." : "Save"}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{error && <p className="mt-2 text-xs text-destructive">{error}</p>}
|
||||
|
||||
{currentRecommended && !hasChanges && (
|
||||
<p className="mt-2 text-xs text-muted-foreground">
|
||||
Currently set to:{" "}
|
||||
<span className="font-medium">{currentRecommended.display_name}</span>
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
/**
|
||||
* Server-side data fetching for LLM Registry page.
|
||||
*/
|
||||
|
||||
import {
|
||||
fetchLlmCreators,
|
||||
fetchLlmMigrations,
|
||||
fetchLlmModels,
|
||||
fetchLlmProviders,
|
||||
} from "./actions";
|
||||
|
||||
export async function getLlmRegistryPageData() {
|
||||
// Fetch providers and models (required)
|
||||
const [providersResponse, modelsResponse] = await Promise.all([
|
||||
fetchLlmProviders(),
|
||||
fetchLlmModels(),
|
||||
]);
|
||||
|
||||
// Fetch migrations separately with fallback (table might not exist yet)
|
||||
let migrations: Awaited<ReturnType<typeof fetchLlmMigrations>>["migrations"] =
|
||||
[];
|
||||
try {
|
||||
const migrationsResponse = await fetchLlmMigrations(false);
|
||||
migrations = migrationsResponse.migrations;
|
||||
} catch {
|
||||
// Migrations table might not exist yet - that's ok, just show empty list
|
||||
console.warn("Could not fetch migrations - table may not exist yet");
|
||||
}
|
||||
|
||||
// Fetch creators separately with fallback (table might not exist yet)
|
||||
let creators: Awaited<ReturnType<typeof fetchLlmCreators>>["creators"] = [];
|
||||
try {
|
||||
const creatorsResponse = await fetchLlmCreators();
|
||||
creators = creatorsResponse.creators;
|
||||
} catch {
|
||||
// Creators table might not exist yet - that's ok, just show empty list
|
||||
console.warn("Could not fetch creators - table may not exist yet");
|
||||
}
|
||||
|
||||
return {
|
||||
providers: providersResponse.providers,
|
||||
models: modelsResponse.models,
|
||||
migrations,
|
||||
creators,
|
||||
};
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
import { withRoleAccess } from "@/lib/withRoleAccess";
|
||||
import { getLlmRegistryPageData } from "./getLlmRegistryPage";
|
||||
import { LlmRegistryDashboard } from "./components/LlmRegistryDashboard";
|
||||
|
||||
async function LlmRegistryPage() {
|
||||
const data = await getLlmRegistryPageData();
|
||||
return <LlmRegistryDashboard {...data} />;
|
||||
}
|
||||
|
||||
export default async function AdminLlmRegistryPage() {
|
||||
const withAdminAccess = await withRoleAccess(["admin"]);
|
||||
const ProtectedLlmRegistryPage = await withAdminAccess(LlmRegistryPage);
|
||||
return <ProtectedLlmRegistryPage />;
|
||||
}
|
||||
@@ -1,8 +1,9 @@
|
||||
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import BackendAPI from "@/lib/autogpt-server-api";
|
||||
import { NextResponse } from "next/server";
|
||||
import { revalidatePath } from "next/cache";
|
||||
import { shouldShowOnboarding } from "@/app/api/helpers";
|
||||
import { getOnboardingStatus } from "@/app/api/helpers";
|
||||
|
||||
// Handle the callback to complete the user session login
|
||||
export async function GET(request: Request) {
|
||||
@@ -25,11 +26,15 @@ export async function GET(request: Request) {
|
||||
const api = new BackendAPI();
|
||||
await api.createUser();
|
||||
|
||||
if (await shouldShowOnboarding()) {
|
||||
// Get onboarding status from backend (includes chat flag evaluated for this user)
|
||||
const { shouldShowOnboarding, isChatEnabled } =
|
||||
await getOnboardingStatus();
|
||||
if (shouldShowOnboarding) {
|
||||
next = "/onboarding";
|
||||
revalidatePath("/onboarding", "layout");
|
||||
} else {
|
||||
revalidatePath("/", "layout");
|
||||
next = getHomepageRoute(isChatEnabled);
|
||||
revalidatePath(next, "layout");
|
||||
}
|
||||
} catch (createUserError) {
|
||||
console.error("Error creating user:", createUserError);
|
||||
|
||||
@@ -7,9 +7,8 @@ import { BlockCategoryResponse } from "@/app/api/__generated__/models/blockCateg
|
||||
import { BlockResponse } from "@/app/api/__generated__/models/blockResponse";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { getQueryClient } from "@/lib/react-query/queryClient";
|
||||
import { useState, useEffect } from "react";
|
||||
import { useState } from "react";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import BackendApi from "@/lib/autogpt-server-api";
|
||||
|
||||
export const useAllBlockContent = () => {
|
||||
const { toast } = useToast();
|
||||
@@ -94,32 +93,6 @@ export const useAllBlockContent = () => {
|
||||
const isErrorOnLoadingMore = (categoryName: string) =>
|
||||
errorLoadingCategories.has(categoryName);
|
||||
|
||||
// Listen for LLM registry refresh notifications
|
||||
useEffect(() => {
|
||||
const api = new BackendApi();
|
||||
const queryClient = getQueryClient();
|
||||
|
||||
const handleNotification = (notification: any) => {
|
||||
if (
|
||||
notification?.type === "LLM_REGISTRY_REFRESH" ||
|
||||
notification?.event === "registry_updated"
|
||||
) {
|
||||
// Invalidate all block-related queries to force refresh
|
||||
const categoriesQueryKey = getGetV2GetBuilderBlockCategoriesQueryKey();
|
||||
queryClient.invalidateQueries({ queryKey: categoriesQueryKey });
|
||||
}
|
||||
};
|
||||
|
||||
const unsubscribe = api.onWebSocketMessage(
|
||||
"notification",
|
||||
handleNotification,
|
||||
);
|
||||
|
||||
return () => {
|
||||
unsubscribe();
|
||||
};
|
||||
}, []);
|
||||
|
||||
return {
|
||||
data,
|
||||
isLoading,
|
||||
|
||||
@@ -610,11 +610,8 @@ const NodeOneOfDiscriminatorField: FC<{
|
||||
|
||||
return oneOfVariants
|
||||
.map((variant) => {
|
||||
const discProperty = variant.properties?.[discriminatorProperty];
|
||||
const variantDiscValue =
|
||||
discProperty && "const" in discProperty
|
||||
? (discProperty.const as string)
|
||||
: undefined; // NOTE: can discriminators only be strings?
|
||||
const variantDiscValue = variant.properties?.[discriminatorProperty]
|
||||
?.const as string; // NOTE: can discriminators only be strings?
|
||||
|
||||
return {
|
||||
value: variantDiscValue,
|
||||
@@ -1127,47 +1124,9 @@ const NodeStringInput: FC<{
|
||||
displayName,
|
||||
}) => {
|
||||
value ||= schema.default || "";
|
||||
|
||||
// Check if we have options with labels (e.g., LLM model picker)
|
||||
const hasOptions = schema.options && schema.options.length > 0;
|
||||
const hasEnum = schema.enum && schema.enum.length > 0;
|
||||
|
||||
// Helper to get display label for a value
|
||||
const getDisplayLabel = (val: string) => {
|
||||
if (hasOptions) {
|
||||
const option = schema.options!.find((opt) => opt.value === val);
|
||||
return option?.label || beautifyString(val);
|
||||
}
|
||||
return beautifyString(val);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className={className}>
|
||||
{hasOptions ? (
|
||||
// Render options with proper labels (used by LLM model picker)
|
||||
<Select
|
||||
defaultValue={value}
|
||||
onValueChange={(newValue) => handleInputChange(selfKey, newValue)}
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder={schema.placeholder || displayName}>
|
||||
{value ? getDisplayLabel(value) : undefined}
|
||||
</SelectValue>
|
||||
</SelectTrigger>
|
||||
<SelectContent className="nodrag">
|
||||
{schema.options!.map((option, index) => (
|
||||
<SelectItem
|
||||
key={index}
|
||||
value={option.value}
|
||||
title={option.description}
|
||||
>
|
||||
{option.label || beautifyString(option.value)}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
) : hasEnum ? (
|
||||
// Fallback to enum with beautified strings
|
||||
{schema.enum && schema.enum.length > 0 ? (
|
||||
<Select
|
||||
defaultValue={value}
|
||||
onValueChange={(newValue) => handleInputChange(selfKey, newValue)}
|
||||
@@ -1176,8 +1135,8 @@ const NodeStringInput: FC<{
|
||||
<SelectValue placeholder={schema.placeholder || displayName} />
|
||||
</SelectTrigger>
|
||||
<SelectContent className="nodrag">
|
||||
{schema
|
||||
.enum!.filter((option) => option)
|
||||
{schema.enum
|
||||
.filter((option) => option)
|
||||
.map((option, index) => (
|
||||
<SelectItem key={index} value={option}>
|
||||
{beautifyString(option)}
|
||||
|
||||
188
autogpt_platform/frontend/src/app/(platform)/copilot-2/page.tsx
Normal file
188
autogpt_platform/frontend/src/app/(platform)/copilot-2/page.tsx
Normal file
@@ -0,0 +1,188 @@
|
||||
"use client";
|
||||
|
||||
import { useChat } from "@ai-sdk/react";
|
||||
import { DefaultChatTransport } from "ai";
|
||||
import { useState, useMemo } from "react";
|
||||
import { parseAsString, useQueryState } from "nuqs";
|
||||
import { MessageSquare } from "lucide-react";
|
||||
import { ChatSidebar } from "./tools/components/ChatSidebar/ChatSidebar";
|
||||
import {
|
||||
Conversation,
|
||||
ConversationContent,
|
||||
ConversationEmptyState,
|
||||
ConversationScrollButton,
|
||||
} from "@/components/ai-elements/conversation";
|
||||
import {
|
||||
Message,
|
||||
MessageContent,
|
||||
MessageResponse,
|
||||
} from "@/components/ai-elements/message";
|
||||
|
||||
export default function Page() {
|
||||
const [sessionId] = useQueryState("sessionId", parseAsString);
|
||||
const [isCreating, setIsCreating] = useState(false);
|
||||
const [input, setInput] = useState("");
|
||||
|
||||
const transport = useMemo(() => {
|
||||
if (!sessionId) return null;
|
||||
return new DefaultChatTransport({
|
||||
api: `/api/chat/sessions/${sessionId}/stream`,
|
||||
prepareSendMessagesRequest: ({ messages }) => {
|
||||
const last = messages[messages.length - 1];
|
||||
return {
|
||||
body: {
|
||||
message: last.parts
|
||||
?.map((p) => (p.type === "text" ? p.text : ""))
|
||||
.join(""),
|
||||
is_user_message: last.role === "user",
|
||||
context: null,
|
||||
},
|
||||
};
|
||||
},
|
||||
});
|
||||
}, [sessionId]);
|
||||
|
||||
const { messages, sendMessage, status, error } = useChat({
|
||||
transport: transport ?? undefined,
|
||||
});
|
||||
|
||||
function handleSubmit(e: React.FormEvent) {
|
||||
e.preventDefault();
|
||||
if (input.trim()) {
|
||||
sendMessage({ text: input });
|
||||
setInput("");
|
||||
}
|
||||
}
|
||||
|
||||
// Show landing page when no session exists
|
||||
if (!sessionId) {
|
||||
return (
|
||||
<div className="flex h-full">
|
||||
<ChatSidebar isCreating={isCreating} setIsCreating={setIsCreating} />
|
||||
|
||||
<div className="flex h-full flex-1 flex-col items-center justify-center bg-zinc-100 p-4">
|
||||
<h2 className="mb-4 text-xl font-semibold text-zinc-700">
|
||||
Start a new conversation
|
||||
</h2>
|
||||
<form onSubmit={handleSubmit} className="w-full max-w-md">
|
||||
<input
|
||||
value={input}
|
||||
onChange={(e) => setInput(e.target.value)}
|
||||
disabled={isCreating}
|
||||
placeholder="Type your message to start..."
|
||||
className="w-full rounded-md border border-zinc-300 px-4 py-2"
|
||||
/>
|
||||
<button
|
||||
type="submit"
|
||||
disabled={isCreating || !input.trim()}
|
||||
className="mt-2 w-full rounded-md bg-blue-600 px-4 py-2 text-white transition-colors hover:bg-blue-700 disabled:opacity-50"
|
||||
>
|
||||
{isCreating ? "Starting..." : "Start Chat"}
|
||||
</button>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex h-full">
|
||||
<ChatSidebar isCreating={isCreating} setIsCreating={setIsCreating} />
|
||||
|
||||
<div className="flex h-full flex-1 flex-col">
|
||||
<Conversation className="flex-1">
|
||||
<ConversationContent>
|
||||
{messages.length === 0 ? (
|
||||
<ConversationEmptyState
|
||||
icon={<MessageSquare className="size-12" />}
|
||||
title="Start a conversation"
|
||||
description="Type a message below to begin chatting"
|
||||
/>
|
||||
) : (
|
||||
messages.map((message) => (
|
||||
<Message from={message.role} key={message.id}>
|
||||
<MessageContent>
|
||||
{message.parts.map((part, i) => {
|
||||
switch (part.type) {
|
||||
case "text":
|
||||
return (
|
||||
<MessageResponse key={`${message.id}-${i}`}>
|
||||
{part.text}
|
||||
</MessageResponse>
|
||||
);
|
||||
case "tool-find_block":
|
||||
return (
|
||||
<div
|
||||
key={`${message.id}-${i}`}
|
||||
className="w-fit rounded-xl border border-zinc-200 bg-zinc-100 p-2 text-xs text-zinc-700"
|
||||
>
|
||||
{part.state === "input-streaming" && (
|
||||
<p>Finding blocks for you</p>
|
||||
)}
|
||||
{part.state === "input-available" && (
|
||||
<p>
|
||||
Searching blocks for{" "}
|
||||
{(part.input as { query: string }).query}
|
||||
</p>
|
||||
)}
|
||||
{part.state === "output-available" && (
|
||||
<p>
|
||||
Found{" "}
|
||||
{
|
||||
(
|
||||
JSON.parse(part.output as string) as {
|
||||
count: number;
|
||||
}
|
||||
).count
|
||||
}{" "}
|
||||
blocks
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
})}
|
||||
</MessageContent>
|
||||
</Message>
|
||||
))
|
||||
)}
|
||||
{status === "submitted" && (
|
||||
<Message from="assistant">
|
||||
<MessageContent>
|
||||
<p className="text-zinc-500">Thinking...</p>
|
||||
</MessageContent>
|
||||
</Message>
|
||||
)}
|
||||
{error && (
|
||||
<div className="rounded-lg bg-red-50 p-3 text-red-600">
|
||||
Error: {error.message}
|
||||
</div>
|
||||
)}
|
||||
</ConversationContent>
|
||||
<ConversationScrollButton />
|
||||
</Conversation>
|
||||
|
||||
<form onSubmit={handleSubmit} className="border-t p-4">
|
||||
<div className="mx-auto flex max-w-2xl gap-2">
|
||||
<input
|
||||
value={input}
|
||||
onChange={(e) => setInput(e.target.value)}
|
||||
disabled={status !== "ready"}
|
||||
placeholder="Say something..."
|
||||
className="flex-1 rounded-md border border-zinc-300 px-4 py-2 focus:border-zinc-500 focus:outline-none"
|
||||
/>
|
||||
<button
|
||||
type="submit"
|
||||
disabled={status !== "ready" || !input.trim()}
|
||||
className="rounded-md bg-zinc-900 px-4 py-2 text-white transition-colors hover:bg-zinc-800 disabled:opacity-50"
|
||||
>
|
||||
Send
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
import { useState } from "react";
|
||||
import { postV2CreateSession } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { parseAsString, useQueryState } from "nuqs";
|
||||
|
||||
export const ChatSidebar = ({ isCreating, setIsCreating }: { isCreating: boolean, setIsCreating: (isCreating: boolean) => void }) => {
|
||||
const [sessionId, setSessionId] = useQueryState("sessionId", parseAsString);
|
||||
|
||||
async function createSession(): Promise<string | null> {
|
||||
setIsCreating(true);
|
||||
try {
|
||||
const response = await postV2CreateSession({
|
||||
body: JSON.stringify({}),
|
||||
});
|
||||
if (response.status === 200 && response.data?.id) {
|
||||
return response.data.id;
|
||||
}
|
||||
return null;
|
||||
} catch (error) {
|
||||
return null;
|
||||
} finally {
|
||||
setIsCreating(false);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleNewSession() {
|
||||
const newSessionId = await createSession();
|
||||
if (newSessionId) {
|
||||
setSessionId(newSessionId);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return (
|
||||
<div className="flex w-64 flex-col border-r border-zinc-200 bg-zinc-50 p-4">
|
||||
<button
|
||||
onClick={handleNewSession}
|
||||
disabled={isCreating}
|
||||
className="rounded-md bg-blue-600 px-4 py-2 text-white transition-colors hover:bg-blue-700 disabled:opacity-50"
|
||||
>
|
||||
{isCreating ? "Creating..." : "New Session"}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
);
|
||||
};
|
||||
@@ -1,12 +1,10 @@
|
||||
"use client";
|
||||
|
||||
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
|
||||
import type { ReactNode } from "react";
|
||||
import { useEffect } from "react";
|
||||
import { useCopilotStore } from "../../copilot-page-store";
|
||||
import { DesktopSidebar } from "./components/DesktopSidebar/DesktopSidebar";
|
||||
import { LoadingState } from "./components/LoadingState/LoadingState";
|
||||
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
|
||||
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
|
||||
import { useCopilotShell } from "./useCopilotShell";
|
||||
@@ -20,38 +18,21 @@ export function CopilotShell({ children }: Props) {
|
||||
isMobile,
|
||||
isDrawerOpen,
|
||||
isLoading,
|
||||
isCreatingSession,
|
||||
isLoggedIn,
|
||||
hasActiveSession,
|
||||
sessions,
|
||||
currentSessionId,
|
||||
handleSelectSession,
|
||||
handleOpenDrawer,
|
||||
handleCloseDrawer,
|
||||
handleDrawerOpenChange,
|
||||
handleNewChat,
|
||||
handleNewChatClick,
|
||||
handleSessionClick,
|
||||
hasNextPage,
|
||||
isFetchingNextPage,
|
||||
fetchNextPage,
|
||||
isReadyToShowContent,
|
||||
} = useCopilotShell();
|
||||
|
||||
const setNewChatHandler = useCopilotStore((s) => s.setNewChatHandler);
|
||||
const requestNewChat = useCopilotStore((s) => s.requestNewChat);
|
||||
|
||||
useEffect(
|
||||
function registerNewChatHandler() {
|
||||
setNewChatHandler(handleNewChat);
|
||||
return function cleanup() {
|
||||
setNewChatHandler(null);
|
||||
};
|
||||
},
|
||||
[handleNewChat],
|
||||
);
|
||||
|
||||
function handleNewChatClick() {
|
||||
requestNewChat();
|
||||
}
|
||||
|
||||
if (!isLoggedIn) {
|
||||
return (
|
||||
<div className="flex h-full items-center justify-center">
|
||||
@@ -72,7 +53,7 @@ export function CopilotShell({ children }: Props) {
|
||||
isLoading={isLoading}
|
||||
hasNextPage={hasNextPage}
|
||||
isFetchingNextPage={isFetchingNextPage}
|
||||
onSelectSession={handleSelectSession}
|
||||
onSelectSession={handleSessionClick}
|
||||
onFetchNextPage={fetchNextPage}
|
||||
onNewChat={handleNewChatClick}
|
||||
hasActiveSession={Boolean(hasActiveSession)}
|
||||
@@ -82,7 +63,18 @@ export function CopilotShell({ children }: Props) {
|
||||
<div className="relative flex min-h-0 flex-1 flex-col">
|
||||
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
|
||||
<div className="flex min-h-0 flex-1 flex-col">
|
||||
{isReadyToShowContent ? children : <LoadingState />}
|
||||
{isCreatingSession ? (
|
||||
<div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
|
||||
<div className="flex flex-col items-center gap-4">
|
||||
<ChatLoader />
|
||||
<Text variant="body" className="text-zinc-500">
|
||||
Creating your chat...
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
children
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -94,7 +86,7 @@ export function CopilotShell({ children }: Props) {
|
||||
isLoading={isLoading}
|
||||
hasNextPage={hasNextPage}
|
||||
isFetchingNextPage={isFetchingNextPage}
|
||||
onSelectSession={handleSelectSession}
|
||||
onSelectSession={handleSessionClick}
|
||||
onFetchNextPage={fetchNextPage}
|
||||
onNewChat={handleNewChatClick}
|
||||
onClose={handleCloseDrawer}
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
|
||||
|
||||
export function LoadingState() {
|
||||
return (
|
||||
<div className="flex flex-1 items-center justify-center">
|
||||
<div className="flex flex-col items-center gap-4">
|
||||
<ChatLoader />
|
||||
<Text variant="body" className="text-zinc-500">
|
||||
Loading your chats...
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -3,17 +3,17 @@ import { useState } from "react";
|
||||
export function useMobileDrawer() {
|
||||
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
|
||||
|
||||
function handleOpenDrawer() {
|
||||
const handleOpenDrawer = () => {
|
||||
setIsDrawerOpen(true);
|
||||
}
|
||||
};
|
||||
|
||||
function handleCloseDrawer() {
|
||||
const handleCloseDrawer = () => {
|
||||
setIsDrawerOpen(false);
|
||||
}
|
||||
};
|
||||
|
||||
function handleDrawerOpenChange(open: boolean) {
|
||||
const handleDrawerOpenChange = (open: boolean) => {
|
||||
setIsDrawerOpen(open);
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
isDrawerOpen,
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
import {
|
||||
getGetV2ListSessionsQueryKey,
|
||||
useGetV2ListSessions,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { useChatStore } from "@/components/contextual/Chat/chat-store";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useEffect, useState } from "react";
|
||||
|
||||
const PAGE_SIZE = 50;
|
||||
@@ -16,12 +11,12 @@ export interface UseSessionsPaginationArgs {
|
||||
|
||||
export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
|
||||
const [offset, setOffset] = useState(0);
|
||||
|
||||
const [accumulatedSessions, setAccumulatedSessions] = useState<
|
||||
SessionSummaryResponse[]
|
||||
>([]);
|
||||
|
||||
const [totalCount, setTotalCount] = useState<number | null>(null);
|
||||
const queryClient = useQueryClient();
|
||||
const onStreamComplete = useChatStore((state) => state.onStreamComplete);
|
||||
|
||||
const { data, isLoading, isFetching, isError } = useGetV2ListSessions(
|
||||
{ limit: PAGE_SIZE, offset },
|
||||
@@ -32,38 +27,23 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
|
||||
},
|
||||
);
|
||||
|
||||
useEffect(function refreshOnStreamComplete() {
|
||||
const unsubscribe = onStreamComplete(function handleStreamComplete() {
|
||||
setOffset(0);
|
||||
useEffect(() => {
|
||||
const responseData = okData(data);
|
||||
if (responseData) {
|
||||
const newSessions = responseData.sessions;
|
||||
const total = responseData.total;
|
||||
setTotalCount(total);
|
||||
|
||||
if (offset === 0) {
|
||||
setAccumulatedSessions(newSessions);
|
||||
} else {
|
||||
setAccumulatedSessions((prev) => [...prev, ...newSessions]);
|
||||
}
|
||||
} else if (!enabled) {
|
||||
setAccumulatedSessions([]);
|
||||
setTotalCount(null);
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
});
|
||||
});
|
||||
return unsubscribe;
|
||||
}, []);
|
||||
|
||||
useEffect(
|
||||
function updateSessionsFromResponse() {
|
||||
const responseData = okData(data);
|
||||
if (responseData) {
|
||||
const newSessions = responseData.sessions;
|
||||
const total = responseData.total;
|
||||
setTotalCount(total);
|
||||
|
||||
if (offset === 0) {
|
||||
setAccumulatedSessions(newSessions);
|
||||
} else {
|
||||
setAccumulatedSessions((prev) => [...prev, ...newSessions]);
|
||||
}
|
||||
} else if (!enabled) {
|
||||
setAccumulatedSessions([]);
|
||||
setTotalCount(null);
|
||||
}
|
||||
},
|
||||
[data, offset, enabled],
|
||||
);
|
||||
}
|
||||
}, [data, offset, enabled]);
|
||||
|
||||
const hasNextPage =
|
||||
totalCount !== null && accumulatedSessions.length < totalCount;
|
||||
@@ -86,17 +66,17 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
|
||||
}
|
||||
}, [hasNextPage, isFetching, isLoading, isError, totalCount]);
|
||||
|
||||
function fetchNextPage() {
|
||||
const fetchNextPage = () => {
|
||||
if (hasNextPage && !isFetching) {
|
||||
setOffset((prev) => prev + PAGE_SIZE);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
function reset() {
|
||||
const reset = () => {
|
||||
setOffset(0);
|
||||
setAccumulatedSessions([]);
|
||||
setTotalCount(null);
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
sessions: accumulatedSessions,
|
||||
|
||||
@@ -104,76 +104,3 @@ export function mergeCurrentSessionIntoList(
|
||||
export function getCurrentSessionId(searchParams: URLSearchParams) {
|
||||
return searchParams.get("sessionId");
|
||||
}
|
||||
|
||||
export function shouldAutoSelectSession(
|
||||
areAllSessionsLoaded: boolean,
|
||||
hasAutoSelectedSession: boolean,
|
||||
paramSessionId: string | null,
|
||||
visibleSessions: SessionSummaryResponse[],
|
||||
accumulatedSessions: SessionSummaryResponse[],
|
||||
isLoading: boolean,
|
||||
totalCount: number | null,
|
||||
) {
|
||||
if (!areAllSessionsLoaded || hasAutoSelectedSession) {
|
||||
return {
|
||||
shouldSelect: false,
|
||||
sessionIdToSelect: null,
|
||||
shouldCreate: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (paramSessionId) {
|
||||
return {
|
||||
shouldSelect: false,
|
||||
sessionIdToSelect: null,
|
||||
shouldCreate: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (visibleSessions.length > 0) {
|
||||
return {
|
||||
shouldSelect: true,
|
||||
sessionIdToSelect: visibleSessions[0].id,
|
||||
shouldCreate: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (accumulatedSessions.length === 0 && !isLoading && totalCount === 0) {
|
||||
return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: true };
|
||||
}
|
||||
|
||||
if (totalCount === 0) {
|
||||
return {
|
||||
shouldSelect: false,
|
||||
sessionIdToSelect: null,
|
||||
shouldCreate: false,
|
||||
};
|
||||
}
|
||||
|
||||
return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: false };
|
||||
}
|
||||
|
||||
export function checkReadyToShowContent(
|
||||
areAllSessionsLoaded: boolean,
|
||||
paramSessionId: string | null,
|
||||
accumulatedSessions: SessionSummaryResponse[],
|
||||
isCurrentSessionLoading: boolean,
|
||||
currentSessionData: SessionDetailResponse | null | undefined,
|
||||
hasAutoSelectedSession: boolean,
|
||||
) {
|
||||
if (!areAllSessionsLoaded) return false;
|
||||
|
||||
if (paramSessionId) {
|
||||
const sessionFound = accumulatedSessions.some(
|
||||
(s) => s.id === paramSessionId,
|
||||
);
|
||||
return (
|
||||
sessionFound ||
|
||||
(!isCurrentSessionLoading &&
|
||||
currentSessionData !== undefined &&
|
||||
currentSessionData !== null)
|
||||
);
|
||||
}
|
||||
|
||||
return hasAutoSelectedSession;
|
||||
}
|
||||
|
||||
@@ -1,26 +1,22 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
getGetV2GetSessionQueryKey,
|
||||
getGetV2ListSessionsQueryKey,
|
||||
useGetV2GetSession,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { useChatStore } from "@/components/contextual/Chat/chat-store";
|
||||
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { parseAsString, useQueryState } from "nuqs";
|
||||
import { usePathname, useSearchParams } from "next/navigation";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useRef } from "react";
|
||||
import { useCopilotStore } from "../../copilot-page-store";
|
||||
import { useCopilotSessionId } from "../../useCopilotSessionId";
|
||||
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
|
||||
import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
|
||||
import {
|
||||
checkReadyToShowContent,
|
||||
convertSessionDetailToSummary,
|
||||
filterVisibleSessions,
|
||||
getCurrentSessionId,
|
||||
mergeCurrentSessionIntoList,
|
||||
} from "./helpers";
|
||||
import { getCurrentSessionId } from "./helpers";
|
||||
import { useShellSessionList } from "./useShellSessionList";
|
||||
|
||||
export function useCopilotShell() {
|
||||
const pathname = usePathname();
|
||||
@@ -31,7 +27,7 @@ export function useCopilotShell() {
|
||||
const isMobile =
|
||||
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
|
||||
|
||||
const [, setUrlSessionId] = useQueryState("sessionId", parseAsString);
|
||||
const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
|
||||
|
||||
const isOnHomepage = pathname === "/copilot";
|
||||
const paramSessionId = searchParams.get("sessionId");
|
||||
@@ -45,123 +41,80 @@ export function useCopilotShell() {
|
||||
|
||||
const paginationEnabled = !isMobile || isDrawerOpen || !!paramSessionId;
|
||||
|
||||
const {
|
||||
sessions: accumulatedSessions,
|
||||
isLoading: isSessionsLoading,
|
||||
isFetching: isSessionsFetching,
|
||||
hasNextPage,
|
||||
areAllSessionsLoaded,
|
||||
fetchNextPage,
|
||||
reset: resetPagination,
|
||||
} = useSessionsPagination({
|
||||
enabled: paginationEnabled,
|
||||
});
|
||||
|
||||
const currentSessionId = getCurrentSessionId(searchParams);
|
||||
|
||||
const { data: currentSessionData, isLoading: isCurrentSessionLoading } =
|
||||
useGetV2GetSession(currentSessionId || "", {
|
||||
const { data: currentSessionData } = useGetV2GetSession(
|
||||
currentSessionId || "",
|
||||
{
|
||||
query: {
|
||||
enabled: !!currentSessionId,
|
||||
select: okData,
|
||||
},
|
||||
});
|
||||
|
||||
const [hasAutoSelectedSession, setHasAutoSelectedSession] = useState(false);
|
||||
const hasAutoSelectedRef = useRef(false);
|
||||
const recentlyCreatedSessionsRef = useRef<
|
||||
Map<string, SessionSummaryResponse>
|
||||
>(new Map());
|
||||
|
||||
// Mark as auto-selected when sessionId is in URL
|
||||
useEffect(() => {
|
||||
if (paramSessionId && !hasAutoSelectedRef.current) {
|
||||
hasAutoSelectedRef.current = true;
|
||||
setHasAutoSelectedSession(true);
|
||||
}
|
||||
}, [paramSessionId]);
|
||||
|
||||
// On homepage without sessionId, mark as ready immediately
|
||||
useEffect(() => {
|
||||
if (isOnHomepage && !paramSessionId && !hasAutoSelectedRef.current) {
|
||||
hasAutoSelectedRef.current = true;
|
||||
setHasAutoSelectedSession(true);
|
||||
}
|
||||
}, [isOnHomepage, paramSessionId]);
|
||||
|
||||
// Invalidate sessions list when navigating to homepage (to show newly created sessions)
|
||||
useEffect(() => {
|
||||
if (isOnHomepage && !paramSessionId) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
});
|
||||
}
|
||||
}, [isOnHomepage, paramSessionId, queryClient]);
|
||||
|
||||
// Track newly created sessions to ensure they stay visible even when switching away
|
||||
useEffect(() => {
|
||||
if (currentSessionId && currentSessionData) {
|
||||
const isNewSession =
|
||||
currentSessionData.updated_at === currentSessionData.created_at;
|
||||
const isNotInAccumulated = !accumulatedSessions.some(
|
||||
(s) => s.id === currentSessionId,
|
||||
);
|
||||
if (isNewSession || isNotInAccumulated) {
|
||||
const summary = convertSessionDetailToSummary(currentSessionData);
|
||||
recentlyCreatedSessionsRef.current.set(currentSessionId, summary);
|
||||
}
|
||||
}
|
||||
}, [currentSessionId, currentSessionData, accumulatedSessions]);
|
||||
|
||||
// Clean up recently created sessions that are now in the accumulated list
|
||||
useEffect(() => {
|
||||
for (const sessionId of recentlyCreatedSessionsRef.current.keys()) {
|
||||
if (accumulatedSessions.some((s) => s.id === sessionId)) {
|
||||
recentlyCreatedSessionsRef.current.delete(sessionId);
|
||||
}
|
||||
}
|
||||
}, [accumulatedSessions]);
|
||||
|
||||
// Reset pagination when query becomes disabled
|
||||
const prevPaginationEnabledRef = useRef(paginationEnabled);
|
||||
useEffect(() => {
|
||||
if (prevPaginationEnabledRef.current && !paginationEnabled) {
|
||||
resetPagination();
|
||||
resetAutoSelect();
|
||||
}
|
||||
prevPaginationEnabledRef.current = paginationEnabled;
|
||||
}, [paginationEnabled, resetPagination]);
|
||||
|
||||
const sessions = mergeCurrentSessionIntoList(
|
||||
accumulatedSessions,
|
||||
currentSessionId,
|
||||
currentSessionData,
|
||||
recentlyCreatedSessionsRef.current,
|
||||
},
|
||||
);
|
||||
|
||||
const visibleSessions = filterVisibleSessions(sessions);
|
||||
const {
|
||||
sessions,
|
||||
isLoading,
|
||||
isSessionsFetching,
|
||||
hasNextPage,
|
||||
fetchNextPage,
|
||||
resetPagination,
|
||||
recentlyCreatedSessionsRef,
|
||||
} = useShellSessionList({
|
||||
paginationEnabled,
|
||||
currentSessionId,
|
||||
currentSessionData,
|
||||
isOnHomepage,
|
||||
paramSessionId,
|
||||
});
|
||||
|
||||
const sidebarSelectedSessionId =
|
||||
isOnHomepage && !paramSessionId ? null : currentSessionId;
|
||||
const stopStream = useChatStore((s) => s.stopStream);
|
||||
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
|
||||
const isStreaming = useCopilotStore((s) => s.isStreaming);
|
||||
const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
|
||||
const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession);
|
||||
const openInterruptModal = useCopilotStore((s) => s.openInterruptModal);
|
||||
|
||||
const isReadyToShowContent = isOnHomepage
|
||||
? true
|
||||
: checkReadyToShowContent(
|
||||
areAllSessionsLoaded,
|
||||
paramSessionId,
|
||||
accumulatedSessions,
|
||||
isCurrentSessionLoading,
|
||||
currentSessionData,
|
||||
hasAutoSelectedSession,
|
||||
);
|
||||
const pendingActionRef = useRef<(() => void) | null>(null);
|
||||
|
||||
function handleSelectSession(sessionId: string) {
|
||||
async function stopCurrentStream() {
|
||||
if (!currentSessionId) return;
|
||||
|
||||
setIsSwitchingSession(true);
|
||||
await new Promise<void>((resolve) => {
|
||||
const unsubscribe = onStreamComplete((completedId) => {
|
||||
if (completedId === currentSessionId) {
|
||||
clearTimeout(timeout);
|
||||
unsubscribe();
|
||||
resolve();
|
||||
}
|
||||
});
|
||||
const timeout = setTimeout(() => {
|
||||
unsubscribe();
|
||||
resolve();
|
||||
}, 3000);
|
||||
stopStream(currentSessionId);
|
||||
});
|
||||
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(currentSessionId),
|
||||
});
|
||||
setIsSwitchingSession(false);
|
||||
}
|
||||
|
||||
function selectSession(sessionId: string) {
|
||||
if (sessionId === currentSessionId) return;
|
||||
if (recentlyCreatedSessionsRef.current.has(sessionId)) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
||||
});
|
||||
}
|
||||
setUrlSessionId(sessionId, { shallow: false });
|
||||
if (isMobile) handleCloseDrawer();
|
||||
}
|
||||
|
||||
function handleNewChat() {
|
||||
resetAutoSelect();
|
||||
function startNewChat() {
|
||||
resetPagination();
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
@@ -170,12 +123,31 @@ export function useCopilotShell() {
|
||||
if (isMobile) handleCloseDrawer();
|
||||
}
|
||||
|
||||
function resetAutoSelect() {
|
||||
hasAutoSelectedRef.current = false;
|
||||
setHasAutoSelectedSession(false);
|
||||
function handleSessionClick(sessionId: string) {
|
||||
if (sessionId === currentSessionId) return;
|
||||
|
||||
if (isStreaming) {
|
||||
pendingActionRef.current = async () => {
|
||||
await stopCurrentStream();
|
||||
selectSession(sessionId);
|
||||
};
|
||||
openInterruptModal(pendingActionRef.current);
|
||||
} else {
|
||||
selectSession(sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
const isLoading = isSessionsLoading && accumulatedSessions.length === 0;
|
||||
function handleNewChatClick() {
|
||||
if (isStreaming) {
|
||||
pendingActionRef.current = async () => {
|
||||
await stopCurrentStream();
|
||||
startNewChat();
|
||||
};
|
||||
openInterruptModal(pendingActionRef.current);
|
||||
} else {
|
||||
startNewChat();
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
isMobile,
|
||||
@@ -183,17 +155,17 @@ export function useCopilotShell() {
|
||||
isLoggedIn,
|
||||
hasActiveSession:
|
||||
Boolean(currentSessionId) && (!isOnHomepage || Boolean(paramSessionId)),
|
||||
isLoading,
|
||||
sessions: visibleSessions,
|
||||
currentSessionId: sidebarSelectedSessionId,
|
||||
handleSelectSession,
|
||||
isLoading: isLoading || isCreatingSession,
|
||||
isCreatingSession,
|
||||
sessions,
|
||||
currentSessionId: urlSessionId,
|
||||
handleOpenDrawer,
|
||||
handleCloseDrawer,
|
||||
handleDrawerOpenChange,
|
||||
handleNewChat,
|
||||
handleNewChatClick,
|
||||
handleSessionClick,
|
||||
hasNextPage,
|
||||
isFetchingNextPage: isSessionsFetching,
|
||||
fetchNextPage,
|
||||
isReadyToShowContent,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
import { getGetV2ListSessionsQueryKey } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||
import { useChatStore } from "@/components/contextual/Chat/chat-store";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useEffect, useMemo, useRef } from "react";
|
||||
import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
|
||||
import {
|
||||
convertSessionDetailToSummary,
|
||||
filterVisibleSessions,
|
||||
mergeCurrentSessionIntoList,
|
||||
} from "./helpers";
|
||||
|
||||
interface UseShellSessionListArgs {
|
||||
paginationEnabled: boolean;
|
||||
currentSessionId: string | null;
|
||||
currentSessionData: SessionDetailResponse | null | undefined;
|
||||
isOnHomepage: boolean;
|
||||
paramSessionId: string | null;
|
||||
}
|
||||
|
||||
export function useShellSessionList({
|
||||
paginationEnabled,
|
||||
currentSessionId,
|
||||
currentSessionData,
|
||||
isOnHomepage,
|
||||
paramSessionId,
|
||||
}: UseShellSessionListArgs) {
|
||||
const queryClient = useQueryClient();
|
||||
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
|
||||
|
||||
const {
|
||||
sessions: accumulatedSessions,
|
||||
isLoading: isSessionsLoading,
|
||||
isFetching: isSessionsFetching,
|
||||
hasNextPage,
|
||||
fetchNextPage,
|
||||
reset: resetPagination,
|
||||
} = useSessionsPagination({
|
||||
enabled: paginationEnabled,
|
||||
});
|
||||
|
||||
const recentlyCreatedSessionsRef = useRef<
|
||||
Map<string, SessionSummaryResponse>
|
||||
>(new Map());
|
||||
|
||||
useEffect(() => {
|
||||
if (isOnHomepage && !paramSessionId) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
});
|
||||
}
|
||||
}, [isOnHomepage, paramSessionId, queryClient]);
|
||||
|
||||
useEffect(() => {
|
||||
if (currentSessionId && currentSessionData) {
|
||||
const isNewSession =
|
||||
currentSessionData.updated_at === currentSessionData.created_at;
|
||||
const isNotInAccumulated = !accumulatedSessions.some(
|
||||
(s) => s.id === currentSessionId,
|
||||
);
|
||||
if (isNewSession || isNotInAccumulated) {
|
||||
const summary = convertSessionDetailToSummary(currentSessionData);
|
||||
recentlyCreatedSessionsRef.current.set(currentSessionId, summary);
|
||||
}
|
||||
}
|
||||
}, [currentSessionId, currentSessionData, accumulatedSessions]);
|
||||
|
||||
useEffect(() => {
|
||||
for (const sessionId of recentlyCreatedSessionsRef.current.keys()) {
|
||||
if (accumulatedSessions.some((s) => s.id === sessionId)) {
|
||||
recentlyCreatedSessionsRef.current.delete(sessionId);
|
||||
}
|
||||
}
|
||||
}, [accumulatedSessions]);
|
||||
|
||||
useEffect(() => {
|
||||
const unsubscribe = onStreamComplete(() => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
});
|
||||
});
|
||||
return unsubscribe;
|
||||
}, [onStreamComplete, queryClient]);
|
||||
|
||||
const sessions = useMemo(
|
||||
() =>
|
||||
mergeCurrentSessionIntoList(
|
||||
accumulatedSessions,
|
||||
currentSessionId,
|
||||
currentSessionData,
|
||||
recentlyCreatedSessionsRef.current,
|
||||
),
|
||||
[accumulatedSessions, currentSessionId, currentSessionData],
|
||||
);
|
||||
|
||||
const visibleSessions = useMemo(
|
||||
() => filterVisibleSessions(sessions),
|
||||
[sessions],
|
||||
);
|
||||
|
||||
const isLoading = isSessionsLoading && accumulatedSessions.length === 0;
|
||||
|
||||
return {
|
||||
sessions: visibleSessions,
|
||||
isLoading,
|
||||
isSessionsFetching,
|
||||
hasNextPage,
|
||||
fetchNextPage,
|
||||
resetPagination,
|
||||
recentlyCreatedSessionsRef,
|
||||
};
|
||||
}
|
||||
@@ -4,51 +4,53 @@ import { create } from "zustand";
|
||||
|
||||
interface CopilotStoreState {
|
||||
isStreaming: boolean;
|
||||
isNewChatModalOpen: boolean;
|
||||
newChatHandler: (() => void) | null;
|
||||
isSwitchingSession: boolean;
|
||||
isCreatingSession: boolean;
|
||||
isInterruptModalOpen: boolean;
|
||||
pendingAction: (() => void) | null;
|
||||
}
|
||||
|
||||
interface CopilotStoreActions {
|
||||
setIsStreaming: (isStreaming: boolean) => void;
|
||||
setNewChatHandler: (handler: (() => void) | null) => void;
|
||||
requestNewChat: () => void;
|
||||
confirmNewChat: () => void;
|
||||
cancelNewChat: () => void;
|
||||
setIsSwitchingSession: (isSwitchingSession: boolean) => void;
|
||||
setIsCreatingSession: (isCreating: boolean) => void;
|
||||
openInterruptModal: (onConfirm: () => void) => void;
|
||||
confirmInterrupt: () => void;
|
||||
cancelInterrupt: () => void;
|
||||
}
|
||||
|
||||
type CopilotStore = CopilotStoreState & CopilotStoreActions;
|
||||
|
||||
export const useCopilotStore = create<CopilotStore>((set, get) => ({
|
||||
isStreaming: false,
|
||||
isNewChatModalOpen: false,
|
||||
newChatHandler: null,
|
||||
isSwitchingSession: false,
|
||||
isCreatingSession: false,
|
||||
isInterruptModalOpen: false,
|
||||
pendingAction: null,
|
||||
|
||||
setIsStreaming(isStreaming) {
|
||||
set({ isStreaming });
|
||||
},
|
||||
|
||||
setNewChatHandler(handler) {
|
||||
set({ newChatHandler: handler });
|
||||
setIsSwitchingSession(isSwitchingSession) {
|
||||
set({ isSwitchingSession });
|
||||
},
|
||||
|
||||
requestNewChat() {
|
||||
const { isStreaming, newChatHandler } = get();
|
||||
if (isStreaming) {
|
||||
set({ isNewChatModalOpen: true });
|
||||
} else if (newChatHandler) {
|
||||
newChatHandler();
|
||||
}
|
||||
setIsCreatingSession(isCreatingSession) {
|
||||
set({ isCreatingSession });
|
||||
},
|
||||
|
||||
confirmNewChat() {
|
||||
const { newChatHandler } = get();
|
||||
set({ isNewChatModalOpen: false });
|
||||
if (newChatHandler) {
|
||||
newChatHandler();
|
||||
}
|
||||
openInterruptModal(onConfirm) {
|
||||
set({ isInterruptModalOpen: true, pendingAction: onConfirm });
|
||||
},
|
||||
|
||||
cancelNewChat() {
|
||||
set({ isNewChatModalOpen: false });
|
||||
confirmInterrupt() {
|
||||
const { pendingAction } = get();
|
||||
set({ isInterruptModalOpen: false, pendingAction: null });
|
||||
if (pendingAction) pendingAction();
|
||||
},
|
||||
|
||||
cancelInterrupt() {
|
||||
set({ isInterruptModalOpen: false, pendingAction: null });
|
||||
},
|
||||
}));
|
||||
|
||||
@@ -1,28 +1,5 @@
|
||||
import type { User } from "@supabase/supabase-js";
|
||||
|
||||
export type PageState =
|
||||
| { type: "welcome" }
|
||||
| { type: "newChat" }
|
||||
| { type: "creating"; prompt: string }
|
||||
| { type: "chat"; sessionId: string; initialPrompt?: string };
|
||||
|
||||
export function getInitialPromptFromState(
|
||||
pageState: PageState,
|
||||
storedInitialPrompt: string | undefined,
|
||||
) {
|
||||
if (storedInitialPrompt) return storedInitialPrompt;
|
||||
if (pageState.type === "creating") return pageState.prompt;
|
||||
if (pageState.type === "chat") return pageState.initialPrompt;
|
||||
}
|
||||
|
||||
export function shouldResetToWelcome(pageState: PageState) {
|
||||
return (
|
||||
pageState.type !== "newChat" &&
|
||||
pageState.type !== "creating" &&
|
||||
pageState.type !== "welcome"
|
||||
);
|
||||
}
|
||||
|
||||
export function getGreetingName(user?: User | null): string {
|
||||
if (!user) return "there";
|
||||
const metadata = user.user_metadata as Record<string, unknown> | undefined;
|
||||
|
||||
@@ -1,25 +1,25 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
|
||||
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Chat } from "@/components/contextual/Chat/Chat";
|
||||
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
|
||||
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { useCopilotStore } from "./copilot-page-store";
|
||||
import { useCopilotPage } from "./useCopilotPage";
|
||||
|
||||
export default function CopilotPage() {
|
||||
const { state, handlers } = useCopilotPage();
|
||||
const confirmNewChat = useCopilotStore((s) => s.confirmNewChat);
|
||||
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
|
||||
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
|
||||
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
|
||||
const {
|
||||
greetingName,
|
||||
quickActions,
|
||||
isLoading,
|
||||
pageState,
|
||||
isNewChatModalOpen,
|
||||
hasSession,
|
||||
initialPrompt,
|
||||
isReady,
|
||||
} = state;
|
||||
const {
|
||||
@@ -27,20 +27,16 @@ export default function CopilotPage() {
|
||||
startChatWithPrompt,
|
||||
handleSessionNotFound,
|
||||
handleStreamingChange,
|
||||
handleCancelNewChat,
|
||||
handleNewChatModalOpen,
|
||||
} = handlers;
|
||||
|
||||
if (!isReady) return null;
|
||||
|
||||
if (pageState.type === "chat") {
|
||||
if (hasSession) {
|
||||
return (
|
||||
<div className="flex h-full flex-col">
|
||||
<Chat
|
||||
key={pageState.sessionId ?? "welcome"}
|
||||
className="flex-1"
|
||||
urlSessionId={pageState.sessionId}
|
||||
initialPrompt={pageState.initialPrompt}
|
||||
initialPrompt={initialPrompt}
|
||||
onSessionNotFound={handleSessionNotFound}
|
||||
onStreamingChange={handleStreamingChange}
|
||||
/>
|
||||
@@ -48,31 +44,33 @@ export default function CopilotPage() {
|
||||
title="Interrupt current chat?"
|
||||
styling={{ maxWidth: 300, width: "100%" }}
|
||||
controlled={{
|
||||
isOpen: isNewChatModalOpen,
|
||||
set: handleNewChatModalOpen,
|
||||
isOpen: isInterruptModalOpen,
|
||||
set: (open) => {
|
||||
if (!open) cancelInterrupt();
|
||||
},
|
||||
}}
|
||||
onClose={handleCancelNewChat}
|
||||
onClose={cancelInterrupt}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="flex flex-col gap-4">
|
||||
<Text variant="body">
|
||||
The current chat response will be interrupted. Are you sure you
|
||||
want to start a new chat?
|
||||
want to continue?
|
||||
</Text>
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
onClick={handleCancelNewChat}
|
||||
onClick={cancelInterrupt}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="button"
|
||||
variant="primary"
|
||||
onClick={confirmNewChat}
|
||||
onClick={confirmInterrupt}
|
||||
>
|
||||
Start new chat
|
||||
Continue
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</div>
|
||||
@@ -82,19 +80,6 @@ export default function CopilotPage() {
|
||||
);
|
||||
}
|
||||
|
||||
if (pageState.type === "newChat" || pageState.type === "creating") {
|
||||
return (
|
||||
<div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
|
||||
<div className="flex flex-col items-center gap-4">
|
||||
<ChatLoader />
|
||||
<Text variant="body" className="text-zinc-500">
|
||||
Loading your chats...
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
|
||||
<div className="w-full text-center">
|
||||
|
||||
@@ -5,79 +5,40 @@ import {
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
|
||||
import {
|
||||
Flag,
|
||||
type FlagValues,
|
||||
useGetFlag,
|
||||
} from "@/services/feature-flags/use-get-flag";
|
||||
import { SessionKey, sessionStorage } from "@/services/storage/session-storage";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useFlags } from "launchdarkly-react-client-sdk";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useEffect, useReducer } from "react";
|
||||
import { useEffect } from "react";
|
||||
import { useCopilotStore } from "./copilot-page-store";
|
||||
import { getGreetingName, getQuickActions, type PageState } from "./helpers";
|
||||
import { useCopilotURLState } from "./useCopilotURLState";
|
||||
|
||||
type CopilotState = {
|
||||
pageState: PageState;
|
||||
initialPrompts: Record<string, string>;
|
||||
previousSessionId: string | null;
|
||||
};
|
||||
|
||||
type CopilotAction =
|
||||
| { type: "setPageState"; pageState: PageState }
|
||||
| { type: "setInitialPrompt"; sessionId: string; prompt: string }
|
||||
| { type: "setPreviousSessionId"; sessionId: string | null };
|
||||
|
||||
function isSamePageState(next: PageState, current: PageState) {
|
||||
if (next.type !== current.type) return false;
|
||||
if (next.type === "creating" && current.type === "creating") {
|
||||
return next.prompt === current.prompt;
|
||||
}
|
||||
if (next.type === "chat" && current.type === "chat") {
|
||||
return (
|
||||
next.sessionId === current.sessionId &&
|
||||
next.initialPrompt === current.initialPrompt
|
||||
);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
function copilotReducer(
|
||||
state: CopilotState,
|
||||
action: CopilotAction,
|
||||
): CopilotState {
|
||||
if (action.type === "setPageState") {
|
||||
if (isSamePageState(action.pageState, state.pageState)) return state;
|
||||
return { ...state, pageState: action.pageState };
|
||||
}
|
||||
if (action.type === "setInitialPrompt") {
|
||||
if (state.initialPrompts[action.sessionId] === action.prompt) return state;
|
||||
return {
|
||||
...state,
|
||||
initialPrompts: {
|
||||
...state.initialPrompts,
|
||||
[action.sessionId]: action.prompt,
|
||||
},
|
||||
};
|
||||
}
|
||||
if (action.type === "setPreviousSessionId") {
|
||||
if (state.previousSessionId === action.sessionId) return state;
|
||||
return { ...state, previousSessionId: action.sessionId };
|
||||
}
|
||||
return state;
|
||||
}
|
||||
import { getGreetingName, getQuickActions } from "./helpers";
|
||||
import { useCopilotSessionId } from "./useCopilotSessionId";
|
||||
|
||||
export function useCopilotPage() {
|
||||
const router = useRouter();
|
||||
const queryClient = useQueryClient();
|
||||
const { user, isLoggedIn, isUserLoading } = useSupabase();
|
||||
const { toast } = useToast();
|
||||
const { completeStep } = useOnboarding();
|
||||
|
||||
const isNewChatModalOpen = useCopilotStore((s) => s.isNewChatModalOpen);
|
||||
const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
|
||||
const setIsStreaming = useCopilotStore((s) => s.setIsStreaming);
|
||||
const cancelNewChat = useCopilotStore((s) => s.cancelNewChat);
|
||||
const isCreating = useCopilotStore((s) => s.isCreatingSession);
|
||||
const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);
|
||||
|
||||
// Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus
|
||||
useEffect(() => {
|
||||
if (isLoggedIn) {
|
||||
completeStep("VISIT_COPILOT");
|
||||
}
|
||||
}, [completeStep, isLoggedIn]);
|
||||
|
||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||
const flags = useFlags<FlagValues>();
|
||||
@@ -88,72 +49,27 @@ export function useCopilotPage() {
|
||||
const isFlagReady =
|
||||
!isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;
|
||||
|
||||
const [state, dispatch] = useReducer(copilotReducer, {
|
||||
pageState: { type: "welcome" },
|
||||
initialPrompts: {},
|
||||
previousSessionId: null,
|
||||
});
|
||||
|
||||
const greetingName = getGreetingName(user);
|
||||
const quickActions = getQuickActions();
|
||||
|
||||
function setPageState(pageState: PageState) {
|
||||
dispatch({ type: "setPageState", pageState });
|
||||
}
|
||||
const hasSession = Boolean(urlSessionId);
|
||||
const initialPrompt = urlSessionId
|
||||
? getInitialPrompt(urlSessionId)
|
||||
: undefined;
|
||||
|
||||
function setInitialPrompt(sessionId: string, prompt: string) {
|
||||
dispatch({ type: "setInitialPrompt", sessionId, prompt });
|
||||
}
|
||||
|
||||
function setPreviousSessionId(sessionId: string | null) {
|
||||
dispatch({ type: "setPreviousSessionId", sessionId });
|
||||
}
|
||||
|
||||
const { setUrlSessionId } = useCopilotURLState({
|
||||
pageState: state.pageState,
|
||||
initialPrompts: state.initialPrompts,
|
||||
previousSessionId: state.previousSessionId,
|
||||
setPageState,
|
||||
setInitialPrompt,
|
||||
setPreviousSessionId,
|
||||
});
|
||||
|
||||
useEffect(
|
||||
function transitionNewChatToWelcome() {
|
||||
if (state.pageState.type === "newChat") {
|
||||
function setWelcomeState() {
|
||||
dispatch({ type: "setPageState", pageState: { type: "welcome" } });
|
||||
}
|
||||
|
||||
const timer = setTimeout(setWelcomeState, 300);
|
||||
|
||||
return function cleanup() {
|
||||
clearTimeout(timer);
|
||||
};
|
||||
}
|
||||
},
|
||||
[state.pageState.type],
|
||||
);
|
||||
|
||||
useEffect(
|
||||
function ensureAccess() {
|
||||
if (!isFlagReady) return;
|
||||
if (isChatEnabled === false) {
|
||||
router.replace(homepageRoute);
|
||||
}
|
||||
},
|
||||
[homepageRoute, isChatEnabled, isFlagReady, router],
|
||||
);
|
||||
useEffect(() => {
|
||||
if (!isFlagReady) return;
|
||||
if (isChatEnabled === false) {
|
||||
router.replace(homepageRoute);
|
||||
}
|
||||
}, [homepageRoute, isChatEnabled, isFlagReady, router]);
|
||||
|
||||
async function startChatWithPrompt(prompt: string) {
|
||||
if (!prompt?.trim()) return;
|
||||
if (state.pageState.type === "creating") return;
|
||||
if (isCreating) return;
|
||||
|
||||
const trimmedPrompt = prompt.trim();
|
||||
dispatch({
|
||||
type: "setPageState",
|
||||
pageState: { type: "creating", prompt: trimmedPrompt },
|
||||
});
|
||||
setIsCreating(true);
|
||||
|
||||
try {
|
||||
const sessionResponse = await postV2CreateSession({
|
||||
@@ -165,27 +81,19 @@ export function useCopilotPage() {
|
||||
}
|
||||
|
||||
const sessionId = sessionResponse.data.id;
|
||||
|
||||
dispatch({
|
||||
type: "setInitialPrompt",
|
||||
sessionId,
|
||||
prompt: trimmedPrompt,
|
||||
});
|
||||
setInitialPrompt(sessionId, trimmedPrompt);
|
||||
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
});
|
||||
|
||||
await setUrlSessionId(sessionId, { shallow: false });
|
||||
dispatch({
|
||||
type: "setPageState",
|
||||
pageState: { type: "chat", sessionId, initialPrompt: trimmedPrompt },
|
||||
});
|
||||
await setUrlSessionId(sessionId, { shallow: true });
|
||||
} catch (error) {
|
||||
console.error("[CopilotPage] Failed to start chat:", error);
|
||||
toast({ title: "Failed to start chat", variant: "destructive" });
|
||||
Sentry.captureException(error);
|
||||
dispatch({ type: "setPageState", pageState: { type: "welcome" } });
|
||||
} finally {
|
||||
setIsCreating(false);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,21 +109,13 @@ export function useCopilotPage() {
|
||||
setIsStreaming(isStreamingValue);
|
||||
}
|
||||
|
||||
function handleCancelNewChat() {
|
||||
cancelNewChat();
|
||||
}
|
||||
|
||||
function handleNewChatModalOpen(isOpen: boolean) {
|
||||
if (!isOpen) cancelNewChat();
|
||||
}
|
||||
|
||||
return {
|
||||
state: {
|
||||
greetingName,
|
||||
quickActions,
|
||||
isLoading: isUserLoading,
|
||||
pageState: state.pageState,
|
||||
isNewChatModalOpen,
|
||||
hasSession,
|
||||
initialPrompt,
|
||||
isReady: isFlagReady && isChatEnabled !== false && isLoggedIn,
|
||||
},
|
||||
handlers: {
|
||||
@@ -223,8 +123,32 @@ export function useCopilotPage() {
|
||||
startChatWithPrompt,
|
||||
handleSessionNotFound,
|
||||
handleStreamingChange,
|
||||
handleCancelNewChat,
|
||||
handleNewChatModalOpen,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function getInitialPrompt(sessionId: string): string | undefined {
|
||||
try {
|
||||
const prompts = JSON.parse(
|
||||
sessionStorage.get(SessionKey.CHAT_INITIAL_PROMPTS) || "{}",
|
||||
);
|
||||
return prompts[sessionId];
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
function setInitialPrompt(sessionId: string, prompt: string): void {
|
||||
try {
|
||||
const prompts = JSON.parse(
|
||||
sessionStorage.get(SessionKey.CHAT_INITIAL_PROMPTS) || "{}",
|
||||
);
|
||||
prompts[sessionId] = prompt;
|
||||
sessionStorage.set(
|
||||
SessionKey.CHAT_INITIAL_PROMPTS,
|
||||
JSON.stringify(prompts),
|
||||
);
|
||||
} catch {
|
||||
// Ignore storage errors
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
import { parseAsString, useQueryState } from "nuqs";
|
||||
|
||||
export function useCopilotSessionId() {
|
||||
const [urlSessionId, setUrlSessionId] = useQueryState(
|
||||
"sessionId",
|
||||
parseAsString,
|
||||
);
|
||||
|
||||
return { urlSessionId, setUrlSessionId };
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
import { parseAsString, useQueryState } from "nuqs";
|
||||
import { useLayoutEffect } from "react";
|
||||
import {
|
||||
getInitialPromptFromState,
|
||||
type PageState,
|
||||
shouldResetToWelcome,
|
||||
} from "./helpers";
|
||||
|
||||
interface UseCopilotUrlStateArgs {
|
||||
pageState: PageState;
|
||||
initialPrompts: Record<string, string>;
|
||||
previousSessionId: string | null;
|
||||
setPageState: (pageState: PageState) => void;
|
||||
setInitialPrompt: (sessionId: string, prompt: string) => void;
|
||||
setPreviousSessionId: (sessionId: string | null) => void;
|
||||
}
|
||||
|
||||
export function useCopilotURLState({
|
||||
pageState,
|
||||
initialPrompts,
|
||||
previousSessionId,
|
||||
setPageState,
|
||||
setInitialPrompt,
|
||||
setPreviousSessionId,
|
||||
}: UseCopilotUrlStateArgs) {
|
||||
const [urlSessionId, setUrlSessionId] = useQueryState(
|
||||
"sessionId",
|
||||
parseAsString,
|
||||
);
|
||||
|
||||
function syncSessionFromUrl() {
|
||||
if (urlSessionId) {
|
||||
if (pageState.type === "chat" && pageState.sessionId === urlSessionId) {
|
||||
setPreviousSessionId(urlSessionId);
|
||||
return;
|
||||
}
|
||||
|
||||
const storedInitialPrompt = initialPrompts[urlSessionId];
|
||||
const currentInitialPrompt = getInitialPromptFromState(
|
||||
pageState,
|
||||
storedInitialPrompt,
|
||||
);
|
||||
|
||||
if (currentInitialPrompt) {
|
||||
setInitialPrompt(urlSessionId, currentInitialPrompt);
|
||||
}
|
||||
|
||||
setPageState({
|
||||
type: "chat",
|
||||
sessionId: urlSessionId,
|
||||
initialPrompt: currentInitialPrompt,
|
||||
});
|
||||
setPreviousSessionId(urlSessionId);
|
||||
return;
|
||||
}
|
||||
|
||||
const wasInChat = previousSessionId !== null && pageState.type === "chat";
|
||||
setPreviousSessionId(null);
|
||||
if (wasInChat) {
|
||||
setPageState({ type: "newChat" });
|
||||
return;
|
||||
}
|
||||
|
||||
if (shouldResetToWelcome(pageState)) {
|
||||
setPageState({ type: "welcome" });
|
||||
}
|
||||
}
|
||||
|
||||
useLayoutEffect(syncSessionFromUrl, [
|
||||
urlSessionId,
|
||||
pageState.type,
|
||||
previousSessionId,
|
||||
initialPrompts,
|
||||
]);
|
||||
|
||||
return {
|
||||
urlSessionId,
|
||||
setUrlSessionId,
|
||||
};
|
||||
}
|
||||
@@ -1,10 +1,11 @@
|
||||
"use server";
|
||||
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import BackendAPI from "@/lib/autogpt-server-api";
|
||||
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
||||
import { loginFormSchema } from "@/types/auth";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { shouldShowOnboarding } from "../../api/helpers";
|
||||
import { getOnboardingStatus } from "../../api/helpers";
|
||||
|
||||
export async function login(email: string, password: string) {
|
||||
try {
|
||||
@@ -36,11 +37,15 @@ export async function login(email: string, password: string) {
|
||||
const api = new BackendAPI();
|
||||
await api.createUser();
|
||||
|
||||
const onboarding = await shouldShowOnboarding();
|
||||
// Get onboarding status from backend (includes chat flag evaluated for this user)
|
||||
const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
|
||||
const next = shouldShowOnboarding
|
||||
? "/onboarding"
|
||||
: getHomepageRoute(isChatEnabled);
|
||||
|
||||
return {
|
||||
success: true,
|
||||
onboarding,
|
||||
next,
|
||||
};
|
||||
} catch (err) {
|
||||
Sentry.captureException(err);
|
||||
|
||||
@@ -97,13 +97,8 @@ export function useLoginPage() {
|
||||
throw new Error(result.error || "Login failed");
|
||||
}
|
||||
|
||||
if (nextUrl) {
|
||||
router.replace(nextUrl);
|
||||
} else if (result.onboarding) {
|
||||
router.replace("/onboarding");
|
||||
} else {
|
||||
router.replace(homepageRoute);
|
||||
}
|
||||
// Prefer URL's next parameter, then use backend-determined route
|
||||
router.replace(nextUrl || result.next || homepageRoute);
|
||||
} catch (error) {
|
||||
toast({
|
||||
title:
|
||||
|
||||
@@ -5,14 +5,13 @@ import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
||||
import { signupFormSchema } from "@/types/auth";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { isWaitlistError, logWaitlistError } from "../../api/auth/utils";
|
||||
import { shouldShowOnboarding } from "../../api/helpers";
|
||||
import { getOnboardingStatus } from "../../api/helpers";
|
||||
|
||||
export async function signup(
|
||||
email: string,
|
||||
password: string,
|
||||
confirmPassword: string,
|
||||
agreeToTerms: boolean,
|
||||
isChatEnabled: boolean,
|
||||
) {
|
||||
try {
|
||||
const parsed = signupFormSchema.safeParse({
|
||||
@@ -59,8 +58,9 @@ export async function signup(
|
||||
await supabase.auth.setSession(data.session);
|
||||
}
|
||||
|
||||
const isOnboardingEnabled = await shouldShowOnboarding();
|
||||
const next = isOnboardingEnabled
|
||||
// Get onboarding status from backend (includes chat flag evaluated for this user)
|
||||
const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
|
||||
const next = shouldShowOnboarding
|
||||
? "/onboarding"
|
||||
: getHomepageRoute(isChatEnabled);
|
||||
|
||||
|
||||
@@ -108,7 +108,6 @@ export function useSignupPage() {
|
||||
data.password,
|
||||
data.confirmPassword,
|
||||
data.agreeToTerms,
|
||||
isChatEnabled === true,
|
||||
);
|
||||
|
||||
setIsLoading(false);
|
||||
|
||||
@@ -175,9 +175,12 @@ export async function resolveResponse<
|
||||
return res.data;
|
||||
}
|
||||
|
||||
export async function shouldShowOnboarding() {
|
||||
const isEnabled = await resolveResponse(getV1IsOnboardingEnabled());
|
||||
export async function getOnboardingStatus() {
|
||||
const status = await resolveResponse(getV1IsOnboardingEnabled());
|
||||
const onboarding = await resolveResponse(getV1OnboardingState());
|
||||
const isCompleted = onboarding.completedSteps.includes("CONGRATS");
|
||||
return isEnabled && !isCompleted;
|
||||
return {
|
||||
shouldShowOnboarding: status.is_onboarding_enabled && !isCompleted,
|
||||
isChatEnabled: status.is_chat_enabled,
|
||||
};
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user