mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
12 Commits
feat/llm-p
...
feat/llm-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
be328c1ec5 | ||
|
|
8410448c16 | ||
|
|
e168597663 | ||
|
|
1d903ae287 | ||
|
|
1be7aebdea | ||
|
|
36045c7007 | ||
|
|
445eb173a5 | ||
|
|
393a138fee | ||
|
|
ccc1e35c5b | ||
|
|
c66f114e28 | ||
|
|
939edc73b8 | ||
|
|
d52409c853 |
@@ -403,6 +403,11 @@ app.include_router(
|
||||
tags=["v2", "llm"],
|
||||
prefix="/api",
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.llm.admin_router,
|
||||
tags=["v2", "llm", "admin"],
|
||||
prefix="/api",
|
||||
)
|
||||
|
||||
app.mount("/external-api", external_api)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""LLM registry public API."""
|
||||
"""LLM registry API (public + admin)."""
|
||||
|
||||
from .admin_routes import router as admin_router
|
||||
from .routes import router
|
||||
|
||||
__all__ = ["router"]
|
||||
__all__ = ["router", "admin_router"]
|
||||
|
||||
115
autogpt_platform/backend/backend/server/v2/llm/admin_model.py
Normal file
115
autogpt_platform/backend/backend/server/v2/llm/admin_model.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Request/response models for LLM registry admin API."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CreateLlmProviderRequest(BaseModel):
|
||||
"""Request model for creating an LLM provider."""
|
||||
|
||||
name: str = Field(
|
||||
..., description="Provider identifier (e.g., 'openai', 'anthropic')"
|
||||
)
|
||||
display_name: str = Field(..., description="Human-readable provider name")
|
||||
description: str | None = Field(None, description="Provider description")
|
||||
default_credential_provider: str | None = Field(
|
||||
None, description="Default credential system identifier"
|
||||
)
|
||||
default_credential_id: str | None = Field(None, description="Default credential ID")
|
||||
default_credential_type: str | None = Field(
|
||||
None, description="Default credential type"
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Additional metadata"
|
||||
)
|
||||
|
||||
|
||||
class UpdateLlmProviderRequest(BaseModel):
|
||||
"""Request model for updating an LLM provider."""
|
||||
|
||||
display_name: str | None = Field(None, description="Human-readable provider name")
|
||||
description: str | None = Field(None, description="Provider description")
|
||||
default_credential_provider: str | None = Field(
|
||||
None, description="Default credential system identifier"
|
||||
)
|
||||
default_credential_id: str | None = Field(None, description="Default credential ID")
|
||||
default_credential_type: str | None = Field(
|
||||
None, description="Default credential type"
|
||||
)
|
||||
metadata: dict[str, Any] | None = Field(None, description="Additional metadata")
|
||||
|
||||
|
||||
class CreateLlmModelRequest(BaseModel):
|
||||
"""Request model for creating an LLM model."""
|
||||
|
||||
slug: str = Field(..., description="Model slug (e.g., 'gpt-4', 'claude-3-opus')")
|
||||
display_name: str = Field(..., description="Human-readable model name")
|
||||
description: str | None = Field(None, description="Model description")
|
||||
provider_id: str = Field(..., description="Provider ID (UUID)")
|
||||
creator_id: str | None = Field(None, description="Creator ID (UUID)")
|
||||
context_window: int = Field(
|
||||
..., description="Maximum context window in tokens", gt=0
|
||||
)
|
||||
max_output_tokens: int | None = Field(
|
||||
None, description="Maximum output tokens (None if unlimited)", gt=0
|
||||
)
|
||||
price_tier: int = Field(
|
||||
..., description="Price tier (1=cheapest, 2=medium, 3=expensive)", ge=1, le=3
|
||||
)
|
||||
is_enabled: bool = Field(default=True, description="Whether the model is enabled")
|
||||
is_recommended: bool = Field(
|
||||
default=False, description="Whether the model is recommended"
|
||||
)
|
||||
supports_tools: bool = Field(default=False, description="Supports function calling")
|
||||
supports_json_output: bool = Field(
|
||||
default=False, description="Supports JSON output mode"
|
||||
)
|
||||
supports_reasoning: bool = Field(
|
||||
default=False, description="Supports reasoning mode"
|
||||
)
|
||||
supports_parallel_tool_calls: bool = Field(
|
||||
default=False, description="Supports parallel tool calls"
|
||||
)
|
||||
capabilities: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Additional capabilities"
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Additional metadata"
|
||||
)
|
||||
costs: list[dict[str, Any]] = Field(
|
||||
default_factory=list, description="Cost entries for the model"
|
||||
)
|
||||
|
||||
|
||||
class UpdateLlmModelRequest(BaseModel):
|
||||
"""Request model for updating an LLM model."""
|
||||
|
||||
display_name: str | None = Field(None, description="Human-readable model name")
|
||||
description: str | None = Field(None, description="Model description")
|
||||
creator_id: str | None = Field(None, description="Creator ID (UUID)")
|
||||
context_window: int | None = Field(
|
||||
None, description="Maximum context window in tokens", gt=0
|
||||
)
|
||||
max_output_tokens: int | None = Field(
|
||||
None, description="Maximum output tokens (None if unlimited)", gt=0
|
||||
)
|
||||
price_tier: int | None = Field(
|
||||
None, description="Price tier (1=cheapest, 2=medium, 3=expensive)", ge=1, le=3
|
||||
)
|
||||
is_enabled: bool | None = Field(None, description="Whether the model is enabled")
|
||||
is_recommended: bool | None = Field(
|
||||
None, description="Whether the model is recommended"
|
||||
)
|
||||
supports_tools: bool | None = Field(None, description="Supports function calling")
|
||||
supports_json_output: bool | None = Field(
|
||||
None, description="Supports JSON output mode"
|
||||
)
|
||||
supports_reasoning: bool | None = Field(None, description="Supports reasoning mode")
|
||||
supports_parallel_tool_calls: bool | None = Field(
|
||||
None, description="Supports parallel tool calls"
|
||||
)
|
||||
capabilities: dict[str, Any] | None = Field(
|
||||
None, description="Additional capabilities"
|
||||
)
|
||||
metadata: dict[str, Any] | None = Field(None, description="Additional metadata")
|
||||
689
autogpt_platform/backend/backend/server/v2/llm/admin_routes.py
Normal file
689
autogpt_platform/backend/backend/server/v2/llm/admin_routes.py
Normal file
@@ -0,0 +1,689 @@
|
||||
"""Admin API for LLM registry management.
|
||||
|
||||
Provides endpoints for:
|
||||
- Reading creators (GET)
|
||||
- Creating, updating, and deleting models
|
||||
- Creating, updating, and deleting providers
|
||||
|
||||
All endpoints require admin authentication. Mutations refresh the registry cache.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import prisma
|
||||
import autogpt_libs.auth
|
||||
from fastapi import APIRouter, HTTPException, Security, status
|
||||
|
||||
from backend.server.v2.llm import db_write
|
||||
from backend.server.v2.llm.admin_model import (
|
||||
CreateLlmModelRequest,
|
||||
CreateLlmProviderRequest,
|
||||
UpdateLlmModelRequest,
|
||||
UpdateLlmProviderRequest,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _map_provider_response(provider: Any) -> dict[str, Any]:
|
||||
"""Map Prisma provider model to response dict."""
|
||||
return {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.displayName,
|
||||
"description": provider.description,
|
||||
"default_credential_provider": provider.defaultCredentialProvider,
|
||||
"default_credential_id": provider.defaultCredentialId,
|
||||
"default_credential_type": provider.defaultCredentialType,
|
||||
"metadata": dict(provider.metadata or {}),
|
||||
"created_at": provider.createdAt.isoformat() if provider.createdAt else None,
|
||||
"updated_at": provider.updatedAt.isoformat() if provider.updatedAt else None,
|
||||
}
|
||||
|
||||
|
||||
def _map_model_response(model: Any) -> dict[str, Any]:
|
||||
"""Map Prisma model to response dict."""
|
||||
return {
|
||||
"id": model.id,
|
||||
"slug": model.slug,
|
||||
"display_name": model.displayName,
|
||||
"description": model.description,
|
||||
"provider_id": model.providerId,
|
||||
"creator_id": model.creatorId,
|
||||
"context_window": model.contextWindow,
|
||||
"max_output_tokens": model.maxOutputTokens,
|
||||
"price_tier": model.priceTier,
|
||||
"is_enabled": model.isEnabled,
|
||||
"is_recommended": model.isRecommended,
|
||||
"supports_tools": model.supportsTools,
|
||||
"supports_json_output": model.supportsJsonOutput,
|
||||
"supports_reasoning": model.supportsReasoning,
|
||||
"supports_parallel_tool_calls": model.supportsParallelToolCalls,
|
||||
"capabilities": dict(model.capabilities or {}),
|
||||
"metadata": dict(model.metadata or {}),
|
||||
"created_at": model.createdAt.isoformat() if model.createdAt else None,
|
||||
"updated_at": model.updatedAt.isoformat() if model.updatedAt else None,
|
||||
}
|
||||
|
||||
|
||||
def _map_creator_response(creator: Any) -> dict[str, Any]:
|
||||
"""Map Prisma creator model to response dict."""
|
||||
return {
|
||||
"id": creator.id,
|
||||
"name": creator.name,
|
||||
"display_name": creator.displayName,
|
||||
"description": creator.description,
|
||||
"website_url": creator.websiteUrl,
|
||||
"logo_url": creator.logoUrl,
|
||||
"metadata": dict(creator.metadata or {}),
|
||||
"created_at": creator.createdAt.isoformat() if creator.createdAt else None,
|
||||
"updated_at": creator.updatedAt.isoformat() if creator.updatedAt else None,
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/llm/models",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def create_model(
|
||||
request: CreateLlmModelRequest,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new LLM model.
|
||||
|
||||
Requires admin authentication.
|
||||
"""
|
||||
try:
|
||||
import prisma.models as pm
|
||||
|
||||
# Resolve provider name to ID
|
||||
provider = await pm.LlmProvider.prisma().find_unique(
|
||||
where={"name": request.provider_id}
|
||||
)
|
||||
if not provider:
|
||||
# Try as UUID fallback
|
||||
provider = await pm.LlmProvider.prisma().find_unique(
|
||||
where={"id": request.provider_id}
|
||||
)
|
||||
if not provider:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Provider '{request.provider_id}' not found",
|
||||
)
|
||||
|
||||
model = await db_write.create_model(
|
||||
slug=request.slug,
|
||||
display_name=request.display_name,
|
||||
provider_id=provider.id,
|
||||
context_window=request.context_window,
|
||||
price_tier=request.price_tier,
|
||||
description=request.description,
|
||||
creator_id=request.creator_id,
|
||||
max_output_tokens=request.max_output_tokens,
|
||||
is_enabled=request.is_enabled,
|
||||
is_recommended=request.is_recommended,
|
||||
supports_tools=request.supports_tools,
|
||||
supports_json_output=request.supports_json_output,
|
||||
supports_reasoning=request.supports_reasoning,
|
||||
supports_parallel_tool_calls=request.supports_parallel_tool_calls,
|
||||
capabilities=request.capabilities,
|
||||
metadata=request.metadata,
|
||||
)
|
||||
# Create costs if provided in the raw request body
|
||||
if hasattr(request, 'costs') and request.costs:
|
||||
for cost_input in request.costs:
|
||||
await pm.LlmModelCost.prisma().create(
|
||||
data={
|
||||
"unit": cost_input.get("unit", "RUN"),
|
||||
"creditCost": int(cost_input.get("credit_cost", 1)),
|
||||
"credentialProvider": provider.name,
|
||||
"metadata": prisma.Json(cost_input.get("metadata", {})),
|
||||
"Model": {"connect": {"id": model.id}},
|
||||
}
|
||||
)
|
||||
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(f"Created model '{request.slug}' (id: {model.id})")
|
||||
|
||||
# Re-fetch with costs included
|
||||
model = await pm.LlmModel.prisma().find_unique(
|
||||
where={"id": model.id},
|
||||
include={"Costs": True, "Creator": True},
|
||||
)
|
||||
return _map_model_response(model)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Model creation validation failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create model: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to create model")
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/llm/models/{slug:path}",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def update_model(
|
||||
slug: str,
|
||||
request: UpdateLlmModelRequest,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing LLM model.
|
||||
|
||||
Requires admin authentication.
|
||||
"""
|
||||
try:
|
||||
# Find model by slug first to get ID
|
||||
import prisma.models
|
||||
|
||||
existing = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"slug": slug}
|
||||
)
|
||||
if not existing:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Model with slug '{slug}' not found"
|
||||
)
|
||||
|
||||
model = await db_write.update_model(
|
||||
model_id=existing.id,
|
||||
display_name=request.display_name,
|
||||
description=request.description,
|
||||
creator_id=request.creator_id,
|
||||
context_window=request.context_window,
|
||||
max_output_tokens=request.max_output_tokens,
|
||||
price_tier=request.price_tier,
|
||||
is_enabled=request.is_enabled,
|
||||
is_recommended=request.is_recommended,
|
||||
supports_tools=request.supports_tools,
|
||||
supports_json_output=request.supports_json_output,
|
||||
supports_reasoning=request.supports_reasoning,
|
||||
supports_parallel_tool_calls=request.supports_parallel_tool_calls,
|
||||
capabilities=request.capabilities,
|
||||
metadata=request.metadata,
|
||||
)
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(f"Updated model '{slug}' (id: {model.id})")
|
||||
return _map_model_response(model)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Model update validation failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to update model: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to update model")
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/llm/models/{slug:path}",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def delete_model(
|
||||
slug: str,
|
||||
replacement_model_slug: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Delete an LLM model with optional migration.
|
||||
|
||||
If workflows are using this model and no replacement_model_slug is given,
|
||||
returns 400 with the node count. Provide replacement_model_slug to migrate
|
||||
affected nodes before deletion.
|
||||
"""
|
||||
try:
|
||||
import prisma.models
|
||||
|
||||
existing = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"slug": slug}
|
||||
)
|
||||
if not existing:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Model with slug '{slug}' not found"
|
||||
)
|
||||
|
||||
result = await db_write.delete_model(
|
||||
model_id=existing.id,
|
||||
replacement_model_slug=replacement_model_slug,
|
||||
)
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(
|
||||
f"Deleted model '{slug}' (migrated {result['nodes_migrated']} nodes)"
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
logger.warning(f"Model deletion validation failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to delete model: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to delete model")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/llm/models/{slug:path}/usage",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def get_model_usage(slug: str) -> dict[str, Any]:
|
||||
"""Get usage count for a model — how many workflow nodes reference it."""
|
||||
try:
|
||||
return await db_write.get_model_usage(slug)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to get model usage: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to get model usage")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/llm/models/{slug:path}/toggle",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def toggle_model(
|
||||
slug: str,
|
||||
request: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Toggle a model's enabled status with optional migration when disabling.
|
||||
|
||||
Body params:
|
||||
is_enabled: bool
|
||||
migrate_to_slug: optional str
|
||||
migration_reason: optional str
|
||||
custom_credit_cost: optional int
|
||||
"""
|
||||
try:
|
||||
import prisma.models
|
||||
|
||||
existing = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"slug": slug}
|
||||
)
|
||||
if not existing:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Model with slug '{slug}' not found"
|
||||
)
|
||||
|
||||
result = await db_write.toggle_model_with_migration(
|
||||
model_id=existing.id,
|
||||
is_enabled=request.get("is_enabled", True),
|
||||
migrate_to_slug=request.get("migrate_to_slug"),
|
||||
migration_reason=request.get("migration_reason"),
|
||||
custom_credit_cost=request.get("custom_credit_cost"),
|
||||
)
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(
|
||||
f"Toggled model '{slug}' enabled={request.get('is_enabled')} "
|
||||
f"(migrated {result['nodes_migrated']} nodes)"
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
logger.warning(f"Model toggle failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to toggle model: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to toggle model")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/llm/migrations",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def list_migrations(
|
||||
include_reverted: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""List model migrations."""
|
||||
try:
|
||||
migrations = await db_write.list_migrations(
|
||||
include_reverted=include_reverted
|
||||
)
|
||||
return {"migrations": migrations}
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to list migrations: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to list migrations"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/llm/migrations/{migration_id}/revert",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def revert_migration(
|
||||
migration_id: str,
|
||||
re_enable_source_model: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""Revert a model migration, restoring affected nodes."""
|
||||
try:
|
||||
result = await db_write.revert_migration(
|
||||
migration_id=migration_id,
|
||||
re_enable_source_model=re_enable_source_model,
|
||||
)
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(
|
||||
f"Reverted migration {migration_id}: "
|
||||
f"{result['nodes_reverted']} nodes restored"
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
logger.warning(f"Migration revert failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to revert migration: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to revert migration"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/llm/providers",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def create_provider(
|
||||
request: CreateLlmProviderRequest,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new LLM provider.
|
||||
|
||||
Requires admin authentication.
|
||||
"""
|
||||
try:
|
||||
provider = await db_write.create_provider(
|
||||
name=request.name,
|
||||
display_name=request.display_name,
|
||||
description=request.description,
|
||||
default_credential_provider=request.default_credential_provider,
|
||||
default_credential_id=request.default_credential_id,
|
||||
default_credential_type=request.default_credential_type,
|
||||
metadata=request.metadata,
|
||||
)
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(f"Created provider '{request.name}' (id: {provider.id})")
|
||||
return _map_provider_response(provider)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Provider creation validation failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create provider: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to create provider")
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/llm/providers/{name}",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def update_provider(
|
||||
name: str,
|
||||
request: UpdateLlmProviderRequest,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing LLM provider.
|
||||
|
||||
Requires admin authentication.
|
||||
"""
|
||||
try:
|
||||
# Find provider by name first to get ID
|
||||
import prisma.models
|
||||
|
||||
existing = await prisma.models.LlmProvider.prisma().find_unique(
|
||||
where={"name": name}
|
||||
)
|
||||
if not existing:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Provider with name '{name}' not found"
|
||||
)
|
||||
|
||||
provider = await db_write.update_provider(
|
||||
provider_id=existing.id,
|
||||
display_name=request.display_name,
|
||||
description=request.description,
|
||||
default_credential_provider=request.default_credential_provider,
|
||||
default_credential_id=request.default_credential_id,
|
||||
default_credential_type=request.default_credential_type,
|
||||
metadata=request.metadata,
|
||||
)
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(f"Updated provider '{name}' (id: {provider.id})")
|
||||
return _map_provider_response(provider)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Provider update validation failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to update provider: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to update provider")
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/llm/providers/{name}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def delete_provider(
|
||||
name: str,
|
||||
) -> None:
|
||||
"""Delete an LLM provider.
|
||||
|
||||
Requires admin authentication.
|
||||
A provider can only be deleted if it has no associated models.
|
||||
"""
|
||||
try:
|
||||
# Find provider by name first to get ID
|
||||
import prisma.models
|
||||
|
||||
existing = await prisma.models.LlmProvider.prisma().find_unique(
|
||||
where={"name": name}
|
||||
)
|
||||
if not existing:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Provider with name '{name}' not found"
|
||||
)
|
||||
|
||||
await db_write.delete_provider(provider_id=existing.id)
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(f"Deleted provider '{name}' (id: {existing.id})")
|
||||
except ValueError as e:
|
||||
logger.warning(f"Provider deletion validation failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to delete provider: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to delete provider")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/llm/admin/providers",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def admin_list_providers() -> dict[str, Any]:
|
||||
"""List all LLM providers from the database.
|
||||
|
||||
Unlike the public endpoint, this returns ALL providers including
|
||||
those with no models. Requires admin authentication.
|
||||
"""
|
||||
try:
|
||||
import prisma.models
|
||||
|
||||
providers = await prisma.models.LlmProvider.prisma().find_many(
|
||||
order={"name": "asc"},
|
||||
include={"Models": True},
|
||||
)
|
||||
return {
|
||||
"providers": [
|
||||
{**_map_provider_response(p), "model_count": len(p.Models) if p.Models else 0}
|
||||
for p in providers
|
||||
]
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to list providers: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to list providers")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/llm/admin/models",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def admin_list_models(
|
||||
page: int = 1,
|
||||
page_size: int = 100,
|
||||
enabled_only: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""List all LLM models from the database.
|
||||
|
||||
Unlike the public endpoint, this returns full model data including
|
||||
costs and creator info. Requires admin authentication.
|
||||
"""
|
||||
try:
|
||||
import prisma.models
|
||||
|
||||
where = {"isEnabled": True} if enabled_only else {}
|
||||
models = await prisma.models.LlmModel.prisma().find_many(
|
||||
where=where,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
order={"displayName": "asc"},
|
||||
include={"Costs": True, "Creator": True},
|
||||
)
|
||||
return {
|
||||
"models": [
|
||||
{
|
||||
**_map_model_response(m),
|
||||
"creator": _map_creator_response(m.Creator) if m.Creator else None,
|
||||
"costs": [
|
||||
{
|
||||
"unit": c.unit,
|
||||
"credit_cost": float(c.creditCost),
|
||||
"credential_provider": c.credentialProvider,
|
||||
"credential_type": c.credentialType,
|
||||
"metadata": dict(c.metadata or {}),
|
||||
}
|
||||
for c in (m.Costs or [])
|
||||
],
|
||||
}
|
||||
for m in models
|
||||
]
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to list models: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to list models")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/llm/creators",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def list_creators() -> dict[str, Any]:
|
||||
"""List all LLM model creators.
|
||||
|
||||
Requires admin authentication.
|
||||
"""
|
||||
try:
|
||||
import prisma.models
|
||||
|
||||
creators = await prisma.models.LlmModelCreator.prisma().find_many(
|
||||
order={"name": "asc"}
|
||||
)
|
||||
logger.info(f"Retrieved {len(creators)} creators")
|
||||
return {"creators": [_map_creator_response(c) for c in creators]}
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to list creators: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to list creators")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/llm/creators",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def create_creator(
|
||||
request: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new LLM model creator."""
|
||||
try:
|
||||
import prisma.models
|
||||
|
||||
creator = await prisma.models.LlmModelCreator.prisma().create(
|
||||
data={
|
||||
"name": request["name"],
|
||||
"displayName": request["display_name"],
|
||||
"description": request.get("description"),
|
||||
"websiteUrl": request.get("website_url"),
|
||||
"logoUrl": request.get("logo_url"),
|
||||
"metadata": prisma.Json(request.get("metadata", {})),
|
||||
}
|
||||
)
|
||||
logger.info(f"Created creator '{creator.name}' (id: {creator.id})")
|
||||
return _map_creator_response(creator)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create creator: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/llm/creators/{name}",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def update_creator(
|
||||
name: str,
|
||||
request: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing LLM model creator."""
|
||||
try:
|
||||
import prisma.models
|
||||
|
||||
existing = await prisma.models.LlmModelCreator.prisma().find_unique(
|
||||
where={"name": name}
|
||||
)
|
||||
if not existing:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Creator '{name}' not found"
|
||||
)
|
||||
|
||||
data: dict[str, Any] = {}
|
||||
if "display_name" in request:
|
||||
data["displayName"] = request["display_name"]
|
||||
if "description" in request:
|
||||
data["description"] = request["description"]
|
||||
if "website_url" in request:
|
||||
data["websiteUrl"] = request["website_url"]
|
||||
if "logo_url" in request:
|
||||
data["logoUrl"] = request["logo_url"]
|
||||
|
||||
creator = await prisma.models.LlmModelCreator.prisma().update(
|
||||
where={"id": existing.id},
|
||||
data=data,
|
||||
)
|
||||
logger.info(f"Updated creator '{name}' (id: {creator.id})")
|
||||
return _map_creator_response(creator)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to update creator: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/llm/creators/{name}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def delete_creator(
|
||||
name: str,
|
||||
) -> None:
|
||||
"""Delete an LLM model creator."""
|
||||
try:
|
||||
import prisma.models
|
||||
|
||||
existing = await prisma.models.LlmModelCreator.prisma().find_unique(
|
||||
where={"name": name},
|
||||
include={"Models": True},
|
||||
)
|
||||
if not existing:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Creator '{name}' not found"
|
||||
)
|
||||
|
||||
if existing.Models and len(existing.Models) > 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot delete creator '{name}' — it has {len(existing.Models)} associated models",
|
||||
)
|
||||
|
||||
await prisma.models.LlmModelCreator.prisma().delete(
|
||||
where={"id": existing.id}
|
||||
)
|
||||
logger.info(f"Deleted creator '{name}' (id: {existing.id})")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to delete creator: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
588
autogpt_platform/backend/backend/server/v2/llm/db_write.py
Normal file
588
autogpt_platform/backend/backend/server/v2/llm/db_write.py
Normal file
@@ -0,0 +1,588 @@
|
||||
"""Database write operations for LLM registry admin API."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import prisma
|
||||
import prisma.models
|
||||
|
||||
from backend.data import llm_registry
|
||||
from backend.data.db import transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _build_provider_data(
|
||||
name: str,
|
||||
display_name: str,
|
||||
description: str | None = None,
|
||||
default_credential_provider: str | None = None,
|
||||
default_credential_id: str | None = None,
|
||||
default_credential_type: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build provider data dict for Prisma operations."""
|
||||
return {
|
||||
"name": name,
|
||||
"displayName": display_name,
|
||||
"description": description,
|
||||
"defaultCredentialProvider": default_credential_provider,
|
||||
"defaultCredentialId": default_credential_id,
|
||||
"defaultCredentialType": default_credential_type,
|
||||
"metadata": prisma.Json(metadata or {}),
|
||||
}
|
||||
|
||||
|
||||
def _build_model_data(
|
||||
slug: str,
|
||||
display_name: str,
|
||||
provider_id: str,
|
||||
context_window: int,
|
||||
price_tier: int,
|
||||
description: str | None = None,
|
||||
creator_id: str | None = None,
|
||||
max_output_tokens: int | None = None,
|
||||
is_enabled: bool = True,
|
||||
is_recommended: bool = False,
|
||||
supports_tools: bool = False,
|
||||
supports_json_output: bool = False,
|
||||
supports_reasoning: bool = False,
|
||||
supports_parallel_tool_calls: bool = False,
|
||||
capabilities: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build model data dict for Prisma operations."""
|
||||
data: dict[str, Any] = {
|
||||
"slug": slug,
|
||||
"displayName": display_name,
|
||||
"description": description,
|
||||
"Provider": {"connect": {"id": provider_id}},
|
||||
"contextWindow": context_window,
|
||||
"maxOutputTokens": max_output_tokens,
|
||||
"priceTier": price_tier,
|
||||
"isEnabled": is_enabled,
|
||||
"isRecommended": is_recommended,
|
||||
"supportsTools": supports_tools,
|
||||
"supportsJsonOutput": supports_json_output,
|
||||
"supportsReasoning": supports_reasoning,
|
||||
"supportsParallelToolCalls": supports_parallel_tool_calls,
|
||||
"capabilities": prisma.Json(capabilities or {}),
|
||||
"metadata": prisma.Json(metadata or {}),
|
||||
}
|
||||
if creator_id:
|
||||
data["Creator"] = {"connect": {"id": creator_id}}
|
||||
return data
|
||||
|
||||
|
||||
async def create_provider(
|
||||
name: str,
|
||||
display_name: str,
|
||||
description: str | None = None,
|
||||
default_credential_provider: str | None = None,
|
||||
default_credential_id: str | None = None,
|
||||
default_credential_type: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> prisma.models.LlmProvider:
|
||||
"""Create a new LLM provider."""
|
||||
data = _build_provider_data(
|
||||
name=name,
|
||||
display_name=display_name,
|
||||
description=description,
|
||||
default_credential_provider=default_credential_provider,
|
||||
default_credential_id=default_credential_id,
|
||||
default_credential_type=default_credential_type,
|
||||
metadata=metadata,
|
||||
)
|
||||
provider = await prisma.models.LlmProvider.prisma().create(
|
||||
data=data,
|
||||
include={"Models": True},
|
||||
)
|
||||
if not provider:
|
||||
raise ValueError("Failed to create provider")
|
||||
return provider
|
||||
|
||||
|
||||
async def update_provider(
|
||||
provider_id: str,
|
||||
display_name: str | None = None,
|
||||
description: str | None = None,
|
||||
default_credential_provider: str | None = None,
|
||||
default_credential_id: str | None = None,
|
||||
default_credential_type: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> prisma.models.LlmProvider:
|
||||
"""Update an existing LLM provider."""
|
||||
# Fetch existing provider to get current name
|
||||
provider = await prisma.models.LlmProvider.prisma().find_unique(
|
||||
where={"id": provider_id}
|
||||
)
|
||||
if not provider:
|
||||
raise ValueError(f"Provider with id '{provider_id}' not found")
|
||||
|
||||
# Build update data (only include fields that are provided)
|
||||
data: dict[str, Any] = {}
|
||||
if display_name is not None:
|
||||
data["displayName"] = display_name
|
||||
if description is not None:
|
||||
data["description"] = description
|
||||
if default_credential_provider is not None:
|
||||
data["defaultCredentialProvider"] = default_credential_provider
|
||||
if default_credential_id is not None:
|
||||
data["defaultCredentialId"] = default_credential_id
|
||||
if default_credential_type is not None:
|
||||
data["defaultCredentialType"] = default_credential_type
|
||||
if metadata is not None:
|
||||
data["metadata"] = prisma.Json(metadata)
|
||||
|
||||
updated = await prisma.models.LlmProvider.prisma().update(
|
||||
where={"id": provider_id},
|
||||
data=data,
|
||||
include={"Models": True},
|
||||
)
|
||||
if not updated:
|
||||
raise ValueError("Failed to update provider")
|
||||
return updated
|
||||
|
||||
|
||||
async def delete_provider(provider_id: str) -> bool:
|
||||
"""Delete an LLM provider.
|
||||
|
||||
A provider can only be deleted if it has no associated models.
|
||||
"""
|
||||
# Check if provider exists
|
||||
provider = await prisma.models.LlmProvider.prisma().find_unique(
|
||||
where={"id": provider_id},
|
||||
include={"Models": True},
|
||||
)
|
||||
if not provider:
|
||||
raise ValueError(f"Provider with id '{provider_id}' not found")
|
||||
|
||||
# Check if provider has any models
|
||||
model_count = len(provider.Models) if provider.Models else 0
|
||||
if model_count > 0:
|
||||
raise ValueError(
|
||||
f"Cannot delete provider '{provider.displayName}' because it has "
|
||||
f"{model_count} model(s). Delete all models first."
|
||||
)
|
||||
|
||||
await prisma.models.LlmProvider.prisma().delete(where={"id": provider_id})
|
||||
return True
|
||||
|
||||
|
||||
async def create_model(
|
||||
slug: str,
|
||||
display_name: str,
|
||||
provider_id: str,
|
||||
context_window: int,
|
||||
price_tier: int,
|
||||
description: str | None = None,
|
||||
creator_id: str | None = None,
|
||||
max_output_tokens: int | None = None,
|
||||
is_enabled: bool = True,
|
||||
is_recommended: bool = False,
|
||||
supports_tools: bool = False,
|
||||
supports_json_output: bool = False,
|
||||
supports_reasoning: bool = False,
|
||||
supports_parallel_tool_calls: bool = False,
|
||||
capabilities: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> prisma.models.LlmModel:
|
||||
"""Create a new LLM model."""
|
||||
data = _build_model_data(
|
||||
slug=slug,
|
||||
display_name=display_name,
|
||||
provider_id=provider_id,
|
||||
context_window=context_window,
|
||||
price_tier=price_tier,
|
||||
description=description,
|
||||
creator_id=creator_id,
|
||||
max_output_tokens=max_output_tokens,
|
||||
is_enabled=is_enabled,
|
||||
is_recommended=is_recommended,
|
||||
supports_tools=supports_tools,
|
||||
supports_json_output=supports_json_output,
|
||||
supports_reasoning=supports_reasoning,
|
||||
supports_parallel_tool_calls=supports_parallel_tool_calls,
|
||||
capabilities=capabilities,
|
||||
metadata=metadata,
|
||||
)
|
||||
model = await prisma.models.LlmModel.prisma().create(
|
||||
data=data,
|
||||
include={"Costs": True, "Creator": True, "Provider": True},
|
||||
)
|
||||
if not model:
|
||||
raise ValueError("Failed to create model")
|
||||
return model
|
||||
|
||||
|
||||
async def update_model(
|
||||
model_id: str,
|
||||
display_name: str | None = None,
|
||||
description: str | None = None,
|
||||
creator_id: str | None = None,
|
||||
context_window: int | None = None,
|
||||
max_output_tokens: int | None = None,
|
||||
price_tier: int | None = None,
|
||||
is_enabled: bool | None = None,
|
||||
is_recommended: bool | None = None,
|
||||
supports_tools: bool | None = None,
|
||||
supports_json_output: bool | None = None,
|
||||
supports_reasoning: bool | None = None,
|
||||
supports_parallel_tool_calls: bool | None = None,
|
||||
capabilities: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> prisma.models.LlmModel:
|
||||
"""Update an existing LLM model.
|
||||
|
||||
When is_recommended=True, clears the flag on all other models first so
|
||||
only one model can be recommended at a time.
|
||||
"""
|
||||
# Build update data (only include fields that are provided)
|
||||
data: dict[str, Any] = {}
|
||||
if display_name is not None:
|
||||
data["displayName"] = display_name
|
||||
if description is not None:
|
||||
data["description"] = description
|
||||
if context_window is not None:
|
||||
data["contextWindow"] = context_window
|
||||
if max_output_tokens is not None:
|
||||
data["maxOutputTokens"] = max_output_tokens
|
||||
if price_tier is not None:
|
||||
data["priceTier"] = price_tier
|
||||
if is_enabled is not None:
|
||||
data["isEnabled"] = is_enabled
|
||||
if is_recommended is not None:
|
||||
data["isRecommended"] = is_recommended
|
||||
if supports_tools is not None:
|
||||
data["supportsTools"] = supports_tools
|
||||
if supports_json_output is not None:
|
||||
data["supportsJsonOutput"] = supports_json_output
|
||||
if supports_reasoning is not None:
|
||||
data["supportsReasoning"] = supports_reasoning
|
||||
if supports_parallel_tool_calls is not None:
|
||||
data["supportsParallelToolCalls"] = supports_parallel_tool_calls
|
||||
if capabilities is not None:
|
||||
data["capabilities"] = prisma.Json(capabilities)
|
||||
if metadata is not None:
|
||||
data["metadata"] = prisma.Json(metadata)
|
||||
if creator_id is not None:
|
||||
data["creatorId"] = creator_id if creator_id else None
|
||||
|
||||
async with transaction() as tx:
|
||||
# Enforce single recommended model: unset all others first.
|
||||
if is_recommended is True:
|
||||
await tx.llmmodel.update_many(
|
||||
where={"id": {"not": model_id}},
|
||||
data={"isRecommended": False},
|
||||
)
|
||||
|
||||
model = await tx.llmmodel.update(
|
||||
where={"id": model_id},
|
||||
data=data,
|
||||
include={"Costs": True, "Creator": True, "Provider": True},
|
||||
)
|
||||
|
||||
if not model:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
return model
|
||||
|
||||
|
||||
async def get_model_usage(slug: str) -> dict[str, Any]:
|
||||
"""Get usage count for a model — how many AgentNodes reference it."""
|
||||
import prisma as prisma_module
|
||||
|
||||
count_result = await prisma_module.get_client().query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM "AgentNode"
|
||||
WHERE "constantInput"::jsonb->>'model' = $1
|
||||
""",
|
||||
slug,
|
||||
)
|
||||
node_count = int(count_result[0]["count"]) if count_result else 0
|
||||
return {"model_slug": slug, "node_count": node_count}
|
||||
|
||||
|
||||
async def toggle_model_with_migration(
|
||||
model_id: str,
|
||||
is_enabled: bool,
|
||||
migrate_to_slug: str | None = None,
|
||||
migration_reason: str | None = None,
|
||||
custom_credit_cost: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Toggle a model's enabled status, optionally migrating workflows when disabling."""
|
||||
model = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model_id}, include={"Costs": True}
|
||||
)
|
||||
if not model:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
|
||||
nodes_migrated = 0
|
||||
migration_id: str | None = None
|
||||
|
||||
if not is_enabled and migrate_to_slug:
|
||||
async with transaction() as tx:
|
||||
replacement = await tx.llmmodel.find_unique(
|
||||
where={"slug": migrate_to_slug}
|
||||
)
|
||||
if not replacement:
|
||||
raise ValueError(
|
||||
f"Replacement model '{migrate_to_slug}' not found"
|
||||
)
|
||||
if not replacement.isEnabled:
|
||||
raise ValueError(
|
||||
f"Replacement model '{migrate_to_slug}' is disabled. "
|
||||
f"Please enable it before using it as a replacement."
|
||||
)
|
||||
|
||||
node_ids_result = await tx.query_raw(
|
||||
"""
|
||||
SELECT id
|
||||
FROM "AgentNode"
|
||||
WHERE "constantInput"::jsonb->>'model' = $1
|
||||
FOR UPDATE
|
||||
""",
|
||||
model.slug,
|
||||
)
|
||||
migrated_node_ids = (
|
||||
[row["id"] for row in node_ids_result] if node_ids_result else []
|
||||
)
|
||||
nodes_migrated = len(migrated_node_ids)
|
||||
|
||||
if nodes_migrated > 0:
|
||||
node_ids_json = json.dumps(migrated_node_ids)
|
||||
await tx.execute_raw(
|
||||
"""
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
to_jsonb($1::text)
|
||||
)
|
||||
WHERE id::text IN (
|
||||
SELECT jsonb_array_elements_text($2::jsonb)
|
||||
)
|
||||
""",
|
||||
migrate_to_slug,
|
||||
node_ids_json,
|
||||
)
|
||||
|
||||
await tx.llmmodel.update(
|
||||
where={"id": model_id},
|
||||
data={"isEnabled": is_enabled},
|
||||
)
|
||||
|
||||
if nodes_migrated > 0:
|
||||
migration_record = await tx.llmmodelmigration.create(
|
||||
data={
|
||||
"sourceModelSlug": model.slug,
|
||||
"targetModelSlug": migrate_to_slug,
|
||||
"reason": migration_reason,
|
||||
"migratedNodeIds": json.dumps(migrated_node_ids),
|
||||
"nodeCount": nodes_migrated,
|
||||
"customCreditCost": custom_credit_cost,
|
||||
}
|
||||
)
|
||||
migration_id = migration_record.id
|
||||
else:
|
||||
await prisma.models.LlmModel.prisma().update(
|
||||
where={"id": model_id},
|
||||
data={"isEnabled": is_enabled},
|
||||
)
|
||||
|
||||
return {
|
||||
"nodes_migrated": nodes_migrated,
|
||||
"migrated_to_slug": migrate_to_slug if nodes_migrated > 0 else None,
|
||||
"migration_id": migration_id,
|
||||
}
|
||||
|
||||
|
||||
async def delete_model(
|
||||
model_id: str, replacement_model_slug: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Delete an LLM model, optionally migrating affected AgentNodes first.
|
||||
|
||||
If workflows are using this model and no replacement is given, raises ValueError.
|
||||
If replacement is given, atomically migrates all affected nodes then deletes.
|
||||
"""
|
||||
model = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model_id}, include={"Costs": True}
|
||||
)
|
||||
if not model:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
|
||||
deleted_slug = model.slug
|
||||
deleted_display_name = model.displayName
|
||||
|
||||
async with transaction() as tx:
|
||||
count_result = await tx.query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM "AgentNode"
|
||||
WHERE "constantInput"::jsonb->>'model' = $1
|
||||
""",
|
||||
deleted_slug,
|
||||
)
|
||||
nodes_to_migrate = int(count_result[0]["count"]) if count_result else 0
|
||||
|
||||
if nodes_to_migrate > 0:
|
||||
if not replacement_model_slug:
|
||||
raise ValueError(
|
||||
f"Cannot delete model '{deleted_slug}': {nodes_to_migrate} workflow node(s) "
|
||||
f"are using it. Please provide a replacement_model_slug to migrate them."
|
||||
)
|
||||
replacement = await tx.llmmodel.find_unique(
|
||||
where={"slug": replacement_model_slug}
|
||||
)
|
||||
if not replacement:
|
||||
raise ValueError(
|
||||
f"Replacement model '{replacement_model_slug}' not found"
|
||||
)
|
||||
if not replacement.isEnabled:
|
||||
raise ValueError(
|
||||
f"Replacement model '{replacement_model_slug}' is disabled."
|
||||
)
|
||||
|
||||
await tx.execute_raw(
|
||||
"""
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
to_jsonb($1::text)
|
||||
)
|
||||
WHERE "constantInput"::jsonb->>'model' = $2
|
||||
""",
|
||||
replacement_model_slug,
|
||||
deleted_slug,
|
||||
)
|
||||
|
||||
await tx.llmmodel.delete(where={"id": model_id})
|
||||
|
||||
return {
|
||||
"deleted_model_slug": deleted_slug,
|
||||
"deleted_model_display_name": deleted_display_name,
|
||||
"replacement_model_slug": replacement_model_slug,
|
||||
"nodes_migrated": nodes_to_migrate,
|
||||
}
|
||||
|
||||
|
||||
async def list_migrations(
|
||||
include_reverted: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""List model migrations."""
|
||||
where: Any = None if include_reverted else {"isReverted": False}
|
||||
records = await prisma.models.LlmModelMigration.prisma().find_many(
|
||||
where=where,
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
return [
|
||||
{
|
||||
"id": r.id,
|
||||
"source_model_slug": r.sourceModelSlug,
|
||||
"target_model_slug": r.targetModelSlug,
|
||||
"reason": r.reason,
|
||||
"node_count": r.nodeCount,
|
||||
"custom_credit_cost": r.customCreditCost,
|
||||
"is_reverted": r.isReverted,
|
||||
"reverted_at": r.revertedAt.isoformat() if r.revertedAt else None,
|
||||
"created_at": r.createdAt.isoformat(),
|
||||
}
|
||||
for r in records
|
||||
]
|
||||
|
||||
|
||||
async def revert_migration(
|
||||
migration_id: str,
|
||||
re_enable_source_model: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""Revert a model migration, restoring affected nodes to their original model."""
|
||||
migration = await prisma.models.LlmModelMigration.prisma().find_unique(
|
||||
where={"id": migration_id}
|
||||
)
|
||||
if not migration:
|
||||
raise ValueError(f"Migration with id '{migration_id}' not found")
|
||||
|
||||
if migration.isReverted:
|
||||
raise ValueError(
|
||||
f"Migration '{migration_id}' has already been reverted"
|
||||
)
|
||||
|
||||
source_model = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"slug": migration.sourceModelSlug}
|
||||
)
|
||||
if not source_model:
|
||||
raise ValueError(
|
||||
f"Source model '{migration.sourceModelSlug}' no longer exists."
|
||||
)
|
||||
|
||||
migrated_node_ids: list[str] = (
|
||||
migration.migratedNodeIds
|
||||
if isinstance(migration.migratedNodeIds, list)
|
||||
else json.loads(migration.migratedNodeIds) # type: ignore
|
||||
)
|
||||
if not migrated_node_ids:
|
||||
raise ValueError("No nodes to revert in this migration")
|
||||
|
||||
source_model_re_enabled = False
|
||||
|
||||
async with transaction() as tx:
|
||||
if not source_model.isEnabled and re_enable_source_model:
|
||||
await tx.llmmodel.update(
|
||||
where={"id": source_model.id},
|
||||
data={"isEnabled": True},
|
||||
)
|
||||
source_model_re_enabled = True
|
||||
|
||||
node_ids_json = json.dumps(migrated_node_ids)
|
||||
result = await tx.execute_raw(
|
||||
"""
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
to_jsonb($1::text)
|
||||
)
|
||||
WHERE id::text IN (
|
||||
SELECT jsonb_array_elements_text($2::jsonb)
|
||||
)
|
||||
AND "constantInput"::jsonb->>'model' = $3
|
||||
""",
|
||||
migration.sourceModelSlug,
|
||||
node_ids_json,
|
||||
migration.targetModelSlug,
|
||||
)
|
||||
nodes_reverted = result if isinstance(result, int) else 0
|
||||
|
||||
await tx.llmmodelmigration.update(
|
||||
where={"id": migration_id},
|
||||
data={
|
||||
"isReverted": True,
|
||||
"revertedAt": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"migration_id": migration_id,
|
||||
"source_model_slug": migration.sourceModelSlug,
|
||||
"target_model_slug": migration.targetModelSlug,
|
||||
"nodes_reverted": nodes_reverted,
|
||||
"nodes_already_changed": len(migrated_node_ids) - nodes_reverted,
|
||||
"source_model_re_enabled": source_model_re_enabled,
|
||||
}
|
||||
|
||||
|
||||
async def refresh_runtime_caches() -> None:
|
||||
"""Invalidate the shared Redis cache, refresh this process, notify other workers."""
|
||||
from backend.data.llm_registry.notifications import (
|
||||
publish_registry_refresh_notification,
|
||||
)
|
||||
|
||||
# Invalidate Redis so the next fetch hits the DB.
|
||||
llm_registry.clear_registry_cache()
|
||||
# Refresh this process (also repopulates Redis via @cached(shared_cache=True)).
|
||||
await llm_registry.refresh_llm_registry()
|
||||
# Tell other workers to reload their in-process cache from the fresh Redis data.
|
||||
await publish_registry_refresh_notification()
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user