mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
feat(platform): Implement LLM registry admin API functionality
Implement full CRUD operations for admin API: Database layer (db_write.py): - create_provider, update_provider, delete_provider - create_model, update_model, delete_model - refresh_runtime_caches - invalidates in-memory registry after mutations - Proper validation and error handling Admin routes (admin_routes.py): - All endpoints now functional (no more 501) - Proper error responses (400 for validation, 404 for not found, 500 for server errors) - Lookup by slug/name before operations - Cache refresh after all mutations Features: - Provider deletion blocked if models exist (FK constraint) - All mutations refresh registry cache automatically - Proper logging for audit trail - Admin auth enforced on all endpoints Based on original implementation from PR #11699 (upstream-llm branch). Builds on: - PR #12357: Schema foundation - PR #12359: Registry core - PR #12371: Public read API
This commit is contained in:
@@ -8,15 +8,21 @@ from pydantic import BaseModel, Field
|
||||
class CreateLlmProviderRequest(BaseModel):
|
||||
"""Request model for creating an LLM provider."""
|
||||
|
||||
name: str = Field(..., description="Provider identifier (e.g., 'openai', 'anthropic')")
|
||||
name: str = Field(
|
||||
..., description="Provider identifier (e.g., 'openai', 'anthropic')"
|
||||
)
|
||||
display_name: str = Field(..., description="Human-readable provider name")
|
||||
description: str | None = Field(None, description="Provider description")
|
||||
default_credential_provider: str | None = Field(
|
||||
None, description="Default credential system identifier"
|
||||
)
|
||||
default_credential_id: str | None = Field(None, description="Default credential ID")
|
||||
default_credential_type: str | None = Field(None, description="Default credential type")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
|
||||
default_credential_type: str | None = Field(
|
||||
None, description="Default credential type"
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Additional metadata"
|
||||
)
|
||||
|
||||
|
||||
class UpdateLlmProviderRequest(BaseModel):
|
||||
@@ -28,7 +34,9 @@ class UpdateLlmProviderRequest(BaseModel):
|
||||
None, description="Default credential system identifier"
|
||||
)
|
||||
default_credential_id: str | None = Field(None, description="Default credential ID")
|
||||
default_credential_type: str | None = Field(None, description="Default credential type")
|
||||
default_credential_type: str | None = Field(
|
||||
None, description="Default credential type"
|
||||
)
|
||||
metadata: dict[str, Any] | None = Field(None, description="Additional metadata")
|
||||
|
||||
|
||||
@@ -40,21 +48,35 @@ class CreateLlmModelRequest(BaseModel):
|
||||
description: str | None = Field(None, description="Model description")
|
||||
provider_id: str = Field(..., description="Provider ID (UUID)")
|
||||
creator_id: str | None = Field(None, description="Creator ID (UUID)")
|
||||
context_window: int = Field(..., description="Maximum context window in tokens", gt=0)
|
||||
context_window: int = Field(
|
||||
..., description="Maximum context window in tokens", gt=0
|
||||
)
|
||||
max_output_tokens: int | None = Field(
|
||||
None, description="Maximum output tokens (None if unlimited)", gt=0
|
||||
)
|
||||
price_tier: int = Field(..., description="Price tier (1=cheapest, 2=medium, 3=expensive)", ge=1, le=3)
|
||||
price_tier: int = Field(
|
||||
..., description="Price tier (1=cheapest, 2=medium, 3=expensive)", ge=1, le=3
|
||||
)
|
||||
is_enabled: bool = Field(default=True, description="Whether the model is enabled")
|
||||
is_recommended: bool = Field(default=False, description="Whether the model is recommended")
|
||||
is_recommended: bool = Field(
|
||||
default=False, description="Whether the model is recommended"
|
||||
)
|
||||
supports_tools: bool = Field(default=False, description="Supports function calling")
|
||||
supports_json_output: bool = Field(default=False, description="Supports JSON output mode")
|
||||
supports_reasoning: bool = Field(default=False, description="Supports reasoning mode")
|
||||
supports_parallel_tool_calls: bool = Field(default=False, description="Supports parallel tool calls")
|
||||
supports_json_output: bool = Field(
|
||||
default=False, description="Supports JSON output mode"
|
||||
)
|
||||
supports_reasoning: bool = Field(
|
||||
default=False, description="Supports reasoning mode"
|
||||
)
|
||||
supports_parallel_tool_calls: bool = Field(
|
||||
default=False, description="Supports parallel tool calls"
|
||||
)
|
||||
capabilities: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Additional capabilities"
|
||||
)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Additional metadata"
|
||||
)
|
||||
|
||||
|
||||
class UpdateLlmModelRequest(BaseModel):
|
||||
@@ -73,12 +95,18 @@ class UpdateLlmModelRequest(BaseModel):
|
||||
None, description="Price tier (1=cheapest, 2=medium, 3=expensive)", ge=1, le=3
|
||||
)
|
||||
is_enabled: bool | None = Field(None, description="Whether the model is enabled")
|
||||
is_recommended: bool | None = Field(None, description="Whether the model is recommended")
|
||||
is_recommended: bool | None = Field(
|
||||
None, description="Whether the model is recommended"
|
||||
)
|
||||
supports_tools: bool | None = Field(None, description="Supports function calling")
|
||||
supports_json_output: bool | None = Field(None, description="Supports JSON output mode")
|
||||
supports_json_output: bool | None = Field(
|
||||
None, description="Supports JSON output mode"
|
||||
)
|
||||
supports_reasoning: bool | None = Field(None, description="Supports reasoning mode")
|
||||
supports_parallel_tool_calls: bool | None = Field(
|
||||
None, description="Supports parallel tool calls"
|
||||
)
|
||||
capabilities: dict[str, Any] | None = Field(None, description="Additional capabilities")
|
||||
capabilities: dict[str, Any] | None = Field(
|
||||
None, description="Additional capabilities"
|
||||
)
|
||||
metadata: dict[str, Any] | None = Field(None, description="Additional metadata")
|
||||
|
||||
@@ -3,18 +3,17 @@
|
||||
Provides endpoints for creating, updating, and deleting:
|
||||
- Models
|
||||
- Providers
|
||||
- Costs
|
||||
- Creators
|
||||
- Migrations
|
||||
|
||||
All endpoints require admin authentication.
|
||||
All endpoints require admin authentication and refresh the registry cache after mutations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import autogpt_libs.auth
|
||||
from fastapi import APIRouter, HTTPException, Security, status
|
||||
|
||||
from backend.server.v2.llm import db_write
|
||||
from backend.server.v2.llm.admin_model import (
|
||||
CreateLlmModelRequest,
|
||||
CreateLlmProviderRequest,
|
||||
@@ -22,9 +21,52 @@ from backend.server.v2.llm.admin_model import (
|
||||
UpdateLlmProviderRequest,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _map_provider_response(provider: Any) -> dict[str, Any]:
|
||||
"""Map Prisma provider model to response dict."""
|
||||
return {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.displayName,
|
||||
"description": provider.description,
|
||||
"default_credential_provider": provider.defaultCredentialProvider,
|
||||
"default_credential_id": provider.defaultCredentialId,
|
||||
"default_credential_type": provider.defaultCredentialType,
|
||||
"metadata": dict(provider.metadata or {}),
|
||||
"created_at": provider.createdAt.isoformat() if provider.createdAt else None,
|
||||
"updated_at": provider.updatedAt.isoformat() if provider.updatedAt else None,
|
||||
}
|
||||
|
||||
|
||||
def _map_model_response(model: Any) -> dict[str, Any]:
|
||||
"""Map Prisma model to response dict."""
|
||||
return {
|
||||
"id": model.id,
|
||||
"slug": model.slug,
|
||||
"display_name": model.displayName,
|
||||
"description": model.description,
|
||||
"provider_id": model.providerId,
|
||||
"creator_id": model.creatorId,
|
||||
"context_window": model.contextWindow,
|
||||
"max_output_tokens": model.maxOutputTokens,
|
||||
"price_tier": model.priceTier,
|
||||
"is_enabled": model.isEnabled,
|
||||
"is_recommended": model.isRecommended,
|
||||
"supports_tools": model.supportsTools,
|
||||
"supports_json_output": model.supportsJsonOutput,
|
||||
"supports_reasoning": model.supportsReasoning,
|
||||
"supports_parallel_tool_calls": model.supportsParallelToolCalls,
|
||||
"capabilities": dict(model.capabilities or {}),
|
||||
"metadata": dict(model.metadata or {}),
|
||||
"created_at": model.createdAt.isoformat() if model.createdAt else None,
|
||||
"updated_at": model.updatedAt.isoformat() if model.updatedAt else None,
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/llm/models",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
@@ -37,14 +79,40 @@ async def create_model(
|
||||
|
||||
Requires admin authentication.
|
||||
"""
|
||||
# TODO: Implement model creation
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Model creation not yet implemented",
|
||||
)
|
||||
try:
|
||||
model = await db_write.create_model(
|
||||
slug=request.slug,
|
||||
display_name=request.display_name,
|
||||
provider_id=request.provider_id,
|
||||
context_window=request.context_window,
|
||||
price_tier=request.price_tier,
|
||||
description=request.description,
|
||||
creator_id=request.creator_id,
|
||||
max_output_tokens=request.max_output_tokens,
|
||||
is_enabled=request.is_enabled,
|
||||
is_recommended=request.is_recommended,
|
||||
supports_tools=request.supports_tools,
|
||||
supports_json_output=request.supports_json_output,
|
||||
supports_reasoning=request.supports_reasoning,
|
||||
supports_parallel_tool_calls=request.supports_parallel_tool_calls,
|
||||
capabilities=request.capabilities,
|
||||
metadata=request.metadata,
|
||||
)
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(f"Created model '{request.slug}' (id: {model.id})")
|
||||
return _map_model_response(model)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Model creation validation failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create model: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to create model")
|
||||
|
||||
|
||||
@router.patch("/llm/models/{slug}", dependencies=[Security(autogpt_libs.auth.requires_admin_user)])
|
||||
@router.patch(
|
||||
"/llm/models/{slug}",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def update_model(
|
||||
slug: str,
|
||||
request: UpdateLlmModelRequest,
|
||||
@@ -53,14 +121,51 @@ async def update_model(
|
||||
|
||||
Requires admin authentication.
|
||||
"""
|
||||
# TODO: Implement model update
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Model update not yet implemented",
|
||||
)
|
||||
try:
|
||||
# Find model by slug first to get ID
|
||||
import prisma.models
|
||||
|
||||
existing = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"slug": slug}
|
||||
)
|
||||
if not existing:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Model with slug '{slug}' not found"
|
||||
)
|
||||
|
||||
model = await db_write.update_model(
|
||||
model_id=existing.id,
|
||||
display_name=request.display_name,
|
||||
description=request.description,
|
||||
creator_id=request.creator_id,
|
||||
context_window=request.context_window,
|
||||
max_output_tokens=request.max_output_tokens,
|
||||
price_tier=request.price_tier,
|
||||
is_enabled=request.is_enabled,
|
||||
is_recommended=request.is_recommended,
|
||||
supports_tools=request.supports_tools,
|
||||
supports_json_output=request.supports_json_output,
|
||||
supports_reasoning=request.supports_reasoning,
|
||||
supports_parallel_tool_calls=request.supports_parallel_tool_calls,
|
||||
capabilities=request.capabilities,
|
||||
metadata=request.metadata,
|
||||
)
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(f"Updated model '{slug}' (id: {model.id})")
|
||||
return _map_model_response(model)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Model update validation failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to update model: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to update model")
|
||||
|
||||
|
||||
@router.delete("/llm/models/{slug}", dependencies=[Security(autogpt_libs.auth.requires_admin_user)], status_code=status.HTTP_204_NO_CONTENT)
|
||||
@router.delete(
|
||||
"/llm/models/{slug}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def delete_model(
|
||||
slug: str,
|
||||
) -> None:
|
||||
@@ -68,14 +173,34 @@ async def delete_model(
|
||||
|
||||
Requires admin authentication.
|
||||
"""
|
||||
# TODO: Implement model deletion
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Model deletion not yet implemented",
|
||||
)
|
||||
try:
|
||||
# Find model by slug first to get ID
|
||||
import prisma.models
|
||||
|
||||
existing = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"slug": slug}
|
||||
)
|
||||
if not existing:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Model with slug '{slug}' not found"
|
||||
)
|
||||
|
||||
await db_write.delete_model(model_id=existing.id)
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(f"Deleted model '{slug}' (id: {existing.id})")
|
||||
except ValueError as e:
|
||||
logger.warning(f"Model deletion validation failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to delete model: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to delete model")
|
||||
|
||||
|
||||
@router.post("/llm/providers", status_code=status.HTTP_201_CREATED)
|
||||
@router.post(
|
||||
"/llm/providers",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def create_provider(
|
||||
request: CreateLlmProviderRequest,
|
||||
) -> dict[str, Any]:
|
||||
@@ -83,14 +208,31 @@ async def create_provider(
|
||||
|
||||
Requires admin authentication.
|
||||
"""
|
||||
# TODO: Implement provider creation
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Provider creation not yet implemented",
|
||||
)
|
||||
try:
|
||||
provider = await db_write.create_provider(
|
||||
name=request.name,
|
||||
display_name=request.display_name,
|
||||
description=request.description,
|
||||
default_credential_provider=request.default_credential_provider,
|
||||
default_credential_id=request.default_credential_id,
|
||||
default_credential_type=request.default_credential_type,
|
||||
metadata=request.metadata,
|
||||
)
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(f"Created provider '{request.name}' (id: {provider.id})")
|
||||
return _map_provider_response(provider)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Provider creation validation failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create provider: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to create provider")
|
||||
|
||||
|
||||
@router.patch("/llm/providers/{name}", dependencies=[Security(autogpt_libs.auth.requires_admin_user)])
|
||||
@router.patch(
|
||||
"/llm/providers/{name}",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def update_provider(
|
||||
name: str,
|
||||
request: UpdateLlmProviderRequest,
|
||||
@@ -99,23 +241,69 @@ async def update_provider(
|
||||
|
||||
Requires admin authentication.
|
||||
"""
|
||||
# TODO: Implement provider update
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Provider update not yet implemented",
|
||||
)
|
||||
try:
|
||||
# Find provider by name first to get ID
|
||||
import prisma.models
|
||||
|
||||
existing = await prisma.models.LlmProvider.prisma().find_unique(
|
||||
where={"name": name}
|
||||
)
|
||||
if not existing:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Provider with name '{name}' not found"
|
||||
)
|
||||
|
||||
provider = await db_write.update_provider(
|
||||
provider_id=existing.id,
|
||||
display_name=request.display_name,
|
||||
description=request.description,
|
||||
default_credential_provider=request.default_credential_provider,
|
||||
default_credential_id=request.default_credential_id,
|
||||
default_credential_type=request.default_credential_type,
|
||||
metadata=request.metadata,
|
||||
)
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(f"Updated provider '{name}' (id: {provider.id})")
|
||||
return _map_provider_response(provider)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Provider update validation failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to update provider: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to update provider")
|
||||
|
||||
|
||||
@router.delete("/llm/providers/{name}", dependencies=[Security(autogpt_libs.auth.requires_admin_user)], status_code=status.HTTP_204_NO_CONTENT)
|
||||
@router.delete(
|
||||
"/llm/providers/{name}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def delete_provider(
|
||||
name: str,
|
||||
) -> None:
|
||||
"""Delete an LLM provider.
|
||||
|
||||
Requires admin authentication.
|
||||
A provider can only be deleted if it has no associated models.
|
||||
"""
|
||||
# TODO: Implement provider deletion
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Provider deletion not yet implemented",
|
||||
)
|
||||
try:
|
||||
# Find provider by name first to get ID
|
||||
import prisma.models
|
||||
|
||||
existing = await prisma.models.LlmProvider.prisma().find_unique(
|
||||
where={"name": name}
|
||||
)
|
||||
if not existing:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Provider with name '{name}' not found"
|
||||
)
|
||||
|
||||
await db_write.delete_provider(provider_id=existing.id)
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(f"Deleted provider '{name}' (id: {existing.id})")
|
||||
except ValueError as e:
|
||||
logger.warning(f"Provider deletion validation failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to delete provider: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to delete provider")
|
||||
|
||||
294
autogpt_platform/backend/backend/server/v2/llm/db_write.py
Normal file
294
autogpt_platform/backend/backend/server/v2/llm/db_write.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""Database write operations for LLM registry admin API."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import prisma
|
||||
import prisma.models
|
||||
|
||||
from backend.data import llm_registry
|
||||
|
||||
|
||||
def _build_provider_data(
|
||||
name: str,
|
||||
display_name: str,
|
||||
description: str | None = None,
|
||||
default_credential_provider: str | None = None,
|
||||
default_credential_id: str | None = None,
|
||||
default_credential_type: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build provider data dict for Prisma operations."""
|
||||
return {
|
||||
"name": name,
|
||||
"displayName": display_name,
|
||||
"description": description,
|
||||
"defaultCredentialProvider": default_credential_provider,
|
||||
"defaultCredentialId": default_credential_id,
|
||||
"defaultCredentialType": default_credential_type,
|
||||
"metadata": prisma.Json(metadata or {}),
|
||||
}
|
||||
|
||||
|
||||
def _build_model_data(
|
||||
slug: str,
|
||||
display_name: str,
|
||||
provider_id: str,
|
||||
context_window: int,
|
||||
price_tier: int,
|
||||
description: str | None = None,
|
||||
creator_id: str | None = None,
|
||||
max_output_tokens: int | None = None,
|
||||
is_enabled: bool = True,
|
||||
is_recommended: bool = False,
|
||||
supports_tools: bool = False,
|
||||
supports_json_output: bool = False,
|
||||
supports_reasoning: bool = False,
|
||||
supports_parallel_tool_calls: bool = False,
|
||||
capabilities: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build model data dict for Prisma operations."""
|
||||
data: dict[str, Any] = {
|
||||
"slug": slug,
|
||||
"displayName": display_name,
|
||||
"description": description,
|
||||
"Provider": {"connect": {"id": provider_id}},
|
||||
"contextWindow": context_window,
|
||||
"maxOutputTokens": max_output_tokens,
|
||||
"priceTier": price_tier,
|
||||
"isEnabled": is_enabled,
|
||||
"isRecommended": is_recommended,
|
||||
"supportsTools": supports_tools,
|
||||
"supportsJsonOutput": supports_json_output,
|
||||
"supportsReasoning": supports_reasoning,
|
||||
"supportsParallelToolCalls": supports_parallel_tool_calls,
|
||||
"capabilities": prisma.Json(capabilities or {}),
|
||||
"metadata": prisma.Json(metadata or {}),
|
||||
}
|
||||
if creator_id:
|
||||
data["Creator"] = {"connect": {"id": creator_id}}
|
||||
return data
|
||||
|
||||
|
||||
async def create_provider(
|
||||
name: str,
|
||||
display_name: str,
|
||||
description: str | None = None,
|
||||
default_credential_provider: str | None = None,
|
||||
default_credential_id: str | None = None,
|
||||
default_credential_type: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> prisma.models.LlmProvider:
|
||||
"""Create a new LLM provider."""
|
||||
data = _build_provider_data(
|
||||
name=name,
|
||||
display_name=display_name,
|
||||
description=description,
|
||||
default_credential_provider=default_credential_provider,
|
||||
default_credential_id=default_credential_id,
|
||||
default_credential_type=default_credential_type,
|
||||
metadata=metadata,
|
||||
)
|
||||
provider = await prisma.models.LlmProvider.prisma().create(
|
||||
data=data,
|
||||
include={"Models": True},
|
||||
)
|
||||
if not provider:
|
||||
raise ValueError("Failed to create provider")
|
||||
return provider
|
||||
|
||||
|
||||
async def update_provider(
|
||||
provider_id: str,
|
||||
display_name: str | None = None,
|
||||
description: str | None = None,
|
||||
default_credential_provider: str | None = None,
|
||||
default_credential_id: str | None = None,
|
||||
default_credential_type: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> prisma.models.LlmProvider:
|
||||
"""Update an existing LLM provider."""
|
||||
# Fetch existing provider to get current name
|
||||
provider = await prisma.models.LlmProvider.prisma().find_unique(
|
||||
where={"id": provider_id}
|
||||
)
|
||||
if not provider:
|
||||
raise ValueError(f"Provider with id '{provider_id}' not found")
|
||||
|
||||
# Build update data (only include fields that are provided)
|
||||
data: dict[str, Any] = {}
|
||||
if display_name is not None:
|
||||
data["displayName"] = display_name
|
||||
if description is not None:
|
||||
data["description"] = description
|
||||
if default_credential_provider is not None:
|
||||
data["defaultCredentialProvider"] = default_credential_provider
|
||||
if default_credential_id is not None:
|
||||
data["defaultCredentialId"] = default_credential_id
|
||||
if default_credential_type is not None:
|
||||
data["defaultCredentialType"] = default_credential_type
|
||||
if metadata is not None:
|
||||
data["metadata"] = prisma.Json(metadata)
|
||||
|
||||
updated = await prisma.models.LlmProvider.prisma().update(
|
||||
where={"id": provider_id},
|
||||
data=data,
|
||||
include={"Models": True},
|
||||
)
|
||||
if not updated:
|
||||
raise ValueError("Failed to update provider")
|
||||
return updated
|
||||
|
||||
|
||||
async def delete_provider(provider_id: str) -> bool:
|
||||
"""Delete an LLM provider.
|
||||
|
||||
A provider can only be deleted if it has no associated models.
|
||||
"""
|
||||
# Check if provider exists
|
||||
provider = await prisma.models.LlmProvider.prisma().find_unique(
|
||||
where={"id": provider_id},
|
||||
include={"Models": True},
|
||||
)
|
||||
if not provider:
|
||||
raise ValueError(f"Provider with id '{provider_id}' not found")
|
||||
|
||||
# Check if provider has any models
|
||||
model_count = len(provider.Models) if provider.Models else 0
|
||||
if model_count > 0:
|
||||
raise ValueError(
|
||||
f"Cannot delete provider '{provider.displayName}' because it has "
|
||||
f"{model_count} model(s). Delete all models first."
|
||||
)
|
||||
|
||||
await prisma.models.LlmProvider.prisma().delete(where={"id": provider_id})
|
||||
return True
|
||||
|
||||
|
||||
async def create_model(
|
||||
slug: str,
|
||||
display_name: str,
|
||||
provider_id: str,
|
||||
context_window: int,
|
||||
price_tier: int,
|
||||
description: str | None = None,
|
||||
creator_id: str | None = None,
|
||||
max_output_tokens: int | None = None,
|
||||
is_enabled: bool = True,
|
||||
is_recommended: bool = False,
|
||||
supports_tools: bool = False,
|
||||
supports_json_output: bool = False,
|
||||
supports_reasoning: bool = False,
|
||||
supports_parallel_tool_calls: bool = False,
|
||||
capabilities: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> prisma.models.LlmModel:
|
||||
"""Create a new LLM model."""
|
||||
data = _build_model_data(
|
||||
slug=slug,
|
||||
display_name=display_name,
|
||||
provider_id=provider_id,
|
||||
context_window=context_window,
|
||||
price_tier=price_tier,
|
||||
description=description,
|
||||
creator_id=creator_id,
|
||||
max_output_tokens=max_output_tokens,
|
||||
is_enabled=is_enabled,
|
||||
is_recommended=is_recommended,
|
||||
supports_tools=supports_tools,
|
||||
supports_json_output=supports_json_output,
|
||||
supports_reasoning=supports_reasoning,
|
||||
supports_parallel_tool_calls=supports_parallel_tool_calls,
|
||||
capabilities=capabilities,
|
||||
metadata=metadata,
|
||||
)
|
||||
model = await prisma.models.LlmModel.prisma().create(
|
||||
data=data,
|
||||
include={"Costs": True, "Creator": True, "Provider": True},
|
||||
)
|
||||
if not model:
|
||||
raise ValueError("Failed to create model")
|
||||
return model
|
||||
|
||||
|
||||
async def update_model(
|
||||
model_id: str,
|
||||
display_name: str | None = None,
|
||||
description: str | None = None,
|
||||
creator_id: str | None = None,
|
||||
context_window: int | None = None,
|
||||
max_output_tokens: int | None = None,
|
||||
price_tier: int | None = None,
|
||||
is_enabled: bool | None = None,
|
||||
is_recommended: bool | None = None,
|
||||
supports_tools: bool | None = None,
|
||||
supports_json_output: bool | None = None,
|
||||
supports_reasoning: bool | None = None,
|
||||
supports_parallel_tool_calls: bool | None = None,
|
||||
capabilities: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> prisma.models.LlmModel:
|
||||
"""Update an existing LLM model."""
|
||||
# Build update data (only include fields that are provided)
|
||||
data: dict[str, Any] = {}
|
||||
if display_name is not None:
|
||||
data["displayName"] = display_name
|
||||
if description is not None:
|
||||
data["description"] = description
|
||||
if context_window is not None:
|
||||
data["contextWindow"] = context_window
|
||||
if max_output_tokens is not None:
|
||||
data["maxOutputTokens"] = max_output_tokens
|
||||
if price_tier is not None:
|
||||
data["priceTier"] = price_tier
|
||||
if is_enabled is not None:
|
||||
data["isEnabled"] = is_enabled
|
||||
if is_recommended is not None:
|
||||
data["isRecommended"] = is_recommended
|
||||
if supports_tools is not None:
|
||||
data["supportsTools"] = supports_tools
|
||||
if supports_json_output is not None:
|
||||
data["supportsJsonOutput"] = supports_json_output
|
||||
if supports_reasoning is not None:
|
||||
data["supportsReasoning"] = supports_reasoning
|
||||
if supports_parallel_tool_calls is not None:
|
||||
data["supportsParallelToolCalls"] = supports_parallel_tool_calls
|
||||
if capabilities is not None:
|
||||
data["capabilities"] = prisma.Json(capabilities)
|
||||
if metadata is not None:
|
||||
data["metadata"] = prisma.Json(metadata)
|
||||
if creator_id is not None:
|
||||
data["creatorId"] = creator_id if creator_id else None
|
||||
|
||||
model = await prisma.models.LlmModel.prisma().update(
|
||||
where={"id": model_id},
|
||||
data=data,
|
||||
include={"Costs": True, "Creator": True, "Provider": True},
|
||||
)
|
||||
if not model:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
return model
|
||||
|
||||
|
||||
async def delete_model(model_id: str) -> bool:
|
||||
"""Delete an LLM model.
|
||||
|
||||
Note: This should check if any workflows are using this model first.
|
||||
For now, we'll allow deletion and rely on FK constraints.
|
||||
"""
|
||||
# Check if model exists
|
||||
model = await prisma.models.LlmModel.prisma().find_unique(where={"id": model_id})
|
||||
if not model:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
|
||||
await prisma.models.LlmModel.prisma().delete(where={"id": model_id})
|
||||
return True
|
||||
|
||||
|
||||
async def refresh_runtime_caches() -> None:
|
||||
"""Refresh the LLM registry and clear all related caches."""
|
||||
# Refresh the in-memory registry
|
||||
await llm_registry.refresh_llm_registry()
|
||||
|
||||
# TODO: Clear block schema caches when block integration is implemented
|
||||
# TODO: Publish registry refresh notification to executors
|
||||
Reference in New Issue
Block a user