mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
47 Commits
dx/add-age
...
feat/llm-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
be328c1ec5 | ||
|
|
8410448c16 | ||
|
|
e168597663 | ||
|
|
1d903ae287 | ||
|
|
1be7aebdea | ||
|
|
36045c7007 | ||
|
|
445eb173a5 | ||
|
|
393a138fee | ||
|
|
ccc1e35c5b | ||
|
|
c66f114e28 | ||
|
|
939edc73b8 | ||
|
|
d52409c853 | ||
|
|
90a68084eb | ||
|
|
fb9a3224be | ||
|
|
eb76b95aa5 | ||
|
|
cc17884360 | ||
|
|
1ce3cc0231 | ||
|
|
bd1f4b5701 | ||
|
|
e89e56d90d | ||
|
|
2a923dcd92 | ||
|
|
1fffd21b16 | ||
|
|
2241a62b75 | ||
|
|
a5b71b9783 | ||
|
|
7632548408 | ||
|
|
05fa10925c | ||
|
|
c64246be87 | ||
|
|
253937e7b9 | ||
|
|
73e481b508 | ||
|
|
f0cc4ae573 | ||
|
|
e0282b00db | ||
|
|
9a9c36b806 | ||
|
|
d5381625cd | ||
|
|
f6ae3d6593 | ||
|
|
0fb1b854df | ||
|
|
64a011664a | ||
|
|
1db7c048d9 | ||
|
|
4c5627c966 | ||
|
|
d97d137a51 | ||
|
|
ded9e293ff | ||
|
|
83d504bed2 | ||
|
|
a5f1ffb35b | ||
|
|
97c6516a14 | ||
|
|
876dde8bc7 | ||
|
|
0bfdd74b25 | ||
|
|
a7d2f81b18 | ||
|
|
3699eaa556 | ||
|
|
21adf9e0fb |
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import platform
|
||||
@@ -37,8 +38,10 @@ import backend.api.features.workspace.routes as workspace_routes
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.llm_registry
|
||||
import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.server.v2.llm
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.api.features.library.exceptions import (
|
||||
@@ -117,16 +120,56 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Load LLM registry before initializing blocks so blocks can use registry data.
|
||||
# Tries Redis first (fast path on warm restart), falls back to DB.
|
||||
# Note: Graceful fallback for now since no blocks consume registry yet (comes in PR #5)
|
||||
try:
|
||||
await backend.data.llm_registry.refresh_llm_registry()
|
||||
logger.info("LLM registry loaded successfully at startup")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to load LLM registry at startup: {e}. "
|
||||
"Blocks will initialize with empty registry."
|
||||
)
|
||||
|
||||
# Start background task so this worker reloads its in-process cache whenever
|
||||
# another worker (e.g. the admin API) refreshes the registry.
|
||||
_registry_subscription_task = asyncio.create_task(
|
||||
backend.data.llm_registry.subscribe_to_registry_refresh(
|
||||
backend.data.llm_registry.refresh_llm_registry
|
||||
)
|
||||
)
|
||||
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
try:
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
except Exception as e:
|
||||
err_str = str(e)
|
||||
if "AgentNode" in err_str or "does not exist" in err_str:
|
||||
logger.warning(
|
||||
f"migrate_llm_models skipped: AgentNode table not found ({e}). "
|
||||
"This is expected in test environments."
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"migrate_llm_models failed unexpectedly: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
with launch_darkly_context():
|
||||
yield
|
||||
|
||||
_registry_subscription_task.cancel()
|
||||
try:
|
||||
await _registry_subscription_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
try:
|
||||
await shutdown_cloud_storage_handler()
|
||||
except Exception as e:
|
||||
@@ -355,6 +398,16 @@ app.include_router(
|
||||
tags=["oauth"],
|
||||
prefix="/api/oauth",
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.llm.router,
|
||||
tags=["v2", "llm"],
|
||||
prefix="/api",
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.llm.admin_router,
|
||||
tags=["v2", "llm", "admin"],
|
||||
prefix="/api",
|
||||
)
|
||||
|
||||
app.mount("/external-api", external_api)
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ from backend.util.request import parse_url
|
||||
from .block import BlockInput
|
||||
from .db import BaseDbModel
|
||||
from .db import prisma as db
|
||||
from .db import query_raw_with_schema, transaction
|
||||
from .db import execute_raw_with_schema, query_raw_with_schema, transaction
|
||||
from .dynamic_fields import is_tool_pin, sanitize_pin_name
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE, MAX_GRAPH_VERSIONS_FETCH
|
||||
from .model import CredentialsFieldInfo, CredentialsMetaInput, is_credentials_field_name
|
||||
@@ -1669,16 +1669,15 @@ async def migrate_llm_models(migrate_to: LlmModel):
|
||||
|
||||
# Update each block
|
||||
for id, path in llm_model_fields.items():
|
||||
query = f"""
|
||||
UPDATE platform."AgentNode"
|
||||
query = """
|
||||
UPDATE {schema_prefix}"AgentNode"
|
||||
SET "constantInput" = jsonb_set("constantInput", $1, to_jsonb($2), true)
|
||||
WHERE "agentBlockId" = $3
|
||||
AND "constantInput" ? ($4)::text
|
||||
AND "constantInput"->>($4)::text NOT IN {escaped_enum_values}
|
||||
"""
|
||||
AND "constantInput"->>($4)::text NOT IN """ + escaped_enum_values
|
||||
|
||||
await db.execute_raw(
|
||||
query, # type: ignore - is supposed to be LiteralString
|
||||
await execute_raw_with_schema(
|
||||
query,
|
||||
[path],
|
||||
migrate_to.value,
|
||||
id,
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
"""LLM Registry - Dynamic model management system."""
|
||||
|
||||
from backend.blocks.llm import ModelMetadata
|
||||
from .notifications import (
|
||||
publish_registry_refresh_notification,
|
||||
subscribe_to_registry_refresh,
|
||||
)
|
||||
from .registry import (
|
||||
RegistryModel,
|
||||
RegistryModelCost,
|
||||
RegistryModelCreator,
|
||||
clear_registry_cache,
|
||||
get_all_model_slugs_for_validation,
|
||||
get_all_models,
|
||||
get_default_model_slug,
|
||||
get_enabled_models,
|
||||
get_model,
|
||||
get_schema_options,
|
||||
refresh_llm_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Models
|
||||
"ModelMetadata",
|
||||
"RegistryModel",
|
||||
"RegistryModelCost",
|
||||
"RegistryModelCreator",
|
||||
# Cache management
|
||||
"clear_registry_cache",
|
||||
"publish_registry_refresh_notification",
|
||||
"subscribe_to_registry_refresh",
|
||||
# Read functions
|
||||
"refresh_llm_registry",
|
||||
"get_model",
|
||||
"get_all_models",
|
||||
"get_enabled_models",
|
||||
"get_schema_options",
|
||||
"get_default_model_slug",
|
||||
"get_all_model_slugs_for_validation",
|
||||
]
|
||||
@@ -0,0 +1,84 @@
|
||||
"""Pub/sub notifications for LLM registry cross-process synchronisation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"
|
||||
|
||||
|
||||
async def publish_registry_refresh_notification() -> None:
    """Broadcast a registry refresh signal on ``REGISTRY_REFRESH_CHANNEL``.

    Best-effort: any failure to obtain the Redis client or to publish is
    logged as a warning and swallowed, so callers never see an exception.
    """
    # Imported lazily, inside the function, as in the rest of this module.
    from backend.data.redis_client import get_redis_async

    try:
        client = await get_redis_async()
        await client.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
        logger.debug("Published LLM registry refresh notification")
    except Exception as exc:
        logger.warning("Failed to publish registry refresh notification: %s", exc)
|
||||
|
||||
|
||||
async def subscribe_to_registry_refresh(
    on_refresh: Callable[[], Awaitable[None]],
) -> None:
    """Listen for registry refresh signals and call on_refresh each time one arrives.

    Designed to run as a long-lived background asyncio.Task. Each pass of the
    outer loop opens a dedicated Redis connection, subscribes to
    ``REGISTRY_REFRESH_CHANNEL`` and polls for messages. If that pass fails,
    the connection and its pub/sub object are closed and a new attempt is
    made after 5 seconds.

    Args:
        on_refresh: Async callable invoked on each refresh signal.
            Typically ``llm_registry.refresh_llm_registry``. Exceptions it
            raises are logged and do not stop the subscription.

    Returns:
        None, once the task is cancelled while subscribing or polling.
        Cancellation that lands during the 5 second retry back-off propagates
        as ``asyncio.CancelledError``.
    """
    from backend.data.redis_client import HOST, PASSWORD, PORT
    from redis.asyncio import Redis as AsyncRedis

    async def _close_quietly(resource: object) -> None:
        """Close a Redis client / PubSub without letting cleanup errors escape."""
        # Newer redis-py exposes ``aclose``; older releases only have ``close``.
        closer = getattr(resource, "aclose", None) or getattr(
            resource, "close", None
        )
        if closer is None:
            return
        try:
            await closer()
        except Exception as e:
            logger.debug("Error closing LLM registry subscription resource: %s", e)

    while True:
        redis_sub = None
        pubsub = None
        try:
            # Dedicated connection — pub/sub must not share a connection used
            # for regular commands.
            redis_sub = AsyncRedis(
                host=HOST, port=PORT, password=PASSWORD, decode_responses=True
            )
            pubsub = redis_sub.pubsub()
            await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
            logger.info("Subscribed to LLM registry refresh channel")

            while True:
                try:
                    message = await pubsub.get_message(
                        ignore_subscribe_messages=True, timeout=1.0
                    )
                    if (
                        message
                        and message["type"] == "message"
                        and message["channel"] == REGISTRY_REFRESH_CHANNEL
                    ):
                        logger.debug("LLM registry refresh signal received")
                        try:
                            await on_refresh()
                        except Exception as e:
                            # Keep listening; include the traceback so a
                            # failing callback can actually be diagnosed.
                            logger.error(
                                "Error in registry on_refresh callback: %s",
                                e,
                                exc_info=True,
                            )
                except asyncio.CancelledError:
                    raise
                except Exception as e:
                    logger.warning(
                        "Error processing registry refresh message: %s", e
                    )
                    await asyncio.sleep(1)

        except asyncio.CancelledError:
            logger.info("LLM registry subscription task cancelled")
            break
        except Exception as e:
            logger.warning(
                "LLM registry subscription error: %s. Retrying in 5s...", e
            )
            await asyncio.sleep(5)
        finally:
            # Release this pass's connection so reconnect attempts and
            # shutdown do not leak sockets.
            if pubsub is not None:
                await _close_quietly(pubsub)
            if redis_sub is not None:
                await _close_quietly(redis_sub)
|
||||
254
autogpt_platform/backend/backend/data/llm_registry/registry.py
Normal file
254
autogpt_platform/backend/backend/data/llm_registry/registry.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""Core LLM registry implementation for managing models dynamically."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import prisma.models
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from backend.blocks.llm import ModelMetadata
|
||||
from backend.util.cache import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RegistryModelCost(BaseModel):
    """Immutable cost entry attached to a registry model."""

    # frozen=True: instances reject attribute assignment after construction.
    model_config = ConfigDict(frozen=True)

    # Billing unit: "RUN" or "TOKENS".
    unit: str
    # Credits charged per unit.
    credit_cost: int
    # Credential provider this cost entry applies to.
    credential_provider: str
    # Optional credential narrowing; None when not set on the record.
    credential_id: str | None = None
    credential_type: str | None = None
    currency: str | None = None
    # Free-form extra data carried over from the cost record.
    metadata: dict[str, Any] = {}
|
||||
|
||||
|
||||
class RegistryModelCreator(BaseModel):
    """Immutable description of the creator associated with a model."""

    # frozen=True: instances reject attribute assignment after construction.
    model_config = ConfigDict(frozen=True)

    # Required identity fields.
    id: str
    name: str
    display_name: str
    # Optional presentation fields; None when absent on the record.
    description: str | None = None
    website_url: str | None = None
    logo_url: str | None = None
|
||||
|
||||
|
||||
class RegistryModel(BaseModel):
    """Immutable snapshot of one LLM model held in the registry."""

    # frozen=True: instances reject attribute assignment after construction.
    model_config = ConfigDict(frozen=True)

    # Identity and presentation.
    slug: str
    display_name: str
    description: str | None = None
    # Block-level model metadata (provider, context window, price tier, ...).
    metadata: ModelMetadata
    # Free-form dictionaries copied from the DB record.
    capabilities: dict[str, Any] = {}
    extra_metadata: dict[str, Any] = {}
    provider_display_name: str
    # Availability flags.
    is_enabled: bool
    is_recommended: bool = False
    # Related records, already converted to registry types.
    costs: tuple[RegistryModelCost, ...] = ()
    creator: RegistryModelCreator | None = None

    # Typed capability fields from DB schema
    supports_tools: bool = False
    supports_json_output: bool = False
    supports_reasoning: bool = False
    supports_parallel_tool_calls: bool = False
|
||||
|
||||
|
||||
# L1 in-process cache — Redis is the shared L2 via @cached(shared_cache=True)
|
||||
_dynamic_models: dict[str, RegistryModel] = {}
|
||||
_schema_options: list[dict[str, str]] = []
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel:  # type: ignore[name-defined]
    """Convert a Prisma ``LlmModel`` record into a ``RegistryModel``.

    Fallbacks applied while converting:
      * missing ``Provider`` relation -> ``providerId`` is used for both the
        provider name and its display name, and a warning is logged;
      * missing ``Creator`` relation -> creator name is ``"Unknown"``;
      * ``priceTier`` outside 1..3 -> coerced to 1 with a warning;
      * ``maxOutputTokens`` of None -> ``contextWindow`` is used instead.
    """
    cost_entries = []
    for cost_row in record.Costs or []:
        cost_entries.append(
            RegistryModelCost(
                unit=str(cost_row.unit),
                credit_cost=cost_row.creditCost,
                credential_provider=cost_row.credentialProvider,
                credential_id=cost_row.credentialId,
                credential_type=cost_row.credentialType,
                currency=cost_row.currency,
                metadata=dict(cost_row.metadata or {}),
            )
        )
    costs = tuple(cost_entries)

    creator_row = record.Creator
    creator = (
        RegistryModelCreator(
            id=creator_row.id,
            name=creator_row.name,
            display_name=creator_row.displayName,
            description=creator_row.description,
            website_url=creator_row.websiteUrl,
            logo_url=creator_row.logoUrl,
        )
        if creator_row
        else None
    )

    capabilities = dict(record.capabilities or {})

    provider_row = record.Provider
    if provider_row:
        provider_name = provider_row.name
        provider_display = provider_row.displayName
    else:
        logger.warning(
            "LlmModel %s has no Provider despite NOT NULL FK - "
            "falling back to providerId %s",
            record.slug,
            record.providerId,
        )
        provider_name = provider_display = record.providerId

    creator_name = creator_row.displayName if creator_row else "Unknown"

    price_tier = record.priceTier
    if price_tier not in (1, 2, 3):
        logger.warning(
            "LlmModel %s has out-of-range priceTier=%s, defaulting to 1",
            record.slug,
            price_tier,
        )
        price_tier = 1

    max_output_tokens = record.maxOutputTokens
    if max_output_tokens is None:
        max_output_tokens = record.contextWindow

    metadata = ModelMetadata(
        provider=provider_name,
        context_window=record.contextWindow,
        max_output_tokens=max_output_tokens,
        display_name=record.displayName,
        provider_name=provider_display,
        creator_name=creator_name,
        price_tier=price_tier,
    )

    return RegistryModel(
        slug=record.slug,
        display_name=record.displayName,
        description=record.description,
        metadata=metadata,
        capabilities=capabilities,
        extra_metadata=dict(record.metadata or {}),
        provider_display_name=provider_display,
        is_enabled=record.isEnabled,
        is_recommended=record.isRecommended,
        costs=costs,
        creator=creator,
        supports_tools=record.supportsTools,
        supports_json_output=record.supportsJsonOutput,
        supports_reasoning=record.supportsReasoning,
        supports_parallel_tool_calls=record.supportsParallelToolCalls,
    )
|
||||
|
||||
|
||||
@cached(maxsize=1, ttl_seconds=300, shared_cache=True, refresh_ttl_on_get=True)
async def _fetch_registry_from_db() -> list[RegistryModel]:
    """Load every ``LlmModel`` row and convert it to a ``RegistryModel``.

    The Provider, Costs and Creator relations are included in the query so
    the conversion has everything it needs. The ``@cached`` decorator is
    configured with ``shared_cache=True`` and a 300 second TTL; how the
    result is shared between workers is up to that decorator.
    """
    relations = {"Provider": True, "Costs": True, "Creator": True}
    rows = await prisma.models.LlmModel.prisma().find_many(  # type: ignore[attr-defined]
        include=relations
    )
    logger.info("Fetched %d LLM models from database", len(rows))

    converted = []
    for row in rows:
        converted.append(_record_to_registry_model(row))
    return converted
|
||||
|
||||
|
||||
def clear_registry_cache() -> None:
    """Drop the cached result of ``_fetch_registry_from_db``.

    Intended to be called after an admin DB mutation and before
    ``refresh_llm_registry()``, so that the refresh re-reads the database
    instead of reusing a stale cached result.
    """
    invalidate = _fetch_registry_from_db.cache_clear
    invalidate()
|
||||
|
||||
|
||||
async def refresh_llm_registry() -> None:
    """Rebuild the module-level registry state from ``_fetch_registry_from_db``.

    Runs under the module lock so concurrent refreshes are serialised. On
    success both ``_dynamic_models`` (keyed by slug) and ``_schema_options``
    are replaced.

    Raises:
        Exception: any error from fetching or rebuilding is logged with its
            traceback and re-raised unchanged.
    """
    global _dynamic_models, _schema_options

    async with _lock:
        try:
            fetched = await _fetch_registry_from_db()
            by_slug = {entry.slug: entry for entry in fetched}

            _dynamic_models = by_slug
            # Must run after the assignment above: it reads _dynamic_models.
            _schema_options = _build_schema_options()

            logger.info(
                "LLM registry refreshed: %d models, %d schema options",
                len(_dynamic_models),
                len(_schema_options),
            )
        except Exception as exc:
            logger.error("Failed to refresh LLM registry: %s", exc, exc_info=True)
            raise
|
||||
|
||||
|
||||
def _build_schema_options() -> list[dict[str, str]]:
    """Return dropdown entries for every enabled model in ``_dynamic_models``.

    Entries are ordered by display name, compared case-insensitively. Each
    entry has the keys ``label``, ``value``, ``group`` and ``description``;
    a missing description becomes an empty string.
    """
    ordered = sorted(
        _dynamic_models.values(), key=lambda entry: entry.display_name.lower()
    )

    options: list[dict[str, str]] = []
    for entry in ordered:
        if not entry.is_enabled:
            continue
        options.append(
            {
                "label": entry.display_name,
                "value": entry.slug,
                "group": entry.metadata.provider,
                "description": entry.description or "",
            }
        )
    return options
|
||||
|
||||
|
||||
def get_model(slug: str) -> RegistryModel | None:
    """Look up a single registry model by slug; None when the slug is unknown."""
    try:
        return _dynamic_models[slug]
    except KeyError:
        return None
|
||||
|
||||
|
||||
def get_all_models() -> list[RegistryModel]:
    """Return a new list with every registry model, enabled or not."""
    return [*_dynamic_models.values()]
|
||||
|
||||
|
||||
def get_enabled_models() -> list[RegistryModel]:
    """Return the registry models whose ``is_enabled`` flag is truthy."""
    enabled: list[RegistryModel] = []
    for entry in _dynamic_models.values():
        if entry.is_enabled:
            enabled.append(entry)
    return enabled
|
||||
|
||||
|
||||
def get_schema_options() -> list[dict[str, str]]:
    """Return a shallow copy of the cached dropdown options (enabled models only).

    A new list is returned so callers can mutate it without touching the
    module-level cache.
    """
    return [*_schema_options]
|
||||
|
||||
|
||||
def get_default_model_slug() -> str | None:
    """Get the default model slug (first recommended, or first enabled).

    "First" is decided by display name, compared case-insensitively. This is
    the same ordering ``_build_schema_options`` uses, so the default matches
    what the dropdown shows first.

    Returns:
        The slug of the first enabled and recommended model, otherwise the
        slug of the first enabled model, otherwise None (for example when the
        registry is empty or every model is disabled).
    """
    models = sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower())
    recommended = next(
        (m.slug for m in models if m.is_recommended and m.is_enabled), None
    )
    return recommended or next((m.slug for m in models if m.is_enabled), None)
|
||||
|
||||
|
||||
def get_all_model_slugs_for_validation() -> list[str]:
    """Return the slugs of all enabled models, for use in validation."""
    slugs: list[str] = []
    for entry in _dynamic_models.values():
        if entry.is_enabled:
            slugs.append(entry.slug)
    return slugs
|
||||
@@ -0,0 +1,358 @@
|
||||
"""Unit tests for the LLM registry module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import pydantic
|
||||
|
||||
from backend.data.llm_registry.registry import (
|
||||
RegistryModel,
|
||||
RegistryModelCost,
|
||||
RegistryModelCreator,
|
||||
_build_schema_options,
|
||||
_record_to_registry_model,
|
||||
get_default_model_slug,
|
||||
get_schema_options,
|
||||
refresh_llm_registry,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_record(**overrides):
    """Return a Mock shaped like a Prisma LlmModel record.

    Keyword arguments replace the default attribute values, e.g.
    ``_make_mock_record(priceTier=99)``.
    """
    provider = Mock()
    # ``name`` is assigned after construction because Mock(name=...) would
    # name the mock itself instead of setting the attribute.
    provider.name = "openai"
    provider.displayName = "OpenAI"

    attributes = {
        "slug": "openai/gpt-4o",
        "displayName": "GPT-4o",
        "description": "Latest GPT model",
        "providerId": "provider-uuid",
        "Provider": provider,
        "creatorId": "creator-uuid",
        "Creator": None,
        "contextWindow": 128000,
        "maxOutputTokens": 16384,
        "priceTier": 2,
        "isEnabled": True,
        "isRecommended": False,
        "supportsTools": True,
        "supportsJsonOutput": True,
        "supportsReasoning": False,
        "supportsParallelToolCalls": True,
        "capabilities": {},
        "metadata": {},
        "Costs": [],
    }
    attributes.update(overrides)

    record = Mock()
    for attr_name, attr_value in attributes.items():
        setattr(record, attr_name, attr_value)
    return record
|
||||
|
||||
|
||||
def _make_registry_model(**kwargs) -> RegistryModel:
    """Return a small, valid RegistryModel; keyword arguments override defaults."""
    from backend.blocks.llm import ModelMetadata

    base_metadata = ModelMetadata(
        provider="openai",
        context_window=128000,
        max_output_tokens=16384,
        display_name="GPT-4o",
        provider_name="OpenAI",
        creator_name="Unknown",
        price_tier=2,
    )
    fields = {
        "slug": "openai/gpt-4o",
        "display_name": "GPT-4o",
        "description": None,
        "metadata": base_metadata,
        "capabilities": {},
        "extra_metadata": {},
        "provider_display_name": "OpenAI",
        "is_enabled": True,
        "is_recommended": False,
        **kwargs,
    }
    return RegistryModel(**fields)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _record_to_registry_model tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_record_to_registry_model():
    """A fully populated record maps field-for-field onto a RegistryModel."""
    result = _record_to_registry_model(_make_mock_record())

    # Identity and presentation
    assert result.slug == "openai/gpt-4o"
    assert result.display_name == "GPT-4o"
    assert result.description == "Latest GPT model"
    assert result.provider_display_name == "OpenAI"

    # Flags
    assert result.is_enabled is True
    assert result.is_recommended is False
    assert result.supports_tools is True
    assert result.supports_json_output is True
    assert result.supports_reasoning is False
    assert result.supports_parallel_tool_calls is True

    # Nested metadata
    meta = result.metadata
    assert meta.provider == "openai"
    assert meta.context_window == 128000
    assert meta.max_output_tokens == 16384
    assert meta.price_tier == 2

    # Relations that the default mock leaves empty
    assert result.creator is None
    assert result.costs == ()
|
||||
|
||||
|
||||
def test_record_to_registry_model_missing_provider(caplog):
    """Without a Provider relation, providerId is used and a warning is logged."""
    orphan = _make_mock_record(Provider=None, providerId="provider-uuid")

    with caplog.at_level("WARNING"):
        converted = _record_to_registry_model(orphan)

    assert "no Provider" in caplog.text
    # Both the machine name and the display name fall back to the raw id.
    assert converted.metadata.provider == "provider-uuid"
    assert converted.provider_display_name == "provider-uuid"
|
||||
|
||||
|
||||
def test_record_to_registry_model_missing_creator():
    """Creator=None yields creator None and a creator_name of 'Unknown'."""
    converted = _record_to_registry_model(_make_mock_record(Creator=None))

    assert converted.creator is None
    assert converted.metadata.creator_name == "Unknown"
|
||||
|
||||
|
||||
def test_record_to_registry_model_with_creator():
    """A present Creator relation becomes a RegistryModelCreator."""
    creator_fields = {
        "id": "creator-uuid",
        "name": "openai",
        "displayName": "OpenAI",
        "description": "AI company",
        "websiteUrl": "https://openai.com",
        "logoUrl": "https://openai.com/logo.png",
    }
    creator_row = Mock()
    for attr_name, attr_value in creator_fields.items():
        setattr(creator_row, attr_name, attr_value)

    converted = _record_to_registry_model(_make_mock_record(Creator=creator_row))

    parsed = converted.creator
    assert parsed is not None
    assert isinstance(parsed, RegistryModelCreator)
    assert parsed.id == "creator-uuid"
    assert parsed.display_name == "OpenAI"
    # The creator's display name is also copied into the model metadata.
    assert converted.metadata.creator_name == "OpenAI"
|
||||
|
||||
|
||||
def test_record_to_registry_model_null_max_output_tokens():
    """A missing maxOutputTokens is replaced by the context window size."""
    unbounded = _make_mock_record(maxOutputTokens=None, contextWindow=64000)

    converted = _record_to_registry_model(unbounded)

    assert converted.metadata.max_output_tokens == 64000
|
||||
|
||||
|
||||
def test_record_to_registry_model_invalid_price_tier(caplog):
    """A priceTier outside 1..3 is replaced by 1, with a warning."""
    overpriced = _make_mock_record(priceTier=99)

    with caplog.at_level("WARNING"):
        converted = _record_to_registry_model(overpriced)

    assert "out-of-range priceTier" in caplog.text
    assert converted.metadata.price_tier == 1
|
||||
|
||||
|
||||
def test_record_to_registry_model_with_costs():
    """Cost rows on the record are converted to a tuple of RegistryModelCost."""
    cost_fields = {
        "unit": "TOKENS",
        "creditCost": 10,
        "credentialProvider": "openai",
        "credentialId": None,
        "credentialType": None,
        "currency": "USD",
        "metadata": {},
    }
    cost_row = Mock()
    for attr_name, attr_value in cost_fields.items():
        setattr(cost_row, attr_name, attr_value)

    converted = _record_to_registry_model(_make_mock_record(Costs=[cost_row]))

    assert len(converted.costs) == 1
    (parsed,) = converted.costs
    assert isinstance(parsed, RegistryModelCost)
    assert parsed.unit == "TOKENS"
    assert parsed.credit_cost == 10
    assert parsed.credential_provider == "openai"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_default_model_slug tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_default_model_slug_recommended(monkeypatch):
    """Recommended model is preferred over non-recommended enabled models."""
    import backend.data.llm_registry.registry as reg

    # monkeypatch restores the module-level registry after the test, so the
    # models installed here cannot leak into other tests.
    monkeypatch.setattr(
        reg,
        "_dynamic_models",
        {
            "openai/gpt-4o": _make_registry_model(
                slug="openai/gpt-4o", display_name="GPT-4o", is_recommended=False
            ),
            "openai/gpt-4o-recommended": _make_registry_model(
                slug="openai/gpt-4o-recommended",
                display_name="GPT-4o Recommended",
                is_recommended=True,
            ),
        },
    )

    result = get_default_model_slug()
    assert result == "openai/gpt-4o-recommended"
|
||||
|
||||
|
||||
def test_get_default_model_slug_fallback(monkeypatch):
    """With no recommended model, falls back to first enabled (alphabetical)."""
    import backend.data.llm_registry.registry as reg

    # monkeypatch restores the module-level registry after the test, so the
    # models installed here cannot leak into other tests.
    monkeypatch.setattr(
        reg,
        "_dynamic_models",
        {
            "openai/gpt-4o": _make_registry_model(
                slug="openai/gpt-4o", display_name="GPT-4o", is_recommended=False
            ),
            "openai/gpt-3.5": _make_registry_model(
                slug="openai/gpt-3.5", display_name="GPT-3.5", is_recommended=False
            ),
        },
    )

    result = get_default_model_slug()
    # Sorted alphabetically: GPT-3.5 < GPT-4o
    assert result == "openai/gpt-3.5"
|
||||
|
||||
|
||||
def test_get_default_model_slug_empty(monkeypatch):
    """Empty registry returns None."""
    import backend.data.llm_registry.registry as reg

    # monkeypatch restores the previous registry contents after the test.
    monkeypatch.setattr(reg, "_dynamic_models", {})

    result = get_default_model_slug()
    assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_schema_options / get_schema_options tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_schema_options(monkeypatch):
    """Only enabled models appear, sorted case-insensitively."""
    import backend.data.llm_registry.registry as reg

    # monkeypatch restores the module-level registry after the test, so the
    # models installed here cannot leak into other tests.
    monkeypatch.setattr(
        reg,
        "_dynamic_models",
        {
            "openai/gpt-4o": _make_registry_model(
                slug="openai/gpt-4o", display_name="GPT-4o", is_enabled=True
            ),
            "openai/disabled": _make_registry_model(
                slug="openai/disabled",
                display_name="Disabled Model",
                is_enabled=False,
            ),
            "openai/gpt-3.5": _make_registry_model(
                slug="openai/gpt-3.5", display_name="gpt-3.5", is_enabled=True
            ),
        },
    )

    options = _build_schema_options()
    slugs = [o["value"] for o in options]

    # disabled model should be excluded
    assert "openai/disabled" not in slugs
    # only enabled models
    assert "openai/gpt-4o" in slugs
    assert "openai/gpt-3.5" in slugs
    # case-insensitive sort: "gpt-3.5" < "GPT-4o" (both lowercase: "gpt-3.5" < "gpt-4o")
    assert slugs.index("openai/gpt-3.5") < slugs.index("openai/gpt-4o")

    # Verify structure
    for option in options:
        assert "label" in option
        assert "value" in option
        assert "group" in option
        assert "description" in option
|
||||
|
||||
|
||||
def test_get_schema_options_returns_copy(monkeypatch):
    """Mutating the returned list does not affect the internal cache."""
    import backend.data.llm_registry.registry as reg

    # Both module-level caches are set through monkeypatch so they are
    # restored after the test instead of leaking into other tests.
    monkeypatch.setattr(
        reg,
        "_dynamic_models",
        {
            "openai/gpt-4o": _make_registry_model(
                slug="openai/gpt-4o", display_name="GPT-4o"
            ),
        },
    )
    # Built after _dynamic_models is in place: the builder reads that global.
    monkeypatch.setattr(reg, "_schema_options", _build_schema_options())

    options = get_schema_options()
    original_length = len(options)
    options.append(
        {"label": "Injected", "value": "evil/model", "group": "evil", "description": ""}
    )

    # Internal state should be unchanged
    assert len(get_schema_options()) == original_length
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic frozen model tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_registry_model_frozen():
    """Assigning to a field of a frozen RegistryModel must raise."""
    frozen_model = _make_registry_model()
    expected_errors = (pydantic.ValidationError, TypeError)

    with pytest.raises(expected_errors):
        frozen_model.slug = "changed/slug"  # type: ignore[misc]
|
||||
|
||||
|
||||
def test_registry_model_cost_frozen():
    """Assigning to a field of a frozen RegistryModelCost must raise."""
    frozen_cost = RegistryModelCost(
        unit="TOKENS", credit_cost=5, credential_provider="openai"
    )
    expected_errors = (pydantic.ValidationError, TypeError)

    with pytest.raises(expected_errors):
        frozen_cost.unit = "RUN"  # type: ignore[misc]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# refresh_llm_registry tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_refresh_llm_registry(monkeypatch):
    """Mock prisma find_many, verify cache is populated after refresh."""
    import backend.data.llm_registry.registry as reg

    record = _make_mock_record()
    mock_find_many = AsyncMock(return_value=[record])

    # Start from empty state. Going through monkeypatch means whatever
    # refresh_llm_registry() writes into these globals is rolled back after
    # the test instead of leaking into other tests.
    monkeypatch.setattr(reg, "_dynamic_models", {})
    monkeypatch.setattr(reg, "_schema_options", [])

    # NOTE(review): _fetch_registry_from_db is wrapped in
    # @cached(shared_cache=True); presumably a warm cache could bypass the
    # patched find_many below - confirm whether the cache needs clearing here.
    with patch("prisma.models.LlmModel.prisma") as mock_prisma_cls:
        mock_prisma_instance = Mock()
        mock_prisma_instance.find_many = mock_find_many
        mock_prisma_cls.return_value = mock_prisma_instance

        await refresh_llm_registry()

    assert "openai/gpt-4o" in reg._dynamic_models
    model = reg._dynamic_models["openai/gpt-4o"]
    assert isinstance(model, RegistryModel)
    assert model.slug == "openai/gpt-4o"
    # Schema options should be populated too
    assert len(reg._schema_options) == 1
    assert reg._schema_options[0]["value"] == "openai/gpt-4o"
|
||||
@@ -0,0 +1,6 @@
|
||||
"""LLM registry API (public + admin)."""
|
||||
|
||||
from .admin_routes import router as admin_router
|
||||
from .routes import router
|
||||
|
||||
__all__ = ["router", "admin_router"]
|
||||
115
autogpt_platform/backend/backend/server/v2/llm/admin_model.py
Normal file
115
autogpt_platform/backend/backend/server/v2/llm/admin_model.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Request/response models for LLM registry admin API."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CreateLlmProviderRequest(BaseModel):
    """Body accepted when creating an LLM provider.

    ``name`` and ``display_name`` are required; the remaining fields default
    to ``None`` (or an empty dict for ``metadata``).
    """

    # Required identification fields.
    name: str = Field(
        description="Provider identifier (e.g., 'openai', 'anthropic')",
    )
    display_name: str = Field(description="Human-readable provider name")

    # Optional descriptive text.
    description: str | None = Field(default=None, description="Provider description")

    # Optional default-credential settings.
    default_credential_provider: str | None = Field(
        default=None,
        description="Default credential system identifier",
    )
    default_credential_id: str | None = Field(
        default=None,
        description="Default credential ID",
    )
    default_credential_type: str | None = Field(
        default=None,
        description="Default credential type",
    )

    # Free-form extras; a fresh dict is created per instance.
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional metadata",
    )
|
||||
|
||||
|
||||
class UpdateLlmProviderRequest(BaseModel):
    """Body accepted when updating an LLM provider.

    Every field is optional and defaults to ``None``.
    """

    display_name: str | None = Field(
        default=None,
        description="Human-readable provider name",
    )
    description: str | None = Field(default=None, description="Provider description")

    # Optional default-credential settings.
    default_credential_provider: str | None = Field(
        default=None,
        description="Default credential system identifier",
    )
    default_credential_id: str | None = Field(
        default=None,
        description="Default credential ID",
    )
    default_credential_type: str | None = Field(
        default=None,
        description="Default credential type",
    )

    metadata: dict[str, Any] | None = Field(
        default=None,
        description="Additional metadata",
    )
|
||||
|
||||
|
||||
class CreateLlmModelRequest(BaseModel):
    """Body accepted when creating an LLM model.

    ``slug``, ``display_name``, ``provider_id``, ``context_window`` and
    ``price_tier`` are required. Numeric fields carry range constraints
    (``gt`` / ``ge`` / ``le``) that pydantic enforces during validation.
    """

    # --- identity ---------------------------------------------------------
    slug: str = Field(description="Model slug (e.g., 'gpt-4', 'claude-3-opus')")
    display_name: str = Field(description="Human-readable model name")
    description: str | None = Field(default=None, description="Model description")
    provider_id: str = Field(description="Provider ID (UUID)")
    creator_id: str | None = Field(default=None, description="Creator ID (UUID)")

    # --- limits and pricing -----------------------------------------------
    context_window: int = Field(
        gt=0,
        description="Maximum context window in tokens",
    )
    max_output_tokens: int | None = Field(
        default=None,
        gt=0,
        description="Maximum output tokens (None if unlimited)",
    )
    price_tier: int = Field(
        ge=1,
        le=3,
        description="Price tier (1=cheapest, 2=medium, 3=expensive)",
    )

    # --- flags ------------------------------------------------------------
    is_enabled: bool = Field(default=True, description="Whether the model is enabled")
    is_recommended: bool = Field(
        default=False,
        description="Whether the model is recommended",
    )
    supports_tools: bool = Field(
        default=False,
        description="Supports function calling",
    )
    supports_json_output: bool = Field(
        default=False,
        description="Supports JSON output mode",
    )
    supports_reasoning: bool = Field(
        default=False,
        description="Supports reasoning mode",
    )
    supports_parallel_tool_calls: bool = Field(
        default=False,
        description="Supports parallel tool calls",
    )

    # --- free-form extras (fresh container per instance) --------------------
    capabilities: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional capabilities",
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional metadata",
    )
    costs: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Cost entries for the model",
    )
|
||||
|
||||
|
||||
class UpdateLlmModelRequest(BaseModel):
    """Body accepted when updating an LLM model.

    Every field is optional and defaults to ``None``. Numeric fields keep
    the same range constraints as on creation.
    """

    # --- identity ---------------------------------------------------------
    display_name: str | None = Field(
        default=None,
        description="Human-readable model name",
    )
    description: str | None = Field(default=None, description="Model description")
    creator_id: str | None = Field(default=None, description="Creator ID (UUID)")

    # --- limits and pricing -----------------------------------------------
    context_window: int | None = Field(
        default=None,
        gt=0,
        description="Maximum context window in tokens",
    )
    max_output_tokens: int | None = Field(
        default=None,
        gt=0,
        description="Maximum output tokens (None if unlimited)",
    )
    price_tier: int | None = Field(
        default=None,
        ge=1,
        le=3,
        description="Price tier (1=cheapest, 2=medium, 3=expensive)",
    )

    # --- flags ------------------------------------------------------------
    is_enabled: bool | None = Field(
        default=None,
        description="Whether the model is enabled",
    )
    is_recommended: bool | None = Field(
        default=None,
        description="Whether the model is recommended",
    )
    supports_tools: bool | None = Field(
        default=None,
        description="Supports function calling",
    )
    supports_json_output: bool | None = Field(
        default=None,
        description="Supports JSON output mode",
    )
    supports_reasoning: bool | None = Field(
        default=None,
        description="Supports reasoning mode",
    )
    supports_parallel_tool_calls: bool | None = Field(
        default=None,
        description="Supports parallel tool calls",
    )

    # --- free-form extras -------------------------------------------------
    capabilities: dict[str, Any] | None = Field(
        default=None,
        description="Additional capabilities",
    )
    metadata: dict[str, Any] | None = Field(
        default=None,
        description="Additional metadata",
    )
|
||||
689
autogpt_platform/backend/backend/server/v2/llm/admin_routes.py
Normal file
689
autogpt_platform/backend/backend/server/v2/llm/admin_routes.py
Normal file
@@ -0,0 +1,689 @@
|
||||
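"""Admin API for LLM registry management.

Provides endpoints for:
- Creating, updating, deleting, and toggling models, plus model usage counts
- Creating, updating, and deleting providers
- Listing, creating, updating, and deleting creators
- Listing and reverting model migrations
- Admin listings of all providers and models

All endpoints require admin authentication. Model, provider, and
migration-revert mutations refresh the registry cache; creator mutations
currently do not.
"""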
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import prisma
|
||||
import autogpt_libs.auth
|
||||
from fastapi import APIRouter, HTTPException, Security, status
|
||||
|
||||
from backend.server.v2.llm import db_write
|
||||
from backend.server.v2.llm.admin_model import (
|
||||
CreateLlmModelRequest,
|
||||
CreateLlmProviderRequest,
|
||||
UpdateLlmModelRequest,
|
||||
UpdateLlmProviderRequest,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _map_provider_response(provider: Any) -> dict[str, Any]:
|
||||
"""Map Prisma provider model to response dict."""
|
||||
return {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.displayName,
|
||||
"description": provider.description,
|
||||
"default_credential_provider": provider.defaultCredentialProvider,
|
||||
"default_credential_id": provider.defaultCredentialId,
|
||||
"default_credential_type": provider.defaultCredentialType,
|
||||
"metadata": dict(provider.metadata or {}),
|
||||
"created_at": provider.createdAt.isoformat() if provider.createdAt else None,
|
||||
"updated_at": provider.updatedAt.isoformat() if provider.updatedAt else None,
|
||||
}
|
||||
|
||||
|
||||
def _map_model_response(model: Any) -> dict[str, Any]:
|
||||
"""Map Prisma model to response dict."""
|
||||
return {
|
||||
"id": model.id,
|
||||
"slug": model.slug,
|
||||
"display_name": model.displayName,
|
||||
"description": model.description,
|
||||
"provider_id": model.providerId,
|
||||
"creator_id": model.creatorId,
|
||||
"context_window": model.contextWindow,
|
||||
"max_output_tokens": model.maxOutputTokens,
|
||||
"price_tier": model.priceTier,
|
||||
"is_enabled": model.isEnabled,
|
||||
"is_recommended": model.isRecommended,
|
||||
"supports_tools": model.supportsTools,
|
||||
"supports_json_output": model.supportsJsonOutput,
|
||||
"supports_reasoning": model.supportsReasoning,
|
||||
"supports_parallel_tool_calls": model.supportsParallelToolCalls,
|
||||
"capabilities": dict(model.capabilities or {}),
|
||||
"metadata": dict(model.metadata or {}),
|
||||
"created_at": model.createdAt.isoformat() if model.createdAt else None,
|
||||
"updated_at": model.updatedAt.isoformat() if model.updatedAt else None,
|
||||
}
|
||||
|
||||
|
||||
def _map_creator_response(creator: Any) -> dict[str, Any]:
|
||||
"""Map Prisma creator model to response dict."""
|
||||
return {
|
||||
"id": creator.id,
|
||||
"name": creator.name,
|
||||
"display_name": creator.displayName,
|
||||
"description": creator.description,
|
||||
"website_url": creator.websiteUrl,
|
||||
"logo_url": creator.logoUrl,
|
||||
"metadata": dict(creator.metadata or {}),
|
||||
"created_at": creator.createdAt.isoformat() if creator.createdAt else None,
|
||||
"updated_at": creator.updatedAt.isoformat() if creator.updatedAt else None,
|
||||
}
|
||||
|
||||
|
||||
@router.post(
    "/llm/models",
    status_code=status.HTTP_201_CREATED,
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def create_model(
    request: CreateLlmModelRequest,
) -> dict[str, Any]:
    """Create a new LLM model, plus any cost rows supplied in the request.

    Requires admin authentication.

    ``request.provider_id`` is looked up as a provider *name* first and as a
    provider ID second.

    Returns:
        The created model mapped through ``_map_model_response``.

    Raises:
        HTTPException: 404 if the provider cannot be resolved, 400 if a
            ``ValueError`` is raised while creating, 500 for any other failure.
    """
    try:
        import prisma.models as pm

        # Resolve provider name to ID
        provider = await pm.LlmProvider.prisma().find_unique(
            where={"name": request.provider_id}
        )
        if not provider:
            # Try as UUID fallback
            provider = await pm.LlmProvider.prisma().find_unique(
                where={"id": request.provider_id}
            )
        if not provider:
            raise HTTPException(
                status_code=404,
                detail=f"Provider '{request.provider_id}' not found",
            )

        model = await db_write.create_model(
            slug=request.slug,
            display_name=request.display_name,
            provider_id=provider.id,
            context_window=request.context_window,
            price_tier=request.price_tier,
            description=request.description,
            creator_id=request.creator_id,
            max_output_tokens=request.max_output_tokens,
            is_enabled=request.is_enabled,
            is_recommended=request.is_recommended,
            supports_tools=request.supports_tools,
            supports_json_output=request.supports_json_output,
            supports_reasoning=request.supports_reasoning,
            supports_parallel_tool_calls=request.supports_parallel_tool_calls,
            capabilities=request.capabilities,
            metadata=request.metadata,
        )

        # ``costs`` is a declared field that defaults to an empty list, so it
        # can be iterated directly; an empty list creates no rows.
        for cost_input in request.costs:
            await pm.LlmModelCost.prisma().create(
                data={
                    "unit": cost_input.get("unit", "RUN"),
                    "creditCost": int(cost_input.get("credit_cost", 1)),
                    "credentialProvider": provider.name,
                    "metadata": prisma.Json(cost_input.get("metadata", {})),
                    "Model": {"connect": {"id": model.id}},
                }
            )

        await db_write.refresh_runtime_caches()
        logger.info(f"Created model '{request.slug}' (id: {model.id})")

        # Re-fetch with costs included
        model = await pm.LlmModel.prisma().find_unique(
            where={"id": model.id},
            include={"Costs": True, "Creator": True},
        )
        return _map_model_response(model)
    except HTTPException:
        # Deliberate HTTP errors (the 404 above) must reach the client
        # unchanged instead of being rewritten to a 500 by the handler below.
        raise
    except ValueError as e:
        logger.warning(f"Model creation validation failed: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to create model: {e}")
        raise HTTPException(status_code=500, detail="Failed to create model")
|
||||
|
||||
|
||||
@router.patch(
    "/llm/models/{slug:path}",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def update_model(
    slug: str,
    request: UpdateLlmModelRequest,
) -> dict[str, Any]:
    """Update an existing LLM model identified by its slug.

    Requires admin authentication.

    All request fields are forwarded to ``db_write.update_model`` as-is.
    NOTE(review): presumably ``None`` means "leave unchanged" there; confirm
    against ``db_write.update_model``.

    Returns:
        The updated model mapped through ``_map_model_response``.

    Raises:
        HTTPException: 404 if no model has this slug, 400 if a ``ValueError``
            is raised while updating, 500 for any other failure.
    """
    try:
        # Find model by slug first to get ID
        import prisma.models

        existing = await prisma.models.LlmModel.prisma().find_unique(
            where={"slug": slug}
        )
        if not existing:
            raise HTTPException(
                status_code=404, detail=f"Model with slug '{slug}' not found"
            )

        model = await db_write.update_model(
            model_id=existing.id,
            display_name=request.display_name,
            description=request.description,
            creator_id=request.creator_id,
            context_window=request.context_window,
            max_output_tokens=request.max_output_tokens,
            price_tier=request.price_tier,
            is_enabled=request.is_enabled,
            is_recommended=request.is_recommended,
            supports_tools=request.supports_tools,
            supports_json_output=request.supports_json_output,
            supports_reasoning=request.supports_reasoning,
            supports_parallel_tool_calls=request.supports_parallel_tool_calls,
            capabilities=request.capabilities,
            metadata=request.metadata,
        )
        await db_write.refresh_runtime_caches()
        logger.info(f"Updated model '{slug}' (id: {model.id})")
        return _map_model_response(model)
    except HTTPException:
        # Keep the 404 above from being rewritten to a 500 by the generic
        # handler below.
        raise
    except ValueError as e:
        logger.warning(f"Model update validation failed: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to update model: {e}")
        raise HTTPException(status_code=500, detail="Failed to update model")
|
||||
|
||||
|
||||
@router.delete(
    "/llm/models/{slug:path}",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def delete_model(
    slug: str,
    replacement_model_slug: str | None = None,
) -> dict[str, Any]:
    """Delete an LLM model with optional migration.

    If workflows are using this model and no replacement_model_slug is given,
    returns 400 with the node count. Provide replacement_model_slug to migrate
    affected nodes before deletion.

    Returns:
        The result dict produced by ``db_write.delete_model`` (it is read for
        a ``nodes_migrated`` key when logging).

    Raises:
        HTTPException: 404 if no model has this slug, 400 if a ``ValueError``
            is raised while deleting, 500 for any other failure.
    """
    try:
        import prisma.models

        existing = await prisma.models.LlmModel.prisma().find_unique(
            where={"slug": slug}
        )
        if not existing:
            raise HTTPException(
                status_code=404, detail=f"Model with slug '{slug}' not found"
            )

        result = await db_write.delete_model(
            model_id=existing.id,
            replacement_model_slug=replacement_model_slug,
        )
        await db_write.refresh_runtime_caches()
        logger.info(
            f"Deleted model '{slug}' (migrated {result['nodes_migrated']} nodes)"
        )
        return result
    except HTTPException:
        # Keep the 404 above from being rewritten to a 500 by the generic
        # handler below.
        raise
    except ValueError as e:
        logger.warning(f"Model deletion validation failed: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to delete model: {e}")
        raise HTTPException(status_code=500, detail="Failed to delete model")
|
||||
|
||||
|
||||
@router.get(
    "/llm/models/{slug:path}/usage",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def get_model_usage(slug: str) -> dict[str, Any]:
    """Return what ``db_write.get_model_usage`` reports for the given slug.

    Any failure is logged and surfaced as a generic HTTP 500.
    """
    try:
        usage = await db_write.get_model_usage(slug)
    except Exception as e:
        logger.exception(f"Failed to get model usage: {e}")
        raise HTTPException(status_code=500, detail="Failed to get model usage")
    return usage
|
||||
|
||||
|
||||
@router.post(
    "/llm/models/{slug:path}/toggle",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def toggle_model(
    slug: str,
    request: dict[str, Any],
) -> dict[str, Any]:
    """Toggle a model's enabled status with optional migration when disabling.

    Body params:
        is_enabled: bool (treated as True when absent)
        migrate_to_slug: optional str
        migration_reason: optional str
        custom_credit_cost: optional int

    Returns:
        The result dict produced by ``db_write.toggle_model_with_migration``
        (it is read for a ``nodes_migrated`` key when logging).

    Raises:
        HTTPException: 404 if no model has this slug, 400 if a ``ValueError``
            is raised while toggling, 500 for any other failure.
    """
    try:
        import prisma.models

        existing = await prisma.models.LlmModel.prisma().find_unique(
            where={"slug": slug}
        )
        if not existing:
            raise HTTPException(
                status_code=404, detail=f"Model with slug '{slug}' not found"
            )

        result = await db_write.toggle_model_with_migration(
            model_id=existing.id,
            is_enabled=request.get("is_enabled", True),
            migrate_to_slug=request.get("migrate_to_slug"),
            migration_reason=request.get("migration_reason"),
            custom_credit_cost=request.get("custom_credit_cost"),
        )
        await db_write.refresh_runtime_caches()
        logger.info(
            f"Toggled model '{slug}' enabled={request.get('is_enabled')} "
            f"(migrated {result['nodes_migrated']} nodes)"
        )
        return result
    except HTTPException:
        # Keep the 404 above from being rewritten to a 500 by the generic
        # handler below.
        raise
    except ValueError as e:
        logger.warning(f"Model toggle failed: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to toggle model: {e}")
        raise HTTPException(status_code=500, detail="Failed to toggle model")
|
||||
|
||||
|
||||
@router.get(
    "/llm/migrations",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def list_migrations(
    include_reverted: bool = False,
) -> dict[str, Any]:
    """Return model migrations wrapped under a ``"migrations"`` key.

    ``include_reverted`` is passed straight through to
    ``db_write.list_migrations``. Any failure is logged and surfaced as a
    generic HTTP 500.
    """
    try:
        found = await db_write.list_migrations(include_reverted=include_reverted)
    except Exception as e:
        logger.exception(f"Failed to list migrations: {e}")
        raise HTTPException(status_code=500, detail="Failed to list migrations")
    return {"migrations": found}
|
||||
|
||||
|
||||
@router.post(
    "/llm/migrations/{migration_id}/revert",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def revert_migration(
    migration_id: str,
    re_enable_source_model: bool = True,
) -> dict[str, Any]:
    """Revert a model migration and refresh the runtime caches.

    Returns the dict produced by ``db_write.revert_migration``. A
    ``ValueError`` becomes HTTP 400; any other failure becomes HTTP 500.
    """
    try:
        outcome = await db_write.revert_migration(
            migration_id=migration_id,
            re_enable_source_model=re_enable_source_model,
        )
        await db_write.refresh_runtime_caches()

        # Built inside the ``try`` so a missing ``nodes_reverted`` key is
        # reported through the generic 500 path, as before.
        summary = (
            f"Reverted migration {migration_id}: "
            f"{outcome['nodes_reverted']} nodes restored"
        )
        logger.info(summary)
        return outcome
    except ValueError as e:
        logger.warning(f"Migration revert failed: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to revert migration: {e}")
        raise HTTPException(status_code=500, detail="Failed to revert migration")
|
||||
|
||||
|
||||
@router.post(
    "/llm/providers",
    status_code=status.HTTP_201_CREATED,
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def create_provider(
    request: CreateLlmProviderRequest,
) -> dict[str, Any]:
    """Create an LLM provider and refresh the runtime caches.

    Requires admin authentication. A ``ValueError`` becomes HTTP 400; any
    other failure becomes HTTP 500.
    """
    try:
        provider_fields = {
            "name": request.name,
            "display_name": request.display_name,
            "description": request.description,
            "default_credential_provider": request.default_credential_provider,
            "default_credential_id": request.default_credential_id,
            "default_credential_type": request.default_credential_type,
            "metadata": request.metadata,
        }
        created = await db_write.create_provider(**provider_fields)
        await db_write.refresh_runtime_caches()
        logger.info(f"Created provider '{request.name}' (id: {created.id})")
        return _map_provider_response(created)
    except ValueError as e:
        logger.warning(f"Provider creation validation failed: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to create provider: {e}")
        raise HTTPException(status_code=500, detail="Failed to create provider")
|
||||
|
||||
|
||||
@router.patch(
    "/llm/providers/{name}",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def update_provider(
    name: str,
    request: UpdateLlmProviderRequest,
) -> dict[str, Any]:
    """Update an existing LLM provider identified by its name.

    Requires admin authentication.

    Returns:
        The updated provider mapped through ``_map_provider_response``.

    Raises:
        HTTPException: 404 if no provider has this name, 400 if a
            ``ValueError`` is raised while updating, 500 for any other failure.
    """
    try:
        # Find provider by name first to get ID
        import prisma.models

        existing = await prisma.models.LlmProvider.prisma().find_unique(
            where={"name": name}
        )
        if not existing:
            raise HTTPException(
                status_code=404, detail=f"Provider with name '{name}' not found"
            )

        provider = await db_write.update_provider(
            provider_id=existing.id,
            display_name=request.display_name,
            description=request.description,
            default_credential_provider=request.default_credential_provider,
            default_credential_id=request.default_credential_id,
            default_credential_type=request.default_credential_type,
            metadata=request.metadata,
        )
        await db_write.refresh_runtime_caches()
        logger.info(f"Updated provider '{name}' (id: {provider.id})")
        return _map_provider_response(provider)
    except HTTPException:
        # Keep the 404 above from being rewritten to a 500 by the generic
        # handler below.
        raise
    except ValueError as e:
        logger.warning(f"Provider update validation failed: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to update provider: {e}")
        raise HTTPException(status_code=500, detail="Failed to update provider")
|
||||
|
||||
|
||||
@router.delete(
    "/llm/providers/{name}",
    status_code=status.HTTP_204_NO_CONTENT,
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def delete_provider(
    name: str,
) -> None:
    """Delete an LLM provider identified by its name.

    Requires admin authentication.
    A provider can only be deleted if it has no associated models.

    Raises:
        HTTPException: 404 if no provider has this name, 400 if a
            ``ValueError`` is raised while deleting, 500 for any other failure.
    """
    try:
        # Find provider by name first to get ID
        import prisma.models

        existing = await prisma.models.LlmProvider.prisma().find_unique(
            where={"name": name}
        )
        if not existing:
            raise HTTPException(
                status_code=404, detail=f"Provider with name '{name}' not found"
            )

        await db_write.delete_provider(provider_id=existing.id)
        await db_write.refresh_runtime_caches()
        logger.info(f"Deleted provider '{name}' (id: {existing.id})")
    except HTTPException:
        # Keep the 404 above from being rewritten to a 500 by the generic
        # handler below.
        raise
    except ValueError as e:
        logger.warning(f"Provider deletion validation failed: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to delete provider: {e}")
        raise HTTPException(status_code=500, detail="Failed to delete provider")
|
||||
|
||||
|
||||
@router.get(
    "/llm/admin/providers",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def admin_list_providers() -> dict[str, Any]:
    """Return every provider row, ordered by name, with a ``model_count``.

    The query places no filter on providers, so providers without models are
    included. Requires admin authentication. Any failure is logged and
    surfaced as a generic HTTP 500.
    """
    try:
        import prisma.models

        rows = await prisma.models.LlmProvider.prisma().find_many(
            order={"name": "asc"},
            include={"Models": True},
        )

        listing: list[dict[str, Any]] = []
        for row in rows:
            entry = _map_provider_response(row)
            # ``Models`` may be falsy when no relation rows were loaded.
            entry["model_count"] = len(row.Models) if row.Models else 0
            listing.append(entry)
        return {"providers": listing}
    except Exception as e:
        logger.exception(f"Failed to list providers: {e}")
        raise HTTPException(status_code=500, detail="Failed to list providers")
|
||||
|
||||
|
||||
@router.get(
    "/llm/admin/models",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def admin_list_models(
    page: int = 1,
    page_size: int = 100,
    enabled_only: bool = False,
) -> dict[str, Any]:
    """List LLM models from the database, one page at a time.

    Each entry is the ``_map_model_response`` payload extended with the
    model's ``creator`` (or ``None``) and its ``costs``. Requires admin
    authentication.

    Args:
        page: 1-based page number.
        page_size: Number of models per page; must be at least 1.
        enabled_only: When true, only models with ``isEnabled`` set are listed.

    Raises:
        HTTPException: 400 if ``page`` or ``page_size`` is below 1, 500 if
            the query or the mapping fails.
    """
    # Reject bad paging up front: ``page < 1`` would produce a negative
    # ``skip``, which previously surfaced as an opaque 500. Raised outside the
    # ``try`` so the generic handler below cannot rewrite it.
    if page < 1 or page_size < 1:
        raise HTTPException(
            status_code=400,
            detail="page and page_size must be greater than or equal to 1",
        )

    try:
        import prisma.models

        where = {"isEnabled": True} if enabled_only else {}
        models = await prisma.models.LlmModel.prisma().find_many(
            where=where,
            skip=(page - 1) * page_size,
            take=page_size,
            order={"displayName": "asc"},
            include={"Costs": True, "Creator": True},
        )
        return {
            "models": [
                {
                    **_map_model_response(m),
                    "creator": _map_creator_response(m.Creator) if m.Creator else None,
                    "costs": [
                        {
                            "unit": c.unit,
                            "credit_cost": float(c.creditCost),
                            "credential_provider": c.credentialProvider,
                            "credential_type": c.credentialType,
                            "metadata": dict(c.metadata or {}),
                        }
                        for c in (m.Costs or [])
                    ],
                }
                for m in models
            ]
        }
    except Exception as e:
        logger.exception(f"Failed to list models: {e}")
        raise HTTPException(status_code=500, detail="Failed to list models")
|
||||
|
||||
|
||||
@router.get(
    "/llm/creators",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def list_creators() -> dict[str, Any]:
    """Return every model creator, ordered by name, under a ``"creators"`` key.

    Requires admin authentication. Any failure is logged and surfaced as a
    generic HTTP 500.
    """
    try:
        import prisma.models

        creator_rows = await prisma.models.LlmModelCreator.prisma().find_many(
            order={"name": "asc"}
        )
        logger.info(f"Retrieved {len(creator_rows)} creators")

        mapped = []
        for row in creator_rows:
            mapped.append(_map_creator_response(row))
        return {"creators": mapped}
    except Exception as e:
        logger.exception(f"Failed to list creators: {e}")
        raise HTTPException(status_code=500, detail="Failed to list creators")
|
||||
|
||||
|
||||
@router.post(
    "/llm/creators",
    status_code=status.HTTP_201_CREATED,
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def create_creator(
    request: dict[str, Any],
) -> dict[str, Any]:
    """Create a new LLM model creator.

    Body params:
        name: str (required)
        display_name: str (required)
        description: optional str
        website_url: optional str
        logo_url: optional str
        metadata: optional dict (defaults to an empty dict)

    Returns:
        The created creator mapped through ``_map_creator_response``.

    Raises:
        HTTPException: 400 if a required field is missing, 500 if the
            database write fails.
    """
    # The body is an untyped dict, so check the required keys explicitly.
    # A missing key used to raise ``KeyError`` and surface as a 500 whose
    # detail was just the quoted key name.
    missing = [key for key in ("name", "display_name") if key not in request]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Missing required field(s): {', '.join(missing)}",
        )

    try:
        import prisma.models

        creator = await prisma.models.LlmModelCreator.prisma().create(
            data={
                "name": request["name"],
                "displayName": request["display_name"],
                "description": request.get("description"),
                "websiteUrl": request.get("website_url"),
                "logoUrl": request.get("logo_url"),
                "metadata": prisma.Json(request.get("metadata", {})),
            }
        )
        logger.info(f"Created creator '{creator.name}' (id: {creator.id})")
        return _map_creator_response(creator)
    except Exception as e:
        logger.exception(f"Failed to create creator: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.patch(
    "/llm/creators/{name}",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def update_creator(
    name: str,
    request: dict[str, Any],
) -> dict[str, Any]:
    """Update the creator with the given name.

    Only the keys ``display_name``, ``description``, ``website_url`` and
    ``logo_url`` are read from the body; a key is applied only when present.
    Responds with 404 when the creator does not exist and 500 on any other
    failure.
    """
    # Request key -> database column, in the order the columns are applied.
    column_for_key = {
        "display_name": "displayName",
        "description": "description",
        "website_url": "websiteUrl",
        "logo_url": "logoUrl",
    }
    try:
        import prisma.models

        creators = prisma.models.LlmModelCreator.prisma()
        existing = await creators.find_unique(where={"name": name})
        if not existing:
            raise HTTPException(
                status_code=404, detail=f"Creator '{name}' not found"
            )

        data: dict[str, Any] = {
            column: request[key]
            for key, column in column_for_key.items()
            if key in request
        }

        creator = await prisma.models.LlmModelCreator.prisma().update(
            where={"id": existing.id},
            data=data,
        )
        logger.info(f"Updated creator '{name}' (id: {creator.id})")
        return _map_creator_response(creator)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to update creator: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete(
    "/llm/creators/{name}",
    status_code=status.HTTP_204_NO_CONTENT,
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def delete_creator(
    name: str,
) -> None:
    """Delete the creator with the given name.

    Responds with 404 when the creator does not exist, 400 when it still has
    associated models, and 500 on any other failure.
    """
    try:
        import prisma.models

        existing = await prisma.models.LlmModelCreator.prisma().find_unique(
            where={"name": name},
            include={"Models": True},
        )
        if not existing:
            raise HTTPException(
                status_code=404, detail=f"Creator '{name}' not found"
            )

        # Refuse to delete while any model still references this creator.
        attached = existing.Models
        if attached and len(attached) > 0:
            raise HTTPException(
                status_code=400,
                detail=f"Cannot delete creator '{name}' — it has {len(attached)} associated models",
            )

        await prisma.models.LlmModelCreator.prisma().delete(
            where={"id": existing.id}
        )
        logger.info(f"Deleted creator '{name}' (id: {existing.id})")
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to delete creator: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
588
autogpt_platform/backend/backend/server/v2/llm/db_write.py
Normal file
588
autogpt_platform/backend/backend/server/v2/llm/db_write.py
Normal file
@@ -0,0 +1,588 @@
|
||||
"""Database write operations for LLM registry admin API."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import prisma
|
||||
import prisma.models
|
||||
|
||||
from backend.data import llm_registry
|
||||
from backend.data.db import transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _build_provider_data(
    name: str,
    display_name: str,
    description: str | None = None,
    default_credential_provider: str | None = None,
    default_credential_id: str | None = None,
    default_credential_type: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Translate snake_case provider arguments into a camelCase data dict.

    ``metadata`` is wrapped in ``prisma.Json``; a falsy value is stored as an
    empty object.
    """
    metadata_json = prisma.Json(metadata or {})
    return dict(
        name=name,
        displayName=display_name,
        description=description,
        defaultCredentialProvider=default_credential_provider,
        defaultCredentialId=default_credential_id,
        defaultCredentialType=default_credential_type,
        metadata=metadata_json,
    )
|
||||
|
||||
|
||||
def _build_model_data(
    slug: str,
    display_name: str,
    provider_id: str,
    context_window: int,
    price_tier: int,
    description: str | None = None,
    creator_id: str | None = None,
    max_output_tokens: int | None = None,
    is_enabled: bool = True,
    is_recommended: bool = False,
    supports_tools: bool = False,
    supports_json_output: bool = False,
    supports_reasoning: bool = False,
    supports_parallel_tool_calls: bool = False,
    capabilities: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Assemble the Prisma ``data`` payload for an ``LlmModel`` row.

    The provider is always linked through a ``connect`` on its id; the
    creator is linked the same way only when ``creator_id`` is truthy.
    Missing or empty ``capabilities`` / ``metadata`` become empty JSON objects.
    """
    payload: dict[str, Any] = dict(
        slug=slug,
        displayName=display_name,
        description=description,
        Provider={"connect": {"id": provider_id}},
        contextWindow=context_window,
        maxOutputTokens=max_output_tokens,
        priceTier=price_tier,
        isEnabled=is_enabled,
        isRecommended=is_recommended,
        supportsTools=supports_tools,
        supportsJsonOutput=supports_json_output,
        supportsReasoning=supports_reasoning,
        supportsParallelToolCalls=supports_parallel_tool_calls,
        capabilities=prisma.Json(capabilities if capabilities else {}),
        metadata=prisma.Json(metadata if metadata else {}),
    )
    if not creator_id:
        return payload
    payload["Creator"] = {"connect": {"id": creator_id}}
    return payload
|
||||
|
||||
|
||||
async def create_provider(
    name: str,
    display_name: str,
    description: str | None = None,
    default_credential_provider: str | None = None,
    default_credential_id: str | None = None,
    default_credential_type: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmProvider:
    """Insert a new LLM provider and return it with ``Models`` included.

    Raises:
        ValueError: If the create call yields no record.
    """
    created = await prisma.models.LlmProvider.prisma().create(
        data=_build_provider_data(
            name=name,
            display_name=display_name,
            description=description,
            default_credential_provider=default_credential_provider,
            default_credential_id=default_credential_id,
            default_credential_type=default_credential_type,
            metadata=metadata,
        ),
        include={"Models": True},
    )
    if created:
        return created
    raise ValueError("Failed to create provider")
|
||||
|
||||
|
||||
async def update_provider(
    provider_id: str,
    display_name: str | None = None,
    description: str | None = None,
    default_credential_provider: str | None = None,
    default_credential_id: str | None = None,
    default_credential_type: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmProvider:
    """Apply a partial update to an existing LLM provider.

    Only arguments that are not ``None`` are written, so ``None`` means
    "leave this column alone" (a column cannot be reset to NULL here).
    The provider ``name`` is not updatable through this function.

    Raises:
        ValueError: If no provider has ``provider_id``, or the update call
            yields no record.
    """
    providers = prisma.models.LlmProvider.prisma()

    # Existence check up front so a bad id gets a descriptive error.
    existing = await providers.find_unique(where={"id": provider_id})
    if not existing:
        raise ValueError(f"Provider with id '{provider_id}' not found")

    # Column name -> requested value; entries left at None are dropped below.
    requested = {
        "displayName": display_name,
        "description": description,
        "defaultCredentialProvider": default_credential_provider,
        "defaultCredentialId": default_credential_id,
        "defaultCredentialType": default_credential_type,
    }
    data: dict[str, Any] = {
        column: value for column, value in requested.items() if value is not None
    }
    if metadata is not None:
        data["metadata"] = prisma.Json(metadata)

    updated = await providers.update(
        where={"id": provider_id},
        data=data,
        include={"Models": True},
    )
    if updated:
        return updated
    raise ValueError("Failed to update provider")
|
||||
|
||||
|
||||
async def delete_provider(provider_id: str) -> bool:
    """Delete an LLM provider that has no models attached.

    Returns:
        ``True`` once the delete call has completed.

    Raises:
        ValueError: If the provider does not exist, or if it still has at
            least one associated model.
    """
    providers = prisma.models.LlmProvider.prisma()

    found = await providers.find_unique(
        where={"id": provider_id},
        include={"Models": True},
    )
    if not found:
        raise ValueError(f"Provider with id '{provider_id}' not found")

    # Refuse to delete while models still point at this provider.
    model_count = len(found.Models or [])
    if model_count > 0:
        raise ValueError(
            f"Cannot delete provider '{found.displayName}' because it has "
            f"{model_count} model(s). Delete all models first."
        )

    await providers.delete(where={"id": provider_id})
    return True
|
||||
|
||||
|
||||
async def create_model(
    slug: str,
    display_name: str,
    provider_id: str,
    context_window: int,
    price_tier: int,
    description: str | None = None,
    creator_id: str | None = None,
    max_output_tokens: int | None = None,
    is_enabled: bool = True,
    is_recommended: bool = False,
    supports_tools: bool = False,
    supports_json_output: bool = False,
    supports_reasoning: bool = False,
    supports_parallel_tool_calls: bool = False,
    capabilities: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmModel:
    """Insert a new LLM model and return it with its relations loaded.

    The returned record includes ``Costs``, ``Creator`` and ``Provider``.

    Raises:
        ValueError: If the create call yields no record.
    """
    payload = _build_model_data(
        slug=slug,
        display_name=display_name,
        provider_id=provider_id,
        context_window=context_window,
        price_tier=price_tier,
        description=description,
        creator_id=creator_id,
        max_output_tokens=max_output_tokens,
        is_enabled=is_enabled,
        is_recommended=is_recommended,
        supports_tools=supports_tools,
        supports_json_output=supports_json_output,
        supports_reasoning=supports_reasoning,
        supports_parallel_tool_calls=supports_parallel_tool_calls,
        capabilities=capabilities,
        metadata=metadata,
    )
    created = await prisma.models.LlmModel.prisma().create(
        data=payload,
        include={"Costs": True, "Creator": True, "Provider": True},
    )
    if created:
        return created
    raise ValueError("Failed to create model")
|
||||
|
||||
|
||||
async def update_model(
    model_id: str,
    display_name: str | None = None,
    description: str | None = None,
    creator_id: str | None = None,
    context_window: int | None = None,
    max_output_tokens: int | None = None,
    price_tier: int | None = None,
    is_enabled: bool | None = None,
    is_recommended: bool | None = None,
    supports_tools: bool | None = None,
    supports_json_output: bool | None = None,
    supports_reasoning: bool | None = None,
    supports_parallel_tool_calls: bool | None = None,
    capabilities: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmModel:
    """Apply a partial update to an existing LLM model.

    Only arguments that are not ``None`` are written. Passing an empty
    string as ``creator_id`` clears the creator (stored as NULL).

    When ``is_recommended=True``, the flag is first cleared on every other
    model so only one model is recommended at a time. Both writes happen in
    one transaction, and the not-found error is raised inside it so a bad
    ``model_id`` cannot leave every model un-recommended.

    Returns:
        The updated model with ``Costs``, ``Creator`` and ``Provider`` loaded.

    Raises:
        ValueError: If no model has ``model_id``.
    """
    # Column name -> requested value; entries left at None are dropped below.
    requested = {
        "displayName": display_name,
        "description": description,
        "contextWindow": context_window,
        "maxOutputTokens": max_output_tokens,
        "priceTier": price_tier,
        "isEnabled": is_enabled,
        "isRecommended": is_recommended,
        "supportsTools": supports_tools,
        "supportsJsonOutput": supports_json_output,
        "supportsReasoning": supports_reasoning,
        "supportsParallelToolCalls": supports_parallel_tool_calls,
    }
    data: dict[str, Any] = {
        column: value for column, value in requested.items() if value is not None
    }
    if capabilities is not None:
        data["capabilities"] = prisma.Json(capabilities)
    if metadata is not None:
        data["metadata"] = prisma.Json(metadata)
    if creator_id is not None:
        # "" means "remove the creator"; any other value is the new creator id.
        data["creatorId"] = creator_id or None

    async with transaction() as tx:
        # Enforce single recommended model: unset all others first.
        if is_recommended is True:
            await tx.llmmodel.update_many(
                where={"id": {"not": model_id}},
                data={"isRecommended": False},
            )

        model = await tx.llmmodel.update(
            where={"id": model_id},
            data=data,
            include={"Costs": True, "Creator": True, "Provider": True},
        )
        if not model:
            # Raising here aborts the transaction, undoing the bulk
            # un-recommend above when the target model does not exist.
            raise ValueError(f"Model with id '{model_id}' not found")

    return model
|
||||
|
||||
|
||||
async def get_model_usage(slug: str) -> dict[str, Any]:
    """Count the ``AgentNode`` rows whose ``constantInput.model`` equals ``slug``.

    Returns:
        ``{"model_slug": slug, "node_count": <int>}``; the count is 0 when
        the query returns no rows.
    """
    rows = await prisma.get_client().query_raw(
        """
        SELECT COUNT(*) as count
        FROM "AgentNode"
        WHERE "constantInput"::jsonb->>'model' = $1
        """,
        slug,
    )
    if rows:
        node_count = int(rows[0]["count"])
    else:
        node_count = 0
    return {"model_slug": slug, "node_count": node_count}
|
||||
|
||||
|
||||
async def toggle_model_with_migration(
    model_id: str,
    is_enabled: bool,
    migrate_to_slug: str | None = None,
    migration_reason: str | None = None,
    custom_credit_cost: int | None = None,
) -> dict[str, Any]:
    """Toggle a model's enabled status, optionally migrating workflows when disabling.

    When disabling with ``migrate_to_slug`` set, every ``AgentNode`` whose
    ``constantInput.model`` is this model's slug is re-pointed at the
    replacement. The model is disabled and an ``LlmModelMigration`` record is
    written (only if at least one node moved) in the same transaction.
    In every other case only ``isEnabled`` is updated.

    Returns:
        ``nodes_migrated``, ``migrated_to_slug`` (``None`` when nothing was
        migrated) and ``migration_id`` (``None`` when no record was written).

    Raises:
        ValueError: If the model does not exist, if the replacement is the
            model itself, does not exist, or is disabled.
    """
    model = await prisma.models.LlmModel.prisma().find_unique(
        where={"id": model_id}, include={"Costs": True}
    )
    if not model:
        raise ValueError(f"Model with id '{model_id}' not found")

    nodes_migrated = 0
    migration_id: str | None = None

    if not is_enabled and migrate_to_slug:
        # Migrating a model onto itself would leave its nodes pointing at the
        # model we are about to disable and record a meaningless migration.
        if migrate_to_slug == model.slug:
            raise ValueError(
                f"Cannot migrate model '{model.slug}' to itself. "
                f"Please choose a different replacement model."
            )

        async with transaction() as tx:
            replacement = await tx.llmmodel.find_unique(
                where={"slug": migrate_to_slug}
            )
            if not replacement:
                raise ValueError(
                    f"Replacement model '{migrate_to_slug}' not found"
                )
            if not replacement.isEnabled:
                raise ValueError(
                    f"Replacement model '{migrate_to_slug}' is disabled. "
                    f"Please enable it before using it as a replacement."
                )

            # FOR UPDATE locks the affected nodes until the transaction ends,
            # so the id list recorded below matches what is actually rewritten.
            node_ids_result = await tx.query_raw(
                """
                SELECT id
                FROM "AgentNode"
                WHERE "constantInput"::jsonb->>'model' = $1
                FOR UPDATE
                """,
                model.slug,
            )
            migrated_node_ids = (
                [row["id"] for row in node_ids_result] if node_ids_result else []
            )
            nodes_migrated = len(migrated_node_ids)
            # Serialized once; reused for the UPDATE and the migration record.
            node_ids_json = json.dumps(migrated_node_ids)

            if nodes_migrated > 0:
                await tx.execute_raw(
                    """
                    UPDATE "AgentNode"
                    SET "constantInput" = JSONB_SET(
                        "constantInput"::jsonb,
                        '{model}',
                        to_jsonb($1::text)
                    )
                    WHERE id::text IN (
                        SELECT jsonb_array_elements_text($2::jsonb)
                    )
                    """,
                    migrate_to_slug,
                    node_ids_json,
                )

            await tx.llmmodel.update(
                where={"id": model_id},
                data={"isEnabled": is_enabled},
            )

            if nodes_migrated > 0:
                migration_record = await tx.llmmodelmigration.create(
                    data={
                        "sourceModelSlug": model.slug,
                        "targetModelSlug": migrate_to_slug,
                        "reason": migration_reason,
                        "migratedNodeIds": node_ids_json,
                        "nodeCount": nodes_migrated,
                        "customCreditCost": custom_credit_cost,
                    }
                )
                migration_id = migration_record.id
    else:
        await prisma.models.LlmModel.prisma().update(
            where={"id": model_id},
            data={"isEnabled": is_enabled},
        )

    return {
        "nodes_migrated": nodes_migrated,
        "migrated_to_slug": migrate_to_slug if nodes_migrated > 0 else None,
        "migration_id": migration_id,
    }
|
||||
|
||||
|
||||
async def delete_model(
    model_id: str, replacement_model_slug: str | None = None
) -> dict[str, Any]:
    """Delete an LLM model, optionally migrating affected AgentNodes first.

    The usage count, the node migration and the delete all run in one
    transaction. If no node uses the model, ``replacement_model_slug`` is
    not consulted.

    Returns:
        ``deleted_model_slug``, ``deleted_model_display_name``,
        ``replacement_model_slug`` (as passed in) and ``nodes_migrated``.

    Raises:
        ValueError: If the model does not exist; or if nodes use it and the
            replacement is missing, is the model itself, does not exist, or
            is disabled.
    """
    model = await prisma.models.LlmModel.prisma().find_unique(
        where={"id": model_id}, include={"Costs": True}
    )
    if not model:
        raise ValueError(f"Model with id '{model_id}' not found")

    deleted_slug = model.slug
    deleted_display_name = model.displayName

    async with transaction() as tx:
        count_result = await tx.query_raw(
            """
            SELECT COUNT(*) as count
            FROM "AgentNode"
            WHERE "constantInput"::jsonb->>'model' = $1
            """,
            deleted_slug,
        )
        nodes_to_migrate = int(count_result[0]["count"]) if count_result else 0

        if nodes_to_migrate > 0:
            if not replacement_model_slug:
                raise ValueError(
                    f"Cannot delete model '{deleted_slug}': {nodes_to_migrate} workflow node(s) "
                    f"are using it. Please provide a replacement_model_slug to migrate them."
                )
            # Replacing a model with itself would rewrite nothing and then
            # delete the model, leaving every node referencing a missing slug.
            if replacement_model_slug == deleted_slug:
                raise ValueError(
                    f"Replacement model '{replacement_model_slug}' is the model being deleted."
                )
            replacement = await tx.llmmodel.find_unique(
                where={"slug": replacement_model_slug}
            )
            if not replacement:
                raise ValueError(
                    f"Replacement model '{replacement_model_slug}' not found"
                )
            if not replacement.isEnabled:
                raise ValueError(
                    f"Replacement model '{replacement_model_slug}' is disabled."
                )

            await tx.execute_raw(
                """
                UPDATE "AgentNode"
                SET "constantInput" = JSONB_SET(
                    "constantInput"::jsonb,
                    '{model}',
                    to_jsonb($1::text)
                )
                WHERE "constantInput"::jsonb->>'model' = $2
                """,
                replacement_model_slug,
                deleted_slug,
            )

        await tx.llmmodel.delete(where={"id": model_id})

    return {
        "deleted_model_slug": deleted_slug,
        "deleted_model_display_name": deleted_display_name,
        "replacement_model_slug": replacement_model_slug,
        "nodes_migrated": nodes_to_migrate,
    }
|
||||
|
||||
|
||||
async def list_migrations(
    include_reverted: bool = False,
) -> list[dict[str, Any]]:
    """Return model migrations as plain dicts, newest first.

    Reverted migrations are filtered out unless ``include_reverted`` is set.
    Timestamps are returned as ISO-8601 strings; ``reverted_at`` is ``None``
    for migrations that have no revert time.
    """
    filters: Any = {"isReverted": False}
    if include_reverted:
        filters = None

    rows = await prisma.models.LlmModelMigration.prisma().find_many(
        where=filters,
        order={"createdAt": "desc"},
    )

    summaries: list[dict[str, Any]] = []
    for row in rows:
        reverted_at = row.revertedAt.isoformat() if row.revertedAt else None
        summaries.append(
            {
                "id": row.id,
                "source_model_slug": row.sourceModelSlug,
                "target_model_slug": row.targetModelSlug,
                "reason": row.reason,
                "node_count": row.nodeCount,
                "custom_credit_cost": row.customCreditCost,
                "is_reverted": row.isReverted,
                "reverted_at": reverted_at,
                "created_at": row.createdAt.isoformat(),
            }
        )
    return summaries
|
||||
|
||||
|
||||
async def revert_migration(
    migration_id: str,
    re_enable_source_model: bool = True,
) -> dict[str, Any]:
    """Undo a model migration by pointing its nodes back at the source model.

    Only nodes recorded in the migration that still reference the target
    model are rewritten; nodes changed since then are left alone and
    reported as ``nodes_already_changed``. The source model is re-enabled
    first when it is disabled and ``re_enable_source_model`` is true. The
    migration is then marked reverted with the current UTC time. All writes
    share one transaction.

    Raises:
        ValueError: If the migration is unknown or already reverted, the
            source model no longer exists, or the migration lists no nodes.
    """
    migration = await prisma.models.LlmModelMigration.prisma().find_unique(
        where={"id": migration_id}
    )
    if not migration:
        raise ValueError(f"Migration with id '{migration_id}' not found")
    if migration.isReverted:
        raise ValueError(
            f"Migration '{migration_id}' has already been reverted"
        )

    source_slug = migration.sourceModelSlug
    target_slug = migration.targetModelSlug

    source_model = await prisma.models.LlmModel.prisma().find_unique(
        where={"slug": source_slug}
    )
    if not source_model:
        raise ValueError(
            f"Source model '{source_slug}' no longer exists."
        )

    # The stored id list may come back as a list or as a JSON string.
    stored_ids = migration.migratedNodeIds
    if isinstance(stored_ids, list):
        migrated_node_ids: list[str] = stored_ids
    else:
        migrated_node_ids = json.loads(stored_ids)  # type: ignore
    if not migrated_node_ids:
        raise ValueError("No nodes to revert in this migration")

    source_model_re_enabled = False

    async with transaction() as tx:
        needs_re_enable = re_enable_source_model and not source_model.isEnabled
        if needs_re_enable:
            await tx.llmmodel.update(
                where={"id": source_model.id},
                data={"isEnabled": True},
            )
            source_model_re_enabled = True

        # The extra model = $3 predicate skips nodes edited after the migration.
        affected = await tx.execute_raw(
            """
            UPDATE "AgentNode"
            SET "constantInput" = JSONB_SET(
                "constantInput"::jsonb,
                '{model}',
                to_jsonb($1::text)
            )
            WHERE id::text IN (
                SELECT jsonb_array_elements_text($2::jsonb)
            )
            AND "constantInput"::jsonb->>'model' = $3
            """,
            source_slug,
            json.dumps(migrated_node_ids),
            target_slug,
        )
        nodes_reverted = affected if isinstance(affected, int) else 0

        await tx.llmmodelmigration.update(
            where={"id": migration_id},
            data={
                "isReverted": True,
                "revertedAt": datetime.now(timezone.utc),
            },
        )

    return {
        "migration_id": migration_id,
        "source_model_slug": source_slug,
        "target_model_slug": target_slug,
        "nodes_reverted": nodes_reverted,
        "nodes_already_changed": len(migrated_node_ids) - nodes_reverted,
        "source_model_re_enabled": source_model_re_enabled,
    }
|
||||
|
||||
|
||||
async def refresh_runtime_caches() -> None:
    """Invalidate the shared Redis cache, refresh this process, notify other workers."""
    # Imported here rather than at module level, as in the original code.
    import backend.data.llm_registry.notifications as registry_notifications

    # Step 1: drop the cached registry so the reload below reads from the DB.
    llm_registry.clear_registry_cache()

    # Step 2: reload this process (also repopulates Redis via
    # @cached(shared_cache=True)).
    await llm_registry.refresh_llm_registry()

    # Step 3: tell the other workers to reload from the fresh Redis data.
    await registry_notifications.publish_registry_refresh_notification()
|
||||
68
autogpt_platform/backend/backend/server/v2/llm/model.py
Normal file
68
autogpt_platform/backend/backend/server/v2/llm/model.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Pydantic models for LLM registry public API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class LlmModelCost(pydantic.BaseModel):
    """Cost configuration for an LLM model.

    One entry describes what a model costs for a given credential provider,
    optionally narrowed to a specific credential.
    """

    unit: str  # "RUN" or "TOKENS"
    # Validated as non-negative.
    credit_cost: int = pydantic.Field(ge=0)
    credential_provider: str
    # NOTE(review): presumably None means the cost is not tied to one
    # specific credential — confirm against the registry.
    credential_id: str | None = None
    credential_type: str | None = None
    currency: str | None = None
    # default_factory gives each instance its own dict.
    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class LlmModelCreator(pydantic.BaseModel):
    """Represents the organization that created/trained the model.

    Only ``id``, ``name`` and ``display_name`` are required; the descriptive
    fields default to ``None``.
    """

    id: str
    name: str
    display_name: str
    description: str | None = None
    website_url: str | None = None
    logo_url: str | None = None
|
||||
|
||||
|
||||
class LlmModel(pydantic.BaseModel):
    """Public-facing LLM model information.

    This is the per-model shape used inside the ``/llm/models`` and
    ``/llm/providers`` responses.
    """

    slug: str
    display_name: str
    description: str | None = None
    provider_name: str
    # None when the model has no creator attached.
    creator: LlmModelCreator | None = None
    context_window: int
    max_output_tokens: int | None = None
    price_tier: int  # 1=cheapest, 2=medium, 3=expensive
    is_enabled: bool = True
    is_recommended: bool = False
    # default_factory gives each instance its own dict / list.
    capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
    costs: list[LlmModelCost] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class LlmProvider(pydantic.BaseModel):
    """Provider with its enabled models.

    ``models`` defaults to an empty list when not supplied.
    """

    name: str
    display_name: str
    models: list[LlmModel] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class LlmModelsResponse(pydantic.BaseModel):
    """Response for GET /llm/models.

    Both fields are required; nothing in this class ties ``total`` to
    ``len(models)``, so the caller building the response must keep them in sync.
    """

    models: list[LlmModel]
    total: int
|
||||
|
||||
|
||||
class LlmProvidersResponse(pydantic.BaseModel):
    """Response for GET /llm/providers.

    Wraps the list of providers, each carrying its own list of models.
    """

    providers: list[LlmProvider]
|
||||
143
autogpt_platform/backend/backend/server/v2/llm/routes.py
Normal file
143
autogpt_platform/backend/backend/server/v2/llm/routes.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Public read-only API for LLM registry."""
|
||||
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
|
||||
from backend.data.llm_registry import (
|
||||
RegistryModelCreator,
|
||||
get_all_models,
|
||||
get_enabled_models,
|
||||
)
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
|
||||
# Every route registered on this router is served under the /llm prefix,
# tagged "llm", and declares the autogpt_libs.auth.requires_user security
# dependency.
router = fastapi.APIRouter(
    prefix="/llm",
    tags=["llm"],
    dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
|
||||
|
||||
|
||||
def _map_creator(
    creator: RegistryModelCreator | None,
) -> llm_model.LlmModelCreator | None:
    """Convert registry creator to API model (``None`` stays ``None``)."""
    if not creator:
        return None
    return llm_model.LlmModelCreator(
        id=creator.id,
        name=creator.name,
        display_name=creator.display_name,
        description=creator.description,
        website_url=creator.website_url,
        logo_url=creator.logo_url,
    )


def _map_cost(cost) -> llm_model.LlmModelCost:
    """Convert one registry cost entry to its API model."""
    return llm_model.LlmModelCost(
        unit=cost.unit,
        credit_cost=cost.credit_cost,
        credential_provider=cost.credential_provider,
        credential_id=cost.credential_id,
        credential_type=cost.credential_type,
        currency=cost.currency,
        metadata=cost.metadata,
    )


def _map_model(model) -> llm_model.LlmModel:
    """Convert one registry model to its API model.

    Shared by both endpoints so they cannot drift apart.
    """
    return llm_model.LlmModel(
        slug=model.slug,
        display_name=model.display_name,
        description=model.description,
        provider_name=model.provider_display_name,
        creator=_map_creator(model.creator),
        context_window=model.metadata.context_window,
        max_output_tokens=model.metadata.max_output_tokens,
        price_tier=model.metadata.price_tier,
        is_enabled=model.is_enabled,
        is_recommended=model.is_recommended,
        capabilities=model.capabilities,
        costs=[_map_cost(cost) for cost in model.costs],
    )


@router.get("/models", response_model=llm_model.LlmModelsResponse)
async def list_models(
    enabled_only: bool = fastapi.Query(
        default=True, description="Only return enabled models"
    ),
):
    """
    List all LLM models available to users.

    Returns models from the in-memory registry cache.
    Use enabled_only=true to filter to only enabled models (default).
    """
    # Get models from in-memory registry
    registry_models = get_enabled_models() if enabled_only else get_all_models()

    # Map to API response models
    models = [_map_model(model) for model in registry_models]

    return llm_model.LlmModelsResponse(models=models, total=len(models))


@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
async def list_providers():
    """
    List all LLM providers with their enabled models.

    Groups enabled models by provider from the in-memory registry.
    """
    # Group enabled models by their provider key.
    provider_map: dict[str, list] = {}
    for model in get_enabled_models():
        provider_map.setdefault(model.metadata.provider, []).append(model)

    # Providers are emitted sorted by key; models within each by display name.
    providers = [
        llm_model.LlmProvider(
            name=provider_key,
            # Use the first model's provider display name.
            display_name=(
                members[0].provider_display_name if members else provider_key
            ),
            models=[
                _map_model(model)
                for model in sorted(members, key=lambda m: m.display_name)
            ],
        )
        for provider_key, members in sorted(provider_map.items())
    ]

    return llm_model.LlmProvidersResponse(providers=providers)
|
||||
@@ -0,0 +1,148 @@
|
||||
-- CreateEnum
-- Billing unit for a cost row; used by "LlmModelCost"."unit".
CREATE TYPE "LlmCostUnit" AS ENUM ('RUN', 'TOKENS');
|
||||
|
||||
-- CreateTable
-- One row per LLM provider. "updatedAt" has no database default, so every
-- writer must supply it.
CREATE TABLE "LlmProvider" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "name" TEXT NOT NULL,
    "displayName" TEXT NOT NULL,
    "description" TEXT,
    "defaultCredentialProvider" TEXT,
    "defaultCredentialId" TEXT,
    "defaultCredentialType" TEXT,
    "metadata" JSONB NOT NULL DEFAULT '{}',

    CONSTRAINT "LlmProvider_pkey" PRIMARY KEY ("id")
);
|
||||
|
||||
-- CreateTable
-- One row per model creator; a model references it through the optional
-- "LlmModel"."creatorId" column.
CREATE TABLE "LlmModelCreator" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "name" TEXT NOT NULL,
    "displayName" TEXT NOT NULL,
    "description" TEXT,
    "websiteUrl" TEXT,
    "logoUrl" TEXT,
    "metadata" JSONB NOT NULL DEFAULT '{}',

    CONSTRAINT "LlmModelCreator_pkey" PRIMARY KEY ("id")
);
|
||||
|
||||
-- CreateTable
-- One row per model. "providerId" is required, "creatorId" is optional.
-- "priceTier" is restricted to 1..3 by a CHECK constraint added later in
-- this migration.
CREATE TABLE "LlmModel" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "slug" TEXT NOT NULL,
    "displayName" TEXT NOT NULL,
    "description" TEXT,
    "providerId" TEXT NOT NULL,
    "creatorId" TEXT,
    "contextWindow" INTEGER NOT NULL,
    "maxOutputTokens" INTEGER,
    "priceTier" INTEGER NOT NULL DEFAULT 1,
    "isEnabled" BOOLEAN NOT NULL DEFAULT true,
    "isRecommended" BOOLEAN NOT NULL DEFAULT false,
    "supportsTools" BOOLEAN NOT NULL DEFAULT false,
    "supportsJsonOutput" BOOLEAN NOT NULL DEFAULT false,
    "supportsReasoning" BOOLEAN NOT NULL DEFAULT false,
    "supportsParallelToolCalls" BOOLEAN NOT NULL DEFAULT false,
    "capabilities" JSONB NOT NULL DEFAULT '{}',
    "metadata" JSONB NOT NULL DEFAULT '{}',

    CONSTRAINT "LlmModel_pkey" PRIMARY KEY ("id")
);
|
||||
|
||||
-- CreateTable
-- Cost rows attached to a model through "llmModelId". A NULL "credentialId"
-- marks a default cost with no specific credential (see the partial unique
-- indexes later in this migration).
CREATE TABLE "LlmModelCost" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "unit" "LlmCostUnit" NOT NULL DEFAULT 'RUN',
    "creditCost" INTEGER NOT NULL,
    "credentialProvider" TEXT NOT NULL,
    "credentialId" TEXT,
    "credentialType" TEXT,
    "currency" TEXT,
    "metadata" JSONB NOT NULL DEFAULT '{}',
    "llmModelId" TEXT NOT NULL,

    CONSTRAINT "LlmModelCost_pkey" PRIMARY KEY ("id")
);
|
||||
|
||||
-- CreateTable
-- Audit record of moving nodes from one model slug to another.
-- "migratedNodeIds" holds a JSON array of the affected node ids, which is
-- what makes a later revert possible; "isReverted"/"revertedAt" track that
-- revert.
CREATE TABLE "LlmModelMigration" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "sourceModelSlug" TEXT NOT NULL,
    "targetModelSlug" TEXT NOT NULL,
    "reason" TEXT,
    "migratedNodeIds" JSONB NOT NULL DEFAULT '[]',
    "nodeCount" INTEGER NOT NULL,
    "customCreditCost" INTEGER,
    "isReverted" BOOLEAN NOT NULL DEFAULT false,
    "revertedAt" TIMESTAMP(3),

    CONSTRAINT "LlmModelMigration_pkey" PRIMARY KEY ("id")
);
|
||||
|
||||
-- CreateIndex
-- Provider names are unique.
CREATE UNIQUE INDEX "LlmProvider_name_key" ON "LlmProvider"("name");

-- CreateIndex
-- Creator names are unique.
CREATE UNIQUE INDEX "LlmModelCreator_name_key" ON "LlmModelCreator"("name");

-- CreateIndex
-- Model slugs are unique; the migration table's foreign keys reference this column.
CREATE UNIQUE INDEX "LlmModel_slug_key" ON "LlmModel"("slug");

-- CreateIndex
CREATE INDEX "LlmModel_providerId_isEnabled_idx" ON "LlmModel"("providerId", "isEnabled");

-- CreateIndex
CREATE INDEX "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");

-- CreateIndex (partial unique for default costs - no specific credential)
-- Two partial indexes are used because NULL "credentialId" values never
-- collide in an ordinary unique index.
CREATE UNIQUE INDEX "LlmModelCost_default_cost_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "unit") WHERE "credentialId" IS NULL;

-- CreateIndex (partial unique for credential-specific costs)
CREATE UNIQUE INDEX "LlmModelCost_credential_cost_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "credentialId", "unit") WHERE "credentialId" IS NOT NULL;

-- CreateIndex
CREATE INDEX "LlmModelMigration_targetModelSlug_idx" ON "LlmModelMigration"("targetModelSlug");

-- CreateIndex
CREATE INDEX "LlmModelMigration_sourceModelSlug_isReverted_idx" ON "LlmModelMigration"("sourceModelSlug", "isReverted");

-- CreateIndex (partial unique to prevent multiple active migrations per source)
-- A second non-reverted migration for the same source slug is rejected.
CREATE UNIQUE INDEX "LlmModelMigration_active_source_key" ON "LlmModelMigration"("sourceModelSlug") WHERE "isReverted" = false;
|
||||
|
||||
-- AddForeignKey
-- RESTRICT: a provider cannot be deleted while models reference it.
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_providerId_fkey" FOREIGN KEY ("providerId") REFERENCES "LlmProvider"("id") ON DELETE RESTRICT ON UPDATE CASCADE;

-- AddForeignKey
-- SET NULL: deleting a creator detaches it from its models.
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey" FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;

-- AddForeignKey
-- CASCADE: deleting a model deletes its cost rows.
ALTER TABLE "LlmModelCost" ADD CONSTRAINT "LlmModelCost_llmModelId_fkey" FOREIGN KEY ("llmModelId") REFERENCES "LlmModel"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
-- RESTRICT: a model cannot be deleted while any migration record (reverted
-- or not) names it as source.
ALTER TABLE "LlmModelMigration" ADD CONSTRAINT "LlmModelMigration_sourceModelSlug_fkey" FOREIGN KEY ("sourceModelSlug") REFERENCES "LlmModel"("slug") ON DELETE RESTRICT ON UPDATE CASCADE;

-- AddForeignKey
-- RESTRICT: likewise for a model named as migration target.
ALTER TABLE "LlmModelMigration" ADD CONSTRAINT "LlmModelMigration_targetModelSlug_fkey" FOREIGN KEY ("targetModelSlug") REFERENCES "LlmModel"("slug") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
|
||||
-- AddCheckConstraints (enforce data integrity)
-- Price tier is limited to the three values 1, 2 and 3.
ALTER TABLE "LlmModel"
    ADD CONSTRAINT "LlmModel_priceTier_check" CHECK ("priceTier" BETWEEN 1 AND 3);

-- Credit costs may be zero but never negative.
ALTER TABLE "LlmModelCost"
    ADD CONSTRAINT "LlmModelCost_creditCost_check" CHECK ("creditCost" >= 0);

-- Node counts are non-negative; a custom credit cost is optional but, when
-- present, also non-negative.
ALTER TABLE "LlmModelMigration"
    ADD CONSTRAINT "LlmModelMigration_nodeCount_check" CHECK ("nodeCount" >= 0),
    ADD CONSTRAINT "LlmModelMigration_customCreditCost_check" CHECK ("customCreditCost" IS NULL OR "customCreditCost" >= 0);
|
||||
@@ -0,0 +1,287 @@
|
||||
-- Seed LLM Registry from existing hard-coded data
|
||||
-- This migration populates the LlmProvider, LlmModelCreator, LlmModel, and LlmModelCost tables
|
||||
-- with data from the existing MODEL_METADATA and MODEL_COST dictionaries
|
||||
|
||||
-- Insert Providers
-- One row per provider. Every seeded provider uses the 'api_key' credential type
-- and starts with empty metadata, so those constants live in the SELECT list and
-- the VALUES table only carries what differs per row. Providers whose "name"
-- already exists are left untouched.
INSERT INTO "LlmProvider" ("id", "createdAt", "updatedAt", "name", "displayName", "description", "defaultCredentialProvider", "defaultCredentialType", "metadata")
SELECT
    gen_random_uuid(),
    CURRENT_TIMESTAMP,
    CURRENT_TIMESTAMP,
    seed.provider_name,
    seed.display_name,
    seed.description,
    seed.credential_provider,
    'api_key',
    '{}'::jsonb
FROM (VALUES
    ('openai', 'OpenAI', 'OpenAI language models', 'openai'),
    ('anthropic', 'Anthropic', 'Anthropic Claude models', 'anthropic'),
    ('groq', 'Groq', 'Groq inference API', 'groq'),
    ('open_router', 'OpenRouter', 'OpenRouter unified API', 'open_router'),
    ('aiml_api', 'AI/ML API', 'AI/ML API models', 'aiml_api'),
    ('ollama', 'Ollama', 'Ollama local models', 'ollama'),
    ('llama_api', 'Llama API', 'Llama API models', 'llama_api'),
    ('v0', 'v0', 'v0 by Vercel models', 'v0')
) AS seed(provider_name, display_name, description, credential_provider)
ON CONFLICT ("name") DO NOTHING;
|
||||
|
||||
-- Insert Model Creators
-- One row per organization that created/trained models. No logo URL is seeded
-- and metadata starts empty, so both are constants in the SELECT list. Creators
-- whose "name" already exists are left untouched.
INSERT INTO "LlmModelCreator" ("id", "createdAt", "updatedAt", "name", "displayName", "description", "websiteUrl", "logoUrl", "metadata")
SELECT
    gen_random_uuid(),
    CURRENT_TIMESTAMP,
    CURRENT_TIMESTAMP,
    seed.creator_name,
    seed.display_name,
    seed.description,
    seed.website_url,
    NULL,
    '{}'::jsonb
FROM (VALUES
    ('openai', 'OpenAI', 'Creator of GPT, O1, O3, and DALL-E models', 'https://openai.com'),
    ('anthropic', 'Anthropic', 'Creator of Claude AI models', 'https://anthropic.com'),
    ('meta', 'Meta', 'Creator of Llama foundation models', 'https://llama.meta.com'),
    ('google', 'Google', 'Creator of Gemini and PaLM models', 'https://deepmind.google'),
    ('mistralai', 'Mistral AI', 'Creator of Mistral and Codestral models', 'https://mistral.ai'),
    ('cohere', 'Cohere', 'Creator of Command language models', 'https://cohere.com'),
    ('deepseek', 'DeepSeek', 'Creator of DeepSeek reasoning models', 'https://deepseek.com'),
    ('alibaba', 'Alibaba', 'Creator of Qwen language models', 'https://qwenlm.github.io'),
    ('nvidia', 'NVIDIA', 'Creator of Nemotron models', 'https://nvidia.com'),
    ('vercel', 'Vercel', 'Creator of v0 AI models', 'https://v0.dev'),
    ('microsoft', 'Microsoft', 'Creator of Phi models', 'https://microsoft.com'),
    ('xai', 'xAI', 'Creator of Grok models', 'https://x.ai'),
    ('perplexity', 'Perplexity AI', 'Creator of Sonar search models', 'https://perplexity.ai'),
    ('nousresearch', 'Nous Research', 'Creator of Hermes language models', 'https://nousresearch.com'),
    ('amazon', 'Amazon', 'Creator of Nova language models', 'https://aws.amazon.com'),
    ('gryphe', 'Gryphe', 'Creator of MythoMax models', 'https://huggingface.co/Gryphe'),
    ('moonshotai', 'Moonshot AI', 'Creator of Kimi language models', 'https://moonshot.ai')
) AS seed(creator_name, display_name, description, website_url)
ON CONFLICT ("name") DO NOTHING;
|
||||
|
||||
-- Insert Models
-- Each VALUES row is (slug, display name, provider name, creator name,
-- context window, max output tokens). Provider and creator names are resolved
-- to their ids by joining on "name".
--
-- * The provider join is an inner join: "providerId" is required, so a row whose
--   provider name does not exist cannot be inserted.
-- * The creator join is a LEFT JOIN: "creatorId" is nullable, so a model whose
--   creator name has no match is still seeded (with a NULL creator) instead of
--   being silently dropped together with its cost row.
-- * A NULL max_output_tokens means no output limit is recorded for that model.
-- * Models whose "slug" already exists are left untouched.
INSERT INTO "LlmModel" ("id", "createdAt", "updatedAt", "slug", "displayName", "description", "providerId", "creatorId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
SELECT
    gen_random_uuid(),
    CURRENT_TIMESTAMP,
    CURRENT_TIMESTAMP,
    models.model_slug,
    models.model_display_name,
    NULL,
    p."id",
    c."id",
    models.context_window,
    models.max_output_tokens,
    true,
    '{}'::jsonb,
    '{}'::jsonb
FROM (VALUES
    -- OpenAI models (creator: openai)
    ('o3-2025-04-16', 'O3', 'openai', 'openai', 200000, 100000),
    ('o3-mini', 'O3 Mini', 'openai', 'openai', 200000, 100000),
    ('o1', 'O1', 'openai', 'openai', 200000, 100000),
    ('o1-mini', 'O1 Mini', 'openai', 'openai', 128000, 65536),
    ('gpt-5.2-2025-12-11', 'GPT-5.2', 'openai', 'openai', 400000, 128000),
    ('gpt-5-2025-08-07', 'GPT 5', 'openai', 'openai', 400000, 128000),
    ('gpt-5.1-2025-11-13', 'GPT 5.1', 'openai', 'openai', 400000, 128000),
    ('gpt-5-mini-2025-08-07', 'GPT 5 Mini', 'openai', 'openai', 400000, 128000),
    ('gpt-5-nano-2025-08-07', 'GPT 5 Nano', 'openai', 'openai', 400000, 128000),
    ('gpt-5-chat-latest', 'GPT 5 Chat', 'openai', 'openai', 400000, 16384),
    -- GPT 4.1 shares the 1,047,576-token context window of GPT 4.1 Mini.
    ('gpt-4.1-2025-04-14', 'GPT 4.1', 'openai', 'openai', 1047576, 32768),
    ('gpt-4.1-mini-2025-04-14', 'GPT 4.1 Mini', 'openai', 'openai', 1047576, 32768),
    ('gpt-4o-mini', 'GPT 4o Mini', 'openai', 'openai', 128000, 16384),
    ('gpt-4o', 'GPT 4o', 'openai', 'openai', 128000, 16384),
    ('gpt-4-turbo', 'GPT 4 Turbo', 'openai', 'openai', 128000, 4096),
    -- Anthropic models (creator: anthropic)
    ('claude-opus-4-6', 'Claude Opus 4.6', 'anthropic', 'anthropic', 200000, 128000),
    ('claude-sonnet-4-6', 'Claude Sonnet 4.6', 'anthropic', 'anthropic', 200000, 64000),
    ('claude-opus-4-1-20250805', 'Claude 4.1 Opus', 'anthropic', 'anthropic', 200000, 32000),
    ('claude-opus-4-20250514', 'Claude 4 Opus', 'anthropic', 'anthropic', 200000, 32000),
    ('claude-sonnet-4-20250514', 'Claude 4 Sonnet', 'anthropic', 'anthropic', 200000, 64000),
    ('claude-opus-4-5-20251101', 'Claude 4.5 Opus', 'anthropic', 'anthropic', 200000, 64000),
    ('claude-sonnet-4-5-20250929', 'Claude 4.5 Sonnet', 'anthropic', 'anthropic', 200000, 64000),
    ('claude-haiku-4-5-20251001', 'Claude 4.5 Haiku', 'anthropic', 'anthropic', 200000, 64000),
    ('claude-3-haiku-20240307', 'Claude 3 Haiku', 'anthropic', 'anthropic', 200000, 4096),
    -- AI/ML API models (creators: alibaba, nvidia, meta)
    ('Qwen/Qwen2.5-72B-Instruct-Turbo', 'Qwen 2.5 72B', 'aiml_api', 'alibaba', 32000, 8000),
    ('nvidia/llama-3.1-nemotron-70b-instruct', 'Llama 3.1 Nemotron 70B', 'aiml_api', 'nvidia', 128000, 40000),
    ('meta-llama/Llama-3.3-70B-Instruct-Turbo', 'Llama 3.3 70B', 'aiml_api', 'meta', 128000, NULL),
    ('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 'Meta Llama 3.1 70B', 'aiml_api', 'meta', 131000, 2000),
    ('meta-llama/Llama-3.2-3B-Instruct-Turbo', 'Llama 3.2 3B', 'aiml_api', 'meta', 128000, NULL),
    -- Groq models (creator: meta for Llama)
    ('llama-3.3-70b-versatile', 'Llama 3.3 70B', 'groq', 'meta', 128000, 32768),
    ('llama-3.1-8b-instant', 'Llama 3.1 8B', 'groq', 'meta', 128000, 8192),
    -- Ollama models (creators: meta for Llama, mistralai for Mistral)
    ('llama3.3', 'Llama 3.3', 'ollama', 'meta', 8192, NULL),
    ('llama3.2', 'Llama 3.2', 'ollama', 'meta', 8192, NULL),
    ('llama3', 'Llama 3', 'ollama', 'meta', 8192, NULL),
    ('llama3.1:405b', 'Llama 3.1 405B', 'ollama', 'meta', 8192, NULL),
    ('dolphin-mistral:latest', 'Dolphin Mistral', 'ollama', 'mistralai', 32768, NULL),
    -- OpenRouter models (creators: google, mistralai, cohere, deepseek, perplexity, nousresearch, openai, amazon, microsoft, gryphe, meta, xai, moonshotai, alibaba)
    ('google/gemini-2.5-pro-preview-03-25', 'Gemini 2.5 Pro', 'open_router', 'google', 1050000, 8192),
    ('google/gemini-2.5-pro', 'Gemini 2.5 Pro', 'open_router', 'google', 1048576, 65536),
    ('google/gemini-3.1-pro-preview', 'Gemini 3.1 Pro Preview', 'open_router', 'google', 1048576, 65536),
    ('google/gemini-3-flash-preview', 'Gemini 3 Flash Preview', 'open_router', 'google', 1048576, 65536),
    ('google/gemini-2.5-flash', 'Gemini 2.5 Flash', 'open_router', 'google', 1048576, 65535),
    ('google/gemini-2.0-flash-001', 'Gemini 2.0 Flash', 'open_router', 'google', 1048576, 8192),
    ('google/gemini-3.1-flash-lite-preview', 'Gemini 3.1 Flash Lite Preview', 'open_router', 'google', 1048576, 65536),
    ('google/gemini-2.5-flash-lite-preview-06-17', 'Gemini 2.5 Flash Lite Preview', 'open_router', 'google', 1048576, 65535),
    ('google/gemini-2.0-flash-lite-001', 'Gemini 2.0 Flash Lite', 'open_router', 'google', 1048576, 8192),
    ('mistralai/mistral-nemo', 'Mistral Nemo', 'open_router', 'mistralai', 128000, 4096),
    ('mistralai/mistral-large-2512', 'Mistral Large 3 2512', 'open_router', 'mistralai', 262144, NULL),
    ('mistralai/mistral-medium-3.1', 'Mistral Medium 3.1', 'open_router', 'mistralai', 131072, NULL),
    ('mistralai/mistral-small-3.2-24b-instruct', 'Mistral Small 3.2 24B', 'open_router', 'mistralai', 131072, 131072),
    ('mistralai/codestral-2508', 'Codestral 2508', 'open_router', 'mistralai', 256000, NULL),
    ('cohere/command-r-08-2024', 'Command R', 'open_router', 'cohere', 128000, 4096),
    ('cohere/command-r-plus-08-2024', 'Command R Plus', 'open_router', 'cohere', 128000, 4096),
    ('cohere/command-a-03-2025', 'Command A 03.2025', 'open_router', 'cohere', 256000, 8192),
    ('cohere/command-a-reasoning-08-2025', 'Command A Reasoning 08.2025', 'open_router', 'cohere', 256000, 32768),
    ('cohere/command-a-translate-08-2025', 'Command A Translate 08.2025', 'open_router', 'cohere', 128000, 8192),
    ('cohere/command-a-vision-07-2025', 'Command A Vision 07.2025', 'open_router', 'cohere', 128000, 8192),
    ('deepseek/deepseek-chat', 'DeepSeek Chat', 'open_router', 'deepseek', 64000, 2048),
    ('deepseek/deepseek-r1-0528', 'DeepSeek R1', 'open_router', 'deepseek', 163840, 163840),
    ('perplexity/sonar', 'Perplexity Sonar', 'open_router', 'perplexity', 127000, 8000),
    ('perplexity/sonar-pro', 'Perplexity Sonar Pro', 'open_router', 'perplexity', 200000, 8000),
    ('perplexity/sonar-deep-research', 'Perplexity Sonar Deep Research', 'open_router', 'perplexity', 128000, 16000),
    ('perplexity/sonar-reasoning-pro', 'Sonar Reasoning Pro', 'open_router', 'perplexity', 128000, 8000),
    ('nousresearch/hermes-3-llama-3.1-405b', 'Hermes 3 Llama 3.1 405B', 'open_router', 'nousresearch', 131000, 4096),
    ('nousresearch/hermes-3-llama-3.1-70b', 'Hermes 3 Llama 3.1 70B', 'open_router', 'nousresearch', 12288, 12288),
    ('openai/gpt-oss-120b', 'GPT OSS 120B', 'open_router', 'openai', 131072, 131072),
    ('openai/gpt-oss-20b', 'GPT OSS 20B', 'open_router', 'openai', 131072, 32768),
    ('amazon/nova-lite-v1', 'Amazon Nova Lite', 'open_router', 'amazon', 300000, 5120),
    ('amazon/nova-micro-v1', 'Amazon Nova Micro', 'open_router', 'amazon', 128000, 5120),
    ('amazon/nova-pro-v1', 'Amazon Nova Pro', 'open_router', 'amazon', 300000, 5120),
    ('microsoft/wizardlm-2-8x22b', 'WizardLM 2 8x22B', 'open_router', 'microsoft', 65536, 4096),
    ('microsoft/phi-4', 'Phi-4', 'open_router', 'microsoft', 16384, 16384),
    ('gryphe/mythomax-l2-13b', 'MythoMax L2 13B', 'open_router', 'gryphe', 4096, 4096),
    ('meta-llama/llama-4-scout', 'Llama 4 Scout', 'open_router', 'meta', 131072, 131072),
    ('meta-llama/llama-4-maverick', 'Llama 4 Maverick', 'open_router', 'meta', 1048576, 1000000),
    ('x-ai/grok-3', 'Grok 3', 'open_router', 'xai', 131072, 131072),
    ('x-ai/grok-4', 'Grok 4', 'open_router', 'xai', 256000, 256000),
    ('x-ai/grok-4-fast', 'Grok 4 Fast', 'open_router', 'xai', 2000000, 30000),
    ('x-ai/grok-4.1-fast', 'Grok 4.1 Fast', 'open_router', 'xai', 2000000, 30000),
    ('x-ai/grok-code-fast-1', 'Grok Code Fast 1', 'open_router', 'xai', 256000, 10000),
    ('moonshotai/kimi-k2', 'Kimi K2', 'open_router', 'moonshotai', 131000, 131000),
    ('qwen/qwen3-235b-a22b-thinking-2507', 'Qwen 3 235B Thinking', 'open_router', 'alibaba', 262144, 262144),
    ('qwen/qwen3-coder', 'Qwen 3 Coder', 'open_router', 'alibaba', 262144, 262144),
    -- Llama API models (creator: meta)
    ('Llama-4-Scout-17B-16E-Instruct-FP8', 'Llama 4 Scout', 'llama_api', 'meta', 128000, 4028),
    ('Llama-4-Maverick-17B-128E-Instruct-FP8', 'Llama 4 Maverick', 'llama_api', 'meta', 128000, 4028),
    ('Llama-3.3-8B-Instruct', 'Llama 3.3 8B', 'llama_api', 'meta', 128000, 4028),
    ('Llama-3.3-70B-Instruct', 'Llama 3.3 70B', 'llama_api', 'meta', 128000, 4028),
    -- v0 models (creator: vercel)
    ('v0-1.5-md', 'v0 1.5 MD', 'v0', 'vercel', 128000, 64000),
    ('v0-1.5-lg', 'v0 1.5 LG', 'v0', 'vercel', 512000, 64000),
    ('v0-1.0-md', 'v0 1.0 MD', 'v0', 'vercel', 128000, 64000)
) AS models(model_slug, model_display_name, provider_name, creator_name, context_window, max_output_tokens)
JOIN "LlmProvider" p ON p."name" = models.provider_name
LEFT JOIN "LlmModelCreator" c ON c."name" = models.creator_name
ON CONFLICT ("slug") DO NOTHING;
|
||||
|
||||
-- Insert Costs
-- Each VALUES row is (model slug, credit cost). Every seeded cost is a per-RUN
-- default cost: "credentialId" is NULL, the credential type is 'api_key', and
-- "credentialProvider" is the name of the provider that owns the model.
-- Slugs that do not match an existing model produce no row. The conflict target
-- names the default-cost uniqueness rule (rows WHERE "credentialId" IS NULL), so
-- an existing default cost for the same model/provider/unit is left untouched.
INSERT INTO "LlmModelCost" ("id", "createdAt", "updatedAt", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
SELECT
    gen_random_uuid(),
    CURRENT_TIMESTAMP,
    CURRENT_TIMESTAMP,
    'RUN'::"LlmCostUnit",
    seed.cost,
    provider."name",
    NULL,
    'api_key',
    NULL,
    '{}'::jsonb,
    model."id"
FROM (VALUES
    -- OpenAI costs
    ('o3-2025-04-16', 4),
    ('o3-mini', 2),
    ('o1', 16),
    ('o1-mini', 4),
    ('gpt-5.2-2025-12-11', 5),
    ('gpt-5-2025-08-07', 2),
    ('gpt-5.1-2025-11-13', 5),
    ('gpt-5-mini-2025-08-07', 1),
    ('gpt-5-nano-2025-08-07', 1),
    ('gpt-5-chat-latest', 5),
    ('gpt-4.1-2025-04-14', 2),
    ('gpt-4.1-mini-2025-04-14', 1),
    ('gpt-4o-mini', 1),
    ('gpt-4o', 3),
    ('gpt-4-turbo', 10),
    -- Anthropic costs
    ('claude-opus-4-6', 21),
    ('claude-sonnet-4-6', 5),
    ('claude-opus-4-1-20250805', 21),
    ('claude-opus-4-20250514', 21),
    ('claude-sonnet-4-20250514', 5),
    ('claude-haiku-4-5-20251001', 4),
    ('claude-opus-4-5-20251101', 14),
    ('claude-sonnet-4-5-20250929', 9),
    ('claude-3-haiku-20240307', 1),
    -- AI/ML API costs
    ('Qwen/Qwen2.5-72B-Instruct-Turbo', 1),
    ('nvidia/llama-3.1-nemotron-70b-instruct', 1),
    ('meta-llama/Llama-3.3-70B-Instruct-Turbo', 1),
    ('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 1),
    ('meta-llama/Llama-3.2-3B-Instruct-Turbo', 1),
    -- Groq costs
    ('llama-3.3-70b-versatile', 1),
    ('llama-3.1-8b-instant', 1),
    -- Ollama costs
    ('llama3.3', 1),
    ('llama3.2', 1),
    ('llama3', 1),
    ('llama3.1:405b', 1),
    ('dolphin-mistral:latest', 1),
    -- OpenRouter costs
    ('google/gemini-2.5-pro-preview-03-25', 4),
    ('google/gemini-2.5-pro', 4),
    ('google/gemini-3.1-pro-preview', 5),
    ('google/gemini-3-flash-preview', 3),
    ('google/gemini-3.1-flash-lite-preview', 1),
    ('mistralai/mistral-nemo', 1),
    ('mistralai/mistral-large-2512', 3),
    ('mistralai/mistral-medium-3.1', 2),
    ('mistralai/mistral-small-3.2-24b-instruct', 1),
    ('mistralai/codestral-2508', 2),
    ('cohere/command-r-08-2024', 1),
    ('cohere/command-r-plus-08-2024', 3),
    ('cohere/command-a-03-2025', 2),
    ('cohere/command-a-reasoning-08-2025', 3),
    ('cohere/command-a-translate-08-2025', 1),
    ('cohere/command-a-vision-07-2025', 2),
    ('deepseek/deepseek-chat', 2),
    ('perplexity/sonar', 1),
    ('perplexity/sonar-pro', 5),
    ('perplexity/sonar-deep-research', 10),
    ('perplexity/sonar-reasoning-pro', 5),
    ('nousresearch/hermes-3-llama-3.1-405b', 1),
    ('nousresearch/hermes-3-llama-3.1-70b', 1),
    ('amazon/nova-lite-v1', 1),
    ('amazon/nova-micro-v1', 1),
    ('amazon/nova-pro-v1', 1),
    ('microsoft/wizardlm-2-8x22b', 1),
    ('microsoft/phi-4', 1),
    ('gryphe/mythomax-l2-13b', 1),
    ('meta-llama/llama-4-scout', 1),
    ('meta-llama/llama-4-maverick', 1),
    ('x-ai/grok-3', 5),
    ('x-ai/grok-4', 9),
    ('x-ai/grok-4-fast', 1),
    ('x-ai/grok-4.1-fast', 1),
    ('x-ai/grok-code-fast-1', 1),
    ('moonshotai/kimi-k2', 1),
    ('qwen/qwen3-235b-a22b-thinking-2507', 1),
    ('qwen/qwen3-coder', 9),
    ('google/gemini-2.5-flash', 1),
    ('google/gemini-2.0-flash-001', 1),
    ('google/gemini-2.5-flash-lite-preview-06-17', 1),
    ('google/gemini-2.0-flash-lite-001', 1),
    ('deepseek/deepseek-r1-0528', 1),
    ('openai/gpt-oss-120b', 1),
    ('openai/gpt-oss-20b', 1),
    -- Llama API costs
    ('Llama-4-Scout-17B-16E-Instruct-FP8', 1),
    ('Llama-4-Maverick-17B-128E-Instruct-FP8', 1),
    ('Llama-3.3-8B-Instruct', 1),
    ('Llama-3.3-70B-Instruct', 1),
    -- v0 costs
    ('v0-1.5-md', 1),
    ('v0-1.5-lg', 2),
    ('v0-1.0-md', 1)
) AS seed(model_slug, cost)
JOIN "LlmModel" model ON model."slug" = seed.model_slug
JOIN "LlmProvider" provider ON provider."id" = model."providerId"
ON CONFLICT ("llmModelId", "credentialProvider", "unit") WHERE "credentialId" IS NULL DO NOTHING;
|
||||
|
||||
@@ -1301,3 +1301,164 @@ model OAuthRefreshToken {
|
||||
@@index([userId, applicationId])
|
||||
@@index([expiresAt]) // For cleanup
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// LLM Registry Models
|
||||
// ============================================================================
|
||||
|
||||
// Unit in which an LlmModelCost.creditCost is expressed.
enum LlmCostUnit {
  RUN
  TOKENS
}
|
||||
|
||||
// A provider that hosts/serves LLM models (e.g., OpenRouter). This is distinct
// from the creator that trained a model; see LlmModelCreator.
model LlmProvider {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  name        String  @unique
  displayName String
  description String?

  // Optional default credential settings for this provider.
  defaultCredentialProvider String?
  defaultCredentialId       String?
  defaultCredentialType     String?

  metadata Json @default("{}")

  Models LlmModel[]
}
|
||||
|
||||
// A single LLM, identified by a unique slug and served by exactly one provider.
model LlmModel {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  slug        String  @unique
  displayName String
  description String?

  // Required: a provider cannot be deleted while models still reference it.
  providerId String
  Provider   LlmProvider @relation(fields: [providerId], references: [id], onDelete: Restrict)

  // The creator is the organization that created/trained the model
  // (e.g., OpenAI, Meta), as opposed to the provider that hosts/serves it
  // (e.g., OpenRouter). Optional: deleting a creator nulls this reference.
  creatorId String?
  Creator   LlmModelCreator? @relation(fields: [creatorId], references: [id], onDelete: SetNull)

  contextWindow   Int
  maxOutputTokens Int?
  priceTier       Int     @default(1) // 1=cheapest, 2=medium, 3=expensive (DB constraint: 1-3)
  isEnabled       Boolean @default(true)
  isRecommended   Boolean @default(false)

  // Per-model capability flags. They differ between models even under the
  // same provider (e.g., Hugging Face). All default to false for safety, so a
  // partially-seeded row is never assumed to be capable.
  supportsTools             Boolean @default(false)
  supportsJsonOutput        Boolean @default(false)
  supportsReasoning         Boolean @default(false)
  supportsParallelToolCalls Boolean @default(false)

  capabilities Json @default("{}")
  metadata     Json @default("{}")

  Costs            LlmModelCost[]
  SourceMigrations LlmModelMigration[] @relation("SourceMigrations")
  TargetMigrations LlmModelMigration[] @relation("TargetMigrations")

  @@index([providerId, isEnabled])
  @@index([creatorId])
  // No explicit index on slug: @unique already creates one implicitly.
}
|
||||
|
||||
// Credit price of a model for a given credential source and cost unit.
model LlmModelCost {
  id        String      @id @default(uuid())
  createdAt DateTime    @default(now())
  updatedAt DateTime    @updatedAt
  unit      LlmCostUnit @default(RUN)

  creditCost Int // DB constraint: >= 0

  // Provider identifier (e.g., "openai", "anthropic", "open_router").
  // Determines which credential system supplies the API key, which allows
  // different pricing for:
  //   - default provider costs   (credentialId IS NULL)
  //   - a user's own API key     (credentialId IS NOT NULL)
  credentialProvider String
  credentialId       String?
  credentialType     String?
  currency           String?

  metadata Json @default("{}")

  // Cost rows are deleted together with the model they belong to.
  llmModelId String
  Model      LlmModel @relation(fields: [llmModelId], references: [id], onDelete: Cascade)

  // Note: uniqueness is not declared here; it is implemented as two partial
  // unique indexes in the migration SQL:
  //   - one for default costs             (WHERE credentialId IS NULL)
  //   - one for credential-specific costs (WHERE credentialId IS NOT NULL)
  // Together they permit both provider-level defaults and per-credential overrides.
}
|
||||
|
||||
// An organization that created/trained one or more models.
model LlmModelCreator {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  name        String  @unique // e.g., "openai", "anthropic", "meta"
  displayName String          // e.g., "OpenAI", "Anthropic", "Meta"
  description String?
  websiteUrl  String?         // Link to the creator's website
  logoUrl     String?         // URL of the creator's logo

  metadata Json @default("{}")

  Models LlmModel[]
}
|
||||
|
||||
// Record of workflows being moved from one model (source) to another (target),
// including what was moved and whether the move has been reverted.
model LlmModelMigration {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  sourceModelSlug String  // The original model that was disabled
  targetModelSlug String  // The model workflows were migrated to
  reason          String? // Why the migration happened (e.g., "Provider outage")

  // Foreign keys guarantee both slugs point at existing models; a model that
  // is referenced by a migration cannot be deleted.
  SourceModel LlmModel @relation("SourceMigrations", fields: [sourceModelSlug], references: [slug], onDelete: Restrict)
  TargetModel LlmModel @relation("TargetMigrations", fields: [targetModelSlug], references: [slug], onDelete: Restrict)

  // Affected nodes, stored as a JSON array of node IDs.
  // Format: ["node-uuid-1", "node-uuid-2", ...]
  migratedNodeIds Json @default("[]")
  nodeCount       Int // Number of nodes migrated (DB constraint: >= 0)

  // Custom pricing override for migrated workflows during the migration period.
  // Use case: when users are moved from an expensive model (e.g., GPT-4) to a
  // cheaper one (e.g., GPT-3.5), the original price can be kept temporarily to
  // avoid billing surprises, or a discount can be offered during the transition.
  //
  // IMPORTANT: this field is intended for integration with the billing system.
  // When billing calculates costs for nodes affected by this migration, it
  // should use customCreditCost if it is set, in place of the target model's
  // cost. If null, the target model's normal cost applies.
  //
  // TODO: Integrate with billing system to apply this override during cost calculation.
  // LIMITATION: this is a plain Int and does not distinguish RUN from TOKENS
  // pricing, so it may be ambiguous for token-priced models. Consider moving to
  // a relation with LlmModelCost or a dedicated override model in a follow-up PR.
  customCreditCost Int? // DB constraint: >= 0 when not null

  // Revert tracking
  isReverted Boolean   @default(false)
  revertedAt DateTime?

  // Note: a partial unique index in the migration SQL prevents more than one
  // active migration per source:
  //   UNIQUE (sourceModelSlug) WHERE isReverted = false
  @@index([targetModelSlug])
  @@index([sourceModelSlug, isReverted]) // Composite index for active migration queries
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user