mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Compare commits
55 Commits
test-scree
...
feat/llm-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
20657e7784 | ||
|
|
dcf207eb72 | ||
|
|
da533089d3 | ||
|
|
5352dc9778 | ||
|
|
fc932aa415 | ||
|
|
405bdb2808 | ||
|
|
6fcd05ef61 | ||
|
|
33c30c6990 | ||
|
|
52d074d31f | ||
|
|
93be8e5095 | ||
|
|
b0d9ef13e6 | ||
|
|
d5a0ce2815 | ||
|
|
3f6b1120f3 | ||
|
|
77757a25a5 | ||
|
|
e192695884 | ||
|
|
5b2d4595d1 | ||
|
|
4e1774c939 | ||
|
|
845ce6ae8d | ||
|
|
84f30775fd | ||
|
|
62065292ec | ||
|
|
dff9b0f3b2 | ||
|
|
fa47d898d1 | ||
|
|
67455f6a35 | ||
|
|
cb8cf81be7 | ||
|
|
ef30c1ed76 | ||
|
|
b5f63c13a4 | ||
|
|
c5dfe3333d | ||
|
|
696b273afc | ||
|
|
732365cd8f | ||
|
|
7e85371ce5 | ||
|
|
3f964c8aba | ||
|
|
ad1f489c5c | ||
|
|
9c77a2207f | ||
|
|
081fa9f2db | ||
|
|
8b6dea2496 | ||
|
|
5e15213846 | ||
|
|
f0cc4ae573 | ||
|
|
e0282b00db | ||
|
|
9a9c36b806 | ||
|
|
d5381625cd | ||
|
|
f6ae3d6593 | ||
|
|
0fb1b854df | ||
|
|
64a011664a | ||
|
|
1db7c048d9 | ||
|
|
4c5627c966 | ||
|
|
d97d137a51 | ||
|
|
ded9e293ff | ||
|
|
83d504bed2 | ||
|
|
a5f1ffb35b | ||
|
|
97c6516a14 | ||
|
|
876dde8bc7 | ||
|
|
0bfdd74b25 | ||
|
|
a7d2f81b18 | ||
|
|
3699eaa556 | ||
|
|
21adf9e0fb |
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import platform
|
||||
@@ -37,8 +38,10 @@ import backend.api.features.workspace.routes as workspace_routes
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.llm_registry
|
||||
import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.server.v2.llm
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.api.features.library.exceptions import (
|
||||
@@ -117,16 +120,47 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
_registry_subscription_task: asyncio.Task | None = None
|
||||
|
||||
try:
|
||||
await backend.data.llm_registry.refresh_llm_registry()
|
||||
logger.info("LLM registry loaded successfully at startup")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to load LLM registry at startup: {e}. "
|
||||
"Blocks will initialize with empty registry."
|
||||
)
|
||||
|
||||
_registry_subscription_task = asyncio.create_task(
|
||||
backend.data.llm_registry.subscribe_to_registry_refresh(
|
||||
backend.data.llm_registry.refresh_llm_registry
|
||||
)
|
||||
)
|
||||
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
try:
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
except Exception as e:
|
||||
if "AgentNode" in str(e):
|
||||
logger.warning("migrate_llm_models skipped: AgentNode table not found (%s)", e)
|
||||
else:
|
||||
logger.error("migrate_llm_models failed unexpectedly: %s", e, exc_info=True)
|
||||
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
with launch_darkly_context():
|
||||
yield
|
||||
|
||||
if _registry_subscription_task:
|
||||
_registry_subscription_task.cancel()
|
||||
try:
|
||||
await _registry_subscription_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
try:
|
||||
await shutdown_cloud_storage_handler()
|
||||
except Exception as e:
|
||||
@@ -355,6 +389,16 @@ app.include_router(
|
||||
tags=["oauth"],
|
||||
prefix="/api/oauth",
|
||||
)
|
||||
# Mount the LLM registry router under the shared /api prefix.
app.include_router(
    backend.server.v2.llm.router,
    prefix="/api",
    tags=["v2", "llm"],
)
# Mount the LLM registry admin router under the same prefix, with an extra
# "admin" tag.
app.include_router(
    backend.server.v2.llm.admin_router,
    prefix="/api",
    tags=["v2", "llm", "admin"],
)
|
||||
|
||||
app.mount("/external-api", external_api)
|
||||
|
||||
|
||||
@@ -36,9 +36,9 @@ from backend.util.models import Pagination
|
||||
from backend.util.request import parse_url
|
||||
|
||||
from .block import BlockInput
|
||||
from .db import BaseDbModel
|
||||
from .db import BaseDbModel, execute_raw_with_schema
|
||||
from .db import prisma as db
|
||||
from .db import query_raw_with_schema, transaction
|
||||
from .db import execute_raw_with_schema, query_raw_with_schema, transaction
|
||||
from .dynamic_fields import is_tool_pin, sanitize_pin_name
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE, MAX_GRAPH_VERSIONS_FETCH
|
||||
from .model import CredentialsFieldInfo, CredentialsMetaInput, is_credentials_field_name
|
||||
@@ -1663,22 +1663,19 @@ async def migrate_llm_models(migrate_to: LlmModel):
|
||||
if field.annotation == LlmModel:
|
||||
llm_model_fields[block.id] = field_name
|
||||
|
||||
# Convert enum values to a list of strings for the SQL query
|
||||
enum_values = [v.value for v in LlmModel]
|
||||
escaped_enum_values = repr(tuple(enum_values)) # hack but works
|
||||
enum_values = repr(tuple(v.value for v in LlmModel))
|
||||
|
||||
# Update each block
|
||||
for id, path in llm_model_fields.items():
|
||||
query = f"""
|
||||
UPDATE platform."AgentNode"
|
||||
query = """
|
||||
UPDATE {schema_prefix}"AgentNode"
|
||||
SET "constantInput" = jsonb_set("constantInput", $1, to_jsonb($2), true)
|
||||
WHERE "agentBlockId" = $3
|
||||
AND "constantInput" ? ($4)::text
|
||||
AND "constantInput"->>($4)::text NOT IN {escaped_enum_values}
|
||||
"""
|
||||
AND "constantInput"->>($4)::text NOT IN """ + enum_values
|
||||
|
||||
await db.execute_raw(
|
||||
query, # type: ignore - is supposed to be LiteralString
|
||||
await execute_raw_with_schema(
|
||||
query,
|
||||
[path],
|
||||
migrate_to.value,
|
||||
id,
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
"""LLM Registry - Dynamic model management system."""
|
||||
|
||||
from backend.blocks.llm import ModelMetadata
|
||||
from .notifications import (
|
||||
publish_registry_refresh_notification,
|
||||
subscribe_to_registry_refresh,
|
||||
)
|
||||
from .registry import (
|
||||
RegistryModel,
|
||||
RegistryModelCost,
|
||||
RegistryModelCreator,
|
||||
clear_registry_cache,
|
||||
get_all_model_slugs_for_validation,
|
||||
get_all_models,
|
||||
get_default_model_slug,
|
||||
get_enabled_models,
|
||||
get_model,
|
||||
get_schema_options,
|
||||
refresh_llm_registry,
|
||||
)
|
||||
|
||||
# Public API of the llm_registry package. Order has no runtime meaning; the
# entries are grouped by what they do.
__all__ = [
    # Models
    "ModelMetadata",
    "RegistryModel",
    "RegistryModelCost",
    "RegistryModelCreator",
    # Cache management and cross-process refresh
    "clear_registry_cache",
    "refresh_llm_registry",
    "publish_registry_refresh_notification",
    "subscribe_to_registry_refresh",
    # Read functions (served from the in-process cache)
    "get_model",
    "get_all_models",
    "get_enabled_models",
    "get_schema_options",
    "get_default_model_slug",
    "get_all_model_slugs_for_validation",
]
|
||||
@@ -0,0 +1,84 @@
|
||||
"""Pub/sub notifications for LLM registry cross-process synchronisation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from backend.data.redis_client import HOST, PASSWORD, PORT
|
||||
from redis.asyncio import Redis as AsyncRedis
|
||||
|
||||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Redis pub/sub channel shared by the publisher and the subscriber below.
REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"
|
||||
|
||||
|
||||
async def publish_registry_refresh_notification() -> None:
    """Broadcast a "refresh" message on the LLM registry channel.

    Best-effort: any exception raised while obtaining the Redis client or
    publishing is logged as a warning and not re-raised.
    """
    # Resolved at call time, not at module import time.
    from backend.data.redis_client import get_redis_async

    try:
        client = await get_redis_async()
        await client.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
    except Exception as exc:
        logger.warning("Failed to publish registry refresh notification: %s", exc)
    else:
        logger.debug("Published LLM registry refresh notification")
|
||||
|
||||
|
||||
async def _close_quietly(resource: object, label: str) -> None:
    """Close a Redis client or pub/sub object without ever raising ``Exception``."""
    if resource is None:
        return

    import inspect  # function-scope import: only needed for cleanup

    # Prefer ``aclose`` and fall back to ``close`` for clients that lack it.
    closer = getattr(resource, "aclose", None) or getattr(resource, "close", None)
    if closer is None:
        return
    try:
        outcome = closer()
        if inspect.isawaitable(outcome):
            await outcome
    except Exception as e:
        logger.debug("Error closing %s for LLM registry subscription: %s", label, e)


async def _listen_for_registry_refresh(
    on_refresh: Callable[[], Awaitable[None]],
) -> None:
    """Open one dedicated connection and dispatch refresh signals until it fails.

    The connection and its pub/sub object are always released on exit, whether
    the listener is cancelled or an error escapes.
    """
    redis_sub = None
    pubsub = None
    try:
        # Dedicated connection — pub/sub must not share a connection used
        # for regular commands.
        redis_sub = AsyncRedis(
            host=HOST, port=PORT, password=PASSWORD, decode_responses=True
        )
        pubsub = redis_sub.pubsub()
        await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
        logger.info("Subscribed to LLM registry refresh channel")

        while True:
            try:
                message = await pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=1.0
                )
                if (
                    message
                    and message["type"] == "message"
                    and message["channel"] == REGISTRY_REFRESH_CHANNEL
                ):
                    logger.debug("LLM registry refresh signal received")
                    try:
                        await on_refresh()
                    except Exception as e:
                        # A failing callback must not stop the listener.
                        logger.error(
                            "Error in registry on_refresh callback: %s",
                            e,
                            exc_info=True,
                        )
            except asyncio.CancelledError:
                raise
            except Exception as e:
                # NOTE(review): errors from get_message are retried on the
                # same pub/sub object; only failures while connecting or
                # subscribing reach the caller's reconnect path. Confirm this
                # is intended.
                logger.warning("Error processing registry refresh message: %s", e)
                await asyncio.sleep(1)
    finally:
        # Without this, every reconnect or cancellation leaked a connection.
        await _close_quietly(pubsub, "pubsub")
        await _close_quietly(redis_sub, "connection")


async def subscribe_to_registry_refresh(
    on_refresh: Callable[[], Awaitable[None]],
) -> None:
    """Listen for registry refresh signals and call on_refresh each time one arrives.

    Designed to run as a long-lived background asyncio.Task. When connecting
    or subscribing fails, the old connection is closed, the function waits
    5 seconds and tries again with a fresh connection.

    Args:
        on_refresh: Async callable invoked on each refresh signal.
            Typically ``llm_registry.refresh_llm_registry``.

    Returns:
        None. A cancellation received while listening is logged and makes the
        function return normally rather than propagate ``CancelledError``.
    """
    while True:
        try:
            await _listen_for_registry_refresh(on_refresh)
        except asyncio.CancelledError:
            logger.info("LLM registry subscription task cancelled")
            break
        except Exception as e:
            logger.warning(
                "LLM registry subscription error: %s. Retrying in 5s...", e
            )
            await asyncio.sleep(5)
|
||||
@@ -0,0 +1,195 @@
|
||||
"""Tests for LLM registry pub/sub notifications (notifications.py).
|
||||
|
||||
Covers:
|
||||
- publish_registry_refresh_notification: happy path and Redis error swallowed
|
||||
- subscribe_to_registry_refresh: message triggers on_refresh, non-message
|
||||
types ignored, wrong channel ignored, CancelledError stops the loop,
|
||||
connection errors trigger reconnect
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.llm_registry.notifications import (
|
||||
REGISTRY_REFRESH_CHANNEL,
|
||||
publish_registry_refresh_notification,
|
||||
subscribe_to_registry_refresh,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# publish_registry_refresh_notification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_publish_sends_to_correct_channel(mocker):
    """The publisher sends "refresh" on REGISTRY_REFRESH_CHANNEL exactly once."""
    fake_client = AsyncMock()
    # The publisher imports get_redis_async at call time, so patching the
    # source module is effective here.
    mocker.patch(
        "backend.data.redis_client.get_redis_async",
        return_value=fake_client,
    )

    await publish_registry_refresh_notification()

    fake_client.publish.assert_called_once_with(REGISTRY_REFRESH_CHANNEL, "refresh")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_publish_swallows_redis_error(mocker):
    """A failure while obtaining the Redis client does not escape the publisher."""
    redis_failure = ConnectionError("Redis unavailable")
    mocker.patch(
        "backend.data.redis_client.get_redis_async",
        side_effect=redis_failure,
    )

    # Passing means no exception propagated; the publisher only logs it.
    await publish_registry_refresh_notification()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# subscribe_to_registry_refresh
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_pubsub(messages: list) -> AsyncMock:
    """Build a mock pubsub whose get_message yields *messages*, then cancels.

    Args:
        messages: Values returned by successive ``get_message`` awaits. Any
            iterable is accepted; it is copied, not mutated.

    Returns:
        An ``AsyncMock`` with ``subscribe`` and ``get_message`` configured.
    """
    pubsub = AsyncMock()
    pubsub.subscribe = AsyncMock()
    # Once messages are exhausted the next get_message raises CancelledError to
    # break the infinite loop cleanly in tests.
    pubsub.get_message = AsyncMock(
        side_effect=[*messages, asyncio.CancelledError()]
    )
    return pubsub
|
||||
|
||||
|
||||
def _make_message(channel: str = REGISTRY_REFRESH_CHANNEL, msg_type: str = "message"):
    """Return a pub/sub message dict carrying the "refresh" payload."""
    payload = {"type": msg_type, "channel": channel}
    payload["data"] = "refresh"
    return payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_subscribe_calls_on_refresh_for_valid_message(mocker):
    """A message on the registry channel triggers the on_refresh callback."""
    on_refresh = AsyncMock()
    pubsub = _make_pubsub([_make_message()])

    mock_redis = MagicMock()
    mock_redis.pubsub.return_value = pubsub
    # Patch the name as bound inside the notifications module: it imports
    # Redis as AsyncRedis at import time, so patching redis.asyncio.Redis
    # would leave the real class in use.
    mocker.patch(
        "backend.data.llm_registry.notifications.AsyncRedis",
        return_value=mock_redis,
    )

    await subscribe_to_registry_refresh(on_refresh)

    on_refresh.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_subscribe_ignores_non_message_types(mocker):
    """Messages whose type is not 'message' do not trigger on_refresh."""
    on_refresh = AsyncMock()
    pubsub = _make_pubsub(
        [
            _make_message(msg_type="subscribe"),  # handshake — should be ignored
            _make_message(msg_type="psubscribe"),  # also ignored
        ]
    )

    mock_redis = MagicMock()
    mock_redis.pubsub.return_value = pubsub
    # Patch the name as bound inside the notifications module: it imports
    # Redis as AsyncRedis at import time, so patching redis.asyncio.Redis
    # would leave the real class in use.
    mocker.patch(
        "backend.data.llm_registry.notifications.AsyncRedis",
        return_value=mock_redis,
    )

    await subscribe_to_registry_refresh(on_refresh)

    on_refresh.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_subscribe_ignores_wrong_channel(mocker):
    """Messages on a different channel do not trigger on_refresh."""
    on_refresh = AsyncMock()
    pubsub = _make_pubsub([_make_message(channel="some:other:channel")])

    mock_redis = MagicMock()
    mock_redis.pubsub.return_value = pubsub
    # Patch the name as bound inside the notifications module: it imports
    # Redis as AsyncRedis at import time, so patching redis.asyncio.Redis
    # would leave the real class in use.
    mocker.patch(
        "backend.data.llm_registry.notifications.AsyncRedis",
        return_value=mock_redis,
    )

    await subscribe_to_registry_refresh(on_refresh)

    on_refresh.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_subscribe_handles_none_message(mocker):
    """None returned by get_message (timeout) does not crash or trigger on_refresh."""
    on_refresh = AsyncMock()
    pubsub = _make_pubsub([None, None])  # two timeouts, then CancelledError

    mock_redis = MagicMock()
    mock_redis.pubsub.return_value = pubsub
    # Patch the name as bound inside the notifications module: it imports
    # Redis as AsyncRedis at import time, so patching redis.asyncio.Redis
    # would leave the real class in use.
    mocker.patch(
        "backend.data.llm_registry.notifications.AsyncRedis",
        return_value=mock_redis,
    )

    await subscribe_to_registry_refresh(on_refresh)

    on_refresh.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_subscribe_processes_multiple_messages(mocker):
    """Multiple valid messages each trigger on_refresh."""
    on_refresh = AsyncMock()
    pubsub = _make_pubsub([_make_message(), _make_message(), _make_message()])

    mock_redis = MagicMock()
    mock_redis.pubsub.return_value = pubsub
    # Patch the name as bound inside the notifications module: it imports
    # Redis as AsyncRedis at import time, so patching redis.asyncio.Redis
    # would leave the real class in use.
    mocker.patch(
        "backend.data.llm_registry.notifications.AsyncRedis",
        return_value=mock_redis,
    )

    await subscribe_to_registry_refresh(on_refresh)

    assert on_refresh.call_count == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_subscribe_cancelled_error_stops_loop(mocker):
    """CancelledError raised while subscribing makes the function return cleanly."""
    on_refresh = AsyncMock()
    pubsub = AsyncMock()
    pubsub.subscribe = AsyncMock(side_effect=asyncio.CancelledError())

    mock_redis = MagicMock()
    mock_redis.pubsub.return_value = pubsub
    # Patch the name as bound inside the notifications module: it imports
    # Redis as AsyncRedis at import time, so patching redis.asyncio.Redis
    # would leave the real class in use.
    mocker.patch(
        "backend.data.llm_registry.notifications.AsyncRedis",
        return_value=mock_redis,
    )

    # Should return normally — CancelledError is caught and the loop breaks
    await subscribe_to_registry_refresh(on_refresh)

    on_refresh.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_subscribe_reconnects_after_connection_error(mocker):
    """A connection error on the first attempt triggers a reconnect attempt."""
    on_refresh = AsyncMock()

    # First client fails when asked for a pubsub; the second one delivers a
    # valid message and then raises CancelledError.
    good_pubsub = _make_pubsub([_make_message()])
    bad_redis = MagicMock()
    bad_redis.pubsub.side_effect = ConnectionError("Redis down")
    good_redis = MagicMock()
    good_redis.pubsub.return_value = good_pubsub

    # Patch the name as bound inside the notifications module: it imports
    # Redis as AsyncRedis at import time, so patching redis.asyncio.Redis
    # would leave the real class in use.
    mock_redis_cls = mocker.patch(
        "backend.data.llm_registry.notifications.AsyncRedis",
        side_effect=[bad_redis, good_redis],
    )
    mock_sleep = mocker.patch("asyncio.sleep", new=AsyncMock())

    await subscribe_to_registry_refresh(on_refresh)

    # Should have slept before retrying
    mock_sleep.assert_called_once_with(5)
    # Should have tried to create two Redis connections
    assert mock_redis_cls.call_count == 2
    # After reconnect, the valid message triggered on_refresh
    on_refresh.assert_called_once()
|
||||
254
autogpt_platform/backend/backend/data/llm_registry/registry.py
Normal file
254
autogpt_platform/backend/backend/data/llm_registry/registry.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""Core LLM registry implementation for managing models dynamically."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import prisma.models
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from backend.blocks.llm import ModelMetadata
|
||||
from backend.util.cache import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RegistryModelCost(BaseModel):
    """Immutable cost entry attached to a registry model."""

    # frozen=True: instances reject attribute assignment after construction.
    model_config = ConfigDict(frozen=True)

    # Billing unit: "RUN" or "TOKENS".
    unit: str
    # Credits charged per unit.
    credit_cost: int
    # Provider of the credential this cost applies to.
    credential_provider: str
    # Optional narrowing to one credential or credential type.
    credential_id: str | None = None
    credential_type: str | None = None
    currency: str | None = None
    # Free-form extra data copied from the cost record.
    metadata: dict[str, Any] = {}
|
||||
|
||||
|
||||
class RegistryModelCreator(BaseModel):
    """Immutable description of the creator of an LLM model."""

    # frozen=True: instances reject attribute assignment after construction.
    model_config = ConfigDict(frozen=True)

    # Required identity fields.
    id: str
    name: str
    display_name: str
    # Optional descriptive fields.
    description: str | None = None
    website_url: str | None = None
    logo_url: str | None = None
|
||||
|
||||
|
||||
class RegistryModel(BaseModel):
    """Immutable snapshot of one LLM model as held in the registry."""

    # frozen=True: instances reject attribute assignment after construction.
    model_config = ConfigDict(frozen=True)

    # Identity and presentation.
    slug: str
    display_name: str
    description: str | None = None
    # Structured metadata (provider, context window, price tier, ...).
    metadata: ModelMetadata
    # Free-form dictionaries copied from the database record.
    capabilities: dict[str, Any] = {}
    extra_metadata: dict[str, Any] = {}
    provider_display_name: str
    # Availability flags; disabled models are excluded from schema options.
    is_enabled: bool
    is_recommended: bool = False
    costs: tuple[RegistryModelCost, ...] = ()
    creator: RegistryModelCreator | None = None

    # Typed capability flags mirroring columns of the DB schema.
    supports_tools: bool = False
    supports_json_output: bool = False
    supports_reasoning: bool = False
    supports_parallel_tool_calls: bool = False
|
||||
|
||||
|
||||
# L1 in-process cache — Redis is the shared L2 via @cached(shared_cache=True).
# Both containers are replaced wholesale by refresh_llm_registry().
_dynamic_models: dict[str, RegistryModel] = {}  # slug -> model
_schema_options: list[dict[str, str]] = []  # dropdown entries, enabled models only
# Serialises concurrent refreshes within this process.
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel:  # type: ignore[name-defined]
    """Transform a raw Prisma LlmModel record into a RegistryModel instance.

    Args:
        record: LlmModel row, expected to be fetched with its ``Provider``,
            ``Costs`` and ``Creator`` relations included.

    Returns:
        A frozen RegistryModel. Fallbacks applied along the way:
        a missing Provider falls back to ``providerId`` (with a warning),
        a missing Creator yields creator name ``"Unknown"``,
        a ``priceTier`` outside 1..3 becomes 1 (with a warning), and
        a null ``maxOutputTokens`` falls back to ``contextWindow``.
    """
    costs = tuple(
        RegistryModelCost(
            # str() on a (str, Enum) member gives "ClassName.MEMBER", not the
            # stored value, so unwrap ``.value`` when present. Plain strings
            # have no ``.value`` and pass through unchanged.
            unit=str(getattr(cost.unit, "value", cost.unit)),
            credit_cost=cost.creditCost,
            credential_provider=cost.credentialProvider,
            credential_id=cost.credentialId,
            credential_type=cost.credentialType,
            currency=cost.currency,
            metadata=dict(cost.metadata or {}),
        )
        for cost in (record.Costs or [])
    )

    creator = None
    creator_name = "Unknown"
    if record.Creator:
        creator = RegistryModelCreator(
            id=record.Creator.id,
            name=record.Creator.name,
            display_name=record.Creator.displayName,
            description=record.Creator.description,
            website_url=record.Creator.websiteUrl,
            logo_url=record.Creator.logoUrl,
        )
        creator_name = record.Creator.displayName

    capabilities = dict(record.capabilities or {})

    if record.Provider:
        provider_name = record.Provider.name
        provider_display = record.Provider.displayName
    else:
        logger.warning(
            "LlmModel %s has no Provider despite NOT NULL FK - "
            "falling back to providerId %s",
            record.slug,
            record.providerId,
        )
        provider_name = provider_display = record.providerId

    price_tier = record.priceTier
    if price_tier not in (1, 2, 3):
        logger.warning(
            "LlmModel %s has out-of-range priceTier=%s, defaulting to 1",
            record.slug,
            price_tier,
        )
        price_tier = 1

    metadata = ModelMetadata(
        provider=provider_name,
        context_window=record.contextWindow,
        max_output_tokens=(
            record.maxOutputTokens
            if record.maxOutputTokens is not None
            else record.contextWindow
        ),
        display_name=record.displayName,
        provider_name=provider_display,
        creator_name=creator_name,
        price_tier=price_tier,
    )

    return RegistryModel(
        slug=record.slug,
        display_name=record.displayName,
        description=record.description,
        metadata=metadata,
        capabilities=capabilities,
        extra_metadata=dict(record.metadata or {}),
        provider_display_name=provider_display,
        is_enabled=record.isEnabled,
        is_recommended=record.isRecommended,
        costs=costs,
        creator=creator,
        supports_tools=record.supportsTools,
        supports_json_output=record.supportsJsonOutput,
        supports_reasoning=record.supportsReasoning,
        supports_parallel_tool_calls=record.supportsParallelToolCalls,
    )
|
||||
|
||||
|
||||
@cached(maxsize=1, ttl_seconds=300, shared_cache=True, refresh_ttl_on_get=True)
async def _fetch_registry_from_db() -> list[RegistryModel]:
    """Load every LlmModel row, with relations, and convert it to a RegistryModel.

    The result is memoised by ``@cached`` (300 second TTL, shared cache
    enabled), so calls inside the TTL window are served from the cache rather
    than the database.
    """
    rows = await prisma.models.LlmModel.prisma().find_many(  # type: ignore[attr-defined]
        include={"Provider": True, "Costs": True, "Creator": True}
    )
    logger.info("Fetched %d LLM models from database", len(rows))
    return list(map(_record_to_registry_model, rows))
|
||||
|
||||
|
||||
def clear_registry_cache() -> None:
    """Drop the cached result of the registry DB fetch.

    Intended to be called after an admin mutation and before
    refresh_llm_registry(), so that the refresh reads the database instead of
    a stale cached list.
    """
    invalidate = _fetch_registry_from_db.cache_clear
    invalidate()
|
||||
|
||||
|
||||
async def refresh_llm_registry() -> None:
    """Rebuild the in-process model map and schema options.

    Runs under the module lock so concurrent refreshes in one process do not
    interleave. The model list comes from the cached fetch helper, so it only
    reaches the database when that cache is empty or was cleared.

    Raises:
        Exception: whatever the fetch or rebuild raised, after logging it.
    """
    global _dynamic_models, _schema_options

    async with _lock:
        try:
            fetched = await _fetch_registry_from_db()
            by_slug = {entry.slug: entry for entry in fetched}

            # Swap the model map first: the option builder reads the global.
            _dynamic_models = by_slug
            _schema_options = _build_schema_options()

            logger.info(
                "LLM registry refreshed: %d models, %d schema options",
                len(_dynamic_models),
                len(_schema_options),
            )
        except Exception as exc:
            logger.error("Failed to refresh LLM registry: %s", exc, exc_info=True)
            raise
|
||||
|
||||
|
||||
def _build_schema_options() -> list[dict[str, str]]:
    """Return dropdown entries for enabled models, ordered by lower-cased display name."""
    ordered = sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower())
    options: list[dict[str, str]] = []
    for entry in ordered:
        if not entry.is_enabled:
            continue
        options.append(
            {
                "label": entry.display_name,
                "value": entry.slug,
                "group": entry.metadata.provider,
                "description": entry.description or "",
            }
        )
    return options
|
||||
|
||||
|
||||
def get_model(slug: str) -> RegistryModel | None:
    """Look up one model by slug in the in-process cache; None when absent."""
    found = _dynamic_models.get(slug)
    return found
|
||||
|
||||
|
||||
def get_all_models() -> list[RegistryModel]:
    """Return a new list of every cached model, enabled or not."""
    return [*_dynamic_models.values()]
|
||||
|
||||
|
||||
def get_enabled_models() -> list[RegistryModel]:
    """Return the cached models whose is_enabled flag is truthy."""
    enabled: list[RegistryModel] = []
    for entry in _dynamic_models.values():
        if entry.is_enabled:
            enabled.append(entry)
    return enabled
|
||||
|
||||
|
||||
def get_schema_options() -> list[dict[str, str]]:
    """Return a shallow copy of the cached dropdown options (enabled models only)."""
    # Copy so callers cannot reorder or extend the module-level list.
    return _schema_options.copy()
|
||||
|
||||
|
||||
def get_default_model_slug() -> str | None:
    """Pick the default slug: first recommended enabled model, else first enabled.

    Models are ordered by display name (case-sensitive) before picking.
    Returns None when no model is enabled.
    """
    # NOTE(review): _build_schema_options sorts by lower-cased display name,
    # so this order can differ from the dropdown order — confirm intended.
    ordered = sorted(_dynamic_models.values(), key=lambda m: m.display_name)
    enabled = [m for m in ordered if m.is_enabled]

    recommended = next((m.slug for m in enabled if m.is_recommended), None)
    if recommended:
        return recommended
    return next((m.slug for m in enabled), None)
|
||||
|
||||
|
||||
def get_all_model_slugs_for_validation() -> list[str]:
    """Return the slugs accepted by validation: those of enabled models only."""
    slugs: list[str] = []
    for entry in _dynamic_models.values():
        if entry.is_enabled:
            slugs.append(entry.slug)
    return slugs
|
||||
@@ -0,0 +1,466 @@
|
||||
"""Unit tests for the LLM registry module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import pydantic
|
||||
|
||||
from backend.data.llm_registry.registry import (
|
||||
RegistryModel,
|
||||
RegistryModelCost,
|
||||
RegistryModelCreator,
|
||||
_build_schema_options,
|
||||
_record_to_registry_model,
|
||||
clear_registry_cache,
|
||||
get_all_model_slugs_for_validation,
|
||||
get_all_models,
|
||||
get_default_model_slug,
|
||||
get_enabled_models,
|
||||
get_model,
|
||||
get_schema_options,
|
||||
refresh_llm_registry,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_record(**overrides):
    """Return a Mock shaped like a Prisma LlmModel row.

    Keyword arguments replace, or add to, the default attributes.
    """
    provider = Mock()
    # Set after construction: Mock(name=...) names the mock itself instead.
    provider.name = "openai"
    provider.displayName = "OpenAI"

    attributes = {
        "slug": "openai/gpt-4o",
        "displayName": "GPT-4o",
        "description": "Latest GPT model",
        "providerId": "provider-uuid",
        "Provider": provider,
        "creatorId": "creator-uuid",
        "Creator": None,
        "contextWindow": 128000,
        "maxOutputTokens": 16384,
        "priceTier": 2,
        "isEnabled": True,
        "isRecommended": False,
        "supportsTools": True,
        "supportsJsonOutput": True,
        "supportsReasoning": False,
        "supportsParallelToolCalls": True,
        "capabilities": {},
        "metadata": {},
        "Costs": [],
    }
    attributes.update(overrides)

    record = Mock()
    for attr_name, attr_value in attributes.items():
        setattr(record, attr_name, attr_value)
    return record
|
||||
|
||||
|
||||
def _make_registry_model(**kwargs) -> RegistryModel:
    """Construct a RegistryModel from GPT-4o-style defaults, overridden by kwargs."""
    from backend.blocks.llm import ModelMetadata

    base_metadata = ModelMetadata(
        provider="openai",
        context_window=128000,
        max_output_tokens=16384,
        display_name="GPT-4o",
        provider_name="OpenAI",
        creator_name="Unknown",
        price_tier=2,
    )
    fields = {
        "slug": "openai/gpt-4o",
        "display_name": "GPT-4o",
        "description": None,
        "metadata": base_metadata,
        "capabilities": {},
        "extra_metadata": {},
        "provider_display_name": "OpenAI",
        "is_enabled": True,
        "is_recommended": False,
        **kwargs,
    }
    return RegistryModel(**fields)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _record_to_registry_model tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_record_to_registry_model():
    """A well-formed record maps field-for-field onto a RegistryModel."""
    converted = _record_to_registry_model(_make_mock_record())

    # Top-level fields
    assert converted.slug == "openai/gpt-4o"
    assert converted.display_name == "GPT-4o"
    assert converted.description == "Latest GPT model"
    assert converted.provider_display_name == "OpenAI"
    assert converted.is_enabled is True
    assert converted.is_recommended is False

    # Capability flags
    assert converted.supports_tools is True
    assert converted.supports_json_output is True
    assert converted.supports_reasoning is False
    assert converted.supports_parallel_tool_calls is True

    # Nested metadata
    meta = converted.metadata
    assert meta.provider == "openai"
    assert meta.context_window == 128000
    assert meta.max_output_tokens == 16384
    assert meta.price_tier == 2

    # Relations left empty by the default mock record
    assert converted.creator is None
    assert converted.costs == ()
|
||||
|
||||
|
||||
def test_record_to_registry_model_missing_provider(caplog):
    """Without a Provider relation, providerId is used and a warning is logged."""
    orphan = _make_mock_record(Provider=None, providerId="provider-uuid")

    with caplog.at_level("WARNING"):
        converted = _record_to_registry_model(orphan)

    assert "no Provider" in caplog.text
    assert converted.metadata.provider == "provider-uuid"
    assert converted.provider_display_name == "provider-uuid"
|
||||
|
||||
|
||||
def test_record_to_registry_model_missing_creator():
    """Creator=None gives creator=None and a creator_name of 'Unknown'."""
    converted = _record_to_registry_model(_make_mock_record(Creator=None))

    assert converted.creator is None
    assert converted.metadata.creator_name == "Unknown"
|
||||
|
||||
|
||||
def test_record_to_registry_model_with_creator():
    """A present Creator relation is converted into a RegistryModelCreator."""
    creator_row = Mock()
    creator_attrs = {
        "id": "creator-uuid",
        "name": "openai",
        "displayName": "OpenAI",
        "description": "AI company",
        "websiteUrl": "https://openai.com",
        "logoUrl": "https://openai.com/logo.png",
    }
    for attr_name, attr_value in creator_attrs.items():
        setattr(creator_row, attr_name, attr_value)

    converted = _record_to_registry_model(_make_mock_record(Creator=creator_row))

    assert converted.creator is not None
    assert isinstance(converted.creator, RegistryModelCreator)
    assert converted.creator.id == "creator-uuid"
    assert converted.creator.display_name == "OpenAI"
    assert converted.metadata.creator_name == "OpenAI"
|
||||
|
||||
|
||||
def test_record_to_registry_model_null_max_output_tokens():
    """When maxOutputTokens is None, max_output_tokens equals contextWindow."""
    window = 64000
    converted = _record_to_registry_model(
        _make_mock_record(maxOutputTokens=None, contextWindow=window)
    )

    assert converted.metadata.max_output_tokens == window
|
||||
|
||||
|
||||
def test_record_to_registry_model_invalid_price_tier(caplog):
    """A priceTier outside the valid range becomes 1, with a warning logged."""
    bad_tier_record = _make_mock_record(priceTier=99)

    with caplog.at_level("WARNING"):
        converted = _record_to_registry_model(bad_tier_record)

    assert "out-of-range priceTier" in caplog.text
    assert converted.metadata.price_tier == 1
|
||||
|
||||
|
||||
def test_record_to_registry_model_with_costs():
    """Each entry in Costs is converted into a RegistryModelCost."""
    # None of these attribute names clash with Mock's reserved constructor
    # keywords, so they can be supplied directly.
    cost_stub = Mock(
        unit="TOKENS",
        creditCost=10,
        credentialProvider="openai",
        credentialId=None,
        credentialType=None,
        currency="USD",
        metadata={},
    )

    converted = _record_to_registry_model(_make_mock_record(Costs=[cost_stub]))

    assert len(converted.costs) == 1
    parsed_cost = converted.costs[0]
    assert isinstance(parsed_cost, RegistryModelCost)
    assert parsed_cost.unit == "TOKENS"
    assert parsed_cost.credit_cost == 10
    assert parsed_cost.credential_provider == "openai"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_default_model_slug tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_default_model_slug_recommended():
    """The recommended model wins over a non-recommended one in the registry."""
    import backend.data.llm_registry.registry as reg

    plain = _make_registry_model(
        slug="openai/gpt-4o", display_name="GPT-4o", is_recommended=False
    )
    flagged = _make_registry_model(
        slug="openai/gpt-4o-recommended",
        display_name="GPT-4o Recommended",
        is_recommended=True,
    )
    # Same insertion order as before: non-recommended first.
    reg._dynamic_models = {
        "openai/gpt-4o": plain,
        "openai/gpt-4o-recommended": flagged,
    }

    assert get_default_model_slug() == "openai/gpt-4o-recommended"
|
||||
|
||||
|
||||
def test_get_default_model_slug_fallback():
    """Without any recommended model, the alphabetically first one is returned."""
    import backend.data.llm_registry.registry as reg

    newer = _make_registry_model(
        slug="openai/gpt-4o", display_name="GPT-4o", is_recommended=False
    )
    older = _make_registry_model(
        slug="openai/gpt-3.5", display_name="GPT-3.5", is_recommended=False
    )
    # Registered in non-alphabetical order on purpose.
    reg._dynamic_models = {"openai/gpt-4o": newer, "openai/gpt-3.5": older}

    chosen = get_default_model_slug()

    # "GPT-3.5" sorts before "GPT-4o".
    assert chosen == "openai/gpt-3.5"
|
||||
|
||||
|
||||
def test_get_default_model_slug_empty():
    """With no models registered, the default slug is None."""
    import backend.data.llm_registry.registry as reg

    reg._dynamic_models = {}

    assert get_default_model_slug() is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_schema_options / get_schema_options tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_schema_options():
    """Disabled models are left out and the rest are ordered case-insensitively."""
    import backend.data.llm_registry.registry as reg

    fixtures = [
        ("openai/gpt-4o", "GPT-4o", True),
        ("openai/disabled", "Disabled Model", False),
        ("openai/gpt-3.5", "gpt-3.5", True),
    ]
    reg._dynamic_models = {
        slug: _make_registry_model(
            slug=slug, display_name=display_name, is_enabled=enabled
        )
        for slug, display_name, enabled in fixtures
    }

    options = _build_schema_options()
    values = [entry["value"] for entry in options]

    # The disabled model is filtered out; both enabled models are present.
    assert "openai/disabled" not in values
    assert "openai/gpt-4o" in values
    assert "openai/gpt-3.5" in values
    # Case-insensitive ordering: "gpt-3.5" comes before "GPT-4o".
    assert values.index("openai/gpt-3.5") < values.index("openai/gpt-4o")

    # Every option carries the four expected keys.
    for entry in options:
        for key in ("label", "value", "group", "description"):
            assert key in entry
|
||||
|
||||
|
||||
def test_get_schema_options_returns_copy():
    """Appending to the list from get_schema_options leaves later calls unchanged."""
    import backend.data.llm_registry.registry as reg

    only_model = _make_registry_model(slug="openai/gpt-4o", display_name="GPT-4o")
    reg._dynamic_models = {"openai/gpt-4o": only_model}
    reg._schema_options = _build_schema_options()

    returned = get_schema_options()
    size_before = len(returned)
    injected = {
        "label": "Injected",
        "value": "evil/model",
        "group": "evil",
        "description": "",
    }
    returned.append(injected)

    # A fresh call must still report the original number of options.
    assert len(get_schema_options()) == size_before
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic frozen model tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_registry_model_frozen():
    """Assigning to an attribute of a RegistryModel raises."""
    frozen_model = _make_registry_model()
    expected_errors = (pydantic.ValidationError, TypeError)

    with pytest.raises(expected_errors):
        frozen_model.slug = "changed/slug"  # type: ignore[misc]
|
||||
|
||||
|
||||
def test_registry_model_cost_frozen():
    """Assigning to an attribute of a RegistryModelCost raises as well."""
    frozen_cost = RegistryModelCost(
        unit="TOKENS", credit_cost=5, credential_provider="openai"
    )
    expected_errors = (pydantic.ValidationError, TypeError)

    with pytest.raises(expected_errors):
        frozen_cost.unit = "RUN"  # type: ignore[misc]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# refresh_llm_registry tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_refresh_llm_registry():
    """With find_many mocked to return one record, refresh fills both caches."""
    import backend.data.llm_registry.registry as reg

    db_rows = [_make_mock_record()]

    with patch("prisma.models.LlmModel.prisma") as prisma_factory:
        # ``LlmModel.prisma()`` returns an object whose find_many is awaitable.
        prisma_factory.return_value = Mock(find_many=AsyncMock(return_value=db_rows))

        # Start from empty module-level state.
        reg._dynamic_models = {}
        reg._schema_options = []

        await refresh_llm_registry()

    assert "openai/gpt-4o" in reg._dynamic_models
    cached = reg._dynamic_models["openai/gpt-4o"]
    assert isinstance(cached, RegistryModel)
    assert cached.slug == "openai/gpt-4o"
    # The derived schema options are rebuilt too.
    assert len(reg._schema_options) == 1
    assert reg._schema_options[0]["value"] == "openai/gpt-4o"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# clear_registry_cache tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_clear_registry_cache():
    """clear_registry_cache calls cache_clear on the cached fetch function exactly once."""
    import backend.data.llm_registry.registry as reg

    # ``patch`` is the module-level import already used by the other tests in
    # this file, so it is not re-imported here.
    with patch.object(reg._fetch_registry_from_db, "cache_clear") as mock_clear:
        clear_registry_cache()

    mock_clear.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_model / get_all_models / get_enabled_models / get_all_model_slugs tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_model_found():
    """get_model hands back the exact registered object for a known slug."""
    import backend.data.llm_registry.registry as reg

    slug = "openai/gpt-4o"
    registered = _make_registry_model(slug=slug)
    reg._dynamic_models = {slug: registered}

    assert get_model(slug) is registered
|
||||
|
||||
|
||||
def test_get_model_not_found():
    """A slug absent from the registry yields None."""
    import backend.data.llm_registry.registry as reg

    reg._dynamic_models = {}

    missing = get_model("nonexistent/model")
    assert missing is None
|
||||
|
||||
|
||||
def test_get_all_models_includes_disabled():
    """get_all_models lists enabled and disabled models alike."""
    import backend.data.llm_registry.registry as reg

    reg._dynamic_models = {
        "openai/gpt-4": _make_registry_model(slug="openai/gpt-4", is_enabled=True),
        "openai/old-model": _make_registry_model(
            slug="openai/old-model", is_enabled=False
        ),
    }

    everything = get_all_models()
    returned_slugs = [entry.slug for entry in everything]

    assert "openai/gpt-4" in returned_slugs
    assert "openai/old-model" in returned_slugs
    assert len(everything) == 2
|
||||
|
||||
|
||||
def test_get_enabled_models_excludes_disabled():
    """get_enabled_models drops any model whose is_enabled is False."""
    import backend.data.llm_registry.registry as reg

    reg._dynamic_models = {
        "openai/gpt-4": _make_registry_model(slug="openai/gpt-4", is_enabled=True),
        "openai/old-model": _make_registry_model(
            slug="openai/old-model", is_enabled=False
        ),
    }

    active = get_enabled_models()

    assert len(active) == 1
    assert active[0].slug == "openai/gpt-4"
|
||||
|
||||
|
||||
def test_get_all_model_slugs_for_validation():
    """Only slugs of enabled models are returned for validation."""
    import backend.data.llm_registry.registry as reg

    enabled_by_slug = {
        "openai/gpt-4": True,
        "openai/old": False,
        "anthropic/claude": True,
    }
    reg._dynamic_models = {
        slug: _make_registry_model(slug=slug, is_enabled=flag)
        for slug, flag in enabled_by_slug.items()
    }

    valid_slugs = get_all_model_slugs_for_validation()

    assert "openai/gpt-4" in valid_slugs
    assert "anthropic/claude" in valid_slugs
    assert "openai/old" not in valid_slugs
    assert len(valid_slugs) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_refresh_llm_registry_error_is_reraised(mocker):
    """An error raised by the DB fetch propagates out of refresh_llm_registry."""
    failing_fetch = AsyncMock(side_effect=RuntimeError("DB unavailable"))
    mocker.patch(
        "backend.data.llm_registry.registry._fetch_registry_from_db",
        new=failing_fetch,
    )

    with pytest.raises(RuntimeError, match="DB unavailable"):
        await refresh_llm_registry()
|
||||
@@ -0,0 +1,6 @@
|
||||
"""LLM registry API (public + admin)."""
|
||||
|
||||
from .admin_routes import router as admin_router
|
||||
from .routes import router
|
||||
|
||||
__all__ = ["router", "admin_router"]
|
||||
145
autogpt_platform/backend/backend/server/v2/llm/admin_model.py
Normal file
145
autogpt_platform/backend/backend/server/v2/llm/admin_model.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Request/response models for LLM registry admin API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class CreateLlmProviderRequest(pydantic.BaseModel):
    # Request body for creating an LLM provider.

    # Required identity fields.
    name: str
    display_name: str

    # Optional free-text description.
    description: str | None = None

    # Optional default-credential settings.
    default_credential_provider: str | None = None
    default_credential_id: str | None = None
    default_credential_type: str | None = None

    # Arbitrary extra data; a fresh empty dict per instance when omitted.
    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class UpdateLlmProviderRequest(pydantic.BaseModel):
    # Request body for updating an LLM provider; every field defaults to None.

    display_name: str | None = None
    description: str | None = None

    # Default-credential settings.
    default_credential_provider: str | None = None
    default_credential_id: str | None = None
    default_credential_type: str | None = None

    metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class CreateLlmModelRequest(pydantic.BaseModel):
    # Request body for creating an LLM model.

    # Identity and ownership.
    slug: str
    display_name: str
    description: str | None = None
    provider_name: str
    creator_id: str | None = None

    # Numeric limits: token counts must be positive, price_tier is 1..3.
    context_window: int = pydantic.Field(gt=0)
    max_output_tokens: int | None = pydantic.Field(None, gt=0)
    price_tier: int = pydantic.Field(ge=1, le=3)

    # Status flags.
    is_enabled: bool = True
    is_recommended: bool = False

    # Capability flags, all off unless stated.
    supports_tools: bool = False
    supports_json_output: bool = False
    supports_reasoning: bool = False
    supports_parallel_tool_calls: bool = False

    # Free-form containers; each instance gets its own empty default.
    capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
    costs: list[dict[str, Any]] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class UpdateLlmModelRequest(pydantic.BaseModel):
    # Request body for updating an LLM model; every field defaults to None.

    display_name: str | None = None
    description: str | None = None
    creator_id: str | None = None

    # When supplied: token counts must be positive, price_tier is 1..3.
    context_window: int | None = pydantic.Field(None, gt=0)
    max_output_tokens: int | None = pydantic.Field(None, gt=0)
    price_tier: int | None = pydantic.Field(None, ge=1, le=3)

    # Status flags.
    is_enabled: bool | None = None
    is_recommended: bool | None = None

    # Capability flags.
    supports_tools: bool | None = None
    supports_json_output: bool | None = None
    supports_reasoning: bool | None = None
    supports_parallel_tool_calls: bool | None = None

    # Free-form containers.
    capabilities: dict[str, Any] | None = None
    metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ToggleLlmModelRequest(pydantic.BaseModel):
    # Request body for enabling or disabling an LLM model.

    # Target enabled state (required).
    is_enabled: bool

    # Optional migration parameters.
    migrate_to_slug: str | None = None
    migration_reason: str | None = None
    custom_credit_cost: int | None = None
|
||||
|
||||
|
||||
class CreateLlmCreatorRequest(pydantic.BaseModel):
    # Request body for creating an LLM model creator.

    # Required identity fields.
    name: str
    display_name: str

    # Optional descriptive fields.
    description: str | None = None
    website_url: str | None = None
    logo_url: str | None = None

    # Arbitrary extra data; a fresh empty dict per instance when omitted.
    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class UpdateLlmCreatorRequest(pydantic.BaseModel):
    # Request body for updating an LLM model creator; every field defaults to None.

    display_name: str | None = None
    description: str | None = None
    website_url: str | None = None
    logo_url: str | None = None
    metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class LlmCreatorAdminResponse(pydantic.BaseModel):
    # Admin API representation of an LLM model creator.

    # Identity.
    id: str
    name: str
    display_name: str

    # Optional descriptive fields.
    description: str | None = None
    website_url: str | None = None
    logo_url: str | None = None
    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)

    # Timestamps are carried as strings.
    created_at: str | None = None
    updated_at: str | None = None
|
||||
|
||||
|
||||
class LlmModelCostAdminResponse(pydantic.BaseModel):
    # Admin API representation of a single cost entry of an LLM model.

    unit: str
    credit_cost: float
    credential_provider: str
    credential_type: str | None = None
    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class LlmProviderAdminResponse(pydantic.BaseModel):
    # Admin API representation of an LLM provider.

    # Identity.
    id: str
    name: str
    display_name: str
    description: str | None = None

    # Default-credential settings.
    default_credential_provider: str | None = None
    default_credential_id: str | None = None
    default_credential_type: str | None = None

    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)

    # Timestamps are carried as strings.
    created_at: str | None = None
    updated_at: str | None = None

    # Optional count of models; None when not supplied.
    model_count: int | None = None
|
||||
|
||||
|
||||
class LlmModelAdminResponse(pydantic.BaseModel):
    # Admin API representation of an LLM model, with optional creator and costs.

    # Identity and ownership.
    id: str
    slug: str
    display_name: str
    description: str | None = None
    provider_id: str
    creator_id: str | None = None

    # Numeric limits.
    context_window: int
    max_output_tokens: int | None = None
    price_tier: int

    # Status flags.
    is_enabled: bool
    is_recommended: bool

    # Capability flags.
    supports_tools: bool
    supports_json_output: bool
    supports_reasoning: bool
    supports_parallel_tool_calls: bool

    # Free-form containers.
    capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)

    # Timestamps are carried as strings.
    created_at: str | None = None
    updated_at: str | None = None

    # Nested related objects.
    creator: LlmCreatorAdminResponse | None = None
    costs: list[LlmModelCostAdminResponse] = pydantic.Field(default_factory=list)
|
||||
565
autogpt_platform/backend/backend/server/v2/llm/admin_routes.py
Normal file
565
autogpt_platform/backend/backend/server/v2/llm/admin_routes.py
Normal file
@@ -0,0 +1,565 @@
|
||||
"""Admin API for LLM registry management."""
|
||||
|
||||
import logging
|
||||
|
||||
import autogpt_libs.auth
|
||||
import prisma
|
||||
import prisma.models
|
||||
from fastapi import APIRouter, HTTPException, Security, status
|
||||
|
||||
from backend.server.v2.llm import db_write
|
||||
from backend.server.v2.llm.admin_model import (
|
||||
CreateLlmCreatorRequest,
|
||||
CreateLlmModelRequest,
|
||||
CreateLlmProviderRequest,
|
||||
LlmCreatorAdminResponse,
|
||||
LlmModelAdminResponse,
|
||||
LlmModelCostAdminResponse,
|
||||
LlmProviderAdminResponse,
|
||||
ToggleLlmModelRequest,
|
||||
UpdateLlmCreatorRequest,
|
||||
UpdateLlmModelRequest,
|
||||
UpdateLlmProviderRequest,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _map_provider(
    provider: prisma.models.LlmProvider, model_count: int | None = None
) -> LlmProviderAdminResponse:
    """Build the admin response for a provider row.

    Args:
        provider: The provider record to convert.
        model_count: Passed through unchanged to the response.

    Returns:
        The response model, with timestamps rendered via ``isoformat()`` (or
        None when falsy) and falsy metadata replaced by an empty dict.
    """
    created = provider.createdAt.isoformat() if provider.createdAt else None
    updated = provider.updatedAt.isoformat() if provider.updatedAt else None
    extra = dict(provider.metadata or {})

    return LlmProviderAdminResponse(
        id=provider.id,
        name=provider.name,
        display_name=provider.displayName,
        description=provider.description,
        default_credential_provider=provider.defaultCredentialProvider,
        default_credential_id=provider.defaultCredentialId,
        default_credential_type=provider.defaultCredentialType,
        metadata=extra,
        created_at=created,
        updated_at=updated,
        model_count=model_count,
    )
|
||||
|
||||
|
||||
def _map_creator(creator: prisma.models.LlmModelCreator) -> LlmCreatorAdminResponse:
    """Build the admin response for a creator row.

    Timestamps are rendered via ``isoformat()`` (or None when falsy) and
    falsy metadata is replaced by an empty dict.
    """
    created = creator.createdAt.isoformat() if creator.createdAt else None
    updated = creator.updatedAt.isoformat() if creator.updatedAt else None
    extra = dict(creator.metadata or {})

    return LlmCreatorAdminResponse(
        id=creator.id,
        name=creator.name,
        display_name=creator.displayName,
        description=creator.description,
        website_url=creator.websiteUrl,
        logo_url=creator.logoUrl,
        metadata=extra,
        created_at=created,
        updated_at=updated,
    )
|
||||
|
||||
|
||||
def _map_model(model: prisma.models.LlmModel) -> LlmModelAdminResponse:
    """Build the admin response for a model row, including creator and costs.

    ``model.Creator`` and ``model.Costs`` are used when present; a falsy
    Creator maps to None and falsy Costs to an empty list.
    """
    cost_items = []
    for cost in model.Costs or []:
        cost_items.append(
            LlmModelCostAdminResponse(
                unit=cost.unit,
                credit_cost=float(cost.creditCost),
                credential_provider=cost.credentialProvider,
                credential_type=cost.credentialType,
                metadata=dict(cost.metadata or {}),
            )
        )

    creator_item = _map_creator(model.Creator) if model.Creator else None
    created = model.createdAt.isoformat() if model.createdAt else None
    updated = model.updatedAt.isoformat() if model.updatedAt else None

    return LlmModelAdminResponse(
        id=model.id,
        slug=model.slug,
        display_name=model.displayName,
        description=model.description,
        provider_id=model.providerId,
        creator_id=model.creatorId,
        context_window=model.contextWindow,
        max_output_tokens=model.maxOutputTokens,
        price_tier=model.priceTier,
        is_enabled=model.isEnabled,
        is_recommended=model.isRecommended,
        supports_tools=model.supportsTools,
        supports_json_output=model.supportsJsonOutput,
        supports_reasoning=model.supportsReasoning,
        supports_parallel_tool_calls=model.supportsParallelToolCalls,
        capabilities=dict(model.capabilities or {}),
        metadata=dict(model.metadata or {}),
        created_at=created,
        updated_at=updated,
        creator=creator_item,
        costs=cost_items,
    )
|
||||
|
||||
|
||||
@router.post(
    "/llm/models",
    status_code=status.HTTP_201_CREATED,
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def create_model(request: CreateLlmModelRequest) -> LlmModelAdminResponse:
    """Create an LLM model, plus any cost rows, under an existing provider.

    ``request.provider_name`` is looked up as a provider name first and, if
    that finds nothing, as a provider id.

    Args:
        request: The model definition, including an optional list of costs.

    Returns:
        The created model, re-read with its ``Costs`` and ``Creator`` relations.

    Raises:
        HTTPException: 404 if no provider matches; 400 if a cost entry is
            malformed or a ``ValueError`` is raised while creating; 500 for
            any other failure.
    """
    try:
        provider = await prisma.models.LlmProvider.prisma().find_unique(
            where={"name": request.provider_name}
        )
        if not provider:
            provider = await prisma.models.LlmProvider.prisma().find_unique(
                where={"id": request.provider_name}
            )
        if not provider:
            raise HTTPException(
                status_code=404,
                detail=f"Provider '{request.provider_name}' not found",
            )

        # Validate and normalise every cost entry BEFORE writing anything, so
        # a bad ``credit_cost`` is rejected without leaving a model row with
        # missing or partial costs behind.
        try:
            cost_rows = [
                {
                    "unit": cost_input.get("unit", "RUN"),
                    "creditCost": int(cost_input.get("credit_cost", 1)),
                    "credentialProvider": provider.name,
                    "metadata": prisma.Json(cost_input.get("metadata", {})),
                }
                for cost_input in request.costs
            ]
        except (TypeError, ValueError) as e:
            # int() raises TypeError for None and ValueError for bad strings;
            # both are client errors.
            raise HTTPException(
                status_code=400, detail=f"Invalid cost entry: {e}"
            ) from e

        model = await db_write.create_model(
            slug=request.slug,
            display_name=request.display_name,
            provider_id=provider.id,
            context_window=request.context_window,
            price_tier=request.price_tier,
            description=request.description,
            creator_id=request.creator_id,
            max_output_tokens=request.max_output_tokens,
            is_enabled=request.is_enabled,
            is_recommended=request.is_recommended,
            supports_tools=request.supports_tools,
            supports_json_output=request.supports_json_output,
            supports_reasoning=request.supports_reasoning,
            supports_parallel_tool_calls=request.supports_parallel_tool_calls,
            capabilities=request.capabilities,
            metadata=request.metadata,
        )
        for cost_row in cost_rows:
            await prisma.models.LlmModelCost.prisma().create(
                data={**cost_row, "Model": {"connect": {"id": model.id}}}
            )

        await db_write.refresh_runtime_caches()
        logger.info(f"Created model '{request.slug}' (id: {model.id})")

        # Re-read so the response includes the cost rows created above.
        model = await prisma.models.LlmModel.prisma().find_unique(
            where={"id": model.id},
            include={"Costs": True, "Creator": True},
        )
        return _map_model(model)
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception(f"Failed to create model: {e}")
        raise HTTPException(status_code=500, detail="Failed to create model") from e
|
||||
|
||||
|
||||
@router.patch(
    "/llm/models/{slug:path}",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def update_model(
    slug: str, request: UpdateLlmModelRequest
) -> LlmModelAdminResponse:
    """Update the model identified by ``slug`` and return its new state.

    Every request field is forwarded to ``db_write.update_model`` as given,
    including those left as None.

    Raises:
        HTTPException: 404 if the slug is unknown, 400 on ``ValueError``,
            500 for any other failure.
    """
    try:
        model_table = prisma.models.LlmModel.prisma()
        existing = await model_table.find_unique(where={"slug": slug})
        if not existing:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Model with slug '{slug}' not found",
            )

        changes = {
            "display_name": request.display_name,
            "description": request.description,
            "creator_id": request.creator_id,
            "context_window": request.context_window,
            "max_output_tokens": request.max_output_tokens,
            "price_tier": request.price_tier,
            "is_enabled": request.is_enabled,
            "is_recommended": request.is_recommended,
            "supports_tools": request.supports_tools,
            "supports_json_output": request.supports_json_output,
            "supports_reasoning": request.supports_reasoning,
            "supports_parallel_tool_calls": request.supports_parallel_tool_calls,
            "capabilities": request.capabilities,
            "metadata": request.metadata,
        }
        model = await db_write.update_model(model_id=existing.id, **changes)

        await db_write.refresh_runtime_caches()
        logger.info(f"Updated model '{slug}' (id: {model.id})")
        return _map_model(model)
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to update model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to update model",
        )
|
||||
|
||||
|
||||
@router.delete(
    "/llm/models/{slug:path}",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def delete_model(slug: str, replacement_model_slug: str | None = None) -> dict:
    """Delete the model identified by ``slug``.

    ``replacement_model_slug`` is handed to ``db_write.delete_model`` as-is.

    Returns:
        The dict produced by ``db_write.delete_model``; its ``nodes_migrated``
        entry is logged.

    Raises:
        HTTPException: 404 if the slug is unknown, 400 on ``ValueError``,
            500 for any other failure.
    """
    try:
        model_table = prisma.models.LlmModel.prisma()
        existing = await model_table.find_unique(where={"slug": slug})
        if not existing:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Model with slug '{slug}' not found",
            )

        result = await db_write.delete_model(
            model_id=existing.id, replacement_model_slug=replacement_model_slug
        )
        await db_write.refresh_runtime_caches()

        summary = f"Deleted model '{slug}' (migrated {result['nodes_migrated']} nodes)"
        logger.info(summary)
        return result
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to delete model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to delete model",
        )
|
||||
|
||||
|
||||
@router.get(
    "/llm/models/{slug:path}/usage",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def get_model_usage(slug: str) -> dict:
    """Return whatever ``db_write.get_model_usage`` reports for ``slug``.

    Raises:
        HTTPException: 500 if the lookup raises.
    """
    try:
        usage = await db_write.get_model_usage(slug)
        return usage
    except Exception as e:
        logger.exception(f"Failed to get model usage: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get model usage",
        )
|
||||
|
||||
|
||||
@router.post(
    "/llm/models/{slug:path}/toggle",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def toggle_model(slug: str, request: ToggleLlmModelRequest) -> dict:
    """Enable or disable the model identified by ``slug``.

    Disabling a model whose ``isRecommended`` flag is set is rejected.

    Returns:
        The dict produced by ``db_write.toggle_model_with_migration``; its
        ``nodes_migrated`` entry is logged.

    Raises:
        HTTPException: 404 if the slug is unknown; 400 when asked to disable
            the recommended model or on ``ValueError``; 500 otherwise.
    """
    try:
        model_table = prisma.models.LlmModel.prisma()
        existing = await model_table.find_unique(where={"slug": slug})
        if not existing:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Model with slug '{slug}' not found",
            )

        disabling = not request.is_enabled
        if disabling and existing.isRecommended:
            refusal = (
                "Cannot disable the recommended model. "
                "Change the recommended model before disabling this one."
            )
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=refusal)

        result = await db_write.toggle_model_with_migration(
            model_id=existing.id,
            is_enabled=request.is_enabled,
            migrate_to_slug=request.migrate_to_slug,
            migration_reason=request.migration_reason,
            custom_credit_cost=request.custom_credit_cost,
        )
        await db_write.refresh_runtime_caches()

        summary = (
            f"Toggled model '{slug}' enabled={request.is_enabled} "
            f"(migrated {result['nodes_migrated']} nodes)"
        )
        logger.info(summary)
        return result
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to toggle model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to toggle model",
        )
|
||||
|
||||
|
||||
@router.get(
    "/llm/migrations",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def list_migrations(include_reverted: bool = False) -> dict:
    """Return migrations from ``db_write.list_migrations`` under a ``migrations`` key.

    ``include_reverted`` is forwarded unchanged.

    Raises:
        HTTPException: 500 if the lookup raises.
    """
    try:
        found = await db_write.list_migrations(include_reverted=include_reverted)
        payload = {"migrations": found}
        return payload
    except Exception as e:
        logger.exception(f"Failed to list migrations: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to list migrations",
        )
|
||||
|
||||
|
||||
@router.post(
    "/llm/migrations/{migration_id}/revert",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def revert_migration(
    migration_id: str, re_enable_source_model: bool = True
) -> dict:
    """Revert the migration identified by ``migration_id``.

    ``re_enable_source_model`` is forwarded unchanged to
    ``db_write.revert_migration``.

    Returns:
        The dict produced by ``db_write.revert_migration``; its
        ``nodes_reverted`` entry is logged.

    Raises:
        HTTPException: 400 on ``ValueError``, 500 for any other failure. An
            HTTPException raised by the called layer is passed through
            unchanged.
    """
    try:
        result = await db_write.revert_migration(
            migration_id=migration_id,
            re_enable_source_model=re_enable_source_model,
        )
        await db_write.refresh_runtime_caches()
        logger.info(
            f"Reverted migration {migration_id}: "
            f"{result['nodes_reverted']} nodes restored"
        )
        return result
    except HTTPException:
        # Matches the other mutating routes in this module: never rewrite an
        # explicit HTTP error into a generic 500.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception(f"Failed to revert migration: {e}")
        raise HTTPException(
            status_code=500, detail="Failed to revert migration"
        ) from e
|
||||
|
||||
|
||||
@router.post(
    "/llm/providers",
    status_code=status.HTTP_201_CREATED,
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def create_provider(
    request: CreateLlmProviderRequest,
) -> LlmProviderAdminResponse:
    """Create an LLM provider and return its admin representation.

    Raises:
        HTTPException: 400 on ``ValueError``, 500 for any other failure. An
            HTTPException raised by the called layer is passed through
            unchanged.
    """
    try:
        provider = await db_write.create_provider(
            name=request.name,
            display_name=request.display_name,
            description=request.description,
            default_credential_provider=request.default_credential_provider,
            default_credential_id=request.default_credential_id,
            default_credential_type=request.default_credential_type,
            metadata=request.metadata,
        )
        await db_write.refresh_runtime_caches()
        logger.info(f"Created provider '{request.name}' (id: {provider.id})")
        return _map_provider(provider)
    except HTTPException:
        # Matches update_provider / delete_provider: never rewrite an explicit
        # HTTP error into a generic 500.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception(f"Failed to create provider: {e}")
        raise HTTPException(
            status_code=500, detail="Failed to create provider"
        ) from e
|
||||
|
||||
|
||||
@router.patch(
    "/llm/providers/{name}",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def update_provider(
    name: str, request: UpdateLlmProviderRequest
) -> LlmProviderAdminResponse:
    """Update the provider called ``name`` and return its new state.

    Every request field is forwarded to ``db_write.update_provider`` as
    given, including those left as None.

    Raises:
        HTTPException: 404 if the name is unknown, 400 on ``ValueError``,
            500 for any other failure.
    """
    try:
        provider_table = prisma.models.LlmProvider.prisma()
        existing = await provider_table.find_unique(where={"name": name})
        if not existing:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Provider with name '{name}' not found",
            )

        changes = {
            "display_name": request.display_name,
            "description": request.description,
            "default_credential_provider": request.default_credential_provider,
            "default_credential_id": request.default_credential_id,
            "default_credential_type": request.default_credential_type,
            "metadata": request.metadata,
        }
        provider = await db_write.update_provider(provider_id=existing.id, **changes)

        await db_write.refresh_runtime_caches()
        logger.info(f"Updated provider '{name}' (id: {provider.id})")
        return _map_provider(provider)
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to update provider: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to update provider",
        )
|
||||
|
||||
|
||||
@router.delete(
    "/llm/providers/{name}",
    status_code=status.HTTP_204_NO_CONTENT,
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def delete_provider(name: str) -> None:
    """Delete the provider identified by ``name`` (admin only).

    Raises:
        HTTPException: 404 if no provider has that name, 400 if the write
            layer raises ``ValueError``, 500 for any other failure.
    """
    try:
        target = await prisma.models.LlmProvider.prisma().find_unique(
            where={"name": name}
        )
        if not target:
            raise HTTPException(
                status_code=404, detail=f"Provider with name '{name}' not found"
            )

        provider_id = target.id
        await db_write.delete_provider(provider_id=provider_id)
        await db_write.refresh_runtime_caches()
        logger.info(f"Deleted provider '{name}' (id: {provider_id})")
    except HTTPException:
        # Let the deliberate 404 pass through the broad handler below.
        raise
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))
    except Exception as err:
        logger.exception(f"Failed to delete provider: {err}")
        raise HTTPException(status_code=500, detail="Failed to delete provider")
|
||||
|
||||
|
||||
@router.get(
    "/llm/admin/providers",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def admin_list_providers() -> dict:
    """Return every provider, ordered by name, with its model count (admin only).

    Returns:
        ``{"providers": [...]}`` where each entry is the dumped admin
        response for one provider.

    Raises:
        HTTPException: 500 if the lookup or mapping fails.
    """
    try:
        rows = await prisma.models.LlmProvider.prisma().find_many(
            order={"name": "asc"},
            include={"Models": True},
        )
        serialized = []
        for row in rows:
            # A missing/empty relation counts as zero models.
            count = len(row.Models) if row.Models else 0
            serialized.append(_map_provider(row, model_count=count).model_dump())
        return {"providers": serialized}
    except Exception as err:
        logger.exception(f"Failed to list providers: {err}")
        raise HTTPException(status_code=500, detail="Failed to list providers")
|
||||
|
||||
|
||||
@router.get(
    "/llm/admin/models",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def admin_list_models(
    page: int = 1,
    page_size: int = 100,
    enabled_only: bool = False,
) -> dict:
    """Return one page of models, ordered by display name (admin only).

    Args:
        page: 1-based page number.
        page_size: Number of models per page; must be positive.
        enabled_only: When true, only models with ``isEnabled`` set are returned.

    Returns:
        ``{"models": [...]}`` with the dumped admin response of each model.

    Raises:
        HTTPException: 400 if ``page`` or ``page_size`` is not positive,
            500 if the lookup or mapping fails.
    """
    # Validate outside the try block: the broad handler below would otherwise
    # turn this client error into a 500. Non-positive values would also yield
    # a negative skip/take for the query.
    if page < 1 or page_size < 1:
        raise HTTPException(
            status_code=400,
            detail="page and page_size must be positive integers",
        )

    try:
        where = {"isEnabled": True} if enabled_only else {}
        models = await prisma.models.LlmModel.prisma().find_many(
            where=where,
            skip=(page - 1) * page_size,
            take=page_size,
            order={"displayName": "asc"},
            include={"Costs": True, "Creator": True},
        )
        return {"models": [_map_model(m).model_dump() for m in models]}
    except Exception as e:
        logger.exception(f"Failed to list models: {e}")
        raise HTTPException(status_code=500, detail="Failed to list models")
|
||||
|
||||
|
||||
@router.get(
    "/llm/creators",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def list_creators() -> dict:
    """Return every model creator, ordered by name (admin only).

    Returns:
        ``{"creators": [...]}`` with the dumped admin response of each creator.

    Raises:
        HTTPException: 500 if the lookup or mapping fails.
    """
    try:
        rows = await prisma.models.LlmModelCreator.prisma().find_many(
            order={"name": "asc"}
        )
        serialized = []
        for row in rows:
            serialized.append(_map_creator(row).model_dump())
        return {"creators": serialized}
    except Exception as err:
        logger.exception(f"Failed to list creators: {err}")
        raise HTTPException(status_code=500, detail="Failed to list creators")
|
||||
|
||||
|
||||
@router.post(
    "/llm/creators",
    status_code=status.HTTP_201_CREATED,
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def create_creator(request: CreateLlmCreatorRequest) -> LlmCreatorAdminResponse:
    """Create a model creator from the request payload (admin only).

    Returns:
        The admin response for the newly created creator.

    Raises:
        HTTPException: 500 if the record cannot be created.
    """
    try:
        creator = await prisma.models.LlmModelCreator.prisma().create(
            data={
                "name": request.name,
                "displayName": request.display_name,
                "description": request.description,
                "websiteUrl": request.website_url,
                "logoUrl": request.logo_url,
                "metadata": prisma.Json(request.metadata),
            }
        )
        logger.info(f"Created creator '{creator.name}' (id: {creator.id})")
        return _map_creator(creator)
    except Exception as e:
        # The full error goes to the log only; the response stays generic so
        # internal exception text is not exposed, matching the provider routes.
        logger.exception(f"Failed to create creator: {e}")
        raise HTTPException(status_code=500, detail="Failed to create creator")
|
||||
|
||||
|
||||
@router.patch(
    "/llm/creators/{name}",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def update_creator(
    name: str, request: UpdateLlmCreatorRequest
) -> LlmCreatorAdminResponse:
    """Apply a partial update to the creator identified by ``name`` (admin only).

    Only request fields that are not ``None`` are written.

    Raises:
        HTTPException: 404 if the creator does not exist (or disappears
            before the update is applied), 500 for any other failure.
    """
    try:
        existing = await prisma.models.LlmModelCreator.prisma().find_unique(
            where={"name": name}
        )
        if not existing:
            raise HTTPException(status_code=404, detail=f"Creator '{name}' not found")

        # Map each optional request field to its column; None means "leave as is".
        candidates = {
            "displayName": request.display_name,
            "description": request.description,
            "websiteUrl": request.website_url,
            "logoUrl": request.logo_url,
        }
        data: dict = {
            column: value for column, value in candidates.items() if value is not None
        }
        if request.metadata is not None:
            data["metadata"] = prisma.Json(request.metadata)

        creator = await prisma.models.LlmModelCreator.prisma().update(
            where={"id": existing.id},
            data=data,
        )
        if creator is None:
            # The row was removed between the lookup and the update.
            raise HTTPException(status_code=404, detail=f"Creator '{name}' not found")

        logger.info(f"Updated creator '{name}' (id: {creator.id})")
        return _map_creator(creator)
    except HTTPException:
        # Let the deliberate 404s pass through the broad handler below.
        raise
    except Exception as e:
        # The full error goes to the log only; the response stays generic so
        # internal exception text is not exposed, matching the provider routes.
        logger.exception(f"Failed to update creator: {e}")
        raise HTTPException(status_code=500, detail="Failed to update creator")
|
||||
|
||||
|
||||
@router.delete(
    "/llm/creators/{name}",
    status_code=status.HTTP_204_NO_CONTENT,
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def delete_creator(name: str) -> None:
    """Delete the creator identified by ``name`` (admin only).

    A creator that still has associated models is not deleted.

    Raises:
        HTTPException: 404 if the creator does not exist, 400 if it still
            has models, 500 for any other failure.
    """
    try:
        existing = await prisma.models.LlmModelCreator.prisma().find_unique(
            where={"name": name},
            include={"Models": True},
        )
        if not existing:
            raise HTTPException(status_code=404, detail=f"Creator '{name}' not found")

        # A non-empty relation list is truthy; None/[] both mean "no models".
        if existing.Models:
            raise HTTPException(
                status_code=400,
                detail=f"Cannot delete creator '{name}' — it has {len(existing.Models)} associated models",
            )

        await prisma.models.LlmModelCreator.prisma().delete(where={"id": existing.id})
        logger.info(f"Deleted creator '{name}' (id: {existing.id})")
    except HTTPException:
        # Let the deliberate 404/400 pass through the broad handler below.
        raise
    except Exception as e:
        # The full error goes to the log only; the response stays generic so
        # internal exception text is not exposed, matching the provider routes.
        logger.exception(f"Failed to delete creator: {e}")
        raise HTTPException(status_code=500, detail="Failed to delete creator")
|
||||
@@ -0,0 +1,701 @@
|
||||
"""Tests for LLM registry admin CRUD routes (admin_routes.py).
|
||||
|
||||
Covers provider, model, creator, migration CRUD endpoints.
|
||||
All endpoints require admin authentication.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.llm.admin_routes import router as admin_router
|
||||
|
||||
# Standalone FastAPI app that mounts only the admin router, plus the shared test client.
admin_app = fastapi.FastAPI()
|
||||
admin_app.include_router(admin_router)
|
||||
admin_client = fastapi.testclient.TestClient(admin_app)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth fixture
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_admin):
    """Install the admin JWT override for each test, then remove all overrides."""
    from autogpt_libs.auth.jwt_utils import get_jwt_payload

    overrides = admin_app.dependency_overrides
    overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
    yield
    # Teardown: drop every override so nothing carries over between tests.
    overrides.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock factory helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_NOW = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _make_mock_provider(
|
||||
id: str = "prov-1",
|
||||
name: str = "openai",
|
||||
display_name: str = "OpenAI",
|
||||
description: str | None = None,
|
||||
default_credential_provider: str | None = None,
|
||||
default_credential_id: str | None = None,
|
||||
default_credential_type: str | None = None,
|
||||
models: list | None = None,
|
||||
) -> Mock:
|
||||
p = Mock()
|
||||
p.id = id
|
||||
p.name = name
|
||||
p.displayName = display_name
|
||||
p.description = description
|
||||
p.defaultCredentialProvider = default_credential_provider
|
||||
p.defaultCredentialId = default_credential_id
|
||||
p.defaultCredentialType = default_credential_type
|
||||
p.metadata = {}
|
||||
p.createdAt = _NOW
|
||||
p.updatedAt = _NOW
|
||||
p.Models = models if models is not None else []
|
||||
return p
|
||||
|
||||
|
||||
def _make_mock_model(
    id: str = "model-1",
    slug: str = "gpt-4",
    display_name: str = "GPT-4",
    description: str | None = None,
    provider_id: str = "prov-1",
    creator_id: str | None = None,
    context_window: int = 128000,
    max_output_tokens: int | None = 4096,
    price_tier: int = 2,
    is_enabled: bool = True,
    is_recommended: bool = False,
    costs: list | None = None,
    creator: Mock | None = None,
) -> Mock:
    """Build a ``Mock`` carrying the attributes of a model record.

    All ``supports*`` flags are ``False``, ``capabilities``/``metadata`` are
    empty dicts, both timestamps are ``_NOW`` and ``Costs`` falls back to an
    empty list.
    """
    attrs = {
        "id": id,
        "slug": slug,
        "displayName": display_name,
        "description": description,
        "providerId": provider_id,
        "creatorId": creator_id,
        "contextWindow": context_window,
        "maxOutputTokens": max_output_tokens,
        "priceTier": price_tier,
        "isEnabled": is_enabled,
        "isRecommended": is_recommended,
        "supportsTools": False,
        "supportsJsonOutput": False,
        "supportsReasoning": False,
        "supportsParallelToolCalls": False,
        "capabilities": {},
        "metadata": {},
        "createdAt": _NOW,
        "updatedAt": _NOW,
        "Costs": costs or [],
        "Creator": creator,
    }
    model = Mock()
    for attr, value in attrs.items():
        setattr(model, attr, value)
    return model
|
||||
|
||||
|
||||
def _make_mock_creator(
    id: str = "creator-1",
    name: str = "openai",
    display_name: str = "OpenAI",
    description: str | None = None,
    website_url: str | None = None,
    logo_url: str | None = None,
    models: list | None = None,
) -> Mock:
    """Build a ``Mock`` carrying the attributes of a creator record.

    ``models`` defaults to an empty list; ``metadata`` is always an empty dict
    and both timestamps are ``_NOW``.
    """
    attrs = {
        "id": id,
        "name": name,
        "displayName": display_name,
        "description": description,
        "websiteUrl": website_url,
        "logoUrl": logo_url,
        "metadata": {},
        "createdAt": _NOW,
        "updatedAt": _NOW,
        "Models": [] if models is None else models,
    }
    creator = Mock()
    # Assign after construction so ``name`` becomes a plain attribute.
    for attr, value in attrs.items():
        setattr(creator, attr, value)
    return creator
|
||||
|
||||
|
||||
def _make_mock_migration(
    id: str = "mig-1",
    source_slug: str = "gpt-3",
    target_slug: str = "gpt-4",
    node_count: int = 3,
    is_reverted: bool = False,
) -> dict:
    """Build a plain-dict migration record.

    ``reason`` is always ``"upgrade"``, ``custom_credit_cost`` and
    ``reverted_at`` are ``None`` and ``created_at`` is ``_NOW`` in ISO format.
    """
    migration = dict(
        id=id,
        source_model_slug=source_slug,
        target_model_slug=target_slug,
        reason="upgrade",
        node_count=node_count,
        custom_credit_cost=None,
        is_reverted=is_reverted,
        reverted_at=None,
        created_at=_NOW.isoformat(),
    )
    return migration
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_provider(mocker):
    """A successful POST /llm/providers yields 201 and echoes the provider fields."""
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.refresh_runtime_caches = AsyncMock()
    db.create_provider = AsyncMock(return_value=_make_mock_provider())

    payload = {"name": "openai", "display_name": "OpenAI"}
    resp = admin_client.post("/llm/providers", json=payload)

    assert resp.status_code == 201
    body = resp.json()
    assert body["name"] == "openai"
    assert body["display_name"] == "OpenAI"
    # The write and the cache refresh must each happen exactly once.
    db.create_provider.assert_called_once()
    db.refresh_runtime_caches.assert_called_once()
|
||||
|
||||
|
||||
def test_create_provider_validation_error(mocker):
    """A ValueError from db_write.create_provider is reported as 400 with its message."""
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.refresh_runtime_caches = AsyncMock()
    db.create_provider = AsyncMock(side_effect=ValueError("duplicate name"))

    payload = {"name": "openai", "display_name": "OpenAI"}
    resp = admin_client.post("/llm/providers", json=payload)

    assert resp.status_code == 400
    assert "duplicate name" in resp.json()["detail"]
|
||||
|
||||
|
||||
def test_update_provider(mocker):
    """A successful PATCH /llm/providers/{name} yields 200 with the updated fields."""
    provider_table = mocker.patch("prisma.models.LlmProvider.prisma").return_value
    provider_table.find_unique = AsyncMock(return_value=_make_mock_provider())

    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.refresh_runtime_caches = AsyncMock()
    db.update_provider = AsyncMock(
        return_value=_make_mock_provider(display_name="OpenAI Updated")
    )

    payload = {"display_name": "OpenAI Updated"}
    resp = admin_client.patch("/llm/providers/openai", json=payload)

    assert resp.status_code == 200
    assert resp.json()["display_name"] == "OpenAI Updated"
|
||||
|
||||
|
||||
def test_update_provider_not_found(mocker):
    """PATCH on an unknown provider name yields 404."""
    mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    provider_table = mocker.patch("prisma.models.LlmProvider.prisma").return_value
    # The lookup finds nothing.
    provider_table.find_unique = AsyncMock(return_value=None)

    payload = {"display_name": "X"}
    resp = admin_client.patch("/llm/providers/nonexistent", json=payload)

    assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_delete_provider(mocker):
    """DELETE /llm/providers/{name} returns 204 and deletes the looked-up provider."""
    existing = _make_mock_provider()
    mocker.patch(
        "prisma.models.LlmProvider.prisma"
    ).return_value.find_unique = AsyncMock(return_value=existing)

    mock_db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    mock_db.delete_provider = AsyncMock(return_value=True)
    mock_db.refresh_runtime_caches = AsyncMock()

    response = admin_client.delete("/llm/providers/openai")

    assert response.status_code == 204
    # A bare status check would pass even if the handler skipped the delete;
    # pin the write and the cache refresh as well.
    mock_db.delete_provider.assert_called_once_with(provider_id=existing.id)
    mock_db.refresh_runtime_caches.assert_called_once()
|
||||
|
||||
|
||||
def test_delete_provider_not_found(mocker):
    """DELETE on an unknown provider name yields 404."""
    mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    provider_table = mocker.patch("prisma.models.LlmProvider.prisma").return_value
    # The lookup finds nothing.
    provider_table.find_unique = AsyncMock(return_value=None)

    resp = admin_client.delete("/llm/providers/ghost")

    assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_delete_provider_has_models(mocker):
    """DELETE returns 400 when db_write raises ValueError (provider has models)."""
    existing = _make_mock_provider()
    mocker.patch(
        "prisma.models.LlmProvider.prisma"
    ).return_value.find_unique = AsyncMock(return_value=existing)

    mock_db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    mock_db.delete_provider = AsyncMock(
        side_effect=ValueError("Cannot delete provider — it has 2 model(s)")
    )
    mock_db.refresh_runtime_caches = AsyncMock()

    response = admin_client.delete("/llm/providers/openai")

    assert response.status_code == 400
    assert "model" in response.json()["detail"].lower()
    # Nothing was deleted, so the runtime caches must not be refreshed.
    mock_db.refresh_runtime_caches.assert_not_called()
|
||||
|
||||
|
||||
def test_create_provider_server_error(mocker):
    """An unexpected exception during POST /llm/providers is reported as 500."""
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.refresh_runtime_caches = AsyncMock()
    db.create_provider = AsyncMock(side_effect=RuntimeError("unexpected"))

    payload = {"name": "openai", "display_name": "OpenAI"}
    resp = admin_client.post("/llm/providers", json=payload)

    assert resp.status_code == 500
|
||||
|
||||
|
||||
def test_admin_list_providers(mocker):
    """GET /llm/admin/providers lists providers together with their model_count."""
    single_model_provider = _make_mock_provider(models=[_make_mock_model()])
    provider_table = mocker.patch("prisma.models.LlmProvider.prisma").return_value
    provider_table.find_many = AsyncMock(return_value=[single_model_provider])

    resp = admin_client.get("/llm/admin/providers")

    assert resp.status_code == 200
    listed = resp.json()["providers"]
    assert len(listed) == 1
    first = listed[0]
    assert first["model_count"] == 1
    assert first["name"] == "openai"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_model(mocker):
    """A successful POST /llm/models yields 201 and echoes the model fields."""
    fake_provider = _make_mock_provider()
    fake_model = _make_mock_model()

    # Replace the whole prisma.models module so the provider lookups
    # (by name, then by id fallback) and the model refetch are all stubbed.
    fake_models = mocker.patch("prisma.models")
    fake_models.LlmProvider.prisma.return_value.find_unique = AsyncMock(
        return_value=fake_provider
    )
    fake_models.LlmModel.prisma.return_value.find_unique = AsyncMock(
        return_value=fake_model
    )
    fake_models.LlmModelCost.prisma.return_value.create = AsyncMock()

    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.refresh_runtime_caches = AsyncMock()
    db.create_model = AsyncMock(return_value=fake_model)

    payload = {
        "slug": "gpt-4",
        "display_name": "GPT-4",
        "provider_id": "openai",
        "context_window": 128000,
        "price_tier": 2,
    }
    resp = admin_client.post("/llm/models", json=payload)

    assert resp.status_code == 201
    assert resp.json()["slug"] == "gpt-4"
    db.create_model.assert_called_once()
    db.refresh_runtime_caches.assert_called_once()
|
||||
|
||||
|
||||
def test_create_model_provider_not_found(mocker):
    """POST /llm/models yields 404 when the referenced provider cannot be found."""
    mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    fake_models = mocker.patch("prisma.models")
    # Every provider lookup comes back empty.
    fake_models.LlmProvider.prisma.return_value.find_unique = AsyncMock(
        return_value=None
    )

    payload = {
        "slug": "gpt-4",
        "display_name": "GPT-4",
        "provider_id": "nonexistent",
        "context_window": 128000,
        "price_tier": 2,
    }
    resp = admin_client.post("/llm/models", json=payload)

    assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_update_model(mocker):
    """A successful PATCH /llm/models/{slug} yields 200 with the updated model."""
    model_table = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_table.find_unique = AsyncMock(return_value=_make_mock_model())

    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.refresh_runtime_caches = AsyncMock()
    db.update_model = AsyncMock(
        return_value=_make_mock_model(display_name="GPT-4 Turbo")
    )

    payload = {"display_name": "GPT-4 Turbo"}
    resp = admin_client.patch("/llm/models/gpt-4", json=payload)

    assert resp.status_code == 200
    assert resp.json()["display_name"] == "GPT-4 Turbo"
|
||||
|
||||
|
||||
def test_update_model_not_found(mocker):
    """PATCH on an unknown model slug yields 404."""
    mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    model_table = mocker.patch("prisma.models.LlmModel.prisma").return_value
    # The lookup finds nothing.
    model_table.find_unique = AsyncMock(return_value=None)

    resp = admin_client.patch("/llm/models/unknown-slug", json={})

    assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_delete_model(mocker):
    """DELETE /llm/models/{slug} with no replacement yields 200 and nodes_migrated=0."""
    model_table = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_table.find_unique = AsyncMock(return_value=_make_mock_model())

    outcome = {
        "deleted_model_slug": "gpt-4",
        "deleted_model_display_name": "GPT-4",
        "replacement_model_slug": None,
        "nodes_migrated": 0,
    }
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.refresh_runtime_caches = AsyncMock()
    db.delete_model = AsyncMock(return_value=outcome)

    resp = admin_client.delete("/llm/models/gpt-4")

    assert resp.status_code == 200
    body = resp.json()
    assert body["nodes_migrated"] == 0
    assert body["deleted_model_slug"] == "gpt-4"
|
||||
|
||||
|
||||
def test_delete_model_with_migration(mocker):
    """DELETE with a replacement_model_slug query param forwards it and reports the count."""
    target = _make_mock_model()
    model_table = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_table.find_unique = AsyncMock(return_value=target)

    outcome = {
        "deleted_model_slug": "gpt-3",
        "deleted_model_display_name": "GPT-3",
        "replacement_model_slug": "gpt-4",
        "nodes_migrated": 5,
    }
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.refresh_runtime_caches = AsyncMock()
    db.delete_model = AsyncMock(return_value=outcome)

    resp = admin_client.delete("/llm/models/gpt-3?replacement_model_slug=gpt-4")

    assert resp.status_code == 200
    assert resp.json()["nodes_migrated"] == 5
    # The replacement slug from the query string must reach the write layer.
    db.delete_model.assert_called_once_with(
        model_id=target.id, replacement_model_slug="gpt-4"
    )
|
||||
|
||||
|
||||
def test_get_model_usage(mocker):
    """GET /llm/models/{slug}/usage reports the slug and its node_count."""
    usage = {"model_slug": "gpt-4", "node_count": 7}
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.get_model_usage = AsyncMock(return_value=usage)

    resp = admin_client.get("/llm/models/gpt-4/usage")

    assert resp.status_code == 200
    body = resp.json()
    assert body["node_count"] == 7
    assert body["model_slug"] == "gpt-4"
|
||||
|
||||
|
||||
def test_toggle_model_enable(mocker):
    """Enabling a model via POST /llm/models/{slug}/toggle yields 200 with no migration."""
    model_table = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_table.find_unique = AsyncMock(
        return_value=_make_mock_model(is_enabled=False)
    )

    outcome = {"nodes_migrated": 0, "migrated_to_slug": None, "migration_id": None}
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.refresh_runtime_caches = AsyncMock()
    db.toggle_model_with_migration = AsyncMock(return_value=outcome)

    payload = {"is_enabled": True}
    resp = admin_client.post("/llm/models/gpt-4/toggle", json=payload)

    assert resp.status_code == 200
    assert resp.json()["nodes_migrated"] == 0
|
||||
|
||||
|
||||
def test_toggle_model_disable_with_migration(mocker):
    """Disabling with migrate_to_slug forwards the migration args and reports the result."""
    target = _make_mock_model(is_enabled=True)
    model_table = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_table.find_unique = AsyncMock(return_value=target)

    outcome = {
        "nodes_migrated": 3,
        "migrated_to_slug": "gpt-4-turbo",
        "migration_id": "mig-abc",
    }
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.refresh_runtime_caches = AsyncMock()
    db.toggle_model_with_migration = AsyncMock(return_value=outcome)

    payload = {"is_enabled": False, "migrate_to_slug": "gpt-4-turbo"}
    resp = admin_client.post("/llm/models/gpt-4/toggle", json=payload)

    assert resp.status_code == 200
    body = resp.json()
    assert body["nodes_migrated"] == 3
    assert body["migration_id"] == "mig-abc"
    # Optional fields missing from the request are forwarded as None.
    db.toggle_model_with_migration.assert_called_once_with(
        model_id=target.id,
        is_enabled=False,
        migrate_to_slug="gpt-4-turbo",
        migration_reason=None,
        custom_credit_cost=None,
    )
|
||||
|
||||
|
||||
def test_toggle_model_not_found(mocker):
    """POST toggle on an unknown model slug yields 404."""
    mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    model_table = mocker.patch("prisma.models.LlmModel.prisma").return_value
    # The lookup finds nothing.
    model_table.find_unique = AsyncMock(return_value=None)

    payload = {"is_enabled": True}
    resp = admin_client.post("/llm/models/ghost-model/toggle", json=payload)

    assert resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Migrations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_list_migrations(mocker):
    """GET /llm/migrations returns the migrations supplied by db_write, in order."""
    first = _make_mock_migration()
    second = _make_mock_migration(id="mig-2")
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.list_migrations = AsyncMock(return_value=[first, second])

    resp = admin_client.get("/llm/migrations")

    assert resp.status_code == 200
    listed = resp.json()["migrations"]
    assert len(listed) == 2
    assert listed[0]["id"] == "mig-1"
|
||||
|
||||
|
||||
def test_revert_migration(mocker):
    """POST /llm/migrations/{id}/revert delegates to db_write and reports nodes_reverted."""
    outcome = {
        "migration_id": "mig-1",
        "source_model_slug": "gpt-3",
        "target_model_slug": "gpt-4",
        "nodes_reverted": 3,
        "nodes_already_changed": 0,
        "source_model_re_enabled": True,
    }
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.refresh_runtime_caches = AsyncMock()
    db.revert_migration = AsyncMock(return_value=outcome)

    resp = admin_client.post("/llm/migrations/mig-1/revert")

    assert resp.status_code == 200
    body = resp.json()
    assert body["nodes_reverted"] == 3
    assert body["migration_id"] == "mig-1"
    # With no request body, the source model is re-enabled.
    db.revert_migration.assert_called_once_with(
        migration_id="mig-1", re_enable_source_model=True
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Creator CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_creator(mocker):
    """A successful POST /llm/creators yields 201 and echoes the creator fields."""
    creator_table = mocker.patch("prisma.models.LlmModelCreator.prisma").return_value
    creator_table.create = AsyncMock(return_value=_make_mock_creator())

    payload = {"name": "openai", "display_name": "OpenAI"}
    resp = admin_client.post("/llm/creators", json=payload)

    assert resp.status_code == 201
    body = resp.json()
    assert body["name"] == "openai"
    assert body["display_name"] == "OpenAI"
|
||||
|
||||
|
||||
def test_list_creators(mocker):
    """GET /llm/creators returns every creator found."""
    openai_creator = _make_mock_creator()
    anthropic_creator = _make_mock_creator(id="c-2", name="anthropic")
    creator_table = mocker.patch("prisma.models.LlmModelCreator.prisma").return_value
    creator_table.find_many = AsyncMock(
        return_value=[openai_creator, anthropic_creator]
    )

    resp = admin_client.get("/llm/creators")

    assert resp.status_code == 200
    assert len(resp.json()["creators"]) == 2
|
||||
|
||||
|
||||
def test_update_creator(mocker):
    """A successful PATCH /llm/creators/{name} yields 200 with the updated creator."""
    creator_table = mocker.patch("prisma.models.LlmModelCreator.prisma").return_value
    creator_table.find_unique = AsyncMock(return_value=_make_mock_creator())
    creator_table.update = AsyncMock(
        return_value=_make_mock_creator(display_name="OpenAI Corp")
    )

    payload = {"display_name": "OpenAI Corp"}
    resp = admin_client.patch("/llm/creators/openai", json=payload)

    assert resp.status_code == 200
    assert resp.json()["display_name"] == "OpenAI Corp"
|
||||
|
||||
|
||||
def test_update_creator_not_found(mocker):
    """PATCH on an unknown creator name yields 404."""
    creator_table = mocker.patch("prisma.models.LlmModelCreator.prisma").return_value
    # The lookup finds nothing.
    creator_table.find_unique = AsyncMock(return_value=None)

    payload = {"display_name": "Nobody"}
    resp = admin_client.patch("/llm/creators/nobody", json=payload)

    assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_delete_creator_success(mocker):
    """DELETE /llm/creators/{name} returns 204 and deletes a creator with no models."""
    existing = _make_mock_creator(models=[])
    prisma_mock = mocker.patch("prisma.models.LlmModelCreator.prisma").return_value
    prisma_mock.find_unique = AsyncMock(return_value=existing)
    prisma_mock.delete = AsyncMock(return_value=existing)

    response = admin_client.delete("/llm/creators/openai")

    assert response.status_code == 204
    # A bare status check would pass even if nothing was deleted; pin the
    # delete to the id of the record that was looked up.
    prisma_mock.delete.assert_called_once_with(where={"id": existing.id})
|
||||
|
||||
|
||||
def test_delete_creator_has_models(mocker):
    """DELETE returns 400 and deletes nothing when the creator still has models."""
    existing = _make_mock_creator(models=[_make_mock_model()])
    prisma_mock = mocker.patch("prisma.models.LlmModelCreator.prisma").return_value
    prisma_mock.find_unique = AsyncMock(return_value=existing)
    # Stub delete explicitly so the test can prove it is never reached.
    prisma_mock.delete = AsyncMock()

    response = admin_client.delete("/llm/creators/openai")

    assert response.status_code == 400
    assert "models" in response.json()["detail"].lower()
    prisma_mock.delete.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Admin model list
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_admin_list_models(mocker):
    """GET /llm/admin/models returns the models found by the lookup."""
    model_table = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_table.find_many = AsyncMock(return_value=[_make_mock_model()])

    resp = admin_client.get("/llm/admin/models")

    assert resp.status_code == 200
    listed = resp.json()["models"]
    assert len(listed) == 1
    assert listed[0]["slug"] == "gpt-4"
|
||||
28
autogpt_platform/backend/backend/server/v2/llm/conftest.py
Normal file
28
autogpt_platform/backend/backend/server/v2/llm/conftest.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Local test fixtures for LLM registry tests."""
|
||||
|
||||
import fastapi
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def mock_jwt_user(test_user_id):
    """Return a JWT-payload override for a regular user, plus that user's id.

    The ``"get_jwt_payload"`` entry is a callable returning a fresh payload
    with role ``"user"``.
    """

    def override_get_jwt_payload(request: fastapi.Request) -> dict[str, str]:
        payload = {"sub": test_user_id, "role": "user", "email": "test@example.com"}
        return payload

    bundle = {"get_jwt_payload": override_get_jwt_payload, "user_id": test_user_id}
    return bundle
|
||||
|
||||
|
||||
@pytest.fixture
def mock_jwt_admin(admin_user_id):
    """Build a JWT-payload override for an admin user.

    Returns a dict holding the override callable under ``get_jwt_payload`` and
    the id it reports under ``user_id``.
    """

    def override_get_jwt_payload(request: fastapi.Request) -> dict[str, str]:
        # A fresh payload is built on every call; the request is not inspected.
        admin_payload = {
            "sub": admin_user_id,
            "role": "admin",
            "email": "test-admin@example.com",
        }
        return admin_payload

    jwt_override = {"get_jwt_payload": override_get_jwt_payload, "user_id": admin_user_id}
    return jwt_override
|
||||
593
autogpt_platform/backend/backend/server/v2/llm/db_write.py
Normal file
593
autogpt_platform/backend/backend/server/v2/llm/db_write.py
Normal file
@@ -0,0 +1,593 @@
|
||||
"""Database write operations for LLM registry admin API."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import prisma
|
||||
import prisma.models
|
||||
|
||||
from backend.data import llm_registry
|
||||
from backend.data.db import transaction
|
||||
from backend.data.llm_registry.notifications import publish_registry_refresh_notification
|
||||
|
||||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _node_model_value(slug: str) -> str:
    """Return the model value stored in AgentNode.constantInput for a registry slug.

    Registry slugs are formatted as 'provider/model-name' (e.g. 'openai/gpt-4o'),
    while the LLM block stores only the model-name part (e.g. 'gpt-4o').
    Everything after the first '/' is returned; a slug without '/' is returned
    unchanged.
    """
    _provider, separator, model_name = slug.partition("/")
    if not separator:
        return slug
    return model_name
|
||||
|
||||
|
||||
def _build_provider_data(
    name: str,
    display_name: str,
    description: str | None = None,
    default_credential_provider: str | None = None,
    default_credential_id: str | None = None,
    default_credential_type: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Assemble the Prisma data dict for an LlmProvider row.

    Snake-case arguments are mapped onto the camelCase column names; a missing
    or empty ``metadata`` is stored as an empty JSON object.
    """
    metadata_json = prisma.Json(metadata if metadata else {})
    payload: dict[str, Any] = {"name": name, "displayName": display_name}
    payload["description"] = description
    payload["defaultCredentialProvider"] = default_credential_provider
    payload["defaultCredentialId"] = default_credential_id
    payload["defaultCredentialType"] = default_credential_type
    payload["metadata"] = metadata_json
    return payload
|
||||
|
||||
|
||||
def _build_model_data(
    slug: str,
    display_name: str,
    provider_id: str,
    context_window: int,
    price_tier: int,
    description: str | None = None,
    creator_id: str | None = None,
    max_output_tokens: int | None = None,
    is_enabled: bool = True,
    is_recommended: bool = False,
    supports_tools: bool = False,
    supports_json_output: bool = False,
    supports_reasoning: bool = False,
    supports_parallel_tool_calls: bool = False,
    capabilities: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Assemble the Prisma data dict for an LlmModel row.

    The provider is always linked via ``connect``; the creator is linked only
    when ``creator_id`` is truthy. Missing ``capabilities`` / ``metadata`` are
    stored as empty JSON objects.
    """
    provider_link = {"connect": {"id": provider_id}}
    capabilities_json = prisma.Json(capabilities if capabilities else {})
    metadata_json = prisma.Json(metadata if metadata else {})

    model_fields: dict[str, Any] = {
        "slug": slug,
        "displayName": display_name,
        "description": description,
        "Provider": provider_link,
        "contextWindow": context_window,
        "maxOutputTokens": max_output_tokens,
        "priceTier": price_tier,
        "isEnabled": is_enabled,
        "isRecommended": is_recommended,
        "supportsTools": supports_tools,
        "supportsJsonOutput": supports_json_output,
        "supportsReasoning": supports_reasoning,
        "supportsParallelToolCalls": supports_parallel_tool_calls,
        "capabilities": capabilities_json,
        "metadata": metadata_json,
    }
    if not creator_id:
        return model_fields

    model_fields["Creator"] = {"connect": {"id": creator_id}}
    return model_fields
|
||||
|
||||
|
||||
async def create_provider(
    name: str,
    display_name: str,
    description: str | None = None,
    default_credential_provider: str | None = None,
    default_credential_id: str | None = None,
    default_credential_type: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmProvider:
    """Insert a new LLM provider and return it with its ``Models`` relation loaded.

    Raises:
        ValueError: If Prisma returns a falsy result for the create call.
    """
    provider_fields = _build_provider_data(
        name=name,
        display_name=display_name,
        description=description,
        default_credential_provider=default_credential_provider,
        default_credential_id=default_credential_id,
        default_credential_type=default_credential_type,
        metadata=metadata,
    )
    created = await prisma.models.LlmProvider.prisma().create(
        data=provider_fields,
        include={"Models": True},
    )
    if created:
        return created
    raise ValueError("Failed to create provider")
|
||||
|
||||
|
||||
async def update_provider(
    provider_id: str,
    display_name: str | None = None,
    description: str | None = None,
    default_credential_provider: str | None = None,
    default_credential_id: str | None = None,
    default_credential_type: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmProvider:
    """Update an existing LLM provider with the fields that were supplied.

    Arguments left as ``None`` are not written, so a column cannot be reset to
    NULL through this function.

    Raises:
        ValueError: If the provider does not exist, or if Prisma returns a
            falsy result for the update call.
    """
    # Existence check up front so a missing provider yields a clear error.
    existing = await prisma.models.LlmProvider.prisma().find_unique(
        where={"id": provider_id}
    )
    if not existing:
        raise ValueError(f"Provider with id '{provider_id}' not found")

    candidate_fields: dict[str, Any] = {
        "displayName": display_name,
        "description": description,
        "defaultCredentialProvider": default_credential_provider,
        "defaultCredentialId": default_credential_id,
        "defaultCredentialType": default_credential_type,
    }
    changes: dict[str, Any] = {
        column: value
        for column, value in candidate_fields.items()
        if value is not None
    }
    if metadata is not None:
        changes["metadata"] = prisma.Json(metadata)

    refreshed = await prisma.models.LlmProvider.prisma().update(
        where={"id": provider_id},
        data=changes,
        include={"Models": True},
    )
    if refreshed:
        return refreshed
    raise ValueError("Failed to update provider")
|
||||
|
||||
|
||||
async def delete_provider(provider_id: str) -> bool:
    """Delete an LLM provider that has no models attached.

    Returns:
        ``True`` once the provider row has been deleted.

    Raises:
        ValueError: If the provider does not exist or still has models.
    """
    target = await prisma.models.LlmProvider.prisma().find_unique(
        where={"id": provider_id},
        include={"Models": True},
    )
    if not target:
        raise ValueError(f"Provider with id '{provider_id}' not found")

    # Models may be None or empty when nothing is attached.
    model_count = len(target.Models or [])
    if model_count > 0:
        raise ValueError(
            f"Cannot delete provider '{target.displayName}' because it has "
            f"{model_count} model(s). Delete all models first."
        )

    await prisma.models.LlmProvider.prisma().delete(where={"id": provider_id})
    return True
|
||||
|
||||
|
||||
async def create_model(
    slug: str,
    display_name: str,
    provider_id: str,
    context_window: int,
    price_tier: int,
    description: str | None = None,
    creator_id: str | None = None,
    max_output_tokens: int | None = None,
    is_enabled: bool = True,
    is_recommended: bool = False,
    supports_tools: bool = False,
    supports_json_output: bool = False,
    supports_reasoning: bool = False,
    supports_parallel_tool_calls: bool = False,
    capabilities: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmModel:
    """Insert a new LLM model and return it with Costs, Creator and Provider loaded.

    Raises:
        ValueError: If Prisma returns a falsy result for the create call.
    """
    model_fields = _build_model_data(
        slug=slug,
        display_name=display_name,
        provider_id=provider_id,
        context_window=context_window,
        price_tier=price_tier,
        description=description,
        creator_id=creator_id,
        max_output_tokens=max_output_tokens,
        is_enabled=is_enabled,
        is_recommended=is_recommended,
        supports_tools=supports_tools,
        supports_json_output=supports_json_output,
        supports_reasoning=supports_reasoning,
        supports_parallel_tool_calls=supports_parallel_tool_calls,
        capabilities=capabilities,
        metadata=metadata,
    )
    created = await prisma.models.LlmModel.prisma().create(
        data=model_fields,
        include={"Costs": True, "Creator": True, "Provider": True},
    )
    if created:
        return created
    raise ValueError("Failed to create model")
|
||||
|
||||
|
||||
async def update_model(
    model_id: str,
    display_name: str | None = None,
    description: str | None = None,
    creator_id: str | None = None,
    context_window: int | None = None,
    max_output_tokens: int | None = None,
    price_tier: int | None = None,
    is_enabled: bool | None = None,
    is_recommended: bool | None = None,
    supports_tools: bool | None = None,
    supports_json_output: bool | None = None,
    supports_reasoning: bool | None = None,
    supports_parallel_tool_calls: bool | None = None,
    capabilities: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmModel:
    """Update an existing LLM model with the fields that were supplied.

    Arguments left as ``None`` are not written. ``creator_id`` is special: an
    empty string clears the creator link (stored as NULL).

    When ``is_recommended`` is ``True`` the flag is cleared on every other
    model in the same transaction, so only one model is recommended at a time.

    Returns:
        The updated model with Costs, Creator and Provider loaded.

    Raises:
        ValueError: If no model with ``model_id`` exists. The error is raised
            inside the transaction so that the recommended-flag reset is not
            committed for a model that does not exist (this assumes
            ``transaction()`` rolls back when its body raises).
    """
    # Plain scalar columns: written only when a value was supplied.
    scalar_fields: dict[str, Any] = {
        "displayName": display_name,
        "description": description,
        "contextWindow": context_window,
        "maxOutputTokens": max_output_tokens,
        "priceTier": price_tier,
        "isEnabled": is_enabled,
        "isRecommended": is_recommended,
        "supportsTools": supports_tools,
        "supportsJsonOutput": supports_json_output,
        "supportsReasoning": supports_reasoning,
        "supportsParallelToolCalls": supports_parallel_tool_calls,
    }
    data: dict[str, Any] = {
        column: value for column, value in scalar_fields.items() if value is not None
    }
    if capabilities is not None:
        data["capabilities"] = prisma.Json(capabilities)
    if metadata is not None:
        data["metadata"] = prisma.Json(metadata)
    if creator_id is not None:
        # An empty string means "remove the creator".
        data["creatorId"] = creator_id or None

    async with transaction() as tx:
        # Enforce single recommended model: unset all others first.
        if is_recommended is True:
            await tx.llmmodel.update_many(
                where={"id": {"not": model_id}},
                data={"isRecommended": False},
            )

        model = await tx.llmmodel.update(
            where={"id": model_id},
            data=data,
            include={"Costs": True, "Creator": True, "Provider": True},
        )
        # Fail before the transaction commits, otherwise the reset above would
        # persist and leave no model recommended.
        if not model:
            raise ValueError(f"Model with id '{model_id}' not found")

    return model
|
||||
|
||||
|
||||
async def get_model_usage(slug: str) -> dict[str, Any]:
    """Count the AgentNodes whose constantInput references the given model.

    The slug is first reduced to the value stored on nodes via
    ``_node_model_value``.

    Returns:
        A dict with the original ``model_slug`` and the matching ``node_count``.
    """
    import prisma as prisma_module

    stored_value = _node_model_value(slug)
    rows = await prisma_module.get_client().query_raw(
        """
        SELECT COUNT(*) as count
        FROM "AgentNode"
        WHERE "constantInput"::jsonb->>'model' = $1
        """,
        stored_value,
    )
    # An empty result set is treated as zero usage.
    if rows:
        node_count = int(rows[0]["count"])
    else:
        node_count = 0
    return {"model_slug": slug, "node_count": node_count}
|
||||
|
||||
|
||||
async def toggle_model_with_migration(
    model_id: str,
    is_enabled: bool,
    migrate_to_slug: str | None = None,
    migration_reason: str | None = None,
    custom_credit_cost: int | None = None,
) -> dict[str, Any]:
    """Toggle a model's enabled status, optionally migrating workflows when disabling.

    When the model is being disabled and ``migrate_to_slug`` is given, every
    AgentNode whose ``constantInput.model`` matches the model is rewritten to
    the replacement inside one transaction, the model is disabled, and a
    migration record is created if at least one node was changed. In every
    other case only the ``isEnabled`` flag is updated.

    Returns:
        A dict with ``nodes_migrated``, ``migrated_to_slug`` (``None`` when no
        node was migrated) and ``migration_id`` (``None`` when no record was
        created).

    Raises:
        ValueError: If the model does not exist, if the replacement is the
            model itself, or if the replacement does not exist or is disabled.
    """
    model = await prisma.models.LlmModel.prisma().find_unique(
        where={"id": model_id}, include={"Costs": True}
    )
    if not model:
        raise ValueError(f"Model with id '{model_id}' not found")

    nodes_migrated = 0
    migration_id: str | None = None

    if not is_enabled and migrate_to_slug:
        # Migrating a model onto itself and then disabling it would leave every
        # node pointing at a disabled model.
        if migrate_to_slug == model.slug:
            raise ValueError(
                f"Replacement model '{migrate_to_slug}' is the model being disabled. "
                f"Please choose a different replacement."
            )

        async with transaction() as tx:
            replacement = await tx.llmmodel.find_unique(
                where={"slug": migrate_to_slug}
            )
            if not replacement:
                raise ValueError(
                    f"Replacement model '{migrate_to_slug}' not found"
                )
            if not replacement.isEnabled:
                raise ValueError(
                    f"Replacement model '{migrate_to_slug}' is disabled. "
                    f"Please enable it before using it as a replacement."
                )

            # Nodes store the bare model name, not the registry slug.
            source_value = _node_model_value(model.slug)
            target_value = _node_model_value(migrate_to_slug)

            # Select the affected rows FOR UPDATE so the recorded id list
            # matches the rows rewritten below.
            node_ids_result = await tx.query_raw(
                """
                SELECT id
                FROM "AgentNode"
                WHERE "constantInput"::jsonb->>'model' = $1
                FOR UPDATE
                """,
                source_value,
            )
            migrated_node_ids = (
                [row["id"] for row in node_ids_result] if node_ids_result else []
            )
            nodes_migrated = len(migrated_node_ids)
            node_ids_json = json.dumps(migrated_node_ids)

            if nodes_migrated > 0:
                await tx.execute_raw(
                    """
                    UPDATE "AgentNode"
                    SET "constantInput" = JSONB_SET(
                        "constantInput"::jsonb,
                        '{model}',
                        to_jsonb($1::text)
                    )
                    WHERE id::text IN (
                        SELECT jsonb_array_elements_text($2::jsonb)
                    )
                    """,
                    target_value,
                    node_ids_json,
                )

            await tx.llmmodel.update(
                where={"id": model_id},
                data={"isEnabled": is_enabled},
            )

            # A migration record is only useful when something was migrated.
            if nodes_migrated > 0:
                migration_record = await tx.llmmodelmigration.create(
                    data={
                        "sourceModelSlug": model.slug,
                        "targetModelSlug": migrate_to_slug,
                        "reason": migration_reason,
                        "migratedNodeIds": node_ids_json,
                        "nodeCount": nodes_migrated,
                        "customCreditCost": custom_credit_cost,
                    }
                )
                migration_id = migration_record.id
    else:
        await prisma.models.LlmModel.prisma().update(
            where={"id": model_id},
            data={"isEnabled": is_enabled},
        )

    return {
        "nodes_migrated": nodes_migrated,
        "migrated_to_slug": migrate_to_slug if nodes_migrated > 0 else None,
        "migration_id": migration_id,
    }
|
||||
|
||||
|
||||
async def delete_model(
    model_id: str, replacement_model_slug: str | None = None
) -> dict[str, Any]:
    """Delete an LLM model, optionally migrating affected AgentNodes first.

    AgentNodes are matched on the value produced by ``_node_model_value``,
    the same value ``get_model_usage`` and ``toggle_model_with_migration``
    use, so slugs of the form 'provider/model-name' are counted and rewritten
    correctly.

    If nodes are using the model and no replacement is given, nothing is
    deleted. If a replacement is given, the nodes are migrated and the model
    is deleted in a single transaction.

    Returns:
        A dict with ``deleted_model_slug``, ``deleted_model_display_name``,
        ``replacement_model_slug`` (as passed in) and ``nodes_migrated``.

    Raises:
        ValueError: If the model does not exist; if nodes use it and no
            replacement was given; or if the replacement is the model itself,
            does not exist, or is disabled.
    """
    model = await prisma.models.LlmModel.prisma().find_unique(
        where={"id": model_id}, include={"Costs": True}
    )
    if not model:
        raise ValueError(f"Model with id '{model_id}' not found")

    deleted_slug = model.slug
    deleted_display_name = model.displayName
    # Nodes store the bare model name, not the registry slug.
    source_value = _node_model_value(deleted_slug)

    async with transaction() as tx:
        count_result = await tx.query_raw(
            """
            SELECT COUNT(*) as count
            FROM "AgentNode"
            WHERE "constantInput"::jsonb->>'model' = $1
            """,
            source_value,
        )
        nodes_to_migrate = int(count_result[0]["count"]) if count_result else 0

        if nodes_to_migrate > 0:
            if not replacement_model_slug:
                raise ValueError(
                    f"Cannot delete model '{deleted_slug}': {nodes_to_migrate} workflow node(s) "
                    f"are using it. Please provide a replacement_model_slug to migrate them."
                )
            # Migrating onto the model being deleted would leave nodes
            # referencing a model that no longer exists.
            if replacement_model_slug == deleted_slug:
                raise ValueError(
                    f"Replacement model '{replacement_model_slug}' is the model being deleted."
                )
            replacement = await tx.llmmodel.find_unique(
                where={"slug": replacement_model_slug}
            )
            if not replacement:
                raise ValueError(
                    f"Replacement model '{replacement_model_slug}' not found"
                )
            if not replacement.isEnabled:
                raise ValueError(
                    f"Replacement model '{replacement_model_slug}' is disabled."
                )

            await tx.execute_raw(
                """
                UPDATE "AgentNode"
                SET "constantInput" = JSONB_SET(
                    "constantInput"::jsonb,
                    '{model}',
                    to_jsonb($1::text)
                )
                WHERE "constantInput"::jsonb->>'model' = $2
                """,
                _node_model_value(replacement_model_slug),
                source_value,
            )

        await tx.llmmodel.delete(where={"id": model_id})

    return {
        "deleted_model_slug": deleted_slug,
        "deleted_model_display_name": deleted_display_name,
        "replacement_model_slug": replacement_model_slug,
        "nodes_migrated": nodes_to_migrate,
    }
|
||||
|
||||
|
||||
async def list_migrations(
    include_reverted: bool = False,
) -> list[dict[str, Any]]:
    """Return model migrations, newest first, as plain dicts.

    Reverted migrations are left out unless ``include_reverted`` is true.
    Timestamps are rendered as ISO-8601 strings.
    """
    if include_reverted:
        where: Any = None
    else:
        where = {"isReverted": False}

    records = await prisma.models.LlmModelMigration.prisma().find_many(
        where=where,
        order={"createdAt": "desc"},
    )

    migrations: list[dict[str, Any]] = []
    for record in records:
        reverted_at = record.revertedAt.isoformat() if record.revertedAt else None
        migrations.append(
            {
                "id": record.id,
                "source_model_slug": record.sourceModelSlug,
                "target_model_slug": record.targetModelSlug,
                "reason": record.reason,
                "node_count": record.nodeCount,
                "custom_credit_cost": record.customCreditCost,
                "is_reverted": record.isReverted,
                "reverted_at": reverted_at,
                "created_at": record.createdAt.isoformat(),
            }
        )
    return migrations
|
||||
|
||||
|
||||
async def revert_migration(
    migration_id: str,
    re_enable_source_model: bool = True,
) -> dict[str, Any]:
    """Revert a model migration, restoring affected nodes to their original model.

    Only nodes listed in the migration that still point at the migration's
    target model are restored; nodes changed since then are left alone and
    reported as ``nodes_already_changed``. Node values are compared and
    written in the form produced by ``_node_model_value``, the same form
    ``toggle_model_with_migration`` writes when it migrates nodes.

    Returns:
        A dict with ``migration_id``, both slugs, ``nodes_reverted``,
        ``nodes_already_changed`` and ``source_model_re_enabled``.

    Raises:
        ValueError: If the migration does not exist, was already reverted, has
            no recorded nodes, or its source model no longer exists.
    """
    migration = await prisma.models.LlmModelMigration.prisma().find_unique(
        where={"id": migration_id}
    )
    if not migration:
        raise ValueError(f"Migration with id '{migration_id}' not found")

    if migration.isReverted:
        raise ValueError(
            f"Migration '{migration_id}' has already been reverted"
        )

    source_model = await prisma.models.LlmModel.prisma().find_unique(
        where={"slug": migration.sourceModelSlug}
    )
    if not source_model:
        raise ValueError(
            f"Source model '{migration.sourceModelSlug}' no longer exists."
        )

    # The id list may come back already decoded or as a JSON string.
    migrated_node_ids: list[str] = (
        migration.migratedNodeIds
        if isinstance(migration.migratedNodeIds, list)
        else json.loads(migration.migratedNodeIds)  # type: ignore
    )
    if not migrated_node_ids:
        raise ValueError("No nodes to revert in this migration")

    # The migration stores registry slugs; nodes store the bare model name.
    source_value = _node_model_value(migration.sourceModelSlug)
    target_value = _node_model_value(migration.targetModelSlug)

    source_model_re_enabled = False

    async with transaction() as tx:
        if not source_model.isEnabled and re_enable_source_model:
            await tx.llmmodel.update(
                where={"id": source_model.id},
                data={"isEnabled": True},
            )
            source_model_re_enabled = True

        node_ids_json = json.dumps(migrated_node_ids)
        # The extra model filter skips nodes that were edited after the
        # migration, so later changes are not overwritten.
        result = await tx.execute_raw(
            """
            UPDATE "AgentNode"
            SET "constantInput" = JSONB_SET(
                "constantInput"::jsonb,
                '{model}',
                to_jsonb($1::text)
            )
            WHERE id::text IN (
                SELECT jsonb_array_elements_text($2::jsonb)
            )
            AND "constantInput"::jsonb->>'model' = $3
            """,
            source_value,
            node_ids_json,
            target_value,
        )
        nodes_reverted = result if isinstance(result, int) else 0

        await tx.llmmodelmigration.update(
            where={"id": migration_id},
            data={
                "isReverted": True,
                "revertedAt": datetime.now(timezone.utc),
            },
        )

    return {
        "migration_id": migration_id,
        "source_model_slug": migration.sourceModelSlug,
        "target_model_slug": migration.targetModelSlug,
        "nodes_reverted": nodes_reverted,
        "nodes_already_changed": len(migrated_node_ids) - nodes_reverted,
        "source_model_re_enabled": source_model_re_enabled,
    }
|
||||
|
||||
|
||||
async def refresh_runtime_caches() -> None:
    """Clear the registry cache, reload the registry, then publish a refresh notification.

    The three steps run strictly in this order.
    """
    llm_registry.clear_registry_cache()
    await llm_registry.refresh_llm_registry()
    await publish_registry_refresh_notification()
|
||||
698
autogpt_platform/backend/backend/server/v2/llm/db_write_test.py
Normal file
698
autogpt_platform/backend/backend/server/v2/llm/db_write_test.py
Normal file
@@ -0,0 +1,698 @@
|
||||
"""Tests for LLM registry DB write operations (db_write.py).
|
||||
|
||||
All functions under test are async; patch Prisma at the point of use.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.llm import db_write
|
||||
|
||||
# Fixed, timezone-aware timestamp used for createdAt/updatedAt on the mock records below.
_NOW = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_provider(
    id: str = "prov-1",
    name: str = "openai",
    display_name: str = "OpenAI",
    models: list | None = None,
) -> Mock:
    """Build a Mock that stands in for an LlmProvider record.

    ``Models`` defaults to an empty list when ``models`` is ``None``.
    """
    attributes = {
        "id": id,
        # ``name`` must be set after construction: Mock(name=...) names the mock itself.
        "name": name,
        "displayName": display_name,
        "description": None,
        "defaultCredentialProvider": None,
        "defaultCredentialId": None,
        "defaultCredentialType": None,
        "metadata": {},
        "createdAt": _NOW,
        "updatedAt": _NOW,
        "Models": [] if models is None else models,
    }
    provider = Mock()
    for attribute, value in attributes.items():
        setattr(provider, attribute, value)
    return provider
|
||||
|
||||
|
||||
def _make_model(
    id: str = "model-1",
    slug: str = "gpt-4",
    display_name: str = "GPT-4",
    is_enabled: bool = True,
    is_recommended: bool = False,
) -> Mock:
    """Build a Mock that stands in for an LlmModel record.

    Only id, slug, display name and the two flags are configurable; every
    other attribute gets a fixed default.
    """
    attributes = {
        "id": id,
        "slug": slug,
        "displayName": display_name,
        "description": None,
        "providerId": "prov-1",
        "creatorId": None,
        "contextWindow": 128000,
        "maxOutputTokens": 4096,
        "priceTier": 2,
        "isEnabled": is_enabled,
        "isRecommended": is_recommended,
        "supportsTools": False,
        "supportsJsonOutput": False,
        "supportsReasoning": False,
        "supportsParallelToolCalls": False,
        "capabilities": {},
        "metadata": {},
        "createdAt": _NOW,
        "updatedAt": _NOW,
        "Costs": [],
        "Creator": None,
    }
    model = Mock()
    for attribute, value in attributes.items():
        setattr(model, attribute, value)
    return model
|
||||
|
||||
|
||||
def _make_migration(
    id: str = "mig-1",
    source_slug: str = "gpt-3",
    target_slug: str = "gpt-4",
    node_count: int = 3,
    migrated_node_ids: list | None = None,
    is_reverted: bool = False,
) -> Mock:
    """Build a Mock that stands in for an LlmModelMigration record.

    ``migratedNodeIds`` defaults to three placeholder ids when
    ``migrated_node_ids`` is ``None``; an explicit empty list is kept as is.
    """
    if migrated_node_ids is None:
        node_ids = ["n1", "n2", "n3"]
    else:
        node_ids = migrated_node_ids

    attributes = {
        "id": id,
        "sourceModelSlug": source_slug,
        "targetModelSlug": target_slug,
        "reason": "upgrade",
        "nodeCount": node_count,
        "customCreditCost": None,
        "isReverted": is_reverted,
        "revertedAt": None,
        "createdAt": _NOW,
        "migratedNodeIds": node_ids,
    }
    migration = Mock()
    for attribute, value in attributes.items():
        setattr(migration, attribute, value)
    return migration
|
||||
|
||||
|
||||
def _make_tx_ctx(tx: Mock) -> Mock:
    """Return a mock async context manager whose ``async with`` yields ``tx``.

    ``__aexit__`` resolves to ``None``, so exceptions raised inside the
    ``async with`` body propagate to the caller.
    """
    manager = MagicMock()
    manager.__aexit__ = AsyncMock(return_value=None)
    manager.__aenter__ = AsyncMock(return_value=tx)
    return manager
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider operations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_provider(mocker):
    """create_provider returns the record produced by the Prisma create call."""
    created_provider = _make_provider()
    provider_table = mocker.patch("prisma.models.LlmProvider.prisma").return_value
    provider_table.create = AsyncMock(return_value=created_provider)

    result = await db_write.create_provider(name="openai", display_name="OpenAI")

    assert result.name == "openai"
    assert result.id == "prov-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_provider(mocker):
    """update_provider looks the provider up by id, then returns the updated record."""
    before = _make_provider()
    after = _make_provider(display_name="OpenAI v2")
    provider_table = mocker.patch("prisma.models.LlmProvider.prisma").return_value
    provider_table.find_unique = AsyncMock(return_value=before)
    provider_table.update = AsyncMock(return_value=after)

    result = await db_write.update_provider(
        provider_id="prov-1",
        display_name="OpenAI v2",
    )

    assert result.displayName == "OpenAI v2"
    provider_table.find_unique.assert_called_once_with(where={"id": "prov-1"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_provider_not_found(mocker):
    """update_provider raises ValueError when the lookup finds no provider."""
    provider_table = mocker.patch("prisma.models.LlmProvider.prisma").return_value
    provider_table.find_unique = AsyncMock(return_value=None)

    with pytest.raises(ValueError, match="not found"):
        await db_write.update_provider(provider_id="ghost")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_provider_success(mocker):
    """delete_provider deletes a provider with no models and returns True."""
    empty_provider = _make_provider(models=[])
    provider_table = mocker.patch("prisma.models.LlmProvider.prisma").return_value
    provider_table.find_unique = AsyncMock(return_value=empty_provider)
    provider_table.delete = AsyncMock(return_value=empty_provider)

    outcome = await db_write.delete_provider(provider_id="prov-1")

    assert outcome is True
    provider_table.delete.assert_called_once_with(where={"id": "prov-1"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_provider_has_models(mocker):
    """delete_provider refuses to delete a provider that still has models."""
    provider_with_model = _make_provider(models=[_make_model()])
    provider_table = mocker.patch("prisma.models.LlmProvider.prisma").return_value
    provider_table.find_unique = AsyncMock(return_value=provider_with_model)

    with pytest.raises(ValueError, match="Cannot delete"):
        await db_write.delete_provider(provider_id="prov-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model operations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_model(mocker):
    """create_model awaits prisma.create once with slug and displayName in data."""
    mock_model = _make_model()
    # Patch the Prisma accessor once and keep a handle on the create mock so
    # its call arguments can be inspected afterwards.
    prisma_create = AsyncMock(return_value=mock_model)
    mocker.patch("prisma.models.LlmModel.prisma").return_value.create = prisma_create

    result = await db_write.create_model(
        slug="gpt-4",
        display_name="GPT-4",
        provider_id="prov-1",
        context_window=128000,
        price_tier=2,
    )

    assert result.slug == "gpt-4"
    prisma_create.assert_awaited_once()
    # create_model passes the payload as the ``data`` keyword argument.
    data_arg = prisma_create.call_args.kwargs["data"]
    assert data_arg["slug"] == "gpt-4"
    assert data_arg["displayName"] == "GPT-4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_model_recommended_clears_others(mocker):
    """Recommending a model clears the recommended flag on every other model."""
    recommended = _make_model(is_recommended=True)

    tx = AsyncMock()
    tx.llmmodel.update = AsyncMock(return_value=recommended)
    tx.llmmodel.update_many = AsyncMock()
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction",
        return_value=_make_tx_ctx(tx),
    )

    result = await db_write.update_model(model_id="model-1", is_recommended=True)

    expected_filter = {"id": {"not": "model-1"}}
    tx.llmmodel.update_many.assert_called_once_with(
        where=expected_filter,
        data={"isRecommended": False},
    )
    assert result.isRecommended is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_model_recommended_false_no_clear(mocker):
    """Un-recommending a model leaves the other models' flags untouched."""
    not_recommended = _make_model(is_recommended=False)

    tx = AsyncMock()
    tx.llmmodel.update = AsyncMock(return_value=not_recommended)
    tx.llmmodel.update_many = AsyncMock()
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction",
        return_value=_make_tx_ctx(tx),
    )

    await db_write.update_model(model_id="model-1", is_recommended=False)

    tx.llmmodel.update_many.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_model_not_found(mocker):
    """update_model raises ValueError when the transactional update yields None."""
    tx = AsyncMock()
    tx.llmmodel.update = AsyncMock(return_value=None)
    tx.llmmodel.update_many = AsyncMock()
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction",
        return_value=_make_tx_ctx(tx),
    )

    with pytest.raises(ValueError, match="not found"):
        await db_write.update_model(model_id="ghost")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_model_usage(mocker):
    """get_model_usage converts the raw count row into an integer node_count."""
    fake_client = Mock()
    # The count is returned as a string here to exercise the int() conversion.
    fake_client.query_raw = AsyncMock(return_value=[{"count": "3"}])
    mocker.patch("prisma.get_client", return_value=fake_client)

    usage = await db_write.get_model_usage("gpt-4")

    assert usage["node_count"] == 3
    assert usage["model_slug"] == "gpt-4"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Delete model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_model_no_usage(mocker):
    """A model that no node references is deleted with nodes_migrated == 0."""
    unused_model = _make_model()
    model_table = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_table.find_unique = AsyncMock(return_value=unused_model)

    tx = AsyncMock()
    tx.query_raw = AsyncMock(return_value=[{"count": "0"}])
    tx.execute_raw = AsyncMock()
    tx.llmmodel.find_unique = AsyncMock()
    tx.llmmodel.delete = AsyncMock()
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction",
        return_value=_make_tx_ctx(tx),
    )

    outcome = await db_write.delete_model(model_id="model-1")

    tx.llmmodel.delete.assert_called_once_with(where={"id": "model-1"})
    assert outcome["nodes_migrated"] == 0
    assert outcome["deleted_model_slug"] == "gpt-4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_model_with_replacement(mocker):
    """delete_model migrates nodes, then deletes, given an enabled replacement."""
    doomed = _make_model(slug="gpt-3")
    successor = _make_model(id="model-2", slug="gpt-4", is_enabled=True)

    model_actions = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_actions.find_unique = AsyncMock(return_value=doomed)

    fake_tx = AsyncMock()
    # Two nodes currently reference the model being deleted.
    fake_tx.query_raw = AsyncMock(return_value=[{"count": "2"}])
    fake_tx.llmmodel.find_unique = AsyncMock(return_value=successor)
    fake_tx.execute_raw = AsyncMock()
    fake_tx.llmmodel.delete = AsyncMock()
    tx_context = _make_tx_ctx(fake_tx)
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction", return_value=tx_context
    )

    outcome = await db_write.delete_model(
        model_id="model-1", replacement_model_slug="gpt-4"
    )

    fake_tx.execute_raw.assert_called_once()
    fake_tx.llmmodel.delete.assert_called_once()
    assert outcome["nodes_migrated"] == 2
    assert outcome["replacement_model_slug"] == "gpt-4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_model_usage_no_replacement(mocker):
    """delete_model raises ValueError when the model is in use and no replacement is given."""
    in_use_model = _make_model()
    model_actions = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_actions.find_unique = AsyncMock(return_value=in_use_model)

    fake_tx = AsyncMock()
    # Five nodes still reference the model.
    fake_tx.query_raw = AsyncMock(return_value=[{"count": "5"}])
    tx_context = _make_tx_ctx(fake_tx)
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction", return_value=tx_context
    )

    with pytest.raises(ValueError, match="provide a replacement"):
        await db_write.delete_model(model_id="model-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_model_replacement_disabled(mocker):
    """delete_model raises ValueError when the chosen replacement is disabled."""
    doomed = _make_model(slug="gpt-3")
    unusable_successor = _make_model(slug="gpt-4", is_enabled=False)

    model_actions = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_actions.find_unique = AsyncMock(return_value=doomed)

    fake_tx = AsyncMock()
    fake_tx.query_raw = AsyncMock(return_value=[{"count": "2"}])
    # The replacement lookup inside the transaction finds a disabled model.
    fake_tx.llmmodel.find_unique = AsyncMock(return_value=unusable_successor)
    tx_context = _make_tx_ctx(fake_tx)
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction", return_value=tx_context
    )

    with pytest.raises(ValueError, match="is disabled"):
        await db_write.delete_model(
            model_id="model-1", replacement_model_slug="gpt-4"
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_model_not_found(mocker):
    """delete_model raises ValueError when the model lookup returns None."""
    model_actions = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_actions.find_unique = AsyncMock(return_value=None)

    with pytest.raises(ValueError, match="not found"):
        await db_write.delete_model(model_id="ghost")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Toggle model with migration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_toggle_enable_no_migration(mocker):
    """Enabling without migrate_to_slug performs a single plain update."""
    disabled_model = _make_model(is_enabled=False)
    update_call = AsyncMock(return_value=disabled_model)
    model_actions = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_actions.find_unique = AsyncMock(return_value=disabled_model)
    model_actions.update = update_call

    outcome = await db_write.toggle_model_with_migration(
        model_id="model-1", is_enabled=True
    )

    update_call.assert_called_once()
    assert outcome["nodes_migrated"] == 0
    assert outcome["migration_id"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_toggle_disable_with_migration(mocker):
    """Disabling with migrate_to_slug records a migration and reports migrated nodes."""
    source = _make_model(slug="gpt-3", is_enabled=True)
    target = _make_model(id="model-2", slug="gpt-4", is_enabled=True)

    model_actions = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_actions.find_unique = AsyncMock(return_value=source)

    # Record handed back by the migration-create call inside the transaction.
    created_migration = Mock()
    created_migration.id = "mig-new"

    affected_nodes = [{"id": "node-1"}, {"id": "node-2"}]
    fake_tx = AsyncMock()
    fake_tx.llmmodel.find_unique = AsyncMock(return_value=target)
    fake_tx.query_raw = AsyncMock(return_value=affected_nodes)
    fake_tx.execute_raw = AsyncMock()
    fake_tx.llmmodel.update = AsyncMock()
    fake_tx.llmmodelmigration.create = AsyncMock(return_value=created_migration)
    tx_context = _make_tx_ctx(fake_tx)
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction", return_value=tx_context
    )

    outcome = await db_write.toggle_model_with_migration(
        model_id="model-1",
        is_enabled=False,
        migrate_to_slug="gpt-4",
        migration_reason="upgrade",
    )

    assert outcome["nodes_migrated"] == 2
    assert outcome["migration_id"] == "mig-new"
    fake_tx.llmmodelmigration.create.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_toggle_disable_migration_target_not_found(mocker):
    """Disabling raises ValueError when the migration target lookup returns None."""
    source = _make_model(slug="gpt-3", is_enabled=True)
    model_actions = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_actions.find_unique = AsyncMock(return_value=source)

    fake_tx = AsyncMock()
    # The replacement lookup inside the transaction comes back empty.
    fake_tx.llmmodel.find_unique = AsyncMock(return_value=None)
    tx_context = _make_tx_ctx(fake_tx)
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction", return_value=tx_context
    )

    with pytest.raises(ValueError, match="not found"):
        await db_write.toggle_model_with_migration(
            model_id="model-1", is_enabled=False, migrate_to_slug="ghost"
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_toggle_disable_migration_target_disabled(mocker):
    """Disabling raises ValueError when the migration target is itself disabled."""
    source = _make_model(slug="gpt-3", is_enabled=True)
    disabled_target = _make_model(slug="gpt-4", is_enabled=False)

    model_actions = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_actions.find_unique = AsyncMock(return_value=source)

    fake_tx = AsyncMock()
    fake_tx.llmmodel.find_unique = AsyncMock(return_value=disabled_target)
    tx_context = _make_tx_ctx(fake_tx)
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction", return_value=tx_context
    )

    with pytest.raises(ValueError, match="disabled"):
        await db_write.toggle_model_with_migration(
            model_id="model-1", is_enabled=False, migrate_to_slug="gpt-4"
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_toggle_disable_without_migration(mocker):
    """Disabling without migrate_to_slug only flips isEnabled; nothing is migrated."""
    enabled_model = _make_model(slug="gpt-3", is_enabled=True)
    update_call = AsyncMock(return_value=enabled_model)
    model_actions = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_actions.find_unique = AsyncMock(return_value=enabled_model)
    model_actions.update = update_call

    outcome = await db_write.toggle_model_with_migration(
        model_id="model-1", is_enabled=False  # no migrate_to_slug
    )

    # Expect the plain (non-transactional) update path with no migration record.
    update_call.assert_called_once_with(
        where={"id": "model-1"}, data={"isEnabled": False}
    )
    assert outcome["nodes_migrated"] == 0
    assert outcome["migration_id"] is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Migration operations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_migrations_active_only(mocker):
    """list_migrations(include_reverted=False) passes where={isReverted: False}.

    The Prisma accessor is patched exactly once; ``find_many`` on the patched
    accessor is replaced by an ``AsyncMock`` so the filter argument can be
    inspected after the call.
    """
    records = [_make_migration()]
    prisma_find_many = AsyncMock(return_value=records)
    mocker.patch(
        "prisma.models.LlmModelMigration.prisma"
    ).return_value.find_many = prisma_find_many

    result = await db_write.list_migrations(include_reverted=False)

    # Fail with a clear message if find_many was never reached, instead of an
    # AttributeError from dereferencing a None call_args below.
    prisma_find_many.assert_called()
    call = prisma_find_many.call_args
    # Accept the filter either as a keyword or as the first positional argument.
    where_arg = call.kwargs.get("where") or (call.args[0] if call.args else None)
    assert where_arg == {"isReverted": False}
    assert len(result) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_migrations_include_reverted(mocker):
    """list_migrations(include_reverted=True) applies no where filter."""
    all_records = [
        _make_migration(),
        _make_migration(id="mig-2", is_reverted=True),
    ]
    find_many = AsyncMock(return_value=all_records)
    migration_actions = mocker.patch(
        "prisma.models.LlmModelMigration.prisma"
    ).return_value
    migration_actions.find_many = find_many

    listed = await db_write.list_migrations(include_reverted=True)

    # Either where=None was passed or the keyword was omitted entirely.
    assert find_many.call_args.kwargs.get("where") is None
    assert len(listed) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_revert_migration_success(mocker):
    """revert_migration re-enables the source model and marks the migration reverted."""
    pending = _make_migration(migrated_node_ids=["n1", "n2"])
    disabled_source = _make_model(is_enabled=False)

    migration_actions = mocker.patch(
        "prisma.models.LlmModelMigration.prisma"
    ).return_value
    migration_actions.find_unique = AsyncMock(return_value=pending)
    model_actions = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_actions.find_unique = AsyncMock(return_value=disabled_source)

    fake_tx = AsyncMock()
    fake_tx.llmmodel.update = AsyncMock()
    # execute_raw resolves to 2, matching the two migrated node ids above.
    fake_tx.execute_raw = AsyncMock(return_value=2)
    fake_tx.llmmodelmigration.update = AsyncMock()
    tx_context = _make_tx_ctx(fake_tx)
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction", return_value=tx_context
    )

    outcome = await db_write.revert_migration(
        migration_id="mig-1", re_enable_source_model=True
    )

    assert outcome["nodes_reverted"] == 2
    assert outcome["source_model_re_enabled"] is True
    fake_tx.llmmodel.update.assert_called_once_with(
        where={"id": disabled_source.id},
        data={"isEnabled": True},
    )
    fake_tx.llmmodelmigration.update.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_revert_migration_already_reverted(mocker):
    """revert_migration raises ValueError for a migration already marked reverted."""
    reverted = _make_migration(is_reverted=True)
    migration_actions = mocker.patch(
        "prisma.models.LlmModelMigration.prisma"
    ).return_value
    migration_actions.find_unique = AsyncMock(return_value=reverted)

    with pytest.raises(ValueError, match="already been reverted"):
        await db_write.revert_migration(migration_id="mig-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_revert_migration_not_found(mocker):
    """revert_migration raises ValueError when the migration lookup returns None."""
    migration_actions = mocker.patch(
        "prisma.models.LlmModelMigration.prisma"
    ).return_value
    migration_actions.find_unique = AsyncMock(return_value=None)

    with pytest.raises(ValueError, match="not found"):
        await db_write.revert_migration(migration_id="ghost")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_revert_migration_no_re_enable(mocker):
    """With re_enable_source_model=False the source model is left untouched."""
    pending = _make_migration(migrated_node_ids=["n1", "n2"])
    disabled_source = _make_model(is_enabled=False)

    migration_actions = mocker.patch(
        "prisma.models.LlmModelMigration.prisma"
    ).return_value
    migration_actions.find_unique = AsyncMock(return_value=pending)
    model_actions = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_actions.find_unique = AsyncMock(return_value=disabled_source)

    fake_tx = AsyncMock()
    fake_tx.llmmodel.update = AsyncMock()
    fake_tx.execute_raw = AsyncMock(return_value=2)
    fake_tx.llmmodelmigration.update = AsyncMock()
    tx_context = _make_tx_ctx(fake_tx)
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction", return_value=tx_context
    )

    outcome = await db_write.revert_migration(
        migration_id="mig-1", re_enable_source_model=False
    )

    # The model update must be skipped while the nodes are still reverted.
    fake_tx.llmmodel.update.assert_not_called()
    assert outcome["source_model_re_enabled"] is False
    assert outcome["nodes_reverted"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_revert_migration_source_model_gone(mocker):
    """revert_migration raises ValueError when the source model lookup returns None."""
    pending = _make_migration()
    migration_actions = mocker.patch(
        "prisma.models.LlmModelMigration.prisma"
    ).return_value
    migration_actions.find_unique = AsyncMock(return_value=pending)
    model_actions = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_actions.find_unique = AsyncMock(return_value=None)

    with pytest.raises(ValueError, match="no longer exists"):
        await db_write.revert_migration(migration_id="mig-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cache refresh
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_refresh_runtime_caches(mocker):
    """refresh_runtime_caches calls clear, refresh and publish exactly once each."""
    refresh_stub = AsyncMock()
    publish_stub = AsyncMock()
    clear_stub = mocker.patch("backend.data.llm_registry.clear_registry_cache")
    mocker.patch(
        "backend.data.llm_registry.refresh_llm_registry",
        new=refresh_stub,
    )
    mocker.patch(
        "backend.data.llm_registry.notifications.publish_registry_refresh_notification",
        new=publish_stub,
    )

    await db_write.refresh_runtime_caches()

    for stub in (clear_stub, refresh_stub, publish_stub):
        stub.assert_called_once()
|
||||
69
autogpt_platform/backend/backend/server/v2/llm/model.py
Normal file
69
autogpt_platform/backend/backend/server/v2/llm/model.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Pydantic models for LLM registry public API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class LlmModelCost(pydantic.BaseModel):
    """Cost configuration for an LLM model.

    Describes one credit cost entry together with the credential it is
    associated with. Only ``unit``, ``credit_cost`` and
    ``credential_provider`` are required.
    """

    # Billing unit for this cost entry.
    unit: str  # "RUN" or "TOKENS"
    # Credit cost; validated as non-negative (ge=0).
    credit_cost: int = pydantic.Field(ge=0)
    # Name of the credential provider this cost entry is tied to.
    credential_provider: str
    # Optional credential identifier/type; None when not specified.
    credential_id: str | None = None
    credential_type: str | None = None
    # Optional currency label; None when not specified.
    currency: str | None = None
    # Free-form extra data; defaults to a fresh empty dict per instance.
    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class LlmModelCreator(pydantic.BaseModel):
    """Represents the organization that created/trained the model.

    ``id``, ``name`` and ``display_name`` are required; the descriptive
    fields default to None.
    """

    id: str
    name: str
    display_name: str
    # Optional descriptive fields; None when not provided.
    description: str | None = None
    website_url: str | None = None
    logo_url: str | None = None
|
||||
|
||||
|
||||
class LlmModel(pydantic.BaseModel):
    """Public-facing LLM model information.

    Carries identity, provider/creator attribution, token limits, pricing
    tier, status flags, capabilities and cost entries for one model.
    """

    slug: str
    display_name: str
    description: str | None = None
    provider_name: str
    # Creating organization; None when the model has no creator attached.
    creator: LlmModelCreator | None = None
    context_window: int
    # Optional output-token limit; None when not specified.
    max_output_tokens: int | None = None
    price_tier: int  # 1=cheapest, 2=medium, 3=expensive
    # Status flags: enabled by default, not recommended by default.
    is_enabled: bool = True
    is_recommended: bool = False
    # Free-form capability map; defaults to a fresh empty dict per instance.
    capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
    # Cost entries for this model; defaults to a fresh empty list per instance.
    costs: list[LlmModelCost] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class LlmProvider(pydantic.BaseModel):
    """Provider with its enabled models."""

    # Provider key and its human-readable label.
    name: str
    display_name: str
    # Models grouped under this provider; defaults to a fresh empty list.
    models: list[LlmModel] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class LlmModelsResponse(pydantic.BaseModel):
    """Response for GET /llm/models."""

    # Models included in the response.
    models: list[LlmModel]
    # Count reported alongside the list.
    total: int
|
||||
|
||||
|
||||
class LlmProvidersResponse(pydantic.BaseModel):
    """Response for GET /llm/providers."""

    # Providers included in the response.
    providers: list[LlmProvider]
    # Count reported alongside the list.
    total: int
|
||||
96
autogpt_platform/backend/backend/server/v2/llm/routes.py
Normal file
96
autogpt_platform/backend/backend/server/v2/llm/routes.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""Public read-only API for LLM registry."""
|
||||
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
|
||||
from backend.data.llm_registry import (
|
||||
RegistryModel,
|
||||
RegistryModelCreator,
|
||||
get_all_models,
|
||||
get_enabled_models,
|
||||
)
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
|
||||
# Every route registered on this router lives under /llm and carries the
# requires_user security dependency.
router = fastapi.APIRouter(
    prefix="/llm",
    dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
|
||||
|
||||
|
||||
def _map_creator(
    creator: RegistryModelCreator | None,
) -> llm_model.LlmModelCreator | None:
    """Convert a registry creator into its public API representation.

    Returns None when ``creator`` is falsy; otherwise copies the creator's
    fields one-to-one into an ``LlmModelCreator``.
    """
    if creator:
        copied_fields = (
            "id",
            "name",
            "display_name",
            "description",
            "website_url",
            "logo_url",
        )
        return llm_model.LlmModelCreator(
            **{field: getattr(creator, field) for field in copied_fields}
        )
    return None
|
||||
|
||||
|
||||
def _map_model(model: RegistryModel) -> llm_model.LlmModel:
    """Translate a registry model into the public ``LlmModel`` shape.

    Token limits and price tier are read from ``model.metadata``; the
    public ``provider_name`` is filled from ``provider_display_name``.
    """
    meta = model.metadata
    cost_entries = [
        llm_model.LlmModelCost(
            unit=entry.unit,
            credit_cost=entry.credit_cost,
            credential_provider=entry.credential_provider,
            credential_id=entry.credential_id,
            credential_type=entry.credential_type,
            currency=entry.currency,
            metadata=entry.metadata,
        )
        for entry in model.costs
    ]
    return llm_model.LlmModel(
        slug=model.slug,
        display_name=model.display_name,
        description=model.description,
        provider_name=model.provider_display_name,
        creator=_map_creator(model.creator),
        context_window=meta.context_window,
        max_output_tokens=meta.max_output_tokens,
        price_tier=meta.price_tier,
        is_enabled=model.is_enabled,
        is_recommended=model.is_recommended,
        capabilities=model.capabilities,
        costs=cost_entries,
    )
|
||||
|
||||
|
||||
@router.get("/models", response_model=llm_model.LlmModelsResponse)
async def list_models(
    enabled_only: bool = fastapi.Query(
        default=True, description="Only return enabled models"
    ),
):
    """Return registry models, restricted to enabled ones unless asked otherwise."""
    if enabled_only:
        source = get_enabled_models()
    else:
        source = get_all_models()

    payload = [_map_model(entry) for entry in source]
    return llm_model.LlmModelsResponse(models=payload, total=len(payload))
|
||||
|
||||
|
||||
@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
async def list_providers():
    """Return enabled models grouped by provider.

    Providers are ordered by provider key, and the models inside each
    provider are ordered by display name. ``total`` is the number of
    providers returned.
    """
    grouped: dict[str, list[RegistryModel]] = {}
    for model in get_enabled_models():
        grouped.setdefault(model.metadata.provider, []).append(model)

    providers = [
        llm_model.LlmProvider(
            name=provider_key,
            # A group only exists because a model was appended to it, so
            # index 0 is always present.
            display_name=grouped[provider_key][0].provider_display_name,
            models=[
                _map_model(m)
                for m in sorted(grouped[provider_key], key=lambda m: m.display_name)
            ],
        )
        # Sort on the key alone rather than on (key, list) tuples.
        for provider_key in sorted(grouped)
    ]

    return llm_model.LlmProvidersResponse(providers=providers, total=len(providers))
|
||||
282
autogpt_platform/backend/backend/server/v2/llm/routes_test.py
Normal file
282
autogpt_platform/backend/backend/server/v2/llm/routes_test.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""Tests for public read-only LLM registry routes (routes.py).
|
||||
|
||||
Covers:
|
||||
- GET /llm/models (enabled_only=True default, enabled_only=False)
|
||||
- GET /llm/providers
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.llm.routes import router
|
||||
|
||||
# Minimal FastAPI app that mounts only the router under test, plus a
# synchronous test client bound to it.
app = fastapi.FastAPI()
app.include_router(router)
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth fixture
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
    """Replace the JWT payload dependency for every test in this module.

    The override is installed before the test runs and all overrides are
    cleared afterwards.
    """
    from autogpt_libs.auth.jwt_utils import get_jwt_payload

    fake_payload_dependency = mock_jwt_user["get_jwt_payload"]
    app.dependency_overrides.update({get_jwt_payload: fake_payload_dependency})
    yield
    app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_cost(
|
||||
unit: str = "RUN",
|
||||
credit_cost: int = 10,
|
||||
credential_provider: str = "openai",
|
||||
credential_id: str | None = None,
|
||||
credential_type: str | None = None,
|
||||
currency: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> Mock:
|
||||
cost = Mock()
|
||||
cost.unit = unit
|
||||
cost.credit_cost = credit_cost
|
||||
cost.credential_provider = credential_provider
|
||||
cost.credential_id = credential_id
|
||||
cost.credential_type = credential_type
|
||||
cost.currency = currency
|
||||
cost.metadata = metadata or {}
|
||||
return cost
|
||||
|
||||
|
||||
def _make_mock_creator(
|
||||
id: str = "creator-1",
|
||||
name: str = "openai",
|
||||
display_name: str = "OpenAI",
|
||||
description: str | None = "AI company",
|
||||
website_url: str | None = "https://openai.com",
|
||||
logo_url: str | None = None,
|
||||
) -> Mock:
|
||||
creator = Mock()
|
||||
creator.id = id
|
||||
creator.name = name
|
||||
creator.display_name = display_name
|
||||
creator.description = description
|
||||
creator.website_url = website_url
|
||||
creator.logo_url = logo_url
|
||||
return creator
|
||||
|
||||
|
||||
def _make_mock_model(
|
||||
slug: str = "gpt-4",
|
||||
display_name: str = "GPT-4",
|
||||
description: str | None = "Latest GPT",
|
||||
provider_display_name: str = "OpenAI",
|
||||
is_enabled: bool = True,
|
||||
is_recommended: bool = False,
|
||||
capabilities: dict | None = None,
|
||||
provider_key: str = "openai",
|
||||
context_window: int = 128000,
|
||||
max_output_tokens: int | None = 4096,
|
||||
price_tier: int = 2,
|
||||
creator: Mock | None = None,
|
||||
costs: list | None = None,
|
||||
) -> Mock:
|
||||
model = Mock()
|
||||
model.slug = slug
|
||||
model.display_name = display_name
|
||||
model.description = description
|
||||
model.provider_display_name = provider_display_name
|
||||
model.is_enabled = is_enabled
|
||||
model.is_recommended = is_recommended
|
||||
model.capabilities = capabilities or {}
|
||||
model.creator = creator
|
||||
model.costs = costs or []
|
||||
|
||||
meta = Mock()
|
||||
meta.provider = provider_key
|
||||
meta.context_window = context_window
|
||||
meta.max_output_tokens = max_output_tokens
|
||||
meta.price_tier = price_tier
|
||||
model.metadata = meta
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /llm/models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_list_models_enabled_only(mocker):
    """GET /llm/models without parameters serves the enabled models in the expected shape."""
    mocker.patch(
        "backend.server.v2.llm.routes.get_enabled_models",
        return_value=[_make_mock_model()],
    )
    mocker.patch("backend.server.v2.llm.routes.get_all_models", return_value=[])

    response = client.get("/llm/models")

    assert response.status_code == 200
    body = response.json()
    assert body["total"] == 1
    assert len(body["models"]) == 1

    entry = body["models"][0]
    expected_values = {
        "slug": "gpt-4",
        "display_name": "GPT-4",
        "provider_name": "OpenAI",
        "context_window": 128000,
        "price_tier": 2,
        "costs": [],
    }
    for field, expected in expected_values.items():
        assert entry[field] == expected
    # Identity checks keep these strict about real booleans / null.
    assert entry["is_enabled"] is True
    assert entry["is_recommended"] is False
    assert entry["creator"] is None
|
||||
|
||||
|
||||
def test_list_models_all(mocker):
    """enabled_only=false reads from get_all_models and never get_enabled_models."""
    disabled_model = _make_mock_model(is_enabled=False, slug="gpt-3")
    all_models_stub = mocker.patch(
        "backend.server.v2.llm.routes.get_all_models",
        return_value=[disabled_model],
    )
    enabled_models_stub = mocker.patch(
        "backend.server.v2.llm.routes.get_enabled_models", return_value=[]
    )

    response = client.get("/llm/models?enabled_only=false")

    assert response.status_code == 200
    all_models_stub.assert_called_once()
    enabled_models_stub.assert_not_called()
    body = response.json()
    assert body["total"] == 1
    assert body["models"][0]["slug"] == "gpt-3"
|
||||
|
||||
|
||||
def test_list_models_empty(mocker):
    """With no enabled models the endpoint answers with an empty list and total=0."""
    mocker.patch("backend.server.v2.llm.routes.get_enabled_models", return_value=[])

    response = client.get("/llm/models")

    assert response.status_code == 200
    body = response.json()
    assert body["total"] == 0
    assert body["models"] == []
|
||||
|
||||
|
||||
def test_list_models_with_creator(mocker):
    """A model's creator is serialized with its identifying fields."""
    model_with_creator = _make_mock_model(creator=_make_mock_creator())
    mocker.patch(
        "backend.server.v2.llm.routes.get_enabled_models",
        return_value=[model_with_creator],
    )

    response = client.get("/llm/models")

    assert response.status_code == 200
    creator_payload = response.json()["models"][0]["creator"]
    assert creator_payload is not None
    expected_values = {
        "id": "creator-1",
        "name": "openai",
        "display_name": "OpenAI",
        "website_url": "https://openai.com",
    }
    for field, expected in expected_values.items():
        assert creator_payload[field] == expected
|
||||
|
||||
|
||||
def test_list_models_with_costs(mocker):
    """A model's cost entries are serialized into the response."""
    token_cost = _make_mock_cost(
        unit="TOKENS", credit_cost=5, credential_provider="openai"
    )
    mocker.patch(
        "backend.server.v2.llm.routes.get_enabled_models",
        return_value=[_make_mock_model(costs=[token_cost])],
    )

    response = client.get("/llm/models")

    assert response.status_code == 200
    cost_payloads = response.json()["models"][0]["costs"]
    assert len(cost_payloads) == 1
    only_cost = cost_payloads[0]
    assert only_cost["unit"] == "TOKENS"
    assert only_cost["credit_cost"] == 5
    assert only_cost["credential_provider"] == "openai"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /llm/providers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_list_providers(mocker):
    """Two models sharing a provider key end up under one provider entry."""
    same_provider_models = [
        _make_mock_model(slug="gpt-4", display_name="GPT-4", provider_key="openai"),
        _make_mock_model(
            slug="gpt-3.5", display_name="GPT-3.5", provider_key="openai"
        ),
    ]
    mocker.patch(
        "backend.server.v2.llm.routes.get_enabled_models",
        return_value=same_provider_models,
    )

    response = client.get("/llm/providers")

    assert response.status_code == 200
    provider_payloads = response.json()["providers"]
    assert len(provider_payloads) == 1
    only_provider = provider_payloads[0]
    assert only_provider["name"] == "openai"
    assert only_provider["display_name"] == "OpenAI"
    assert len(only_provider["models"]) == 2
|
||||
|
||||
|
||||
def test_list_providers_multiple_providers(mocker):
    """Models from two providers yield two provider entries in alphabetical order."""
    openai_model = _make_mock_model(
        slug="gpt-4",
        display_name="GPT-4",
        provider_key="openai",
        provider_display_name="OpenAI",
    )
    anthropic_model = _make_mock_model(
        slug="claude-3",
        display_name="Claude 3",
        provider_key="anthropic",
        provider_display_name="Anthropic",
    )
    # Deliberately supplied out of alphabetical order.
    mocker.patch(
        "backend.server.v2.llm.routes.get_enabled_models",
        return_value=[openai_model, anthropic_model],
    )

    response = client.get("/llm/providers")

    assert response.status_code == 200
    provider_payloads = response.json()["providers"]
    assert len(provider_payloads) == 2
    returned_names = [payload["name"] for payload in provider_payloads]
    # Exactly two entries, sorted, containing both names: anthropic then openai.
    assert returned_names == ["anthropic", "openai"]
|
||||
|
||||
|
||||
def test_list_providers_empty(mocker):
    """With no enabled models the providers list is empty."""
    mocker.patch("backend.server.v2.llm.routes.get_enabled_models", return_value=[])

    response = client.get("/llm/providers")

    assert response.status_code == 200
    body = response.json()
    assert body["providers"] == []
|
||||
@@ -0,0 +1,148 @@
|
||||
-- CreateEnum
-- Unit used by the "unit" column of "LlmModelCost".
CREATE TYPE "LlmCostUnit" AS ENUM (
    'RUN',
    'TOKENS'
);
|
||||
|
||||
-- CreateTable
-- Providers that "LlmModel" rows point at through "providerId".
-- "name" is made unique by the "LlmProvider_name_key" index.
CREATE TABLE "LlmProvider" (
    "id"                        TEXT         NOT NULL,
    "createdAt"                 TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt"                 TIMESTAMP(3) NOT NULL,
    "name"                      TEXT         NOT NULL,
    "displayName"               TEXT         NOT NULL,
    "description"               TEXT,
    "defaultCredentialProvider" TEXT,
    "defaultCredentialId"       TEXT,
    "defaultCredentialType"     TEXT,
    "metadata"                  JSONB        NOT NULL DEFAULT '{}',

    CONSTRAINT "LlmProvider_pkey" PRIMARY KEY ("id")
);
|
||||
|
||||
-- CreateTable
-- Creators that "LlmModel" rows optionally point at through "creatorId".
-- "name" is made unique by the "LlmModelCreator_name_key" index.
CREATE TABLE "LlmModelCreator" (
    "id"          TEXT         NOT NULL,
    "createdAt"   TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt"   TIMESTAMP(3) NOT NULL,
    "name"        TEXT         NOT NULL,
    "displayName" TEXT         NOT NULL,
    "description" TEXT,
    "websiteUrl"  TEXT,
    "logoUrl"     TEXT,
    "metadata"    JSONB        NOT NULL DEFAULT '{}',

    CONSTRAINT "LlmModelCreator_pkey" PRIMARY KEY ("id")
);
|
||||
|
||||
-- CreateTable
-- One row per model. "providerId" is mandatory, "creatorId" is optional.
-- All "supports*" flags default to false; "priceTier" defaults to 1.
CREATE TABLE "LlmModel" (
    "id"                        TEXT         NOT NULL,
    "createdAt"                 TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt"                 TIMESTAMP(3) NOT NULL,
    "slug"                      TEXT         NOT NULL,
    "displayName"               TEXT         NOT NULL,
    "description"               TEXT,
    "providerId"                TEXT         NOT NULL,
    "creatorId"                 TEXT,
    "contextWindow"             INTEGER      NOT NULL,
    "maxOutputTokens"           INTEGER,
    "priceTier"                 INTEGER      NOT NULL DEFAULT 1,
    "isEnabled"                 BOOLEAN      NOT NULL DEFAULT true,
    "isRecommended"             BOOLEAN      NOT NULL DEFAULT false,
    "supportsTools"             BOOLEAN      NOT NULL DEFAULT false,
    "supportsJsonOutput"        BOOLEAN      NOT NULL DEFAULT false,
    "supportsReasoning"         BOOLEAN      NOT NULL DEFAULT false,
    "supportsParallelToolCalls" BOOLEAN      NOT NULL DEFAULT false,
    "capabilities"              JSONB        NOT NULL DEFAULT '{}',
    "metadata"                  JSONB        NOT NULL DEFAULT '{}',

    CONSTRAINT "LlmModel_pkey" PRIMARY KEY ("id")
);
|
||||
|
||||
-- CreateTable
-- Credit cost rows attached to a model through "llmModelId".
-- A NULL "credentialId" marks a default cost; uniqueness is enforced by the
-- two partial indexes "LlmModelCost_default_cost_key" and
-- "LlmModelCost_credential_cost_key".
CREATE TABLE "LlmModelCost" (
    "id"                 TEXT          NOT NULL,
    "createdAt"          TIMESTAMP(3)  NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt"          TIMESTAMP(3)  NOT NULL,
    "unit"               "LlmCostUnit" NOT NULL DEFAULT 'RUN',
    "creditCost"         INTEGER       NOT NULL,
    "credentialProvider" TEXT          NOT NULL,
    "credentialId"       TEXT,
    "credentialType"     TEXT,
    "currency"           TEXT,
    "metadata"           JSONB         NOT NULL DEFAULT '{}',
    "llmModelId"         TEXT          NOT NULL,

    CONSTRAINT "LlmModelCost_pkey" PRIMARY KEY ("id")
);
|
||||
|
||||
-- CreateTable
-- Records a migration from one model slug to another, the affected node ids
-- (JSON array, default '[]'), and whether the migration was later reverted.
CREATE TABLE "LlmModelMigration" (
    "id"               TEXT         NOT NULL,
    "createdAt"        TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt"        TIMESTAMP(3) NOT NULL,
    "sourceModelSlug"  TEXT         NOT NULL,
    "targetModelSlug"  TEXT         NOT NULL,
    "reason"           TEXT,
    "migratedNodeIds"  JSONB        NOT NULL DEFAULT '[]',
    "nodeCount"        INTEGER      NOT NULL,
    "customCreditCost" INTEGER,
    "isReverted"       BOOLEAN      NOT NULL DEFAULT false,
    "revertedAt"       TIMESTAMP(3),

    CONSTRAINT "LlmModelMigration_pkey" PRIMARY KEY ("id")
);
|
||||
|
||||
-- CreateIndex: natural keys
CREATE UNIQUE INDEX "LlmProvider_name_key"
    ON "LlmProvider" ("name");

CREATE UNIQUE INDEX "LlmModelCreator_name_key"
    ON "LlmModelCreator" ("name");

CREATE UNIQUE INDEX "LlmModel_slug_key"
    ON "LlmModel" ("slug");

-- CreateIndex: lookup indexes on LlmModel
CREATE INDEX "LlmModel_providerId_isEnabled_idx"
    ON "LlmModel" ("providerId", "isEnabled");

CREATE INDEX "LlmModel_creatorId_idx"
    ON "LlmModel" ("creatorId");

-- CreateIndex (partial unique for default costs - no specific credential)
CREATE UNIQUE INDEX "LlmModelCost_default_cost_key"
    ON "LlmModelCost" ("llmModelId", "credentialProvider", "unit")
    WHERE "credentialId" IS NULL;

-- CreateIndex (partial unique for credential-specific costs)
CREATE UNIQUE INDEX "LlmModelCost_credential_cost_key"
    ON "LlmModelCost" ("llmModelId", "credentialProvider", "credentialId", "unit")
    WHERE "credentialId" IS NOT NULL;

-- CreateIndex: lookup indexes on LlmModelMigration
CREATE INDEX "LlmModelMigration_targetModelSlug_idx"
    ON "LlmModelMigration" ("targetModelSlug");

CREATE INDEX "LlmModelMigration_sourceModelSlug_isReverted_idx"
    ON "LlmModelMigration" ("sourceModelSlug", "isReverted");

-- CreateIndex (partial unique to prevent multiple active migrations per source)
CREATE UNIQUE INDEX "LlmModelMigration_active_source_key"
    ON "LlmModelMigration" ("sourceModelSlug")
    WHERE "isReverted" = false;
|
||||
|
||||
-- AddForeignKey: a provider cannot be deleted while models reference it
ALTER TABLE "LlmModel"
    ADD CONSTRAINT "LlmModel_providerId_fkey"
    FOREIGN KEY ("providerId") REFERENCES "LlmProvider" ("id")
    ON DELETE RESTRICT ON UPDATE CASCADE;

-- AddForeignKey: deleting a creator clears "creatorId" on its models
ALTER TABLE "LlmModel"
    ADD CONSTRAINT "LlmModel_creatorId_fkey"
    FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator" ("id")
    ON DELETE SET NULL ON UPDATE CASCADE;

-- AddForeignKey: cost rows are removed together with their model
ALTER TABLE "LlmModelCost"
    ADD CONSTRAINT "LlmModelCost_llmModelId_fkey"
    FOREIGN KEY ("llmModelId") REFERENCES "LlmModel" ("id")
    ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey: migrations reference models by slug (source side)
ALTER TABLE "LlmModelMigration"
    ADD CONSTRAINT "LlmModelMigration_sourceModelSlug_fkey"
    FOREIGN KEY ("sourceModelSlug") REFERENCES "LlmModel" ("slug")
    ON DELETE RESTRICT ON UPDATE CASCADE;

-- AddForeignKey: migrations reference models by slug (target side)
ALTER TABLE "LlmModelMigration"
    ADD CONSTRAINT "LlmModelMigration_targetModelSlug_fkey"
    FOREIGN KEY ("targetModelSlug") REFERENCES "LlmModel" ("slug")
    ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
|
||||
-- AddCheckConstraints (enforce data integrity)

-- priceTier is limited to the values 1, 2 and 3.
ALTER TABLE "LlmModel"
    ADD CONSTRAINT "LlmModel_priceTier_check"
    CHECK ("priceTier" BETWEEN 1 AND 3);

-- Credit costs may be zero but never negative.
ALTER TABLE "LlmModelCost"
    ADD CONSTRAINT "LlmModelCost_creditCost_check"
    CHECK ("creditCost" >= 0);

-- Node counts may be zero but never negative.
ALTER TABLE "LlmModelMigration"
    ADD CONSTRAINT "LlmModelMigration_nodeCount_check"
    CHECK ("nodeCount" >= 0);

-- The optional cost override, when present, may not be negative.
ALTER TABLE "LlmModelMigration"
    ADD CONSTRAINT "LlmModelMigration_customCreditCost_check"
    CHECK ("customCreditCost" IS NULL OR "customCreditCost" >= 0);
|
||||
@@ -0,0 +1,287 @@
|
||||
-- Seed LLM Registry from existing hard-coded data
|
||||
-- This migration populates the LlmProvider, LlmModelCreator, LlmModel, and LlmModelCost tables
|
||||
-- with data from the existing MODEL_METADATA and MODEL_COST dictionaries
|
||||
|
||||
-- Insert Providers
-- Every seeded provider uses its own name as "defaultCredentialProvider",
-- 'api_key' as "defaultCredentialType" and empty metadata, so only the
-- varying columns are listed per row. Existing names are left untouched.
INSERT INTO "LlmProvider" (
    "id", "createdAt", "updatedAt",
    "name", "displayName", "description",
    "defaultCredentialProvider", "defaultCredentialType", "metadata"
)
SELECT
    gen_random_uuid(),
    CURRENT_TIMESTAMP,
    CURRENT_TIMESTAMP,
    seed.provider_name,
    seed.display_name,
    seed.blurb,
    seed.provider_name,
    'api_key',
    '{}'::jsonb
FROM (VALUES
    ('openai', 'OpenAI', 'OpenAI language models'),
    ('anthropic', 'Anthropic', 'Anthropic Claude models'),
    ('groq', 'Groq', 'Groq inference API'),
    ('open_router', 'OpenRouter', 'OpenRouter unified API'),
    ('aiml_api', 'AI/ML API', 'AI/ML API models'),
    ('ollama', 'Ollama', 'Ollama local models'),
    ('llama_api', 'Llama API', 'Llama API models'),
    ('v0', 'v0', 'v0 by Vercel models')
) AS seed(provider_name, display_name, blurb)
ON CONFLICT ("name") DO NOTHING;
|
||||
|
||||
-- Insert Model Creators
-- "logoUrl" is NULL and "metadata" is empty for every seeded creator, so only
-- the varying columns are listed per row. Existing names are left untouched.
INSERT INTO "LlmModelCreator" (
    "id", "createdAt", "updatedAt",
    "name", "displayName", "description",
    "websiteUrl", "logoUrl", "metadata"
)
SELECT
    gen_random_uuid(),
    CURRENT_TIMESTAMP,
    CURRENT_TIMESTAMP,
    seed.creator_name,
    seed.display_name,
    seed.blurb,
    seed.website_url,
    NULL,
    '{}'::jsonb
FROM (VALUES
    ('openai', 'OpenAI', 'Creator of GPT, O1, O3, and DALL-E models', 'https://openai.com'),
    ('anthropic', 'Anthropic', 'Creator of Claude AI models', 'https://anthropic.com'),
    ('meta', 'Meta', 'Creator of Llama foundation models', 'https://llama.meta.com'),
    ('google', 'Google', 'Creator of Gemini and PaLM models', 'https://deepmind.google'),
    ('mistralai', 'Mistral AI', 'Creator of Mistral and Codestral models', 'https://mistral.ai'),
    ('cohere', 'Cohere', 'Creator of Command language models', 'https://cohere.com'),
    ('deepseek', 'DeepSeek', 'Creator of DeepSeek reasoning models', 'https://deepseek.com'),
    ('alibaba', 'Alibaba', 'Creator of Qwen language models', 'https://qwenlm.github.io'),
    ('nvidia', 'NVIDIA', 'Creator of Nemotron models', 'https://nvidia.com'),
    ('vercel', 'Vercel', 'Creator of v0 AI models', 'https://v0.dev'),
    ('microsoft', 'Microsoft', 'Creator of Phi models', 'https://microsoft.com'),
    ('xai', 'xAI', 'Creator of Grok models', 'https://x.ai'),
    ('perplexity', 'Perplexity AI', 'Creator of Sonar search models', 'https://perplexity.ai'),
    ('nousresearch', 'Nous Research', 'Creator of Hermes language models', 'https://nousresearch.com'),
    ('amazon', 'Amazon', 'Creator of Nova language models', 'https://aws.amazon.com'),
    ('gryphe', 'Gryphe', 'Creator of MythoMax models', 'https://huggingface.co/Gryphe'),
    ('moonshotai', 'Moonshot AI', 'Creator of Kimi language models', 'https://moonshot.ai')
) AS seed(creator_name, display_name, blurb, website_url)
ON CONFLICT ("name") DO NOTHING;
|
||||
|
||||
-- Insert Models
-- Each seed row names its provider and creator; both are resolved to ids by
-- joining on the unique "name" columns. Because these are inner joins, a row
-- whose provider or creator name has no match is not inserted. Models are
-- inserted enabled, with no description and empty capabilities/metadata.
-- Slugs that already exist are left untouched.
INSERT INTO "LlmModel" (
    "id", "createdAt", "updatedAt",
    "slug", "displayName", "description",
    "providerId", "creatorId",
    "contextWindow", "maxOutputTokens",
    "isEnabled", "capabilities", "metadata"
)
SELECT
    gen_random_uuid(),
    CURRENT_TIMESTAMP,
    CURRENT_TIMESTAMP,
    seed.model_slug,
    seed.model_display_name,
    NULL,
    prov."id",
    crt."id",
    seed.context_window,
    seed.max_output_tokens,
    true,
    '{}'::jsonb,
    '{}'::jsonb
FROM (VALUES
    -- OpenAI models (creator: openai)
    ('o3-2025-04-16', 'O3', 'openai', 'openai', 200000, 100000),
    ('o3-mini', 'O3 Mini', 'openai', 'openai', 200000, 100000),
    ('o1', 'O1', 'openai', 'openai', 200000, 100000),
    ('o1-mini', 'O1 Mini', 'openai', 'openai', 128000, 65536),
    ('gpt-5.2-2025-12-11', 'GPT-5.2', 'openai', 'openai', 400000, 128000),
    ('gpt-5-2025-08-07', 'GPT 5', 'openai', 'openai', 400000, 128000),
    ('gpt-5.1-2025-11-13', 'GPT 5.1', 'openai', 'openai', 400000, 128000),
    ('gpt-5-mini-2025-08-07', 'GPT 5 Mini', 'openai', 'openai', 400000, 128000),
    ('gpt-5-nano-2025-08-07', 'GPT 5 Nano', 'openai', 'openai', 400000, 128000),
    ('gpt-5-chat-latest', 'GPT 5 Chat', 'openai', 'openai', 400000, 16384),
    ('gpt-4.1-2025-04-14', 'GPT 4.1', 'openai', 'openai', 1000000, 32768),
    ('gpt-4.1-mini-2025-04-14', 'GPT 4.1 Mini', 'openai', 'openai', 1047576, 32768),
    ('gpt-4o-mini', 'GPT 4o Mini', 'openai', 'openai', 128000, 16384),
    ('gpt-4o', 'GPT 4o', 'openai', 'openai', 128000, 16384),
    ('gpt-4-turbo', 'GPT 4 Turbo', 'openai', 'openai', 128000, 4096),
    -- Anthropic models (creator: anthropic)
    ('claude-opus-4-6', 'Claude Opus 4.6', 'anthropic', 'anthropic', 200000, 128000),
    ('claude-sonnet-4-6', 'Claude Sonnet 4.6', 'anthropic', 'anthropic', 200000, 64000),
    ('claude-opus-4-1-20250805', 'Claude 4.1 Opus', 'anthropic', 'anthropic', 200000, 32000),
    ('claude-opus-4-20250514', 'Claude 4 Opus', 'anthropic', 'anthropic', 200000, 32000),
    ('claude-sonnet-4-20250514', 'Claude 4 Sonnet', 'anthropic', 'anthropic', 200000, 64000),
    ('claude-opus-4-5-20251101', 'Claude 4.5 Opus', 'anthropic', 'anthropic', 200000, 64000),
    ('claude-sonnet-4-5-20250929', 'Claude 4.5 Sonnet', 'anthropic', 'anthropic', 200000, 64000),
    ('claude-haiku-4-5-20251001', 'Claude 4.5 Haiku', 'anthropic', 'anthropic', 200000, 64000),
    ('claude-3-haiku-20240307', 'Claude 3 Haiku', 'anthropic', 'anthropic', 200000, 4096),
    -- AI/ML API models (creators: alibaba, nvidia, meta)
    ('Qwen/Qwen2.5-72B-Instruct-Turbo', 'Qwen 2.5 72B', 'aiml_api', 'alibaba', 32000, 8000),
    ('nvidia/llama-3.1-nemotron-70b-instruct', 'Llama 3.1 Nemotron 70B', 'aiml_api', 'nvidia', 128000, 40000),
    ('meta-llama/Llama-3.3-70B-Instruct-Turbo', 'Llama 3.3 70B', 'aiml_api', 'meta', 128000, NULL),
    ('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 'Meta Llama 3.1 70B', 'aiml_api', 'meta', 131000, 2000),
    ('meta-llama/Llama-3.2-3B-Instruct-Turbo', 'Llama 3.2 3B', 'aiml_api', 'meta', 128000, NULL),
    -- Groq models (creator: meta for Llama)
    ('llama-3.3-70b-versatile', 'Llama 3.3 70B', 'groq', 'meta', 128000, 32768),
    ('llama-3.1-8b-instant', 'Llama 3.1 8B', 'groq', 'meta', 128000, 8192),
    -- Ollama models (creators: meta for Llama, mistralai for Mistral)
    ('llama3.3', 'Llama 3.3', 'ollama', 'meta', 8192, NULL),
    ('llama3.2', 'Llama 3.2', 'ollama', 'meta', 8192, NULL),
    ('llama3', 'Llama 3', 'ollama', 'meta', 8192, NULL),
    ('llama3.1:405b', 'Llama 3.1 405B', 'ollama', 'meta', 8192, NULL),
    ('dolphin-mistral:latest', 'Dolphin Mistral', 'ollama', 'mistralai', 32768, NULL),
    -- OpenRouter models (creators: google, mistralai, cohere, deepseek, perplexity, nousresearch, openai, amazon, microsoft, gryphe, meta, xai, moonshotai, alibaba)
    ('google/gemini-2.5-pro-preview-03-25', 'Gemini 2.5 Pro', 'open_router', 'google', 1050000, 8192),
    ('google/gemini-2.5-pro', 'Gemini 2.5 Pro', 'open_router', 'google', 1048576, 65536),
    ('google/gemini-3.1-pro-preview', 'Gemini 3.1 Pro Preview', 'open_router', 'google', 1048576, 65536),
    ('google/gemini-3-flash-preview', 'Gemini 3 Flash Preview', 'open_router', 'google', 1048576, 65536),
    ('google/gemini-2.5-flash', 'Gemini 2.5 Flash', 'open_router', 'google', 1048576, 65535),
    ('google/gemini-2.0-flash-001', 'Gemini 2.0 Flash', 'open_router', 'google', 1048576, 8192),
    ('google/gemini-3.1-flash-lite-preview', 'Gemini 3.1 Flash Lite Preview', 'open_router', 'google', 1048576, 65536),
    ('google/gemini-2.5-flash-lite-preview-06-17', 'Gemini 2.5 Flash Lite Preview', 'open_router', 'google', 1048576, 65535),
    ('google/gemini-2.0-flash-lite-001', 'Gemini 2.0 Flash Lite', 'open_router', 'google', 1048576, 8192),
    ('mistralai/mistral-nemo', 'Mistral Nemo', 'open_router', 'mistralai', 128000, 4096),
    ('mistralai/mistral-large-2512', 'Mistral Large 3 2512', 'open_router', 'mistralai', 262144, NULL),
    ('mistralai/mistral-medium-3.1', 'Mistral Medium 3.1', 'open_router', 'mistralai', 131072, NULL),
    ('mistralai/mistral-small-3.2-24b-instruct', 'Mistral Small 3.2 24B', 'open_router', 'mistralai', 131072, 131072),
    ('mistralai/codestral-2508', 'Codestral 2508', 'open_router', 'mistralai', 256000, NULL),
    ('cohere/command-r-08-2024', 'Command R', 'open_router', 'cohere', 128000, 4096),
    ('cohere/command-r-plus-08-2024', 'Command R Plus', 'open_router', 'cohere', 128000, 4096),
    ('cohere/command-a-03-2025', 'Command A 03.2025', 'open_router', 'cohere', 256000, 8192),
    ('cohere/command-a-reasoning-08-2025', 'Command A Reasoning 08.2025', 'open_router', 'cohere', 256000, 32768),
    ('cohere/command-a-translate-08-2025', 'Command A Translate 08.2025', 'open_router', 'cohere', 128000, 8192),
    ('cohere/command-a-vision-07-2025', 'Command A Vision 07.2025', 'open_router', 'cohere', 128000, 8192),
    ('deepseek/deepseek-chat', 'DeepSeek Chat', 'open_router', 'deepseek', 64000, 2048),
    ('deepseek/deepseek-r1-0528', 'DeepSeek R1', 'open_router', 'deepseek', 163840, 163840),
    ('perplexity/sonar', 'Perplexity Sonar', 'open_router', 'perplexity', 127000, 8000),
    ('perplexity/sonar-pro', 'Perplexity Sonar Pro', 'open_router', 'perplexity', 200000, 8000),
    ('perplexity/sonar-deep-research', 'Perplexity Sonar Deep Research', 'open_router', 'perplexity', 128000, 16000),
    ('perplexity/sonar-reasoning-pro', 'Sonar Reasoning Pro', 'open_router', 'perplexity', 128000, 8000),
    ('nousresearch/hermes-3-llama-3.1-405b', 'Hermes 3 Llama 3.1 405B', 'open_router', 'nousresearch', 131000, 4096),
    ('nousresearch/hermes-3-llama-3.1-70b', 'Hermes 3 Llama 3.1 70B', 'open_router', 'nousresearch', 12288, 12288),
    ('openai/gpt-oss-120b', 'GPT OSS 120B', 'open_router', 'openai', 131072, 131072),
    ('openai/gpt-oss-20b', 'GPT OSS 20B', 'open_router', 'openai', 131072, 32768),
    ('amazon/nova-lite-v1', 'Amazon Nova Lite', 'open_router', 'amazon', 300000, 5120),
    ('amazon/nova-micro-v1', 'Amazon Nova Micro', 'open_router', 'amazon', 128000, 5120),
    ('amazon/nova-pro-v1', 'Amazon Nova Pro', 'open_router', 'amazon', 300000, 5120),
    ('microsoft/wizardlm-2-8x22b', 'WizardLM 2 8x22B', 'open_router', 'microsoft', 65536, 4096),
    ('microsoft/phi-4', 'Phi-4', 'open_router', 'microsoft', 16384, 16384),
    ('gryphe/mythomax-l2-13b', 'MythoMax L2 13B', 'open_router', 'gryphe', 4096, 4096),
    ('meta-llama/llama-4-scout', 'Llama 4 Scout', 'open_router', 'meta', 131072, 131072),
    ('meta-llama/llama-4-maverick', 'Llama 4 Maverick', 'open_router', 'meta', 1048576, 1000000),
    ('x-ai/grok-3', 'Grok 3', 'open_router', 'xai', 131072, 131072),
    ('x-ai/grok-4', 'Grok 4', 'open_router', 'xai', 256000, 256000),
    ('x-ai/grok-4-fast', 'Grok 4 Fast', 'open_router', 'xai', 2000000, 30000),
    ('x-ai/grok-4.1-fast', 'Grok 4.1 Fast', 'open_router', 'xai', 2000000, 30000),
    ('x-ai/grok-code-fast-1', 'Grok Code Fast 1', 'open_router', 'xai', 256000, 10000),
    ('moonshotai/kimi-k2', 'Kimi K2', 'open_router', 'moonshotai', 131000, 131000),
    ('qwen/qwen3-235b-a22b-thinking-2507', 'Qwen 3 235B Thinking', 'open_router', 'alibaba', 262144, 262144),
    ('qwen/qwen3-coder', 'Qwen 3 Coder', 'open_router', 'alibaba', 262144, 262144),
    -- Llama API models (creator: meta)
    ('Llama-4-Scout-17B-16E-Instruct-FP8', 'Llama 4 Scout', 'llama_api', 'meta', 128000, 4028),
    ('Llama-4-Maverick-17B-128E-Instruct-FP8', 'Llama 4 Maverick', 'llama_api', 'meta', 128000, 4028),
    ('Llama-3.3-8B-Instruct', 'Llama 3.3 8B', 'llama_api', 'meta', 128000, 4028),
    ('Llama-3.3-70B-Instruct', 'Llama 3.3 70B', 'llama_api', 'meta', 128000, 4028),
    -- v0 models (creator: vercel)
    ('v0-1.5-md', 'v0 1.5 MD', 'v0', 'vercel', 128000, 64000),
    ('v0-1.5-lg', 'v0 1.5 LG', 'v0', 'vercel', 512000, 64000),
    ('v0-1.0-md', 'v0 1.0 MD', 'v0', 'vercel', 128000, 64000)
) AS seed(model_slug, model_display_name, provider_name, creator_name, context_window, max_output_tokens)
JOIN "LlmProvider" AS prov ON prov."name" = seed.provider_name
JOIN "LlmModelCreator" AS crt ON crt."name" = seed.creator_name
ON CONFLICT ("slug") DO NOTHING;
|
||||
|
||||
-- Insert Costs
-- One default cost row per seeded model: unit RUN, no specific credential
-- ("credentialId" NULL), credential type 'api_key', and the model's own
-- provider name as "credentialProvider". The conflict target matches the
-- partial unique index for default costs, so existing defaults are kept.
INSERT INTO "LlmModelCost" (
    "id", "createdAt", "updatedAt",
    "unit", "creditCost",
    "credentialProvider", "credentialId", "credentialType",
    "currency", "metadata", "llmModelId"
)
SELECT
    gen_random_uuid(),
    CURRENT_TIMESTAMP,
    CURRENT_TIMESTAMP,
    'RUN'::"LlmCostUnit",
    seed.cost,
    prov."name",
    NULL,
    'api_key',
    NULL,
    '{}'::jsonb,
    mdl."id"
FROM (VALUES
    -- OpenAI costs
    ('o3-2025-04-16', 4),
    ('o3-mini', 2),
    ('o1', 16),
    ('o1-mini', 4),
    ('gpt-5.2-2025-12-11', 5),
    ('gpt-5-2025-08-07', 2),
    ('gpt-5.1-2025-11-13', 5),
    ('gpt-5-mini-2025-08-07', 1),
    ('gpt-5-nano-2025-08-07', 1),
    ('gpt-5-chat-latest', 5),
    ('gpt-4.1-2025-04-14', 2),
    ('gpt-4.1-mini-2025-04-14', 1),
    ('gpt-4o-mini', 1),
    ('gpt-4o', 3),
    ('gpt-4-turbo', 10),
    -- Anthropic costs
    ('claude-opus-4-6', 21),
    ('claude-sonnet-4-6', 5),
    ('claude-opus-4-1-20250805', 21),
    ('claude-opus-4-20250514', 21),
    ('claude-sonnet-4-20250514', 5),
    ('claude-haiku-4-5-20251001', 4),
    ('claude-opus-4-5-20251101', 14),
    ('claude-sonnet-4-5-20250929', 9),
    ('claude-3-haiku-20240307', 1),
    -- AI/ML API costs
    ('Qwen/Qwen2.5-72B-Instruct-Turbo', 1),
    ('nvidia/llama-3.1-nemotron-70b-instruct', 1),
    ('meta-llama/Llama-3.3-70B-Instruct-Turbo', 1),
    ('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 1),
    ('meta-llama/Llama-3.2-3B-Instruct-Turbo', 1),
    -- Groq costs
    ('llama-3.3-70b-versatile', 1),
    ('llama-3.1-8b-instant', 1),
    -- Ollama costs
    ('llama3.3', 1),
    ('llama3.2', 1),
    ('llama3', 1),
    ('llama3.1:405b', 1),
    ('dolphin-mistral:latest', 1),
    -- OpenRouter costs
    ('google/gemini-2.5-pro-preview-03-25', 4),
    ('google/gemini-2.5-pro', 4),
    ('google/gemini-3.1-pro-preview', 5),
    ('google/gemini-3-flash-preview', 3),
    ('google/gemini-3.1-flash-lite-preview', 1),
    ('mistralai/mistral-nemo', 1),
    ('mistralai/mistral-large-2512', 3),
    ('mistralai/mistral-medium-3.1', 2),
    ('mistralai/mistral-small-3.2-24b-instruct', 1),
    ('mistralai/codestral-2508', 2),
    ('cohere/command-r-08-2024', 1),
    ('cohere/command-r-plus-08-2024', 3),
    ('cohere/command-a-03-2025', 2),
    ('cohere/command-a-reasoning-08-2025', 3),
    ('cohere/command-a-translate-08-2025', 1),
    ('cohere/command-a-vision-07-2025', 2),
    ('deepseek/deepseek-chat', 2),
    ('perplexity/sonar', 1),
    ('perplexity/sonar-pro', 5),
    ('perplexity/sonar-deep-research', 10),
    ('perplexity/sonar-reasoning-pro', 5),
    ('nousresearch/hermes-3-llama-3.1-405b', 1),
    ('nousresearch/hermes-3-llama-3.1-70b', 1),
    ('amazon/nova-lite-v1', 1),
    ('amazon/nova-micro-v1', 1),
    ('amazon/nova-pro-v1', 1),
    ('microsoft/wizardlm-2-8x22b', 1),
    ('microsoft/phi-4', 1),
    ('gryphe/mythomax-l2-13b', 1),
    ('meta-llama/llama-4-scout', 1),
    ('meta-llama/llama-4-maverick', 1),
    ('x-ai/grok-3', 5),
    ('x-ai/grok-4', 9),
    ('x-ai/grok-4-fast', 1),
    ('x-ai/grok-4.1-fast', 1),
    ('x-ai/grok-code-fast-1', 1),
    ('moonshotai/kimi-k2', 1),
    ('qwen/qwen3-235b-a22b-thinking-2507', 1),
    ('qwen/qwen3-coder', 9),
    ('google/gemini-2.5-flash', 1),
    ('google/gemini-2.0-flash-001', 1),
    ('google/gemini-2.5-flash-lite-preview-06-17', 1),
    ('google/gemini-2.0-flash-lite-001', 1),
    ('deepseek/deepseek-r1-0528', 1),
    ('openai/gpt-oss-120b', 1),
    ('openai/gpt-oss-20b', 1),
    -- Llama API costs
    ('Llama-4-Scout-17B-16E-Instruct-FP8', 1),
    ('Llama-4-Maverick-17B-128E-Instruct-FP8', 1),
    ('Llama-3.3-8B-Instruct', 1),
    ('Llama-3.3-70B-Instruct', 1),
    -- v0 costs
    ('v0-1.5-md', 1),
    ('v0-1.5-lg', 2),
    ('v0-1.0-md', 1)
) AS seed(model_slug, cost)
JOIN "LlmModel" AS mdl ON mdl."slug" = seed.model_slug
JOIN "LlmProvider" AS prov ON prov."id" = mdl."providerId"
ON CONFLICT ("llmModelId", "credentialProvider", "unit") WHERE "credentialId" IS NULL DO NOTHING;
|
||||
|
||||
@@ -1301,3 +1301,164 @@ model OAuthRefreshToken {
|
||||
@@index([userId, applicationId])
|
||||
@@index([expiresAt]) // For cleanup
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// LLM Registry Models
|
||||
// ============================================================================
|
||||
|
||||
// Unit used by LlmModelCost.unit (which defaults to RUN).
enum LlmCostUnit {
  RUN
  TOKENS
}
|
||||
|
||||
// A provider that LlmModel rows belong to (LlmModel.providerId).
model LlmProvider {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  name        String  @unique
  displayName String
  description String?

  // Optional default credential reference for this provider
  defaultCredentialProvider String?
  defaultCredentialId       String?
  defaultCredentialType     String?

  metadata Json @default("{}")

  Models LlmModel[]
}
|
||||
|
||||
// A single LLM model, identified by its unique slug.
model LlmModel {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  slug        String  @unique
  displayName String
  description String?

  // Required provider; a provider cannot be deleted while models reference it
  providerId String
  Provider   LlmProvider @relation(fields: [providerId], references: [id], onDelete: Restrict)

  // The creator is the organization that built/trained the model (e.g., OpenAI, Meta),
  // as opposed to the provider that hosts/serves it (e.g., OpenRouter).
  // Optional: deleting a creator nulls this reference.
  creatorId String?
  Creator   LlmModelCreator? @relation(fields: [creatorId], references: [id], onDelete: SetNull)

  contextWindow   Int
  maxOutputTokens Int?
  priceTier       Int     @default(1) // 1=cheapest, 2=medium, 3=expensive (DB constraint: 1-3)
  isEnabled       Boolean @default(true)
  isRecommended   Boolean @default(false)

  // Per-model capability flags.
  // They differ between models of the same provider (e.g., Hugging Face).
  // All default to false so a partially-seeded row is never assumed capable.
  supportsTools             Boolean @default(false)
  supportsJsonOutput        Boolean @default(false)
  supportsReasoning         Boolean @default(false)
  supportsParallelToolCalls Boolean @default(false)

  capabilities Json @default("{}")
  metadata     Json @default("{}")

  Costs            LlmModelCost[]
  SourceMigrations LlmModelMigration[] @relation("SourceMigrations")
  TargetMigrations LlmModelMigration[] @relation("TargetMigrations")

  @@index([providerId, isEnabled])
  @@index([creatorId])
  // Note: no explicit slug index is needed; @unique already creates one
}
|
||||
|
||||
// Credit cost of a model; rows are deleted together with their model.
model LlmModelCost {
  id        String      @id @default(uuid())
  createdAt DateTime    @default(now())
  updatedAt DateTime    @updatedAt
  unit      LlmCostUnit @default(RUN)

  creditCost Int // DB constraint: >= 0

  // Provider identifier (e.g., "openai", "anthropic", "open_router").
  // Determines which credential system supplies the API key, and lets
  // pricing differ between:
  // - default provider costs (credentialId IS NULL)
  // - costs for a user's own API key (credentialId IS NOT NULL)
  credentialProvider String
  credentialId       String?
  credentialType     String?
  currency           String?

  metadata Json @default("{}")

  llmModelId String
  Model      LlmModel @relation(fields: [llmModelId], references: [id], onDelete: Cascade)

  // Note: uniqueness is enforced by two partial indexes defined in migration SQL:
  // - one over default costs (WHERE credentialId IS NULL)
  // - one over credential-specific costs (WHERE credentialId IS NOT NULL)
  // Together they permit a provider-level default plus per-credential overrides
}
|
||||
|
||||
// The organization that created a model; referenced by LlmModel.creatorId.
model LlmModelCreator {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  name        String  @unique // e.g., "openai", "anthropic", "meta"
  displayName String // e.g., "OpenAI", "Anthropic", "Meta"
  description String?
  websiteUrl  String? // Link to the creator's website
  logoUrl     String? // URL of the creator's logo

  metadata Json @default("{}")

  Models LlmModel[]
}
|
||||
|
||||
// Record of workflows being moved from one model to another, with revert tracking.
model LlmModelMigration {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  sourceModelSlug String // The original model that was disabled
  targetModelSlug String // The model workflows were migrated to
  reason          String? // Why the migration happened (e.g., "Provider outage")

  // Both slugs are foreign keys into LlmModel.slug, so they must name existing models
  SourceModel LlmModel @relation("SourceMigrations", fields: [sourceModelSlug], references: [slug], onDelete: Restrict, onUpdate: Cascade)
  TargetModel LlmModel @relation("TargetMigrations", fields: [targetModelSlug], references: [slug], onDelete: Restrict, onUpdate: Cascade)

  // Affected nodes, stored as a JSON array of node IDs
  // Format: ["node-uuid-1", "node-uuid-2", ...]
  migratedNodeIds Json @default("[]")
  nodeCount       Int // Number of nodes migrated (DB constraint: >= 0)

  // Optional pricing override for workflows affected by this migration.
  // Use case: when users are moved from an expensive model (e.g., GPT-4) to a
  // cheaper one (e.g., GPT-3.5), the original price can be kept temporarily to
  // avoid billing surprises, or a discount can be offered during the transition.
  //
  // IMPORTANT: This field is meant to be consumed by the billing system.
  // When billing prices nodes affected by this migration it should use
  // customCreditCost when set, instead of the target model's cost.
  // When null, the target model's normal cost applies.
  //
  // TODO: Integrate with billing system to apply this override during cost calculation.
  // LIMITATION: A plain Int cannot distinguish RUN from TOKENS pricing, which
  // may be ambiguous for token-priced models. Consider a relation to
  // LlmModelCost or a dedicated override model in a follow-up PR.
  customCreditCost Int? // DB constraint: >= 0 when not null

  // Revert tracking
  isReverted Boolean   @default(false)
  revertedAt DateTime?

  // Note: a partial unique index in migration SQL allows only one active migration per source:
  // UNIQUE (sourceModelSlug) WHERE isReverted = false
  @@index([targetModelSlug])
  @@index([sourceModelSlug, isReverted]) // Composite index for active migration queries
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user