mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Compare commits
60 Commits
fix/artifa
...
feat/llm-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
648a2c76a0 | ||
|
|
a4849038c3 | ||
|
|
3d81f3f42a | ||
|
|
f435187138 | ||
|
|
a5a192d334 | ||
|
|
00bb31c0f4 | ||
|
|
dcf207eb72 | ||
|
|
da533089d3 | ||
|
|
5352dc9778 | ||
|
|
fc932aa415 | ||
|
|
405bdb2808 | ||
|
|
6fcd05ef61 | ||
|
|
33c30c6990 | ||
|
|
52d074d31f | ||
|
|
93be8e5095 | ||
|
|
b0d9ef13e6 | ||
|
|
d5a0ce2815 | ||
|
|
3f6b1120f3 | ||
|
|
77757a25a5 | ||
|
|
e192695884 | ||
|
|
5b2d4595d1 | ||
|
|
4e1774c939 | ||
|
|
845ce6ae8d | ||
|
|
84f30775fd | ||
|
|
62065292ec | ||
|
|
dff9b0f3b2 | ||
|
|
fa47d898d1 | ||
|
|
67455f6a35 | ||
|
|
cb8cf81be7 | ||
|
|
ef30c1ed76 | ||
|
|
b5f63c13a4 | ||
|
|
c5dfe3333d | ||
|
|
696b273afc | ||
|
|
732365cd8f | ||
|
|
7e85371ce5 | ||
|
|
3f964c8aba | ||
|
|
ad1f489c5c | ||
|
|
9c77a2207f | ||
|
|
081fa9f2db | ||
|
|
8b6dea2496 | ||
|
|
5e15213846 | ||
|
|
f0cc4ae573 | ||
|
|
e0282b00db | ||
|
|
9a9c36b806 | ||
|
|
d5381625cd | ||
|
|
f6ae3d6593 | ||
|
|
0fb1b854df | ||
|
|
64a011664a | ||
|
|
1db7c048d9 | ||
|
|
4c5627c966 | ||
|
|
d97d137a51 | ||
|
|
ded9e293ff | ||
|
|
83d504bed2 | ||
|
|
a5f1ffb35b | ||
|
|
97c6516a14 | ||
|
|
876dde8bc7 | ||
|
|
0bfdd74b25 | ||
|
|
a7d2f81b18 | ||
|
|
3699eaa556 | ||
|
|
21adf9e0fb |
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import platform
|
||||
@@ -37,8 +38,10 @@ import backend.api.features.workspace.routes as workspace_routes
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.llm_registry
|
||||
import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.server.v2.llm
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.api.features.library.exceptions import (
|
||||
@@ -117,16 +120,47 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
_registry_subscription_task: asyncio.Task | None = None
|
||||
|
||||
try:
|
||||
await backend.data.llm_registry.refresh_llm_registry()
|
||||
logger.info("LLM registry loaded successfully at startup")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to load LLM registry at startup: {e}. "
|
||||
"Blocks will initialize with empty registry."
|
||||
)
|
||||
|
||||
_registry_subscription_task = asyncio.create_task(
|
||||
backend.data.llm_registry.subscribe_to_registry_refresh(
|
||||
backend.data.llm_registry.refresh_llm_registry
|
||||
)
|
||||
)
|
||||
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
try:
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
except Exception as e:
|
||||
if "AgentNode" in str(e):
|
||||
logger.warning("migrate_llm_models skipped: AgentNode table not found (%s)", e)
|
||||
else:
|
||||
logger.error("migrate_llm_models failed unexpectedly: %s", e, exc_info=True)
|
||||
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
with launch_darkly_context():
|
||||
yield
|
||||
|
||||
if _registry_subscription_task:
|
||||
_registry_subscription_task.cancel()
|
||||
try:
|
||||
await _registry_subscription_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
try:
|
||||
await shutdown_cloud_storage_handler()
|
||||
except Exception as e:
|
||||
@@ -355,6 +389,16 @@ app.include_router(
|
||||
tags=["oauth"],
|
||||
prefix="/api/oauth",
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.llm.router,
|
||||
tags=["v2", "llm"],
|
||||
prefix="/api",
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.llm.admin_router,
|
||||
tags=["v2", "llm", "admin"],
|
||||
prefix="/api",
|
||||
)
|
||||
|
||||
app.mount("/external-api", external_api)
|
||||
|
||||
|
||||
@@ -36,9 +36,9 @@ from backend.util.models import Pagination
|
||||
from backend.util.request import parse_url
|
||||
|
||||
from .block import BlockInput
|
||||
from .db import BaseDbModel
|
||||
from .db import BaseDbModel, execute_raw_with_schema
|
||||
from .db import prisma as db
|
||||
from .db import query_raw_with_schema, transaction
|
||||
from .db import execute_raw_with_schema, query_raw_with_schema, transaction
|
||||
from .dynamic_fields import is_tool_pin, sanitize_pin_name
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE, MAX_GRAPH_VERSIONS_FETCH
|
||||
from .model import CredentialsFieldInfo, CredentialsMetaInput, is_credentials_field_name
|
||||
@@ -1663,22 +1663,19 @@ async def migrate_llm_models(migrate_to: LlmModel):
|
||||
if field.annotation == LlmModel:
|
||||
llm_model_fields[block.id] = field_name
|
||||
|
||||
# Convert enum values to a list of strings for the SQL query
|
||||
enum_values = [v.value for v in LlmModel]
|
||||
escaped_enum_values = repr(tuple(enum_values)) # hack but works
|
||||
enum_values = repr(tuple(v.value for v in LlmModel))
|
||||
|
||||
# Update each block
|
||||
for id, path in llm_model_fields.items():
|
||||
query = f"""
|
||||
UPDATE platform."AgentNode"
|
||||
query = """
|
||||
UPDATE {schema_prefix}"AgentNode"
|
||||
SET "constantInput" = jsonb_set("constantInput", $1, to_jsonb($2), true)
|
||||
WHERE "agentBlockId" = $3
|
||||
AND "constantInput" ? ($4)::text
|
||||
AND "constantInput"->>($4)::text NOT IN {escaped_enum_values}
|
||||
"""
|
||||
AND "constantInput"->>($4)::text NOT IN """ + enum_values
|
||||
|
||||
await db.execute_raw(
|
||||
query, # type: ignore - is supposed to be LiteralString
|
||||
await execute_raw_with_schema(
|
||||
query,
|
||||
[path],
|
||||
migrate_to.value,
|
||||
id,
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
"""LLM Registry - Dynamic model management system."""
|
||||
|
||||
from backend.blocks.llm import ModelMetadata
|
||||
from .notifications import (
|
||||
publish_registry_refresh_notification,
|
||||
subscribe_to_registry_refresh,
|
||||
)
|
||||
from .registry import (
|
||||
RegistryModel,
|
||||
RegistryModelCost,
|
||||
RegistryModelCreator,
|
||||
clear_registry_cache,
|
||||
get_all_model_slugs_for_validation,
|
||||
get_all_models,
|
||||
get_default_model_slug,
|
||||
get_enabled_models,
|
||||
get_model,
|
||||
get_schema_options,
|
||||
refresh_llm_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Models
|
||||
"ModelMetadata",
|
||||
"RegistryModel",
|
||||
"RegistryModelCost",
|
||||
"RegistryModelCreator",
|
||||
# Cache management
|
||||
"clear_registry_cache",
|
||||
"publish_registry_refresh_notification",
|
||||
"subscribe_to_registry_refresh",
|
||||
# Read functions
|
||||
"refresh_llm_registry",
|
||||
"get_model",
|
||||
"get_all_models",
|
||||
"get_enabled_models",
|
||||
"get_schema_options",
|
||||
"get_default_model_slug",
|
||||
"get_all_model_slugs_for_validation",
|
||||
]
|
||||
@@ -0,0 +1,84 @@
|
||||
"""Pub/sub notifications for LLM registry cross-process synchronisation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from backend.data.redis_client import HOST, PASSWORD, PORT
|
||||
from redis.asyncio import Redis as AsyncRedis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"
|
||||
|
||||
|
||||
async def publish_registry_refresh_notification() -> None:
|
||||
"""Publish a refresh signal so all other workers reload their in-process cache."""
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
|
||||
logger.debug("Published LLM registry refresh notification")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish registry refresh notification: %s", e)
|
||||
|
||||
|
||||
async def subscribe_to_registry_refresh(
|
||||
on_refresh: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Listen for registry refresh signals and call on_refresh each time one arrives.
|
||||
|
||||
Designed to run as a long-lived background asyncio.Task. Automatically
|
||||
reconnects if the Redis connection drops.
|
||||
|
||||
Args:
|
||||
on_refresh: Async callable invoked on each refresh signal.
|
||||
Typically ``llm_registry.refresh_llm_registry``.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
# Dedicated connection — pub/sub must not share a connection used
|
||||
# for regular commands.
|
||||
redis_sub = AsyncRedis(
|
||||
host=HOST, port=PORT, password=PASSWORD, decode_responses=True
|
||||
)
|
||||
pubsub = redis_sub.pubsub()
|
||||
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
|
||||
logger.info("Subscribed to LLM registry refresh channel")
|
||||
|
||||
while True:
|
||||
try:
|
||||
message = await pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=1.0
|
||||
)
|
||||
if (
|
||||
message
|
||||
and message["type"] == "message"
|
||||
and message["channel"] == REGISTRY_REFRESH_CHANNEL
|
||||
):
|
||||
logger.debug("LLM registry refresh signal received")
|
||||
try:
|
||||
await on_refresh()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error in registry on_refresh callback: %s", e
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Error processing registry refresh message: %s", e
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("LLM registry subscription task cancelled")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"LLM registry subscription error: %s. Retrying in 5s...", e
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
@@ -0,0 +1,195 @@
|
||||
"""Tests for LLM registry pub/sub notifications (notifications.py).
|
||||
|
||||
Covers:
|
||||
- publish_registry_refresh_notification: happy path and Redis error swallowed
|
||||
- subscribe_to_registry_refresh: message triggers on_refresh, non-message
|
||||
types ignored, wrong channel ignored, CancelledError stops the loop,
|
||||
connection errors trigger reconnect
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.llm_registry.notifications import (
|
||||
REGISTRY_REFRESH_CHANNEL,
|
||||
publish_registry_refresh_notification,
|
||||
subscribe_to_registry_refresh,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# publish_registry_refresh_notification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_sends_to_correct_channel(mocker):
|
||||
"""publish_registry_refresh_notification publishes on the registry channel."""
|
||||
mock_redis = AsyncMock()
|
||||
mocker.patch(
|
||||
"backend.data.redis_client.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
)
|
||||
|
||||
await publish_registry_refresh_notification()
|
||||
|
||||
mock_redis.publish.assert_called_once_with(REGISTRY_REFRESH_CHANNEL, "refresh")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_swallows_redis_error(mocker):
|
||||
"""Redis errors during publish are caught and logged, not raised."""
|
||||
mocker.patch(
|
||||
"backend.data.redis_client.get_redis_async",
|
||||
side_effect=ConnectionError("Redis unavailable"),
|
||||
)
|
||||
|
||||
# Should not raise — errors are swallowed to avoid crashing the admin op
|
||||
await publish_registry_refresh_notification()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# subscribe_to_registry_refresh
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_pubsub(messages: list) -> MagicMock:
|
||||
"""Build a mock pubsub that returns messages in sequence then raises CancelledError."""
|
||||
pubsub = AsyncMock()
|
||||
pubsub.subscribe = AsyncMock()
|
||||
# Once messages are exhausted the next get_message raises CancelledError to
|
||||
# break the infinite loop cleanly in tests.
|
||||
pubsub.get_message = AsyncMock(
|
||||
side_effect=messages + [asyncio.CancelledError()]
|
||||
)
|
||||
return pubsub
|
||||
|
||||
|
||||
def _make_message(channel: str = REGISTRY_REFRESH_CHANNEL, msg_type: str = "message"):
|
||||
return {"type": msg_type, "channel": channel, "data": "refresh"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_calls_on_refresh_for_valid_message(mocker):
|
||||
"""A message on the registry channel triggers the on_refresh callback."""
|
||||
on_refresh = AsyncMock()
|
||||
pubsub = _make_pubsub([_make_message()])
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pubsub.return_value = pubsub
|
||||
mocker.patch("redis.asyncio.Redis", return_value=mock_redis)
|
||||
|
||||
await subscribe_to_registry_refresh(on_refresh)
|
||||
|
||||
on_refresh.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_ignores_non_message_types(mocker):
|
||||
"""Subscribe messages of type 'subscribe' (handshake) do not trigger on_refresh."""
|
||||
on_refresh = AsyncMock()
|
||||
pubsub = _make_pubsub([
|
||||
_make_message(msg_type="subscribe"), # handshake — should be ignored
|
||||
_make_message(msg_type="psubscribe"), # also ignored
|
||||
])
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pubsub.return_value = pubsub
|
||||
mocker.patch("redis.asyncio.Redis", return_value=mock_redis)
|
||||
|
||||
await subscribe_to_registry_refresh(on_refresh)
|
||||
|
||||
on_refresh.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_ignores_wrong_channel(mocker):
|
||||
"""Messages on a different channel do not trigger on_refresh."""
|
||||
on_refresh = AsyncMock()
|
||||
pubsub = _make_pubsub([_make_message(channel="some:other:channel")])
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pubsub.return_value = pubsub
|
||||
mocker.patch("redis.asyncio.Redis", return_value=mock_redis)
|
||||
|
||||
await subscribe_to_registry_refresh(on_refresh)
|
||||
|
||||
on_refresh.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_handles_none_message(mocker):
|
||||
"""None returned by get_message (timeout) does not crash or trigger on_refresh."""
|
||||
on_refresh = AsyncMock()
|
||||
pubsub = _make_pubsub([None, None]) # two timeouts, then CancelledError
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pubsub.return_value = pubsub
|
||||
mocker.patch("redis.asyncio.Redis", return_value=mock_redis)
|
||||
|
||||
await subscribe_to_registry_refresh(on_refresh)
|
||||
|
||||
on_refresh.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_processes_multiple_messages(mocker):
|
||||
"""Multiple valid messages each trigger on_refresh."""
|
||||
on_refresh = AsyncMock()
|
||||
pubsub = _make_pubsub([_make_message(), _make_message(), _make_message()])
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pubsub.return_value = pubsub
|
||||
mocker.patch("redis.asyncio.Redis", return_value=mock_redis)
|
||||
|
||||
await subscribe_to_registry_refresh(on_refresh)
|
||||
|
||||
assert on_refresh.call_count == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_cancelled_error_stops_loop(mocker):
|
||||
"""CancelledError at the outer loop level causes the function to return cleanly."""
|
||||
on_refresh = AsyncMock()
|
||||
pubsub = AsyncMock()
|
||||
pubsub.subscribe = AsyncMock(side_effect=asyncio.CancelledError())
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pubsub.return_value = pubsub
|
||||
mocker.patch("redis.asyncio.Redis", return_value=mock_redis)
|
||||
|
||||
# Should return normally — CancelledError is caught and the loop breaks
|
||||
await subscribe_to_registry_refresh(on_refresh)
|
||||
|
||||
on_refresh.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_reconnects_after_connection_error(mocker):
|
||||
"""A connection error on the first attempt triggers a reconnect attempt."""
|
||||
on_refresh = AsyncMock()
|
||||
|
||||
# First call raises ConnectionError; second call succeeds then CancelledError
|
||||
good_pubsub = _make_pubsub([_make_message()])
|
||||
bad_redis = MagicMock()
|
||||
bad_redis.pubsub.side_effect = ConnectionError("Redis down")
|
||||
good_redis = MagicMock()
|
||||
good_redis.pubsub.return_value = good_pubsub
|
||||
|
||||
mock_redis_cls = mocker.patch(
|
||||
"redis.asyncio.Redis", side_effect=[bad_redis, good_redis]
|
||||
)
|
||||
mock_sleep = mocker.patch("asyncio.sleep", new=AsyncMock())
|
||||
|
||||
await subscribe_to_registry_refresh(on_refresh)
|
||||
|
||||
# Should have slept before retrying
|
||||
mock_sleep.assert_called_once_with(5)
|
||||
# Should have tried to create two Redis connections
|
||||
assert mock_redis_cls.call_count == 2
|
||||
# After reconnect, the valid message triggered on_refresh
|
||||
on_refresh.assert_called_once()
|
||||
254
autogpt_platform/backend/backend/data/llm_registry/registry.py
Normal file
254
autogpt_platform/backend/backend/data/llm_registry/registry.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""Core LLM registry implementation for managing models dynamically."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import prisma.models
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from backend.blocks.llm import ModelMetadata
|
||||
from backend.util.cache import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RegistryModelCost(BaseModel):
|
||||
"""Cost configuration for an LLM model."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
unit: str # "RUN" or "TOKENS"
|
||||
credit_cost: int
|
||||
credential_provider: str
|
||||
credential_id: str | None = None
|
||||
credential_type: str | None = None
|
||||
currency: str | None = None
|
||||
metadata: dict[str, Any] = {}
|
||||
|
||||
|
||||
class RegistryModelCreator(BaseModel):
|
||||
"""Creator information for an LLM model."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: str | None = None
|
||||
website_url: str | None = None
|
||||
logo_url: str | None = None
|
||||
|
||||
|
||||
class RegistryModel(BaseModel):
|
||||
"""Represents a model in the LLM registry."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
slug: str
|
||||
display_name: str
|
||||
description: str | None = None
|
||||
metadata: ModelMetadata
|
||||
capabilities: dict[str, Any] = {}
|
||||
extra_metadata: dict[str, Any] = {}
|
||||
provider_display_name: str
|
||||
is_enabled: bool
|
||||
is_recommended: bool = False
|
||||
costs: tuple[RegistryModelCost, ...] = ()
|
||||
creator: RegistryModelCreator | None = None
|
||||
|
||||
# Typed capability fields from DB schema
|
||||
supports_tools: bool = False
|
||||
supports_json_output: bool = False
|
||||
supports_reasoning: bool = False
|
||||
supports_parallel_tool_calls: bool = False
|
||||
|
||||
|
||||
# L1 in-process cache — Redis is the shared L2 via @cached(shared_cache=True)
|
||||
_dynamic_models: dict[str, RegistryModel] = {}
|
||||
_schema_options: list[dict[str, str]] = []
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel: # type: ignore[name-defined]
|
||||
"""Transform a raw Prisma LlmModel record into a RegistryModel instance."""
|
||||
costs = tuple(
|
||||
RegistryModelCost(
|
||||
unit=str(cost.unit),
|
||||
credit_cost=cost.creditCost,
|
||||
credential_provider=cost.credentialProvider,
|
||||
credential_id=cost.credentialId,
|
||||
credential_type=cost.credentialType,
|
||||
currency=cost.currency,
|
||||
metadata=dict(cost.metadata or {}),
|
||||
)
|
||||
for cost in (record.Costs or [])
|
||||
)
|
||||
|
||||
creator = None
|
||||
if record.Creator:
|
||||
creator = RegistryModelCreator(
|
||||
id=record.Creator.id,
|
||||
name=record.Creator.name,
|
||||
display_name=record.Creator.displayName,
|
||||
description=record.Creator.description,
|
||||
website_url=record.Creator.websiteUrl,
|
||||
logo_url=record.Creator.logoUrl,
|
||||
)
|
||||
|
||||
capabilities = dict(record.capabilities or {})
|
||||
|
||||
if not record.Provider:
|
||||
logger.warning(
|
||||
"LlmModel %s has no Provider despite NOT NULL FK - "
|
||||
"falling back to providerId %s",
|
||||
record.slug,
|
||||
record.providerId,
|
||||
)
|
||||
provider_name = record.Provider.name if record.Provider else record.providerId
|
||||
provider_display = (
|
||||
record.Provider.displayName if record.Provider else record.providerId
|
||||
)
|
||||
creator_name = record.Creator.displayName if record.Creator else "Unknown"
|
||||
|
||||
if record.priceTier not in (1, 2, 3):
|
||||
logger.warning(
|
||||
"LlmModel %s has out-of-range priceTier=%s, defaulting to 1",
|
||||
record.slug,
|
||||
record.priceTier,
|
||||
)
|
||||
price_tier = record.priceTier if record.priceTier in (1, 2, 3) else 1
|
||||
|
||||
metadata = ModelMetadata(
|
||||
provider=provider_name,
|
||||
context_window=record.contextWindow,
|
||||
max_output_tokens=(
|
||||
record.maxOutputTokens
|
||||
if record.maxOutputTokens is not None
|
||||
else record.contextWindow
|
||||
),
|
||||
display_name=record.displayName,
|
||||
provider_name=provider_display,
|
||||
creator_name=creator_name,
|
||||
price_tier=price_tier,
|
||||
)
|
||||
|
||||
return RegistryModel(
|
||||
slug=record.slug,
|
||||
display_name=record.displayName,
|
||||
description=record.description,
|
||||
metadata=metadata,
|
||||
capabilities=capabilities,
|
||||
extra_metadata=dict(record.metadata or {}),
|
||||
provider_display_name=provider_display,
|
||||
is_enabled=record.isEnabled,
|
||||
is_recommended=record.isRecommended,
|
||||
costs=costs,
|
||||
creator=creator,
|
||||
supports_tools=record.supportsTools,
|
||||
supports_json_output=record.supportsJsonOutput,
|
||||
supports_reasoning=record.supportsReasoning,
|
||||
supports_parallel_tool_calls=record.supportsParallelToolCalls,
|
||||
)
|
||||
|
||||
|
||||
@cached(maxsize=1, ttl_seconds=300, shared_cache=True, refresh_ttl_on_get=True)
|
||||
async def _fetch_registry_from_db() -> list[RegistryModel]:
|
||||
"""Fetch all LLM models from the database.
|
||||
|
||||
Results are cached in Redis (shared_cache=True) so subsequent calls within
|
||||
the TTL window skip the DB entirely — both within this process and across
|
||||
all other workers that share the same Redis instance.
|
||||
"""
|
||||
records = await prisma.models.LlmModel.prisma().find_many( # type: ignore[attr-defined]
|
||||
include={"Provider": True, "Costs": True, "Creator": True}
|
||||
)
|
||||
logger.info("Fetched %d LLM models from database", len(records))
|
||||
return [_record_to_registry_model(r) for r in records]
|
||||
|
||||
|
||||
def clear_registry_cache() -> None:
|
||||
"""Invalidate the shared Redis cache for the registry DB fetch.
|
||||
|
||||
Call this before refresh_llm_registry() after any admin DB mutation so the
|
||||
next fetch hits the database rather than serving the now-stale cached data.
|
||||
"""
|
||||
_fetch_registry_from_db.cache_clear()
|
||||
|
||||
|
||||
async def refresh_llm_registry() -> None:
|
||||
"""Refresh the in-process L1 cache from Redis/DB.
|
||||
|
||||
On the first call (or after clear_registry_cache()), fetches fresh data
|
||||
from the database and stores it in Redis. Subsequent calls by other
|
||||
workers hit the Redis cache instead of the DB.
|
||||
"""
|
||||
async with _lock:
|
||||
try:
|
||||
models = await _fetch_registry_from_db()
|
||||
new_models = {m.slug: m for m in models}
|
||||
|
||||
global _dynamic_models, _schema_options
|
||||
_dynamic_models = new_models
|
||||
_schema_options = _build_schema_options()
|
||||
|
||||
logger.info(
|
||||
"LLM registry refreshed: %d models, %d schema options",
|
||||
len(_dynamic_models),
|
||||
len(_schema_options),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to refresh LLM registry: %s", e, exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def _build_schema_options() -> list[dict[str, str]]:
|
||||
"""Build schema options for model selection dropdown. Only includes enabled models."""
|
||||
return [
|
||||
{
|
||||
"label": model.display_name,
|
||||
"value": model.slug,
|
||||
"group": model.metadata.provider,
|
||||
"description": model.description or "",
|
||||
}
|
||||
for model in sorted(
|
||||
_dynamic_models.values(), key=lambda m: m.display_name.lower()
|
||||
)
|
||||
if model.is_enabled
|
||||
]
|
||||
|
||||
|
||||
def get_model(slug: str) -> RegistryModel | None:
|
||||
"""Get a model by slug from the registry."""
|
||||
return _dynamic_models.get(slug)
|
||||
|
||||
|
||||
def get_all_models() -> list[RegistryModel]:
|
||||
"""Get all models from the registry (including disabled)."""
|
||||
return list(_dynamic_models.values())
|
||||
|
||||
|
||||
def get_enabled_models() -> list[RegistryModel]:
|
||||
"""Get only enabled models from the registry."""
|
||||
return [model for model in _dynamic_models.values() if model.is_enabled]
|
||||
|
||||
|
||||
def get_schema_options() -> list[dict[str, str]]:
|
||||
"""Get schema options for model selection dropdown (enabled models only)."""
|
||||
return list(_schema_options)
|
||||
|
||||
|
||||
def get_default_model_slug() -> str | None:
|
||||
"""Get the default model slug (first recommended, or first enabled)."""
|
||||
models = sorted(_dynamic_models.values(), key=lambda m: m.display_name)
|
||||
recommended = next(
|
||||
(m.slug for m in models if m.is_recommended and m.is_enabled), None
|
||||
)
|
||||
return recommended or next((m.slug for m in models if m.is_enabled), None)
|
||||
|
||||
|
||||
def get_all_model_slugs_for_validation() -> list[str]:
|
||||
"""Get all model slugs for validation (enabled models only)."""
|
||||
return [model.slug for model in _dynamic_models.values() if model.is_enabled]
|
||||
@@ -0,0 +1,466 @@
|
||||
"""Unit tests for the LLM registry module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import pydantic
|
||||
|
||||
from backend.data.llm_registry.registry import (
|
||||
RegistryModel,
|
||||
RegistryModelCost,
|
||||
RegistryModelCreator,
|
||||
_build_schema_options,
|
||||
_record_to_registry_model,
|
||||
clear_registry_cache,
|
||||
get_all_model_slugs_for_validation,
|
||||
get_all_models,
|
||||
get_default_model_slug,
|
||||
get_enabled_models,
|
||||
get_model,
|
||||
get_schema_options,
|
||||
refresh_llm_registry,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_record(**overrides):
|
||||
"""Build a realistic mock Prisma LlmModel record."""
|
||||
provider = Mock()
|
||||
provider.name = "openai"
|
||||
provider.displayName = "OpenAI"
|
||||
|
||||
record = Mock()
|
||||
record.slug = "openai/gpt-4o"
|
||||
record.displayName = "GPT-4o"
|
||||
record.description = "Latest GPT model"
|
||||
record.providerId = "provider-uuid"
|
||||
record.Provider = provider
|
||||
record.creatorId = "creator-uuid"
|
||||
record.Creator = None
|
||||
record.contextWindow = 128000
|
||||
record.maxOutputTokens = 16384
|
||||
record.priceTier = 2
|
||||
record.isEnabled = True
|
||||
record.isRecommended = False
|
||||
record.supportsTools = True
|
||||
record.supportsJsonOutput = True
|
||||
record.supportsReasoning = False
|
||||
record.supportsParallelToolCalls = True
|
||||
record.capabilities = {}
|
||||
record.metadata = {}
|
||||
record.Costs = []
|
||||
|
||||
for key, value in overrides.items():
|
||||
setattr(record, key, value)
|
||||
return record
|
||||
|
||||
|
||||
def _make_registry_model(**kwargs) -> RegistryModel:
|
||||
"""Build a minimal RegistryModel for testing registry-level functions."""
|
||||
from backend.blocks.llm import ModelMetadata
|
||||
|
||||
defaults = dict(
|
||||
slug="openai/gpt-4o",
|
||||
display_name="GPT-4o",
|
||||
description=None,
|
||||
metadata=ModelMetadata(
|
||||
provider="openai",
|
||||
context_window=128000,
|
||||
max_output_tokens=16384,
|
||||
display_name="GPT-4o",
|
||||
provider_name="OpenAI",
|
||||
creator_name="Unknown",
|
||||
price_tier=2,
|
||||
),
|
||||
capabilities={},
|
||||
extra_metadata={},
|
||||
provider_display_name="OpenAI",
|
||||
is_enabled=True,
|
||||
is_recommended=False,
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
return RegistryModel(**defaults)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _record_to_registry_model tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_record_to_registry_model():
|
||||
"""Happy-path: well-formed record produces a correct RegistryModel."""
|
||||
record = _make_mock_record()
|
||||
model = _record_to_registry_model(record)
|
||||
|
||||
assert model.slug == "openai/gpt-4o"
|
||||
assert model.display_name == "GPT-4o"
|
||||
assert model.description == "Latest GPT model"
|
||||
assert model.provider_display_name == "OpenAI"
|
||||
assert model.is_enabled is True
|
||||
assert model.is_recommended is False
|
||||
assert model.supports_tools is True
|
||||
assert model.supports_json_output is True
|
||||
assert model.supports_reasoning is False
|
||||
assert model.supports_parallel_tool_calls is True
|
||||
assert model.metadata.provider == "openai"
|
||||
assert model.metadata.context_window == 128000
|
||||
assert model.metadata.max_output_tokens == 16384
|
||||
assert model.metadata.price_tier == 2
|
||||
assert model.creator is None
|
||||
assert model.costs == ()
|
||||
|
||||
|
||||
def test_record_to_registry_model_missing_provider(caplog):
|
||||
"""Record with no Provider relation falls back to providerId and logs a warning."""
|
||||
record = _make_mock_record(Provider=None, providerId="provider-uuid")
|
||||
with caplog.at_level("WARNING"):
|
||||
model = _record_to_registry_model(record)
|
||||
|
||||
assert "no Provider" in caplog.text
|
||||
assert model.metadata.provider == "provider-uuid"
|
||||
assert model.provider_display_name == "provider-uuid"
|
||||
|
||||
|
||||
def test_record_to_registry_model_missing_creator():
|
||||
"""When Creator is None, creator_name defaults to 'Unknown' and creator field is None."""
|
||||
record = _make_mock_record(Creator=None)
|
||||
model = _record_to_registry_model(record)
|
||||
|
||||
assert model.creator is None
|
||||
assert model.metadata.creator_name == "Unknown"
|
||||
|
||||
|
||||
def test_record_to_registry_model_with_creator():
|
||||
"""When Creator is present, it is parsed into RegistryModelCreator."""
|
||||
creator_mock = Mock()
|
||||
creator_mock.id = "creator-uuid"
|
||||
creator_mock.name = "openai"
|
||||
creator_mock.displayName = "OpenAI"
|
||||
creator_mock.description = "AI company"
|
||||
creator_mock.websiteUrl = "https://openai.com"
|
||||
creator_mock.logoUrl = "https://openai.com/logo.png"
|
||||
|
||||
record = _make_mock_record(Creator=creator_mock)
|
||||
model = _record_to_registry_model(record)
|
||||
|
||||
assert model.creator is not None
|
||||
assert isinstance(model.creator, RegistryModelCreator)
|
||||
assert model.creator.id == "creator-uuid"
|
||||
assert model.creator.display_name == "OpenAI"
|
||||
assert model.metadata.creator_name == "OpenAI"
|
||||
|
||||
|
||||
def test_record_to_registry_model_null_max_output_tokens():
|
||||
"""maxOutputTokens=None falls back to contextWindow."""
|
||||
record = _make_mock_record(maxOutputTokens=None, contextWindow=64000)
|
||||
model = _record_to_registry_model(record)
|
||||
|
||||
assert model.metadata.max_output_tokens == 64000
|
||||
|
||||
|
||||
def test_record_to_registry_model_invalid_price_tier(caplog):
|
||||
"""Out-of-range priceTier is coerced to 1 and a warning is logged."""
|
||||
record = _make_mock_record(priceTier=99)
|
||||
with caplog.at_level("WARNING"):
|
||||
model = _record_to_registry_model(record)
|
||||
|
||||
assert "out-of-range priceTier" in caplog.text
|
||||
assert model.metadata.price_tier == 1
|
||||
|
||||
|
||||
def test_record_to_registry_model_with_costs():
|
||||
"""Costs are parsed into RegistryModelCost tuples."""
|
||||
cost_mock = Mock()
|
||||
cost_mock.unit = "TOKENS"
|
||||
cost_mock.creditCost = 10
|
||||
cost_mock.credentialProvider = "openai"
|
||||
cost_mock.credentialId = None
|
||||
cost_mock.credentialType = None
|
||||
cost_mock.currency = "USD"
|
||||
cost_mock.metadata = {}
|
||||
|
||||
record = _make_mock_record(Costs=[cost_mock])
|
||||
model = _record_to_registry_model(record)
|
||||
|
||||
assert len(model.costs) == 1
|
||||
cost = model.costs[0]
|
||||
assert isinstance(cost, RegistryModelCost)
|
||||
assert cost.unit == "TOKENS"
|
||||
assert cost.credit_cost == 10
|
||||
assert cost.credential_provider == "openai"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_default_model_slug tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_default_model_slug_recommended():
|
||||
"""Recommended model is preferred over non-recommended enabled models."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
|
||||
reg._dynamic_models = {
|
||||
"openai/gpt-4o": _make_registry_model(
|
||||
slug="openai/gpt-4o", display_name="GPT-4o", is_recommended=False
|
||||
),
|
||||
"openai/gpt-4o-recommended": _make_registry_model(
|
||||
slug="openai/gpt-4o-recommended",
|
||||
display_name="GPT-4o Recommended",
|
||||
is_recommended=True,
|
||||
),
|
||||
}
|
||||
|
||||
result = get_default_model_slug()
|
||||
assert result == "openai/gpt-4o-recommended"
|
||||
|
||||
|
||||
def test_get_default_model_slug_fallback():
|
||||
"""With no recommended model, falls back to first enabled (alphabetical)."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
|
||||
reg._dynamic_models = {
|
||||
"openai/gpt-4o": _make_registry_model(
|
||||
slug="openai/gpt-4o", display_name="GPT-4o", is_recommended=False
|
||||
),
|
||||
"openai/gpt-3.5": _make_registry_model(
|
||||
slug="openai/gpt-3.5", display_name="GPT-3.5", is_recommended=False
|
||||
),
|
||||
}
|
||||
|
||||
result = get_default_model_slug()
|
||||
# Sorted alphabetically: GPT-3.5 < GPT-4o
|
||||
assert result == "openai/gpt-3.5"
|
||||
|
||||
|
||||
def test_get_default_model_slug_empty():
|
||||
"""Empty registry returns None."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
|
||||
reg._dynamic_models = {}
|
||||
|
||||
result = get_default_model_slug()
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_schema_options / get_schema_options tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_schema_options():
|
||||
"""Only enabled models appear, sorted case-insensitively."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
|
||||
reg._dynamic_models = {
|
||||
"openai/gpt-4o": _make_registry_model(
|
||||
slug="openai/gpt-4o", display_name="GPT-4o", is_enabled=True
|
||||
),
|
||||
"openai/disabled": _make_registry_model(
|
||||
slug="openai/disabled", display_name="Disabled Model", is_enabled=False
|
||||
),
|
||||
"openai/gpt-3.5": _make_registry_model(
|
||||
slug="openai/gpt-3.5", display_name="gpt-3.5", is_enabled=True
|
||||
),
|
||||
}
|
||||
|
||||
options = _build_schema_options()
|
||||
slugs = [o["value"] for o in options]
|
||||
|
||||
# disabled model should be excluded
|
||||
assert "openai/disabled" not in slugs
|
||||
# only enabled models
|
||||
assert "openai/gpt-4o" in slugs
|
||||
assert "openai/gpt-3.5" in slugs
|
||||
# case-insensitive sort: "gpt-3.5" < "GPT-4o" (both lowercase: "gpt-3.5" < "gpt-4o")
|
||||
assert slugs.index("openai/gpt-3.5") < slugs.index("openai/gpt-4o")
|
||||
|
||||
# Verify structure
|
||||
for option in options:
|
||||
assert "label" in option
|
||||
assert "value" in option
|
||||
assert "group" in option
|
||||
assert "description" in option
|
||||
|
||||
|
||||
def test_get_schema_options_returns_copy():
|
||||
"""Mutating the returned list does not affect the internal cache."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
|
||||
reg._dynamic_models = {
|
||||
"openai/gpt-4o": _make_registry_model(slug="openai/gpt-4o", display_name="GPT-4o"),
|
||||
}
|
||||
reg._schema_options = _build_schema_options()
|
||||
|
||||
options = get_schema_options()
|
||||
original_length = len(options)
|
||||
options.append({"label": "Injected", "value": "evil/model", "group": "evil", "description": ""})
|
||||
|
||||
# Internal state should be unchanged
|
||||
assert len(get_schema_options()) == original_length
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic frozen model tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_registry_model_frozen():
|
||||
"""Pydantic frozen=True should reject attribute assignment."""
|
||||
model = _make_registry_model()
|
||||
|
||||
with pytest.raises((pydantic.ValidationError, TypeError)):
|
||||
model.slug = "changed/slug" # type: ignore[misc]
|
||||
|
||||
|
||||
def test_registry_model_cost_frozen():
|
||||
"""RegistryModelCost is also frozen."""
|
||||
cost = RegistryModelCost(
|
||||
unit="TOKENS",
|
||||
credit_cost=5,
|
||||
credential_provider="openai",
|
||||
)
|
||||
with pytest.raises((pydantic.ValidationError, TypeError)):
|
||||
cost.unit = "RUN" # type: ignore[misc]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# refresh_llm_registry tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_llm_registry():
|
||||
"""Mock prisma find_many, verify cache is populated after refresh."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
|
||||
record = _make_mock_record()
|
||||
mock_find_many = AsyncMock(return_value=[record])
|
||||
|
||||
with patch("prisma.models.LlmModel.prisma") as mock_prisma_cls:
|
||||
mock_prisma_instance = Mock()
|
||||
mock_prisma_instance.find_many = mock_find_many
|
||||
mock_prisma_cls.return_value = mock_prisma_instance
|
||||
|
||||
# Clear state first
|
||||
reg._dynamic_models = {}
|
||||
reg._schema_options = []
|
||||
|
||||
await refresh_llm_registry()
|
||||
|
||||
assert "openai/gpt-4o" in reg._dynamic_models
|
||||
model = reg._dynamic_models["openai/gpt-4o"]
|
||||
assert isinstance(model, RegistryModel)
|
||||
assert model.slug == "openai/gpt-4o"
|
||||
# Schema options should be populated too
|
||||
assert len(reg._schema_options) == 1
|
||||
assert reg._schema_options[0]["value"] == "openai/gpt-4o"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# clear_registry_cache tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_clear_registry_cache():
|
||||
"""clear_registry_cache calls cache_clear on the cached fetch function."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch.object(reg._fetch_registry_from_db, "cache_clear") as mock_clear:
|
||||
clear_registry_cache()
|
||||
mock_clear.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_model / get_all_models / get_enabled_models / get_all_model_slugs tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_model_found():
|
||||
"""get_model returns the model when the slug exists in the registry."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
|
||||
m = _make_registry_model(slug="openai/gpt-4o")
|
||||
reg._dynamic_models = {"openai/gpt-4o": m}
|
||||
|
||||
result = get_model("openai/gpt-4o")
|
||||
|
||||
assert result is m
|
||||
|
||||
|
||||
def test_get_model_not_found():
|
||||
"""get_model returns None for an unknown slug."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
|
||||
reg._dynamic_models = {}
|
||||
|
||||
assert get_model("nonexistent/model") is None
|
||||
|
||||
|
||||
def test_get_all_models_includes_disabled():
|
||||
"""get_all_models returns all models, including disabled ones."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
|
||||
enabled = _make_registry_model(slug="openai/gpt-4", is_enabled=True)
|
||||
disabled = _make_registry_model(slug="openai/old-model", is_enabled=False)
|
||||
reg._dynamic_models = {"openai/gpt-4": enabled, "openai/old-model": disabled}
|
||||
|
||||
result = get_all_models()
|
||||
|
||||
slugs = [m.slug for m in result]
|
||||
assert "openai/gpt-4" in slugs
|
||||
assert "openai/old-model" in slugs
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
def test_get_enabled_models_excludes_disabled():
|
||||
"""get_enabled_models returns only models where is_enabled=True."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
|
||||
enabled = _make_registry_model(slug="openai/gpt-4", is_enabled=True)
|
||||
disabled = _make_registry_model(slug="openai/old-model", is_enabled=False)
|
||||
reg._dynamic_models = {"openai/gpt-4": enabled, "openai/old-model": disabled}
|
||||
|
||||
result = get_enabled_models()
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].slug == "openai/gpt-4"
|
||||
|
||||
|
||||
def test_get_all_model_slugs_for_validation():
|
||||
"""get_all_model_slugs_for_validation returns only enabled model slugs."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
|
||||
reg._dynamic_models = {
|
||||
"openai/gpt-4": _make_registry_model(slug="openai/gpt-4", is_enabled=True),
|
||||
"openai/old": _make_registry_model(slug="openai/old", is_enabled=False),
|
||||
"anthropic/claude": _make_registry_model(
|
||||
slug="anthropic/claude", is_enabled=True
|
||||
),
|
||||
}
|
||||
|
||||
result = get_all_model_slugs_for_validation()
|
||||
|
||||
assert "openai/gpt-4" in result
|
||||
assert "anthropic/claude" in result
|
||||
assert "openai/old" not in result
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_llm_registry_error_is_reraised(mocker):
|
||||
"""refresh_llm_registry re-raises exceptions after logging them."""
|
||||
mocker.patch(
|
||||
"backend.data.llm_registry.registry._fetch_registry_from_db",
|
||||
new=AsyncMock(side_effect=RuntimeError("DB unavailable")),
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="DB unavailable"):
|
||||
await refresh_llm_registry()
|
||||
@@ -0,0 +1,6 @@
|
||||
"""LLM registry API (public + admin)."""
|
||||
|
||||
from .admin_routes import router as admin_router
|
||||
from .routes import router
|
||||
|
||||
__all__ = ["router", "admin_router"]
|
||||
145
autogpt_platform/backend/backend/server/v2/llm/admin_model.py
Normal file
145
autogpt_platform/backend/backend/server/v2/llm/admin_model.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Request/response models for LLM registry admin API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class CreateLlmProviderRequest(pydantic.BaseModel):
|
||||
name: str
|
||||
display_name: str
|
||||
description: str | None = None
|
||||
default_credential_provider: str | None = None
|
||||
default_credential_id: str | None = None
|
||||
default_credential_type: str | None = None
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class UpdateLlmProviderRequest(pydantic.BaseModel):
|
||||
display_name: str | None = None
|
||||
description: str | None = None
|
||||
default_credential_provider: str | None = None
|
||||
default_credential_id: str | None = None
|
||||
default_credential_type: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class CreateLlmModelRequest(pydantic.BaseModel):
|
||||
slug: str
|
||||
display_name: str
|
||||
description: str | None = None
|
||||
provider_name: str
|
||||
creator_id: str | None = None
|
||||
context_window: int = pydantic.Field(gt=0)
|
||||
max_output_tokens: int | None = pydantic.Field(default=None, gt=0)
|
||||
price_tier: int = pydantic.Field(ge=1, le=3)
|
||||
is_enabled: bool = True
|
||||
is_recommended: bool = False
|
||||
supports_tools: bool = False
|
||||
supports_json_output: bool = False
|
||||
supports_reasoning: bool = False
|
||||
supports_parallel_tool_calls: bool = False
|
||||
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
costs: list[dict[str, Any]] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class UpdateLlmModelRequest(pydantic.BaseModel):
|
||||
display_name: str | None = None
|
||||
description: str | None = None
|
||||
creator_id: str | None = None
|
||||
context_window: int | None = pydantic.Field(default=None, gt=0)
|
||||
max_output_tokens: int | None = pydantic.Field(default=None, gt=0)
|
||||
price_tier: int | None = pydantic.Field(default=None, ge=1, le=3)
|
||||
is_enabled: bool | None = None
|
||||
is_recommended: bool | None = None
|
||||
supports_tools: bool | None = None
|
||||
supports_json_output: bool | None = None
|
||||
supports_reasoning: bool | None = None
|
||||
supports_parallel_tool_calls: bool | None = None
|
||||
capabilities: dict[str, Any] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ToggleLlmModelRequest(pydantic.BaseModel):
|
||||
is_enabled: bool
|
||||
migrate_to_slug: str | None = None
|
||||
migration_reason: str | None = None
|
||||
custom_credit_cost: int | None = None
|
||||
|
||||
|
||||
class CreateLlmCreatorRequest(pydantic.BaseModel):
|
||||
name: str
|
||||
display_name: str
|
||||
description: str | None = None
|
||||
website_url: str | None = None
|
||||
logo_url: str | None = None
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class UpdateLlmCreatorRequest(pydantic.BaseModel):
|
||||
display_name: str | None = None
|
||||
description: str | None = None
|
||||
website_url: str | None = None
|
||||
logo_url: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class LlmCreatorAdminResponse(pydantic.BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: str | None = None
|
||||
website_url: str | None = None
|
||||
logo_url: str | None = None
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
created_at: str | None = None
|
||||
updated_at: str | None = None
|
||||
|
||||
|
||||
class LlmModelCostAdminResponse(pydantic.BaseModel):
|
||||
unit: str
|
||||
credit_cost: float
|
||||
credential_provider: str
|
||||
credential_type: str | None = None
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class LlmProviderAdminResponse(pydantic.BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: str | None = None
|
||||
default_credential_provider: str | None = None
|
||||
default_credential_id: str | None = None
|
||||
default_credential_type: str | None = None
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
created_at: str | None = None
|
||||
updated_at: str | None = None
|
||||
model_count: int | None = None
|
||||
|
||||
|
||||
class LlmModelAdminResponse(pydantic.BaseModel):
|
||||
id: str
|
||||
slug: str
|
||||
display_name: str
|
||||
description: str | None = None
|
||||
provider_id: str
|
||||
creator_id: str | None = None
|
||||
context_window: int
|
||||
max_output_tokens: int | None = None
|
||||
price_tier: int
|
||||
is_enabled: bool
|
||||
is_recommended: bool
|
||||
supports_tools: bool
|
||||
supports_json_output: bool
|
||||
supports_reasoning: bool
|
||||
supports_parallel_tool_calls: bool
|
||||
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
created_at: str | None = None
|
||||
updated_at: str | None = None
|
||||
creator: LlmCreatorAdminResponse | None = None
|
||||
costs: list[LlmModelCostAdminResponse] = pydantic.Field(default_factory=list)
|
||||
565
autogpt_platform/backend/backend/server/v2/llm/admin_routes.py
Normal file
565
autogpt_platform/backend/backend/server/v2/llm/admin_routes.py
Normal file
@@ -0,0 +1,565 @@
|
||||
"""Admin API for LLM registry management."""
|
||||
|
||||
import logging
|
||||
|
||||
import autogpt_libs.auth
|
||||
import prisma
|
||||
import prisma.models
|
||||
from fastapi import APIRouter, HTTPException, Security, status
|
||||
|
||||
from backend.server.v2.llm import db_write
|
||||
from backend.server.v2.llm.admin_model import (
|
||||
CreateLlmCreatorRequest,
|
||||
CreateLlmModelRequest,
|
||||
CreateLlmProviderRequest,
|
||||
LlmCreatorAdminResponse,
|
||||
LlmModelAdminResponse,
|
||||
LlmModelCostAdminResponse,
|
||||
LlmProviderAdminResponse,
|
||||
ToggleLlmModelRequest,
|
||||
UpdateLlmCreatorRequest,
|
||||
UpdateLlmModelRequest,
|
||||
UpdateLlmProviderRequest,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _map_provider(provider: prisma.models.LlmProvider, model_count: int | None = None) -> LlmProviderAdminResponse:
|
||||
return LlmProviderAdminResponse(
|
||||
id=provider.id,
|
||||
name=provider.name,
|
||||
display_name=provider.displayName,
|
||||
description=provider.description,
|
||||
default_credential_provider=provider.defaultCredentialProvider,
|
||||
default_credential_id=provider.defaultCredentialId,
|
||||
default_credential_type=provider.defaultCredentialType,
|
||||
metadata=dict(provider.metadata or {}),
|
||||
created_at=provider.createdAt.isoformat() if provider.createdAt else None,
|
||||
updated_at=provider.updatedAt.isoformat() if provider.updatedAt else None,
|
||||
model_count=model_count,
|
||||
)
|
||||
|
||||
|
||||
def _map_creator(creator: prisma.models.LlmModelCreator) -> LlmCreatorAdminResponse:
|
||||
return LlmCreatorAdminResponse(
|
||||
id=creator.id,
|
||||
name=creator.name,
|
||||
display_name=creator.displayName,
|
||||
description=creator.description,
|
||||
website_url=creator.websiteUrl,
|
||||
logo_url=creator.logoUrl,
|
||||
metadata=dict(creator.metadata or {}),
|
||||
created_at=creator.createdAt.isoformat() if creator.createdAt else None,
|
||||
updated_at=creator.updatedAt.isoformat() if creator.updatedAt else None,
|
||||
)
|
||||
|
||||
|
||||
def _map_model(model: prisma.models.LlmModel) -> LlmModelAdminResponse:
|
||||
return LlmModelAdminResponse(
|
||||
id=model.id,
|
||||
slug=model.slug,
|
||||
display_name=model.displayName,
|
||||
description=model.description,
|
||||
provider_id=model.providerId,
|
||||
creator_id=model.creatorId,
|
||||
context_window=model.contextWindow,
|
||||
max_output_tokens=model.maxOutputTokens,
|
||||
price_tier=model.priceTier,
|
||||
is_enabled=model.isEnabled,
|
||||
is_recommended=model.isRecommended,
|
||||
supports_tools=model.supportsTools,
|
||||
supports_json_output=model.supportsJsonOutput,
|
||||
supports_reasoning=model.supportsReasoning,
|
||||
supports_parallel_tool_calls=model.supportsParallelToolCalls,
|
||||
capabilities=dict(model.capabilities or {}),
|
||||
metadata=dict(model.metadata or {}),
|
||||
created_at=model.createdAt.isoformat() if model.createdAt else None,
|
||||
updated_at=model.updatedAt.isoformat() if model.updatedAt else None,
|
||||
creator=_map_creator(model.Creator) if model.Creator else None,
|
||||
costs=[
|
||||
LlmModelCostAdminResponse(
|
||||
unit=c.unit,
|
||||
credit_cost=float(c.creditCost),
|
||||
credential_provider=c.credentialProvider,
|
||||
credential_type=c.credentialType,
|
||||
metadata=dict(c.metadata or {}),
|
||||
)
|
||||
for c in (model.Costs or [])
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/llm/models",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def create_model(request: CreateLlmModelRequest) -> LlmModelAdminResponse:
|
||||
try:
|
||||
provider = await prisma.models.LlmProvider.prisma().find_unique(
|
||||
where={"name": request.provider_name}
|
||||
)
|
||||
if not provider:
|
||||
provider = await prisma.models.LlmProvider.prisma().find_unique(
|
||||
where={"id": request.provider_name}
|
||||
)
|
||||
if not provider:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Provider '{request.provider_name}' not found",
|
||||
)
|
||||
|
||||
model = await db_write.create_model(
|
||||
slug=request.slug,
|
||||
display_name=request.display_name,
|
||||
provider_id=provider.id,
|
||||
context_window=request.context_window,
|
||||
price_tier=request.price_tier,
|
||||
description=request.description,
|
||||
creator_id=request.creator_id,
|
||||
max_output_tokens=request.max_output_tokens,
|
||||
is_enabled=request.is_enabled,
|
||||
is_recommended=request.is_recommended,
|
||||
supports_tools=request.supports_tools,
|
||||
supports_json_output=request.supports_json_output,
|
||||
supports_reasoning=request.supports_reasoning,
|
||||
supports_parallel_tool_calls=request.supports_parallel_tool_calls,
|
||||
capabilities=request.capabilities,
|
||||
metadata=request.metadata,
|
||||
)
|
||||
if request.costs:
|
||||
for cost_input in request.costs:
|
||||
await prisma.models.LlmModelCost.prisma().create(
|
||||
data={
|
||||
"unit": cost_input.get("unit", "RUN"),
|
||||
"creditCost": int(cost_input.get("credit_cost", 1)),
|
||||
"credentialProvider": provider.name,
|
||||
"metadata": prisma.Json(cost_input.get("metadata", {})),
|
||||
"Model": {"connect": {"id": model.id}},
|
||||
}
|
||||
)
|
||||
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(f"Created model '{request.slug}' (id: {model.id})")
|
||||
|
||||
model = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model.id},
|
||||
include={"Costs": True, "Creator": True},
|
||||
)
|
||||
return _map_model(model)
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create model: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to create model")
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/llm/models/{slug:path}",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def update_model(slug: str, request: UpdateLlmModelRequest) -> LlmModelAdminResponse:
|
||||
try:
|
||||
existing = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"slug": slug}
|
||||
)
|
||||
if not existing:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Model with slug '{slug}' not found"
|
||||
)
|
||||
|
||||
model = await db_write.update_model(
|
||||
model_id=existing.id,
|
||||
display_name=request.display_name,
|
||||
description=request.description,
|
||||
creator_id=request.creator_id,
|
||||
context_window=request.context_window,
|
||||
max_output_tokens=request.max_output_tokens,
|
||||
price_tier=request.price_tier,
|
||||
is_enabled=request.is_enabled,
|
||||
is_recommended=request.is_recommended,
|
||||
supports_tools=request.supports_tools,
|
||||
supports_json_output=request.supports_json_output,
|
||||
supports_reasoning=request.supports_reasoning,
|
||||
supports_parallel_tool_calls=request.supports_parallel_tool_calls,
|
||||
capabilities=request.capabilities,
|
||||
metadata=request.metadata,
|
||||
)
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(f"Updated model '{slug}' (id: {model.id})")
|
||||
return _map_model(model)
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to update model: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to update model")
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/llm/models/{slug:path}",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def delete_model(slug: str, replacement_model_slug: str | None = None) -> dict:
|
||||
try:
|
||||
existing = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"slug": slug}
|
||||
)
|
||||
if not existing:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Model with slug '{slug}' not found"
|
||||
)
|
||||
|
||||
result = await db_write.delete_model(
|
||||
model_id=existing.id,
|
||||
replacement_model_slug=replacement_model_slug,
|
||||
)
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(
|
||||
f"Deleted model '{slug}' (migrated {result['nodes_migrated']} nodes)"
|
||||
)
|
||||
return result
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to delete model: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to delete model")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/llm/models/{slug:path}/usage",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def get_model_usage(slug: str) -> dict:
|
||||
try:
|
||||
return await db_write.get_model_usage(slug)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to get model usage: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to get model usage")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/llm/models/{slug:path}/toggle",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def toggle_model(slug: str, request: ToggleLlmModelRequest) -> dict:
|
||||
try:
|
||||
existing = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"slug": slug}
|
||||
)
|
||||
if not existing:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Model with slug '{slug}' not found"
|
||||
)
|
||||
|
||||
if not request.is_enabled and existing.isRecommended:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
"Cannot disable the recommended model. "
|
||||
"Change the recommended model before disabling this one."
|
||||
),
|
||||
)
|
||||
|
||||
result = await db_write.toggle_model_with_migration(
|
||||
model_id=existing.id,
|
||||
is_enabled=request.is_enabled,
|
||||
migrate_to_slug=request.migrate_to_slug,
|
||||
migration_reason=request.migration_reason,
|
||||
custom_credit_cost=request.custom_credit_cost,
|
||||
)
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(
|
||||
f"Toggled model '{slug}' enabled={request.is_enabled} "
|
||||
f"(migrated {result['nodes_migrated']} nodes)"
|
||||
)
|
||||
return result
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to toggle model: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to toggle model")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/llm/migrations",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def list_migrations(include_reverted: bool = False) -> dict:
|
||||
try:
|
||||
migrations = await db_write.list_migrations(include_reverted=include_reverted)
|
||||
return {"migrations": migrations}
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to list migrations: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to list migrations")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/llm/migrations/{migration_id}/revert",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def revert_migration(migration_id: str, re_enable_source_model: bool = True) -> dict:
|
||||
try:
|
||||
result = await db_write.revert_migration(
|
||||
migration_id=migration_id,
|
||||
re_enable_source_model=re_enable_source_model,
|
||||
)
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(
|
||||
f"Reverted migration {migration_id}: "
|
||||
f"{result['nodes_reverted']} nodes restored"
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to revert migration: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to revert migration")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/llm/providers",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def create_provider(request: CreateLlmProviderRequest) -> LlmProviderAdminResponse:
|
||||
try:
|
||||
provider = await db_write.create_provider(
|
||||
name=request.name,
|
||||
display_name=request.display_name,
|
||||
description=request.description,
|
||||
default_credential_provider=request.default_credential_provider,
|
||||
default_credential_id=request.default_credential_id,
|
||||
default_credential_type=request.default_credential_type,
|
||||
metadata=request.metadata,
|
||||
)
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(f"Created provider '{request.name}' (id: {provider.id})")
|
||||
return _map_provider(provider)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create provider: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to create provider")
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/llm/providers/{name}",
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def update_provider(name: str, request: UpdateLlmProviderRequest) -> LlmProviderAdminResponse:
|
||||
try:
|
||||
existing = await prisma.models.LlmProvider.prisma().find_unique(
|
||||
where={"name": name}
|
||||
)
|
||||
if not existing:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Provider with name '{name}' not found"
|
||||
)
|
||||
|
||||
provider = await db_write.update_provider(
|
||||
provider_id=existing.id,
|
||||
display_name=request.display_name,
|
||||
description=request.description,
|
||||
default_credential_provider=request.default_credential_provider,
|
||||
default_credential_id=request.default_credential_id,
|
||||
default_credential_type=request.default_credential_type,
|
||||
metadata=request.metadata,
|
||||
)
|
||||
await db_write.refresh_runtime_caches()
|
||||
logger.info(f"Updated provider '{name}' (id: {provider.id})")
|
||||
return _map_provider(provider)
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to update provider: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to update provider")
|
||||
|
||||
|
||||
@router.delete(
    "/llm/providers/{name}",
    status_code=status.HTTP_204_NO_CONTENT,
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def delete_provider(name: str) -> None:
    """Delete the LLM provider with the given unique name.

    Responds 204 on success, 404 when the provider is unknown, 400 when the
    registry layer refuses the deletion (ValueError), and 500 otherwise.
    """
    try:
        record = await prisma.models.LlmProvider.prisma().find_unique(
            where={"name": name}
        )
        if record is None:
            raise HTTPException(
                status_code=404, detail=f"Provider with name '{name}' not found"
            )

        await db_write.delete_provider(provider_id=record.id)
        # Keep runtime caches consistent with the DB after the delete.
        await db_write.refresh_runtime_caches()
        logger.info(f"Deleted provider '{name}' (id: {record.id})")
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to delete provider: {e}")
        raise HTTPException(status_code=500, detail="Failed to delete provider")
|
||||
|
||||
|
||||
@router.get(
    "/llm/admin/providers",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def admin_list_providers() -> dict:
    """Return every LLM provider (sorted by name) with its model count."""
    try:
        rows = await prisma.models.LlmProvider.prisma().find_many(
            order={"name": "asc"},
            include={"Models": True},
        )
        serialized = []
        for row in rows:
            count = len(row.Models) if row.Models else 0
            serialized.append(_map_provider(row, model_count=count).model_dump())
        return {"providers": serialized}
    except Exception as e:
        logger.exception(f"Failed to list providers: {e}")
        raise HTTPException(status_code=500, detail="Failed to list providers")
|
||||
|
||||
|
||||
@router.get(
    "/llm/admin/models",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def admin_list_models(
    page: int = 1,
    page_size: int = 100,
    enabled_only: bool = False,
) -> dict:
    """List LLM models for the admin UI, paginated and optionally filtered.

    Args:
        page: 1-based page number; values below 1 are clamped to 1.
        page_size: rows per page; clamped to the range 1..1000.
        enabled_only: when True, only models with isEnabled=True are returned.

    Returns:
        {"models": [...]} with each model serialized via _map_model.
    """
    try:
        # Clamp pagination inputs so zero/negative query params cannot
        # produce a negative `skip` or non-positive `take`, which would
        # surface as a DB error (500) instead of a sane empty/first page.
        page = max(page, 1)
        page_size = min(max(page_size, 1), 1000)

        where = {"isEnabled": True} if enabled_only else {}
        models = await prisma.models.LlmModel.prisma().find_many(
            where=where,
            skip=(page - 1) * page_size,
            take=page_size,
            order={"displayName": "asc"},
            include={"Costs": True, "Creator": True},
        )
        return {"models": [_map_model(m).model_dump() for m in models]}
    except Exception as e:
        logger.exception(f"Failed to list models: {e}")
        raise HTTPException(status_code=500, detail="Failed to list models")
|
||||
|
||||
|
||||
@router.get(
    "/llm/creators",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def list_creators() -> dict:
    """Return all LLM model creators sorted by name."""
    try:
        rows = await prisma.models.LlmModelCreator.prisma().find_many(
            order={"name": "asc"}
        )
        return {"creators": [_map_creator(row).model_dump() for row in rows]}
    except Exception as e:
        logger.exception(f"Failed to list creators: {e}")
        raise HTTPException(status_code=500, detail="Failed to list creators")
|
||||
|
||||
|
||||
@router.post(
    "/llm/creators",
    status_code=status.HTTP_201_CREATED,
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def create_creator(request: CreateLlmCreatorRequest) -> LlmCreatorAdminResponse:
    """Create a new LLM model creator record and return it.

    NOTE(review): failures respond 500 with str(e) as the detail, which may
    expose raw DB error text — confirm this is intended for this admin-only
    surface (the provider endpoints use a generic message instead).
    """
    try:
        payload = {
            "name": request.name,
            "displayName": request.display_name,
            "description": request.description,
            "websiteUrl": request.website_url,
            "logoUrl": request.logo_url,
            "metadata": prisma.Json(request.metadata),
        }
        creator = await prisma.models.LlmModelCreator.prisma().create(data=payload)
        logger.info(f"Created creator '{creator.name}' (id: {creator.id})")
        return _map_creator(creator)
    except Exception as e:
        logger.exception(f"Failed to create creator: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.patch(
    "/llm/creators/{name}",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def update_creator(name: str, request: UpdateLlmCreatorRequest) -> LlmCreatorAdminResponse:
    """Partially update a creator; only fields present in the request change."""
    try:
        existing = await prisma.models.LlmModelCreator.prisma().find_unique(
            where={"name": name}
        )
        if existing is None:
            raise HTTPException(status_code=404, detail=f"Creator '{name}' not found")

        # Build the patch from non-None request fields only, so omitted
        # fields keep their current values.
        patch: dict = {}
        for value, column in (
            (request.display_name, "displayName"),
            (request.description, "description"),
            (request.website_url, "websiteUrl"),
            (request.logo_url, "logoUrl"),
        ):
            if value is not None:
                patch[column] = value
        if request.metadata is not None:
            patch["metadata"] = prisma.Json(request.metadata)

        creator = await prisma.models.LlmModelCreator.prisma().update(
            where={"id": existing.id},
            data=patch,
        )
        logger.info(f"Updated creator '{name}' (id: {creator.id})")
        return _map_creator(creator)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to update creator: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete(
    "/llm/creators/{name}",
    status_code=status.HTTP_204_NO_CONTENT,
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def delete_creator(name: str) -> None:
    """Delete a creator; refuses (400) while any models still reference it."""
    try:
        existing = await prisma.models.LlmModelCreator.prisma().find_unique(
            where={"name": name},
            include={"Models": True},
        )
        if existing is None:
            raise HTTPException(status_code=404, detail=f"Creator '{name}' not found")

        if existing.Models and len(existing.Models) > 0:
            raise HTTPException(
                status_code=400,
                detail=f"Cannot delete creator '{name}' — it has {len(existing.Models)} associated models",
            )

        await prisma.models.LlmModelCreator.prisma().delete(where={"id": existing.id})
        logger.info(f"Deleted creator '{name}' (id: {existing.id})")
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to delete creator: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -0,0 +1,701 @@
|
||||
"""Tests for LLM registry admin CRUD routes (admin_routes.py).
|
||||
|
||||
Covers provider, model, creator, migration CRUD endpoints.
|
||||
All endpoints require admin authentication.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.llm.admin_routes import router as admin_router
|
||||
|
||||
admin_app = fastapi.FastAPI()
|
||||
admin_app.include_router(admin_router)
|
||||
admin_client = fastapi.testclient.TestClient(admin_app)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth fixture
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_admin):
    """Bypass JWT admin auth for all tests in this module."""
    from autogpt_libs.auth.jwt_utils import get_jwt_payload

    override = mock_jwt_admin["get_jwt_payload"]
    admin_app.dependency_overrides[get_jwt_payload] = override
    yield
    admin_app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock factory helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_NOW = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _make_mock_provider(
|
||||
id: str = "prov-1",
|
||||
name: str = "openai",
|
||||
display_name: str = "OpenAI",
|
||||
description: str | None = None,
|
||||
default_credential_provider: str | None = None,
|
||||
default_credential_id: str | None = None,
|
||||
default_credential_type: str | None = None,
|
||||
models: list | None = None,
|
||||
) -> Mock:
|
||||
p = Mock()
|
||||
p.id = id
|
||||
p.name = name
|
||||
p.displayName = display_name
|
||||
p.description = description
|
||||
p.defaultCredentialProvider = default_credential_provider
|
||||
p.defaultCredentialId = default_credential_id
|
||||
p.defaultCredentialType = default_credential_type
|
||||
p.metadata = {}
|
||||
p.createdAt = _NOW
|
||||
p.updatedAt = _NOW
|
||||
p.Models = models if models is not None else []
|
||||
return p
|
||||
|
||||
|
||||
def _make_mock_model(
    id: str = "model-1",
    slug: str = "gpt-4",
    display_name: str = "GPT-4",
    description: str | None = None,
    provider_id: str = "prov-1",
    creator_id: str | None = None,
    context_window: int = 128000,
    max_output_tokens: int | None = 4096,
    price_tier: int = 2,
    is_enabled: bool = True,
    is_recommended: bool = False,
    costs: list | None = None,
    creator: Mock | None = None,
) -> Mock:
    """Build a Mock shaped like a prisma LlmModel row, with optional relations."""
    model = Mock()
    attrs = {
        "id": id,
        "slug": slug,
        "displayName": display_name,
        "description": description,
        "providerId": provider_id,
        "creatorId": creator_id,
        "contextWindow": context_window,
        "maxOutputTokens": max_output_tokens,
        "priceTier": price_tier,
        "isEnabled": is_enabled,
        "isRecommended": is_recommended,
        "supportsTools": False,
        "supportsJsonOutput": False,
        "supportsReasoning": False,
        "supportsParallelToolCalls": False,
        "capabilities": {},
        "metadata": {},
        "createdAt": _NOW,
        "updatedAt": _NOW,
        "Costs": costs or [],
        "Creator": creator,
    }
    for attr, value in attrs.items():
        setattr(model, attr, value)
    return model
|
||||
|
||||
|
||||
def _make_mock_creator(
    id: str = "creator-1",
    name: str = "openai",
    display_name: str = "OpenAI",
    description: str | None = None,
    website_url: str | None = None,
    logo_url: str | None = None,
    models: list | None = None,
) -> Mock:
    """Build a Mock shaped like a prisma LlmModelCreator row."""
    creator = Mock()
    attrs = {
        "id": id,
        "name": name,
        "displayName": display_name,
        "description": description,
        "websiteUrl": website_url,
        "logoUrl": logo_url,
        "metadata": {},
        "createdAt": _NOW,
        "updatedAt": _NOW,
        "Models": [] if models is None else models,
    }
    for attr, value in attrs.items():
        setattr(creator, attr, value)
    return creator
|
||||
|
||||
|
||||
def _make_mock_migration(
    id: str = "mig-1",
    source_slug: str = "gpt-3",
    target_slug: str = "gpt-4",
    node_count: int = 3,
    is_reverted: bool = False,
) -> dict:
    """Build a plain-dict migration record (shape used by the migrations tests)."""
    return dict(
        id=id,
        source_model_slug=source_slug,
        target_model_slug=target_slug,
        reason="upgrade",
        node_count=node_count,
        custom_credit_cost=None,
        is_reverted=is_reverted,
        reverted_at=None,
        created_at=_NOW.isoformat(),
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_provider(mocker):
    """POST /llm/providers creates a provider and returns 201 with its fields."""
    provider = _make_mock_provider()
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.create_provider = AsyncMock(return_value=provider)
    db.refresh_runtime_caches = AsyncMock()

    resp = admin_client.post(
        "/llm/providers", json={"name": "openai", "display_name": "OpenAI"}
    )

    assert resp.status_code == 201
    body = resp.json()
    assert body["name"] == "openai"
    assert body["display_name"] == "OpenAI"
    db.create_provider.assert_called_once()
    db.refresh_runtime_caches.assert_called_once()


def test_create_provider_validation_error(mocker):
    """db_write.create_provider raising ValueError maps to HTTP 400."""
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.create_provider = AsyncMock(side_effect=ValueError("duplicate name"))
    db.refresh_runtime_caches = AsyncMock()

    resp = admin_client.post(
        "/llm/providers", json={"name": "openai", "display_name": "OpenAI"}
    )

    assert resp.status_code == 400
    assert "duplicate name" in resp.json()["detail"]


def test_update_provider(mocker):
    """PATCH /llm/providers/{name} returns 200 with the updated fields."""
    current = _make_mock_provider()
    renamed = _make_mock_provider(display_name="OpenAI Updated")
    mocker.patch(
        "prisma.models.LlmProvider.prisma"
    ).return_value.find_unique = AsyncMock(return_value=current)
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.update_provider = AsyncMock(return_value=renamed)
    db.refresh_runtime_caches = AsyncMock()

    resp = admin_client.patch(
        "/llm/providers/openai", json={"display_name": "OpenAI Updated"}
    )

    assert resp.status_code == 200
    assert resp.json()["display_name"] == "OpenAI Updated"


def test_update_provider_not_found(mocker):
    """PATCH responds 404 for an unknown provider name."""
    mocker.patch(
        "prisma.models.LlmProvider.prisma"
    ).return_value.find_unique = AsyncMock(return_value=None)
    mocker.patch("backend.server.v2.llm.admin_routes.db_write")

    resp = admin_client.patch(
        "/llm/providers/nonexistent", json={"display_name": "X"}
    )

    assert resp.status_code == 404


def test_delete_provider(mocker):
    """DELETE /llm/providers/{name} responds 204 on success."""
    mocker.patch(
        "prisma.models.LlmProvider.prisma"
    ).return_value.find_unique = AsyncMock(return_value=_make_mock_provider())
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.delete_provider = AsyncMock(return_value=True)
    db.refresh_runtime_caches = AsyncMock()

    resp = admin_client.delete("/llm/providers/openai")

    assert resp.status_code == 204


def test_delete_provider_not_found(mocker):
    """DELETE responds 404 for an unknown provider name."""
    mocker.patch(
        "prisma.models.LlmProvider.prisma"
    ).return_value.find_unique = AsyncMock(return_value=None)
    mocker.patch("backend.server.v2.llm.admin_routes.db_write")

    resp = admin_client.delete("/llm/providers/ghost")

    assert resp.status_code == 404


def test_delete_provider_has_models(mocker):
    """DELETE responds 400 when db_write raises ValueError (provider has models)."""
    mocker.patch(
        "prisma.models.LlmProvider.prisma"
    ).return_value.find_unique = AsyncMock(return_value=_make_mock_provider())
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.delete_provider = AsyncMock(
        side_effect=ValueError("Cannot delete provider — it has 2 model(s)")
    )
    db.refresh_runtime_caches = AsyncMock()

    resp = admin_client.delete("/llm/providers/openai")

    assert resp.status_code == 400
    assert "model" in resp.json()["detail"].lower()


def test_create_provider_server_error(mocker):
    """POST /llm/providers responds 500 on an unexpected exception."""
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.create_provider = AsyncMock(side_effect=RuntimeError("unexpected"))
    db.refresh_runtime_caches = AsyncMock()

    resp = admin_client.post(
        "/llm/providers", json={"name": "openai", "display_name": "OpenAI"}
    )

    assert resp.status_code == 500


def test_admin_list_providers(mocker):
    """GET /llm/admin/providers returns providers annotated with model_count."""
    listed = _make_mock_provider(models=[_make_mock_model()])
    mocker.patch(
        "prisma.models.LlmProvider.prisma"
    ).return_value.find_many = AsyncMock(return_value=[listed])

    resp = admin_client.get("/llm/admin/providers")

    assert resp.status_code == 200
    providers = resp.json()["providers"]
    assert len(providers) == 1
    assert providers[0]["model_count"] == 1
    assert providers[0]["name"] == "openai"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_model(mocker):
    """POST /llm/models creates a model and returns 201 with model fields."""
    provider = _make_mock_provider()
    model = _make_mock_model()

    # Stub the provider lookup, the created-model refetch, and cost creation.
    models_ns = mocker.patch("prisma.models")
    models_ns.LlmProvider.prisma.return_value.find_unique = AsyncMock(
        return_value=provider
    )
    models_ns.LlmModel.prisma.return_value.find_unique = AsyncMock(
        return_value=model
    )
    models_ns.LlmModelCost.prisma.return_value.create = AsyncMock()

    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.create_model = AsyncMock(return_value=model)
    db.refresh_runtime_caches = AsyncMock()

    resp = admin_client.post(
        "/llm/models",
        json={
            "slug": "gpt-4",
            "display_name": "GPT-4",
            "provider_id": "openai",
            "context_window": 128000,
            "price_tier": 2,
        },
    )

    assert resp.status_code == 201
    assert resp.json()["slug"] == "gpt-4"
    db.create_model.assert_called_once()
    db.refresh_runtime_caches.assert_called_once()


def test_create_model_provider_not_found(mocker):
    """POST /llm/models responds 404 when the provider is unknown."""
    models_ns = mocker.patch("prisma.models")
    models_ns.LlmProvider.prisma.return_value.find_unique = AsyncMock(
        return_value=None
    )
    mocker.patch("backend.server.v2.llm.admin_routes.db_write")

    resp = admin_client.post(
        "/llm/models",
        json={
            "slug": "gpt-4",
            "display_name": "GPT-4",
            "provider_id": "nonexistent",
            "context_window": 128000,
            "price_tier": 2,
        },
    )

    assert resp.status_code == 404


def test_update_model(mocker):
    """PATCH /llm/models/{slug} returns 200 with the updated model."""
    mocker.patch(
        "prisma.models.LlmModel.prisma"
    ).return_value.find_unique = AsyncMock(return_value=_make_mock_model())
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.update_model = AsyncMock(
        return_value=_make_mock_model(display_name="GPT-4 Turbo")
    )
    db.refresh_runtime_caches = AsyncMock()

    resp = admin_client.patch(
        "/llm/models/gpt-4", json={"display_name": "GPT-4 Turbo"}
    )

    assert resp.status_code == 200
    assert resp.json()["display_name"] == "GPT-4 Turbo"


def test_update_model_not_found(mocker):
    """PATCH responds 404 for an unknown model slug."""
    mocker.patch(
        "prisma.models.LlmModel.prisma"
    ).return_value.find_unique = AsyncMock(return_value=None)
    mocker.patch("backend.server.v2.llm.admin_routes.db_write")

    resp = admin_client.patch("/llm/models/unknown-slug", json={})

    assert resp.status_code == 404


def test_delete_model(mocker):
    """DELETE /llm/models/{slug} without replacement reports nodes_migrated=0."""
    mocker.patch(
        "prisma.models.LlmModel.prisma"
    ).return_value.find_unique = AsyncMock(return_value=_make_mock_model())
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.delete_model = AsyncMock(
        return_value={
            "deleted_model_slug": "gpt-4",
            "deleted_model_display_name": "GPT-4",
            "replacement_model_slug": None,
            "nodes_migrated": 0,
        }
    )
    db.refresh_runtime_caches = AsyncMock()

    resp = admin_client.delete("/llm/models/gpt-4")

    assert resp.status_code == 200
    payload = resp.json()
    assert payload["nodes_migrated"] == 0
    assert payload["deleted_model_slug"] == "gpt-4"


def test_delete_model_with_migration(mocker):
    """DELETE with replacement_model_slug migrates nodes and reports the count."""
    doomed = _make_mock_model()
    mocker.patch(
        "prisma.models.LlmModel.prisma"
    ).return_value.find_unique = AsyncMock(return_value=doomed)
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.delete_model = AsyncMock(
        return_value={
            "deleted_model_slug": "gpt-3",
            "deleted_model_display_name": "GPT-3",
            "replacement_model_slug": "gpt-4",
            "nodes_migrated": 5,
        }
    )
    db.refresh_runtime_caches = AsyncMock()

    resp = admin_client.delete("/llm/models/gpt-3?replacement_model_slug=gpt-4")

    assert resp.status_code == 200
    assert resp.json()["nodes_migrated"] == 5
    db.delete_model.assert_called_once_with(
        model_id=doomed.id, replacement_model_slug="gpt-4"
    )


def test_get_model_usage(mocker):
    """GET /llm/models/{slug}/usage reports the node_count."""
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.get_model_usage = AsyncMock(
        return_value={"model_slug": "gpt-4", "node_count": 7}
    )

    resp = admin_client.get("/llm/models/gpt-4/usage")

    assert resp.status_code == 200
    payload = resp.json()
    assert payload["node_count"] == 7
    assert payload["model_slug"] == "gpt-4"


def test_toggle_model_enable(mocker):
    """POST /llm/models/{slug}/toggle with is_enabled=True responds 200."""
    mocker.patch(
        "prisma.models.LlmModel.prisma"
    ).return_value.find_unique = AsyncMock(
        return_value=_make_mock_model(is_enabled=False)
    )
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.toggle_model_with_migration = AsyncMock(
        return_value={
            "nodes_migrated": 0,
            "migrated_to_slug": None,
            "migration_id": None,
        }
    )
    db.refresh_runtime_caches = AsyncMock()

    resp = admin_client.post("/llm/models/gpt-4/toggle", json={"is_enabled": True})

    assert resp.status_code == 200
    assert resp.json()["nodes_migrated"] == 0


def test_toggle_model_disable_with_migration(mocker):
    """Disabling with migrate_to_slug forwards migration args and reports counts."""
    target = _make_mock_model(is_enabled=True)
    mocker.patch(
        "prisma.models.LlmModel.prisma"
    ).return_value.find_unique = AsyncMock(return_value=target)
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.toggle_model_with_migration = AsyncMock(
        return_value={
            "nodes_migrated": 3,
            "migrated_to_slug": "gpt-4-turbo",
            "migration_id": "mig-abc",
        }
    )
    db.refresh_runtime_caches = AsyncMock()

    resp = admin_client.post(
        "/llm/models/gpt-4/toggle",
        json={"is_enabled": False, "migrate_to_slug": "gpt-4-turbo"},
    )

    assert resp.status_code == 200
    payload = resp.json()
    assert payload["nodes_migrated"] == 3
    assert payload["migration_id"] == "mig-abc"
    db.toggle_model_with_migration.assert_called_once_with(
        model_id=target.id,
        is_enabled=False,
        migrate_to_slug="gpt-4-turbo",
        migration_reason=None,
        custom_credit_cost=None,
    )


def test_toggle_model_not_found(mocker):
    """POST toggle responds 404 for an unknown model slug."""
    mocker.patch(
        "prisma.models.LlmModel.prisma"
    ).return_value.find_unique = AsyncMock(return_value=None)
    mocker.patch("backend.server.v2.llm.admin_routes.db_write")

    resp = admin_client.post(
        "/llm/models/ghost-model/toggle", json={"is_enabled": True}
    )

    assert resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Migrations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_list_migrations(mocker):
    """GET /llm/migrations returns the migrations list."""
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.list_migrations = AsyncMock(
        return_value=[_make_mock_migration(), _make_mock_migration(id="mig-2")]
    )

    resp = admin_client.get("/llm/migrations")

    assert resp.status_code == 200
    migrations = resp.json()["migrations"]
    assert len(migrations) == 2
    assert migrations[0]["id"] == "mig-1"


def test_revert_migration(mocker):
    """POST /llm/migrations/{id}/revert calls db_write and reports nodes_reverted."""
    db = mocker.patch("backend.server.v2.llm.admin_routes.db_write")
    db.revert_migration = AsyncMock(
        return_value={
            "migration_id": "mig-1",
            "source_model_slug": "gpt-3",
            "target_model_slug": "gpt-4",
            "nodes_reverted": 3,
            "nodes_already_changed": 0,
            "source_model_re_enabled": True,
        }
    )
    db.refresh_runtime_caches = AsyncMock()

    resp = admin_client.post("/llm/migrations/mig-1/revert")

    assert resp.status_code == 200
    payload = resp.json()
    assert payload["nodes_reverted"] == 3
    assert payload["migration_id"] == "mig-1"
    db.revert_migration.assert_called_once_with(
        migration_id="mig-1", re_enable_source_model=True
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Creator CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_creator(mocker):
    """POST /llm/creators responds 201 with creator fields."""
    mocker.patch(
        "prisma.models.LlmModelCreator.prisma"
    ).return_value.create = AsyncMock(return_value=_make_mock_creator())

    resp = admin_client.post(
        "/llm/creators", json={"name": "openai", "display_name": "OpenAI"}
    )

    assert resp.status_code == 201
    body = resp.json()
    assert body["name"] == "openai"
    assert body["display_name"] == "OpenAI"


def test_list_creators(mocker):
    """GET /llm/creators returns the creators list."""
    mocker.patch(
        "prisma.models.LlmModelCreator.prisma"
    ).return_value.find_many = AsyncMock(
        return_value=[
            _make_mock_creator(),
            _make_mock_creator(id="c-2", name="anthropic"),
        ]
    )

    resp = admin_client.get("/llm/creators")

    assert resp.status_code == 200
    assert len(resp.json()["creators"]) == 2


def test_update_creator(mocker):
    """PATCH /llm/creators/{name} responds 200 with the updated creator."""
    creator_api = mocker.patch("prisma.models.LlmModelCreator.prisma").return_value
    creator_api.find_unique = AsyncMock(return_value=_make_mock_creator())
    creator_api.update = AsyncMock(
        return_value=_make_mock_creator(display_name="OpenAI Corp")
    )

    resp = admin_client.patch(
        "/llm/creators/openai", json={"display_name": "OpenAI Corp"}
    )

    assert resp.status_code == 200
    assert resp.json()["display_name"] == "OpenAI Corp"


def test_update_creator_not_found(mocker):
    """PATCH responds 404 when the creator does not exist."""
    mocker.patch(
        "prisma.models.LlmModelCreator.prisma"
    ).return_value.find_unique = AsyncMock(return_value=None)

    resp = admin_client.patch(
        "/llm/creators/nobody", json={"display_name": "Nobody"}
    )

    assert resp.status_code == 404


def test_delete_creator_success(mocker):
    """DELETE /llm/creators/{name} responds 204 when the creator has no models."""
    empty_creator = _make_mock_creator(models=[])
    creator_api = mocker.patch("prisma.models.LlmModelCreator.prisma").return_value
    creator_api.find_unique = AsyncMock(return_value=empty_creator)
    creator_api.delete = AsyncMock(return_value=empty_creator)

    resp = admin_client.delete("/llm/creators/openai")

    assert resp.status_code == 204


def test_delete_creator_has_models(mocker):
    """DELETE responds 400 while the creator still has associated models."""
    mocker.patch(
        "prisma.models.LlmModelCreator.prisma"
    ).return_value.find_unique = AsyncMock(
        return_value=_make_mock_creator(models=[_make_mock_model()])
    )

    resp = admin_client.delete("/llm/creators/openai")

    assert resp.status_code == 400
    assert "models" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Admin model list
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_admin_list_models(mocker):
    """GET /llm/admin/models returns the models list (costs/creator included)."""
    mocker.patch(
        "prisma.models.LlmModel.prisma"
    ).return_value.find_many = AsyncMock(return_value=[_make_mock_model()])

    resp = admin_client.get("/llm/admin/models")

    assert resp.status_code == 200
    models = resp.json()["models"]
    assert len(models) == 1
    assert models[0]["slug"] == "gpt-4"
|
||||
28
autogpt_platform/backend/backend/server/v2/llm/conftest.py
Normal file
28
autogpt_platform/backend/backend/server/v2/llm/conftest.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Local test fixtures for LLM registry tests."""
|
||||
|
||||
import fastapi
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def mock_jwt_user(test_user_id):
    """Provide a mock JWT payload override for regular-user testing."""

    def _payload(request: fastapi.Request) -> dict[str, str]:
        return {"sub": test_user_id, "role": "user", "email": "test@example.com"}

    return {"get_jwt_payload": _payload, "user_id": test_user_id}
|
||||
|
||||
|
||||
@pytest.fixture
def mock_jwt_admin(admin_user_id):
    """Provide mock JWT payload for admin user testing.

    Same shape as ``mock_jwt_user`` but with ``role: admin`` so admin-only
    routes accept the request.
    """

    # Signature mirrors the real dependency; `request` is unused.
    def override_get_jwt_payload(request: fastapi.Request) -> dict[str, str]:
        return {
            "sub": admin_user_id,
            "role": "admin",
            "email": "test-admin@example.com",
        }

    return {"get_jwt_payload": override_get_jwt_payload, "user_id": admin_user_id}
|
||||
593
autogpt_platform/backend/backend/server/v2/llm/db_write.py
Normal file
593
autogpt_platform/backend/backend/server/v2/llm/db_write.py
Normal file
@@ -0,0 +1,593 @@
|
||||
"""Database write operations for LLM registry admin API."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import prisma
|
||||
import prisma.models
|
||||
|
||||
from backend.data import llm_registry
|
||||
from backend.data.db import transaction
|
||||
from backend.data.llm_registry.notifications import publish_registry_refresh_notification
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _node_model_value(slug: str) -> str:
    """Extract the model value stored in AgentNode.constantInput from a registry slug.

    Registry slugs are formatted as 'provider/model-name' (e.g. 'openai/gpt-4o').
    The LLM block stores only the model-name part (e.g. 'gpt-4o') in constantInput.
    """
    _, separator, remainder = slug.partition("/")
    # No separator means the slug is already a bare model name.
    return remainder if separator else slug
|
||||
|
||||
|
||||
def _build_provider_data(
    name: str,
    display_name: str,
    description: str | None = None,
    default_credential_provider: str | None = None,
    default_credential_id: str | None = None,
    default_credential_type: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Assemble the Prisma data payload for an LlmProvider create/update."""
    payload: dict[str, Any] = {
        "name": name,
        "displayName": display_name,
        "description": description,
        "defaultCredentialProvider": default_credential_provider,
        "defaultCredentialId": default_credential_id,
        "defaultCredentialType": default_credential_type,
    }
    # Prisma JSON columns must be wrapped; default to an empty object.
    payload["metadata"] = prisma.Json(metadata or {})
    return payload
|
||||
|
||||
|
||||
def _build_model_data(
    slug: str,
    display_name: str,
    provider_id: str,
    context_window: int,
    price_tier: int,
    description: str | None = None,
    creator_id: str | None = None,
    max_output_tokens: int | None = None,
    is_enabled: bool = True,
    is_recommended: bool = False,
    supports_tools: bool = False,
    supports_json_output: bool = False,
    supports_reasoning: bool = False,
    supports_parallel_tool_calls: bool = False,
    capabilities: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Assemble the Prisma data payload for an LlmModel create."""
    payload: dict[str, Any] = {
        "slug": slug,
        "displayName": display_name,
        "description": description,
        # Relation-connect syntax links the model to its provider row.
        "Provider": {"connect": {"id": provider_id}},
        "contextWindow": context_window,
        "maxOutputTokens": max_output_tokens,
        "priceTier": price_tier,
        "isEnabled": is_enabled,
        "isRecommended": is_recommended,
        "supportsTools": supports_tools,
        "supportsJsonOutput": supports_json_output,
        "supportsReasoning": supports_reasoning,
        "supportsParallelToolCalls": supports_parallel_tool_calls,
        "capabilities": prisma.Json(capabilities or {}),
        "metadata": prisma.Json(metadata or {}),
    }
    # Creator is an optional relation; connect only when an id was supplied.
    if creator_id:
        payload["Creator"] = {"connect": {"id": creator_id}}
    return payload
|
||||
|
||||
|
||||
async def create_provider(
    name: str,
    display_name: str,
    description: str | None = None,
    default_credential_provider: str | None = None,
    default_credential_id: str | None = None,
    default_credential_type: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmProvider:
    """Create a new LLM provider and return it with its models included."""
    created = await prisma.models.LlmProvider.prisma().create(
        data=_build_provider_data(
            name=name,
            display_name=display_name,
            description=description,
            default_credential_provider=default_credential_provider,
            default_credential_id=default_credential_id,
            default_credential_type=default_credential_type,
            metadata=metadata,
        ),
        include={"Models": True},
    )
    if created is None:
        raise ValueError("Failed to create provider")
    return created
|
||||
|
||||
|
||||
async def update_provider(
    provider_id: str,
    display_name: str | None = None,
    description: str | None = None,
    default_credential_provider: str | None = None,
    default_credential_id: str | None = None,
    default_credential_type: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmProvider:
    """Partially update an LLM provider; only non-None fields are written.

    Raises:
        ValueError: if the provider does not exist or the update fails.
    """
    existing = await prisma.models.LlmProvider.prisma().find_unique(
        where={"id": provider_id}
    )
    if existing is None:
        raise ValueError(f"Provider with id '{provider_id}' not found")

    # Only fields explicitly supplied by the caller go into the update.
    scalar_candidates: dict[str, Any] = {
        "displayName": display_name,
        "description": description,
        "defaultCredentialProvider": default_credential_provider,
        "defaultCredentialId": default_credential_id,
        "defaultCredentialType": default_credential_type,
    }
    data: dict[str, Any] = {
        key: value for key, value in scalar_candidates.items() if value is not None
    }
    if metadata is not None:
        data["metadata"] = prisma.Json(metadata)

    refreshed = await prisma.models.LlmProvider.prisma().update(
        where={"id": provider_id},
        data=data,
        include={"Models": True},
    )
    if refreshed is None:
        raise ValueError("Failed to update provider")
    return refreshed
|
||||
|
||||
|
||||
async def delete_provider(provider_id: str) -> bool:
    """Delete an LLM provider.

    A provider can only be deleted if it has no associated models.

    Raises:
        ValueError: if the provider is missing or still owns models.
    """
    target = await prisma.models.LlmProvider.prisma().find_unique(
        where={"id": provider_id},
        include={"Models": True},
    )
    if target is None:
        raise ValueError(f"Provider with id '{provider_id}' not found")

    # Refuse to orphan models: the caller must remove them first.
    attached = target.Models or []
    if attached:
        raise ValueError(
            f"Cannot delete provider '{target.displayName}' because it has "
            f"{len(attached)} model(s). Delete all models first."
        )

    await prisma.models.LlmProvider.prisma().delete(where={"id": provider_id})
    return True
|
||||
|
||||
|
||||
async def create_model(
    slug: str,
    display_name: str,
    provider_id: str,
    context_window: int,
    price_tier: int,
    description: str | None = None,
    creator_id: str | None = None,
    max_output_tokens: int | None = None,
    is_enabled: bool = True,
    is_recommended: bool = False,
    supports_tools: bool = False,
    supports_json_output: bool = False,
    supports_reasoning: bool = False,
    supports_parallel_tool_calls: bool = False,
    capabilities: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmModel:
    """Create a new LLM model, returning it with costs, creator and provider."""
    created = await prisma.models.LlmModel.prisma().create(
        data=_build_model_data(
            slug=slug,
            display_name=display_name,
            provider_id=provider_id,
            context_window=context_window,
            price_tier=price_tier,
            description=description,
            creator_id=creator_id,
            max_output_tokens=max_output_tokens,
            is_enabled=is_enabled,
            is_recommended=is_recommended,
            supports_tools=supports_tools,
            supports_json_output=supports_json_output,
            supports_reasoning=supports_reasoning,
            supports_parallel_tool_calls=supports_parallel_tool_calls,
            capabilities=capabilities,
            metadata=metadata,
        ),
        include={"Costs": True, "Creator": True, "Provider": True},
    )
    if created is None:
        raise ValueError("Failed to create model")
    return created
|
||||
|
||||
|
||||
async def update_model(
    model_id: str,
    display_name: str | None = None,
    description: str | None = None,
    creator_id: str | None = None,
    context_window: int | None = None,
    max_output_tokens: int | None = None,
    price_tier: int | None = None,
    is_enabled: bool | None = None,
    is_recommended: bool | None = None,
    supports_tools: bool | None = None,
    supports_json_output: bool | None = None,
    supports_reasoning: bool | None = None,
    supports_parallel_tool_calls: bool | None = None,
    capabilities: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmModel:
    """Update an existing LLM model.

    When is_recommended=True, clears the flag on all other models first so
    only one model can be recommended at a time.

    Raises:
        ValueError: if the model does not exist.
    """
    # Only fields explicitly supplied by the caller are written.
    scalar_candidates: dict[str, Any] = {
        "displayName": display_name,
        "description": description,
        "contextWindow": context_window,
        "maxOutputTokens": max_output_tokens,
        "priceTier": price_tier,
        "isEnabled": is_enabled,
        "isRecommended": is_recommended,
        "supportsTools": supports_tools,
        "supportsJsonOutput": supports_json_output,
        "supportsReasoning": supports_reasoning,
        "supportsParallelToolCalls": supports_parallel_tool_calls,
    }
    data: dict[str, Any] = {
        key: value for key, value in scalar_candidates.items() if value is not None
    }
    if capabilities is not None:
        data["capabilities"] = prisma.Json(capabilities)
    if metadata is not None:
        data["metadata"] = prisma.Json(metadata)
    if creator_id is not None:
        # An empty string clears the creator association.
        data["creatorId"] = creator_id if creator_id else None

    async with transaction() as tx:
        # Enforce single recommended model: unset all others first.
        if is_recommended is True:
            await tx.llmmodel.update_many(
                where={"id": {"not": model_id}},
                data={"isRecommended": False},
            )

        model = await tx.llmmodel.update(
            where={"id": model_id},
            data=data,
            include={"Costs": True, "Creator": True, "Provider": True},
        )

    if not model:
        raise ValueError(f"Model with id '{model_id}' not found")
    return model
|
||||
|
||||
|
||||
async def get_model_usage(slug: str) -> dict[str, Any]:
    """Get usage count for a model — how many AgentNodes reference it.

    AgentNode.constantInput stores only the model-name part of the registry
    slug, so the lookup normalizes via _node_model_value before matching.

    Returns:
        Dict with the original ``model_slug`` and the integer ``node_count``.
    """
    model_value = _node_model_value(slug)
    # The module-level `import prisma` already exposes get_client(); the
    # previous function-local `import prisma as prisma_module` was redundant.
    count_result = await prisma.get_client().query_raw(
        """
        SELECT COUNT(*) as count
        FROM "AgentNode"
        WHERE "constantInput"::jsonb->>'model' = $1
        """,
        model_value,
    )
    node_count = int(count_result[0]["count"]) if count_result else 0
    return {"model_slug": slug, "node_count": node_count}
|
||||
|
||||
|
||||
async def toggle_model_with_migration(
    model_id: str,
    is_enabled: bool,
    migrate_to_slug: str | None = None,
    migration_reason: str | None = None,
    custom_credit_cost: int | None = None,
) -> dict[str, Any]:
    """Toggle a model's enabled status, optionally migrating workflows when disabling.

    When disabling with ``migrate_to_slug`` set, all AgentNodes currently
    referencing the model are atomically re-pointed to the replacement inside
    a single transaction, and an LlmModelMigration record is written so the
    change can later be reverted.

    Raises:
        ValueError: if the model or the replacement does not exist, or the
            replacement is itself disabled.
    """
    model = await prisma.models.LlmModel.prisma().find_unique(
        where={"id": model_id}, include={"Costs": True}
    )
    if not model:
        raise ValueError(f"Model with id '{model_id}' not found")

    nodes_migrated = 0
    migration_id: str | None = None

    if not is_enabled and migrate_to_slug:
        # Disable + migrate path: everything happens in one transaction so a
        # failure leaves both the model flag and the nodes untouched.
        async with transaction() as tx:
            replacement = await tx.llmmodel.find_unique(
                where={"slug": migrate_to_slug}
            )
            if not replacement:
                raise ValueError(
                    f"Replacement model '{migrate_to_slug}' not found"
                )
            if not replacement.isEnabled:
                raise ValueError(
                    f"Replacement model '{migrate_to_slug}' is disabled. "
                    f"Please enable it before using it as a replacement."
                )

            # Nodes store only the model-name part of the slug (see
            # _node_model_value), so normalize both sides before matching.
            source_value = _node_model_value(model.slug)
            target_value = _node_model_value(migrate_to_slug)
            # FOR UPDATE locks the affected rows for the rest of the
            # transaction, so concurrent writers cannot race the migration.
            node_ids_result = await tx.query_raw(
                """
                SELECT id
                FROM "AgentNode"
                WHERE "constantInput"::jsonb->>'model' = $1
                FOR UPDATE
                """,
                source_value,
            )
            migrated_node_ids = (
                [row["id"] for row in node_ids_result] if node_ids_result else []
            )
            nodes_migrated = len(migrated_node_ids)

            if nodes_migrated > 0:
                # Rewrite the 'model' key in-place for exactly the locked ids.
                node_ids_json = json.dumps(migrated_node_ids)
                await tx.execute_raw(
                    """
                    UPDATE "AgentNode"
                    SET "constantInput" = JSONB_SET(
                        "constantInput"::jsonb,
                        '{model}',
                        to_jsonb($1::text)
                    )
                    WHERE id::text IN (
                        SELECT jsonb_array_elements_text($2::jsonb)
                    )
                    """,
                    target_value,
                    node_ids_json,
                )

            await tx.llmmodel.update(
                where={"id": model_id},
                data={"isEnabled": is_enabled},
            )

            if nodes_migrated > 0:
                # Record exactly which nodes moved so revert_migration can
                # restore them later.
                migration_record = await tx.llmmodelmigration.create(
                    data={
                        "sourceModelSlug": model.slug,
                        "targetModelSlug": migrate_to_slug,
                        "reason": migration_reason,
                        "migratedNodeIds": json.dumps(migrated_node_ids),
                        "nodeCount": nodes_migrated,
                        "customCreditCost": custom_credit_cost,
                    }
                )
                migration_id = migration_record.id
    else:
        # Plain toggle (enable, or disable without migration): single update.
        await prisma.models.LlmModel.prisma().update(
            where={"id": model_id},
            data={"isEnabled": is_enabled},
        )

    return {
        "nodes_migrated": nodes_migrated,
        "migrated_to_slug": migrate_to_slug if nodes_migrated > 0 else None,
        "migration_id": migration_id,
    }
|
||||
|
||||
|
||||
async def delete_model(
    model_id: str, replacement_model_slug: str | None = None
) -> dict[str, Any]:
    """Delete an LLM model, optionally migrating affected AgentNodes first.

    If workflows are using this model and no replacement is given, raises ValueError.
    If replacement is given, atomically migrates all affected nodes then deletes.

    Raises:
        ValueError: if the model is missing, nodes are in use without a
            replacement, or the replacement is missing/disabled.
    """
    model = await prisma.models.LlmModel.prisma().find_unique(
        where={"id": model_id}, include={"Costs": True}
    )
    if not model:
        raise ValueError(f"Model with id '{model_id}' not found")

    deleted_slug = model.slug
    deleted_display_name = model.displayName
    # BUG FIX: AgentNode.constantInput stores only the model-name part of the
    # registry slug (see _node_model_value, and the matching normalization in
    # get_model_usage / toggle_model_with_migration). Matching on the full
    # 'provider/model' slug would find zero nodes, silently deleting a model
    # that workflows still use.
    source_value = _node_model_value(deleted_slug)

    async with transaction() as tx:
        count_result = await tx.query_raw(
            """
            SELECT COUNT(*) as count
            FROM "AgentNode"
            WHERE "constantInput"::jsonb->>'model' = $1
            """,
            source_value,
        )
        nodes_to_migrate = int(count_result[0]["count"]) if count_result else 0

        if nodes_to_migrate > 0:
            if not replacement_model_slug:
                raise ValueError(
                    f"Cannot delete model '{deleted_slug}': {nodes_to_migrate} workflow node(s) "
                    f"are using it. Please provide a replacement_model_slug to migrate them."
                )
            replacement = await tx.llmmodel.find_unique(
                where={"slug": replacement_model_slug}
            )
            if not replacement:
                raise ValueError(
                    f"Replacement model '{replacement_model_slug}' not found"
                )
            if not replacement.isEnabled:
                raise ValueError(
                    f"Replacement model '{replacement_model_slug}' is disabled."
                )

            # Write the node-level value for the replacement, consistent with
            # how toggle_model_with_migration writes migrated nodes.
            await tx.execute_raw(
                """
                UPDATE "AgentNode"
                SET "constantInput" = JSONB_SET(
                    "constantInput"::jsonb,
                    '{model}',
                    to_jsonb($1::text)
                )
                WHERE "constantInput"::jsonb->>'model' = $2
                """,
                _node_model_value(replacement_model_slug),
                source_value,
            )

        await tx.llmmodel.delete(where={"id": model_id})

    return {
        "deleted_model_slug": deleted_slug,
        "deleted_model_display_name": deleted_display_name,
        "replacement_model_slug": replacement_model_slug,
        "nodes_migrated": nodes_to_migrate,
    }
|
||||
|
||||
|
||||
async def list_migrations(
    include_reverted: bool = False,
) -> list[dict[str, Any]]:
    """List model migrations, newest first; reverted ones excluded by default."""
    # Passing where=None means "no filter" to Prisma.
    filters: Any = None if include_reverted else {"isReverted": False}
    rows = await prisma.models.LlmModelMigration.prisma().find_many(
        where=filters,
        order={"createdAt": "desc"},
    )

    def _serialize(row) -> dict[str, Any]:
        return {
            "id": row.id,
            "source_model_slug": row.sourceModelSlug,
            "target_model_slug": row.targetModelSlug,
            "reason": row.reason,
            "node_count": row.nodeCount,
            "custom_credit_cost": row.customCreditCost,
            "is_reverted": row.isReverted,
            "reverted_at": row.revertedAt.isoformat() if row.revertedAt else None,
            "created_at": row.createdAt.isoformat(),
        }

    return [_serialize(row) for row in rows]
|
||||
|
||||
|
||||
async def revert_migration(
    migration_id: str,
    re_enable_source_model: bool = True,
) -> dict[str, Any]:
    """Revert a model migration, restoring affected nodes to their original model.

    Only nodes that still carry the migration's target value are touched, so
    nodes a user manually changed afterwards are left alone and reported via
    ``nodes_already_changed``.

    Raises:
        ValueError: if the migration is missing, already reverted, the source
            model no longer exists, or the migration recorded no nodes.
    """
    migration = await prisma.models.LlmModelMigration.prisma().find_unique(
        where={"id": migration_id}
    )
    if not migration:
        raise ValueError(f"Migration with id '{migration_id}' not found")

    if migration.isReverted:
        raise ValueError(
            f"Migration '{migration_id}' has already been reverted"
        )

    source_model = await prisma.models.LlmModel.prisma().find_unique(
        where={"slug": migration.sourceModelSlug}
    )
    if not source_model:
        raise ValueError(
            f"Source model '{migration.sourceModelSlug}' no longer exists."
        )

    # migratedNodeIds is stored as JSON; depending on the Prisma field type it
    # may already be deserialized to a list.
    migrated_node_ids: list[str] = (
        migration.migratedNodeIds
        if isinstance(migration.migratedNodeIds, list)
        else json.loads(migration.migratedNodeIds)  # type: ignore
    )
    if not migrated_node_ids:
        raise ValueError("No nodes to revert in this migration")

    # BUG FIX: migration records store full registry slugs, but the nodes were
    # migrated to _node_model_value(target) (see toggle_model_with_migration),
    # and AgentNode.constantInput holds only the model-name part. Match and
    # write the normalized node-level values, not the full slugs, otherwise
    # reverting a 'provider/model'-style slug updates zero rows.
    source_value = _node_model_value(migration.sourceModelSlug)
    target_value = _node_model_value(migration.targetModelSlug)

    source_model_re_enabled = False

    async with transaction() as tx:
        if not source_model.isEnabled and re_enable_source_model:
            # Re-enable the source model so the restored nodes are usable.
            await tx.llmmodel.update(
                where={"id": source_model.id},
                data={"isEnabled": True},
            )
            source_model_re_enabled = True

        node_ids_json = json.dumps(migrated_node_ids)
        result = await tx.execute_raw(
            """
            UPDATE "AgentNode"
            SET "constantInput" = JSONB_SET(
                "constantInput"::jsonb,
                '{model}',
                to_jsonb($1::text)
            )
            WHERE id::text IN (
                SELECT jsonb_array_elements_text($2::jsonb)
            )
            AND "constantInput"::jsonb->>'model' = $3
            """,
            source_value,
            node_ids_json,
            target_value,
        )
        # execute_raw returns the affected-row count as an int.
        nodes_reverted = result if isinstance(result, int) else 0

        await tx.llmmodelmigration.update(
            where={"id": migration_id},
            data={
                "isReverted": True,
                "revertedAt": datetime.now(timezone.utc),
            },
        )

    return {
        "migration_id": migration_id,
        "source_model_slug": migration.sourceModelSlug,
        "target_model_slug": migration.targetModelSlug,
        "nodes_reverted": nodes_reverted,
        "nodes_already_changed": len(migrated_node_ids) - nodes_reverted,
        "source_model_re_enabled": source_model_re_enabled,
    }
|
||||
|
||||
|
||||
async def refresh_runtime_caches() -> None:
    """Rebuild the in-process LLM registry cache and notify other processes.

    Clears the local registry cache, repopulates it from the database, then
    publishes a refresh notification so other service instances reload too.
    Call after any admin write that changes providers, models, or costs.
    """
    llm_registry.clear_registry_cache()
    await llm_registry.refresh_llm_registry()
    await publish_registry_refresh_notification()
|
||||
698
autogpt_platform/backend/backend/server/v2/llm/db_write_test.py
Normal file
698
autogpt_platform/backend/backend/server/v2/llm/db_write_test.py
Normal file
@@ -0,0 +1,698 @@
|
||||
"""Tests for LLM registry DB write operations (db_write.py).
|
||||
|
||||
All functions under test are async; patch Prisma at the point of use.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.llm import db_write
|
||||
|
||||
_NOW = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_provider(
    id: str = "prov-1",
    name: str = "openai",
    display_name: str = "OpenAI",
    models: list | None = None,
) -> Mock:
    """Build a Mock shaped like a prisma LlmProvider record."""
    provider = Mock()
    attributes = {
        "id": id,
        "name": name,
        "displayName": display_name,
        "description": None,
        "defaultCredentialProvider": None,
        "defaultCredentialId": None,
        "defaultCredentialType": None,
        "metadata": {},
        "createdAt": _NOW,
        "updatedAt": _NOW,
        "Models": [] if models is None else models,
    }
    for attr, value in attributes.items():
        setattr(provider, attr, value)
    return provider
|
||||
|
||||
|
||||
def _make_model(
    id: str = "model-1",
    slug: str = "gpt-4",
    display_name: str = "GPT-4",
    is_enabled: bool = True,
    is_recommended: bool = False,
) -> Mock:
    """Build a Mock shaped like a prisma LlmModel record."""
    model = Mock()
    attributes = {
        "id": id,
        "slug": slug,
        "displayName": display_name,
        "description": None,
        "providerId": "prov-1",
        "creatorId": None,
        "contextWindow": 128000,
        "maxOutputTokens": 4096,
        "priceTier": 2,
        "isEnabled": is_enabled,
        "isRecommended": is_recommended,
        "supportsTools": False,
        "supportsJsonOutput": False,
        "supportsReasoning": False,
        "supportsParallelToolCalls": False,
        "capabilities": {},
        "metadata": {},
        "createdAt": _NOW,
        "updatedAt": _NOW,
        "Costs": [],
        "Creator": None,
    }
    for attr, value in attributes.items():
        setattr(model, attr, value)
    return model
|
||||
|
||||
|
||||
def _make_migration(
    id: str = "mig-1",
    source_slug: str = "gpt-3",
    target_slug: str = "gpt-4",
    node_count: int = 3,
    migrated_node_ids: list | None = None,
    is_reverted: bool = False,
) -> Mock:
    """Build a Mock shaped like a prisma LlmModelMigration record."""
    migration = Mock()
    attributes = {
        "id": id,
        "sourceModelSlug": source_slug,
        "targetModelSlug": target_slug,
        "reason": "upgrade",
        "nodeCount": node_count,
        "customCreditCost": None,
        "isReverted": is_reverted,
        "revertedAt": None,
        "createdAt": _NOW,
        "migratedNodeIds": (
            ["n1", "n2", "n3"] if migrated_node_ids is None else migrated_node_ids
        ),
    }
    for attr, value in attributes.items():
        setattr(migration, attr, value)
    return migration
|
||||
|
||||
|
||||
def _make_tx_ctx(tx: Mock) -> Mock:
    """Wrap *tx* in a mock async context manager so ``async with`` yields it."""
    manager = MagicMock()
    manager.__aenter__ = AsyncMock(return_value=tx)
    manager.__aexit__ = AsyncMock(return_value=None)
    return manager
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider operations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_provider(mocker):
    """create_provider returns the provider that prisma.create produced."""
    fake_provider = _make_provider()
    registry = mocker.patch("prisma.models.LlmProvider.prisma").return_value
    registry.create = AsyncMock(return_value=fake_provider)

    created = await db_write.create_provider(name="openai", display_name="OpenAI")

    assert created.id == "prov-1"
    assert created.name == "openai"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_provider(mocker):
    """update_provider looks up the provider first, then persists the update."""
    before = _make_provider()
    after = _make_provider(display_name="OpenAI v2")
    registry = mocker.patch("prisma.models.LlmProvider.prisma").return_value
    registry.find_unique = AsyncMock(return_value=before)
    registry.update = AsyncMock(return_value=after)

    outcome = await db_write.update_provider(
        provider_id="prov-1", display_name="OpenAI v2"
    )

    registry.find_unique.assert_called_once_with(where={"id": "prov-1"})
    assert outcome.displayName == "OpenAI v2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_provider_not_found(mocker):
    """A missing provider id makes update_provider raise ValueError."""
    registry = mocker.patch("prisma.models.LlmProvider.prisma").return_value
    registry.find_unique = AsyncMock(return_value=None)

    with pytest.raises(ValueError, match="not found"):
        await db_write.update_provider(provider_id="ghost")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_provider_success(mocker):
    """A provider with no models is deleted and True is returned."""
    empty_provider = _make_provider(models=[])
    registry = mocker.patch("prisma.models.LlmProvider.prisma").return_value
    registry.find_unique = AsyncMock(return_value=empty_provider)
    registry.delete = AsyncMock(return_value=empty_provider)

    assert await db_write.delete_provider(provider_id="prov-1") is True
    registry.delete.assert_called_once_with(where={"id": "prov-1"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_provider_has_models(mocker):
    """A provider that still owns models cannot be deleted."""
    busy_provider = _make_provider(models=[_make_model()])
    registry = mocker.patch("prisma.models.LlmProvider.prisma").return_value
    registry.find_unique = AsyncMock(return_value=busy_provider)

    with pytest.raises(ValueError, match="Cannot delete"):
        await db_write.delete_provider(provider_id="prov-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model operations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_model(mocker):
    """create_model calls prisma.create with slug and display_name in data."""
    mock_model = _make_model()
    # Patch once and keep a handle to the AsyncMock. The previous version
    # patched the same target twice and discarded its first `create` handle
    # (dead assignment immediately overwritten).
    prisma_create = AsyncMock(return_value=mock_model)
    mocker.patch("prisma.models.LlmModel.prisma").return_value.create = prisma_create

    result = await db_write.create_model(
        slug="gpt-4",
        display_name="GPT-4",
        provider_id="prov-1",
        context_window=128000,
        price_tier=2,
    )

    assert result.slug == "gpt-4"
    call_kwargs = prisma_create.call_args
    # `data` may be passed positionally or by keyword; accept either.
    data_arg = call_kwargs.kwargs.get("data") or call_kwargs.args[0]
    assert data_arg["slug"] == "gpt-4"
    assert data_arg["displayName"] == "GPT-4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_model_recommended_clears_others(mocker):
    """Recommending a model first clears the flag on every other model."""
    recommended = _make_model(is_recommended=True)

    tx = AsyncMock()
    tx.llmmodel.update_many = AsyncMock()
    tx.llmmodel.update = AsyncMock(return_value=recommended)
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction",
        return_value=_make_tx_ctx(tx),
    )

    outcome = await db_write.update_model(model_id="model-1", is_recommended=True)

    assert outcome.isRecommended is True
    tx.llmmodel.update_many.assert_called_once_with(
        where={"id": {"not": "model-1"}},
        data={"isRecommended": False},
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_model_recommended_false_no_clear(mocker):
    """Un-recommending a model must not touch the other models."""
    plain = _make_model(is_recommended=False)

    tx = AsyncMock()
    tx.llmmodel.update_many = AsyncMock()
    tx.llmmodel.update = AsyncMock(return_value=plain)
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction",
        return_value=_make_tx_ctx(tx),
    )

    await db_write.update_model(model_id="model-1", is_recommended=False)

    tx.llmmodel.update_many.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_model_not_found(mocker):
    """update_model raises ValueError when the transactional update yields None."""
    tx = AsyncMock()
    tx.llmmodel.update_many = AsyncMock()
    tx.llmmodel.update = AsyncMock(return_value=None)
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction",
        return_value=_make_tx_ctx(tx),
    )

    with pytest.raises(ValueError, match="not found"):
        await db_write.update_model(model_id="ghost")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_model_usage(mocker):
    """get_model_usage converts the raw count row into an int node_count."""
    fake_client = Mock()
    fake_client.query_raw = AsyncMock(return_value=[{"count": "3"}])
    mocker.patch("prisma.get_client", return_value=fake_client)

    usage = await db_write.get_model_usage("gpt-4")

    assert usage["model_slug"] == "gpt-4"
    assert usage["node_count"] == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Delete model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_model_no_usage(mocker):
    """With zero referencing nodes, delete_model just deletes the record."""
    target = _make_model()
    mocker.patch(
        "prisma.models.LlmModel.prisma"
    ).return_value.find_unique = AsyncMock(return_value=target)

    tx = AsyncMock()
    tx.query_raw = AsyncMock(return_value=[{"count": "0"}])
    tx.execute_raw = AsyncMock()
    tx.llmmodel.find_unique = AsyncMock()
    tx.llmmodel.delete = AsyncMock()
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction",
        return_value=_make_tx_ctx(tx),
    )

    outcome = await db_write.delete_model(model_id="model-1")

    assert outcome["deleted_model_slug"] == "gpt-4"
    assert outcome["nodes_migrated"] == 0
    tx.llmmodel.delete.assert_called_once_with(where={"id": "model-1"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_model_with_replacement(mocker):
    """Nodes are migrated to an enabled replacement before the model is deleted."""
    doomed = _make_model(slug="gpt-3")
    target = _make_model(id="model-2", slug="gpt-4", is_enabled=True)
    mocker.patch(
        "prisma.models.LlmModel.prisma"
    ).return_value.find_unique = AsyncMock(return_value=doomed)

    mock_tx = AsyncMock()
    mock_tx.query_raw = AsyncMock(return_value=[{"count": "2"}])
    mock_tx.llmmodel.find_unique = AsyncMock(return_value=target)
    mock_tx.execute_raw = AsyncMock()
    mock_tx.llmmodel.delete = AsyncMock()
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction",
        return_value=_make_tx_ctx(mock_tx),
    )

    outcome = await db_write.delete_model(
        model_id="model-1", replacement_model_slug="gpt-4"
    )

    # One bulk UPDATE for the node migration, then the delete itself.
    mock_tx.execute_raw.assert_called_once()
    mock_tx.llmmodel.delete.assert_called_once()
    assert outcome["nodes_migrated"] == 2
    assert outcome["replacement_model_slug"] == "gpt-4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_model_usage_no_replacement(mocker):
    """Deleting a model that is still in use without a replacement is rejected."""
    mocker.patch(
        "prisma.models.LlmModel.prisma"
    ).return_value.find_unique = AsyncMock(return_value=_make_model())

    mock_tx = AsyncMock()
    mock_tx.query_raw = AsyncMock(return_value=[{"count": "5"}])
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction",
        return_value=_make_tx_ctx(mock_tx),
    )

    with pytest.raises(ValueError, match="provide a replacement"):
        await db_write.delete_model(model_id="model-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_model_replacement_disabled(mocker):
    """A disabled replacement model cannot receive migrated nodes."""
    mocker.patch(
        "prisma.models.LlmModel.prisma"
    ).return_value.find_unique = AsyncMock(
        return_value=_make_model(slug="gpt-3")
    )

    mock_tx = AsyncMock()
    mock_tx.query_raw = AsyncMock(return_value=[{"count": "2"}])
    mock_tx.llmmodel.find_unique = AsyncMock(
        return_value=_make_model(slug="gpt-4", is_enabled=False)
    )
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction",
        return_value=_make_tx_ctx(mock_tx),
    )

    with pytest.raises(ValueError, match="is disabled"):
        await db_write.delete_model(
            model_id="model-1", replacement_model_slug="gpt-4"
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_model_not_found(mocker):
    """Attempting to delete a nonexistent model raises ValueError('not found')."""
    mocker.patch(
        "prisma.models.LlmModel.prisma"
    ).return_value.find_unique = AsyncMock(return_value=None)

    with pytest.raises(ValueError, match="not found"):
        await db_write.delete_model(model_id="ghost")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Toggle model with migration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_toggle_enable_no_migration(mocker):
    """Enabling a model with no migrate_to_slug is a plain single-row update."""
    disabled_model = _make_model(is_enabled=False)
    model_prisma = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_prisma.find_unique = AsyncMock(return_value=disabled_model)
    model_prisma.update = AsyncMock(return_value=disabled_model)

    outcome = await db_write.toggle_model_with_migration(
        model_id="model-1", is_enabled=True
    )

    model_prisma.update.assert_called_once()
    assert outcome["nodes_migrated"] == 0
    assert outcome["migration_id"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_toggle_disable_with_migration(mocker):
    """Disabling with migrate_to_slug migrates nodes and records a migration."""
    mocker.patch(
        "prisma.models.LlmModel.prisma"
    ).return_value.find_unique = AsyncMock(
        return_value=_make_model(slug="gpt-3", is_enabled=True)
    )

    new_migration = Mock()
    new_migration.id = "mig-new"

    mock_tx = AsyncMock()
    mock_tx.llmmodel.find_unique = AsyncMock(
        return_value=_make_model(id="model-2", slug="gpt-4", is_enabled=True)
    )
    # Two graph nodes currently reference the model being disabled.
    mock_tx.query_raw = AsyncMock(
        return_value=[{"id": "node-1"}, {"id": "node-2"}]
    )
    mock_tx.execute_raw = AsyncMock()
    mock_tx.llmmodel.update = AsyncMock()
    mock_tx.llmmodelmigration.create = AsyncMock(return_value=new_migration)
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction",
        return_value=_make_tx_ctx(mock_tx),
    )

    outcome = await db_write.toggle_model_with_migration(
        model_id="model-1",
        is_enabled=False,
        migrate_to_slug="gpt-4",
        migration_reason="upgrade",
    )

    assert outcome["nodes_migrated"] == 2
    assert outcome["migration_id"] == "mig-new"
    mock_tx.llmmodelmigration.create.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_toggle_disable_migration_target_not_found(mocker):
    """A migrate_to_slug that matches no model raises ValueError('not found')."""
    mocker.patch(
        "prisma.models.LlmModel.prisma"
    ).return_value.find_unique = AsyncMock(
        return_value=_make_model(slug="gpt-3", is_enabled=True)
    )

    mock_tx = AsyncMock()
    mock_tx.llmmodel.find_unique = AsyncMock(return_value=None)
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction",
        return_value=_make_tx_ctx(mock_tx),
    )

    with pytest.raises(ValueError, match="not found"):
        await db_write.toggle_model_with_migration(
            model_id="model-1", is_enabled=False, migrate_to_slug="ghost"
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_toggle_disable_migration_target_disabled(mocker):
    """A disabled migration target is rejected with ValueError."""
    mocker.patch(
        "prisma.models.LlmModel.prisma"
    ).return_value.find_unique = AsyncMock(
        return_value=_make_model(slug="gpt-3", is_enabled=True)
    )

    mock_tx = AsyncMock()
    mock_tx.llmmodel.find_unique = AsyncMock(
        return_value=_make_model(slug="gpt-4", is_enabled=False)
    )
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction",
        return_value=_make_tx_ctx(mock_tx),
    )

    with pytest.raises(ValueError, match="disabled"):
        await db_write.toggle_model_with_migration(
            model_id="model-1", is_enabled=False, migrate_to_slug="gpt-4"
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_toggle_disable_without_migration(mocker):
    """Disabling without a migration target is a bare is_enabled flip."""
    enabled_model = _make_model(slug="gpt-3", is_enabled=True)
    model_prisma = mocker.patch("prisma.models.LlmModel.prisma").return_value
    model_prisma.find_unique = AsyncMock(return_value=enabled_model)
    model_prisma.update = AsyncMock(return_value=enabled_model)

    outcome = await db_write.toggle_model_with_migration(
        model_id="model-1", is_enabled=False  # no migrate_to_slug
    )

    # Plain update path: no transaction, no migration record created.
    model_prisma.update.assert_called_once_with(
        where={"id": "model-1"}, data={"isEnabled": False}
    )
    assert outcome["nodes_migrated"] == 0
    assert outcome["migration_id"] is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Migration operations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_migrations_active_only(mocker):
    """list_migrations(include_reverted=False) passes where={isReverted: False}.

    Fixed: the original patched ``LlmModelMigration.prisma`` twice and bound
    ``prisma_find_many`` to the first patch's ``find_many`` before immediately
    rebinding it to a fresh AsyncMock — the first patch and first assignment
    were dead code. One patch with the AsyncMock attached is equivalent.
    """
    records = [_make_migration()]
    prisma_find_many = AsyncMock(return_value=records)
    mocker.patch(
        "prisma.models.LlmModelMigration.prisma"
    ).return_value.find_many = prisma_find_many

    result = await db_write.list_migrations(include_reverted=False)

    call_kwargs = prisma_find_many.call_args
    # Accept the filter whether it was passed positionally or by keyword.
    where_arg = (call_kwargs.kwargs.get("where") or
                 (call_kwargs.args[0] if call_kwargs.args else None))
    assert where_arg == {"isReverted": False}
    assert len(result) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_migrations_include_reverted(mocker):
    """include_reverted=True must query without any where filter."""
    stub_records = [
        _make_migration(),
        _make_migration(id="mig-2", is_reverted=True),
    ]
    find_many_mock = AsyncMock(return_value=stub_records)
    mocker.patch(
        "prisma.models.LlmModelMigration.prisma"
    ).return_value.find_many = find_many_mock

    result = await db_write.list_migrations(include_reverted=True)

    assert find_many_mock.call_args.kwargs.get("where") is None
    assert len(result) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_revert_migration_success(mocker):
    """A successful revert re-enables the source, rewrites nodes, flags the record."""
    migration = _make_migration(migrated_node_ids=["n1", "n2"])
    source_model = _make_model(is_enabled=False)

    migration_prisma = mocker.patch(
        "prisma.models.LlmModelMigration.prisma"
    ).return_value
    migration_prisma.find_unique = AsyncMock(return_value=migration)
    mocker.patch(
        "prisma.models.LlmModel.prisma"
    ).return_value.find_unique = AsyncMock(return_value=source_model)

    mock_tx = AsyncMock()
    mock_tx.llmmodel.update = AsyncMock()
    mock_tx.execute_raw = AsyncMock(return_value=2)
    mock_tx.llmmodelmigration.update = AsyncMock()
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction",
        return_value=_make_tx_ctx(mock_tx),
    )

    outcome = await db_write.revert_migration(
        migration_id="mig-1", re_enable_source_model=True
    )

    assert outcome["nodes_reverted"] == 2
    assert outcome["source_model_re_enabled"] is True
    mock_tx.llmmodel.update.assert_called_once_with(
        where={"id": source_model.id},
        data={"isEnabled": True},
    )
    mock_tx.llmmodelmigration.update.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_revert_migration_already_reverted(mocker):
    """Reverting twice is rejected with ValueError('already been reverted')."""
    mocker.patch(
        "prisma.models.LlmModelMigration.prisma"
    ).return_value.find_unique = AsyncMock(
        return_value=_make_migration(is_reverted=True)
    )

    with pytest.raises(ValueError, match="already been reverted"):
        await db_write.revert_migration(migration_id="mig-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_revert_migration_not_found(mocker):
    """An unknown migration id raises ValueError('not found')."""
    mocker.patch(
        "prisma.models.LlmModelMigration.prisma"
    ).return_value.find_unique = AsyncMock(return_value=None)

    with pytest.raises(ValueError, match="not found"):
        await db_write.revert_migration(migration_id="ghost")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_revert_migration_no_re_enable(mocker):
    """With re_enable_source_model=False the source model stays disabled."""
    mocker.patch(
        "prisma.models.LlmModelMigration.prisma"
    ).return_value.find_unique = AsyncMock(
        return_value=_make_migration(migrated_node_ids=["n1", "n2"])
    )
    mocker.patch(
        "prisma.models.LlmModel.prisma"
    ).return_value.find_unique = AsyncMock(
        return_value=_make_model(is_enabled=False)
    )

    mock_tx = AsyncMock()
    mock_tx.llmmodel.update = AsyncMock()
    mock_tx.execute_raw = AsyncMock(return_value=2)
    mock_tx.llmmodelmigration.update = AsyncMock()
    mocker.patch(
        "backend.server.v2.llm.db_write.transaction",
        return_value=_make_tx_ctx(mock_tx),
    )

    outcome = await db_write.revert_migration(
        migration_id="mig-1", re_enable_source_model=False
    )

    # The source model must NOT have been touched.
    mock_tx.llmmodel.update.assert_not_called()
    assert outcome["source_model_re_enabled"] is False
    assert outcome["nodes_reverted"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_revert_migration_source_model_gone(mocker):
    """A deleted source model aborts the revert with ValueError('no longer exists')."""
    mocker.patch(
        "prisma.models.LlmModelMigration.prisma"
    ).return_value.find_unique = AsyncMock(return_value=_make_migration())
    mocker.patch(
        "prisma.models.LlmModel.prisma"
    ).return_value.find_unique = AsyncMock(return_value=None)

    with pytest.raises(ValueError, match="no longer exists"):
        await db_write.revert_migration(migration_id="mig-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cache refresh
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_refresh_runtime_caches(mocker):
    """refresh_runtime_caches clears, re-loads, and broadcasts exactly once each."""
    clear_mock = mocker.patch("backend.data.llm_registry.clear_registry_cache")
    refresh_mock = mocker.patch(
        "backend.data.llm_registry.refresh_llm_registry",
        new=AsyncMock(),
    )
    publish_mock = mocker.patch(
        "backend.data.llm_registry.notifications.publish_registry_refresh_notification",
        new=AsyncMock(),
    )

    await db_write.refresh_runtime_caches()

    clear_mock.assert_called_once()
    refresh_mock.assert_called_once()
    publish_mock.assert_called_once()
|
||||
69
autogpt_platform/backend/backend/server/v2/llm/model.py
Normal file
69
autogpt_platform/backend/backend/server/v2/llm/model.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Pydantic models for LLM registry public API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class LlmModelCost(pydantic.BaseModel):
    """Cost configuration for an LLM model.

    One model may carry several cost entries (e.g. per-run and per-token
    pricing, or different credential providers).
    """

    unit: str  # billing unit: "RUN" or "TOKENS"
    credit_cost: int = pydantic.Field(ge=0)  # platform credits charged; never negative
    credential_provider: str  # provider whose credentials are billed (e.g. "openai")
    credential_id: str | None = None  # specific credential, if pinned
    credential_type: str | None = None
    currency: str | None = None  # optional real-money currency code
    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)  # free-form extras
|
||||
|
||||
|
||||
class LlmModelCreator(pydantic.BaseModel):
    """The organization that created/trained the model (for attribution in the UI)."""

    id: str
    name: str  # machine-readable name, e.g. "openai"
    display_name: str  # human-readable name, e.g. "OpenAI"
    description: str | None = None
    website_url: str | None = None
    logo_url: str | None = None
|
||||
|
||||
|
||||
class LlmModel(pydantic.BaseModel):
    """Public-facing LLM model information returned by the /llm endpoints."""

    slug: str  # stable identifier used by graph nodes to reference the model
    display_name: str
    description: str | None = None
    provider_name: str  # provider display name (serving infrastructure)
    creator: LlmModelCreator | None = None  # training org; may differ from provider
    context_window: int  # max input context, in tokens
    max_output_tokens: int | None = None
    price_tier: int  # 1=cheapest, 2=medium, 3=expensive
    is_enabled: bool = True
    is_recommended: bool = False
    capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)  # feature flags
    costs: list[LlmModelCost] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class LlmProvider(pydantic.BaseModel):
    """A provider together with its enabled models (grouping for /llm/providers)."""

    name: str  # machine-readable provider key
    display_name: str
    models: list[LlmModel] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class LlmModelsResponse(pydantic.BaseModel):
    """Response body for GET /llm/models."""

    models: list[LlmModel]
    total: int  # convenience count; equals len(models)
|
||||
|
||||
|
||||
class LlmProvidersResponse(pydantic.BaseModel):
    """Response body for GET /llm/providers."""

    providers: list[LlmProvider]
    total: int  # convenience count; equals len(providers)
|
||||
96
autogpt_platform/backend/backend/server/v2/llm/routes.py
Normal file
96
autogpt_platform/backend/backend/server/v2/llm/routes.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""Public read-only API for LLM registry."""
|
||||
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
|
||||
from backend.data.llm_registry import (
|
||||
RegistryModel,
|
||||
RegistryModelCreator,
|
||||
get_all_models,
|
||||
get_enabled_models,
|
||||
)
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
|
||||
# Read-only registry router mounted under /llm; every endpoint requires an
# authenticated user via the shared auth dependency.
router = fastapi.APIRouter(
    prefix="/llm",
    dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
|
||||
|
||||
|
||||
def _map_creator(
    creator: RegistryModelCreator | None,
) -> llm_model.LlmModelCreator | None:
    """Convert a registry creator record to its public API shape; None passes through."""
    if not creator:
        return None
    mapped = llm_model.LlmModelCreator(
        id=creator.id,
        name=creator.name,
        display_name=creator.display_name,
        description=creator.description,
        website_url=creator.website_url,
        logo_url=creator.logo_url,
    )
    return mapped
|
||||
|
||||
|
||||
def _map_model(model: RegistryModel) -> llm_model.LlmModel:
    """Translate an internal RegistryModel into the public LlmModel schema."""
    # Map each cost entry up front so the constructor call stays readable.
    mapped_costs = [
        llm_model.LlmModelCost(
            unit=entry.unit,
            credit_cost=entry.credit_cost,
            credential_provider=entry.credential_provider,
            credential_id=entry.credential_id,
            credential_type=entry.credential_type,
            currency=entry.currency,
            metadata=entry.metadata,
        )
        for entry in model.costs
    ]
    meta = model.metadata
    return llm_model.LlmModel(
        slug=model.slug,
        display_name=model.display_name,
        description=model.description,
        provider_name=model.provider_display_name,
        creator=_map_creator(model.creator),
        context_window=meta.context_window,
        max_output_tokens=meta.max_output_tokens,
        price_tier=meta.price_tier,
        is_enabled=model.is_enabled,
        is_recommended=model.is_recommended,
        capabilities=model.capabilities,
        costs=mapped_costs,
    )
|
||||
|
||||
|
||||
@router.get("/models", response_model=llm_model.LlmModelsResponse)
async def list_models(
    enabled_only: bool = fastapi.Query(
        default=True, description="Only return enabled models"
    ),
):
    """Return registry models, restricted to enabled ones unless asked otherwise."""
    if enabled_only:
        source = get_enabled_models()
    else:
        source = get_all_models()
    mapped = [_map_model(entry) for entry in source]
    return llm_model.LlmModelsResponse(models=mapped, total=len(mapped))
|
||||
|
||||
|
||||
@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
async def list_providers():
    """Return enabled models grouped by provider, providers and models both sorted.

    Providers are sorted by key; each provider's models by display name.
    """
    registry_models = get_enabled_models()

    # Bucket models under their provider key. setdefault replaces the original
    # "check membership then append" two-step with a single idiomatic call.
    provider_map: dict[str, list[RegistryModel]] = {}
    for model in registry_models:
        provider_map.setdefault(model.metadata.provider, []).append(model)

    providers = [
        llm_model.LlmProvider(
            name=provider_key,
            # All models in a bucket share a provider; use the first for the
            # display name (guard kept for the empty-bucket edge case).
            display_name=models[0].provider_display_name if models else provider_key,
            models=[
                _map_model(m) for m in sorted(models, key=lambda m: m.display_name)
            ],
        )
        for provider_key, models in sorted(provider_map.items())
    ]

    return llm_model.LlmProvidersResponse(providers=providers, total=len(providers))
|
||||
282
autogpt_platform/backend/backend/server/v2/llm/routes_test.py
Normal file
282
autogpt_platform/backend/backend/server/v2/llm/routes_test.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""Tests for public read-only LLM registry routes (routes.py).
|
||||
|
||||
Covers:
|
||||
- GET /llm/models (enabled_only=True default, enabled_only=False)
|
||||
- GET /llm/providers
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.llm.routes import router
|
||||
|
||||
# Minimal FastAPI app hosting only the LLM router, so routes can be
# exercised via TestClient without the rest of the platform.
app = fastapi.FastAPI()
app.include_router(router)
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth fixture
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
    """Bypass JWT auth for all tests in this module."""
    # Function-scope import — NOTE(review): presumably to keep auth libs out
    # of module import time; confirm before moving to module level.
    from autogpt_libs.auth.jwt_utils import get_jwt_payload

    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
    yield
    # Clear after each test so overrides never leak into other modules.
    app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_cost(
|
||||
unit: str = "RUN",
|
||||
credit_cost: int = 10,
|
||||
credential_provider: str = "openai",
|
||||
credential_id: str | None = None,
|
||||
credential_type: str | None = None,
|
||||
currency: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> Mock:
|
||||
cost = Mock()
|
||||
cost.unit = unit
|
||||
cost.credit_cost = credit_cost
|
||||
cost.credential_provider = credential_provider
|
||||
cost.credential_id = credential_id
|
||||
cost.credential_type = credential_type
|
||||
cost.currency = currency
|
||||
cost.metadata = metadata or {}
|
||||
return cost
|
||||
|
||||
|
||||
def _make_mock_creator(
|
||||
id: str = "creator-1",
|
||||
name: str = "openai",
|
||||
display_name: str = "OpenAI",
|
||||
description: str | None = "AI company",
|
||||
website_url: str | None = "https://openai.com",
|
||||
logo_url: str | None = None,
|
||||
) -> Mock:
|
||||
creator = Mock()
|
||||
creator.id = id
|
||||
creator.name = name
|
||||
creator.display_name = display_name
|
||||
creator.description = description
|
||||
creator.website_url = website_url
|
||||
creator.logo_url = logo_url
|
||||
return creator
|
||||
|
||||
|
||||
def _make_mock_model(
|
||||
slug: str = "gpt-4",
|
||||
display_name: str = "GPT-4",
|
||||
description: str | None = "Latest GPT",
|
||||
provider_display_name: str = "OpenAI",
|
||||
is_enabled: bool = True,
|
||||
is_recommended: bool = False,
|
||||
capabilities: dict | None = None,
|
||||
provider_key: str = "openai",
|
||||
context_window: int = 128000,
|
||||
max_output_tokens: int | None = 4096,
|
||||
price_tier: int = 2,
|
||||
creator: Mock | None = None,
|
||||
costs: list | None = None,
|
||||
) -> Mock:
|
||||
model = Mock()
|
||||
model.slug = slug
|
||||
model.display_name = display_name
|
||||
model.description = description
|
||||
model.provider_display_name = provider_display_name
|
||||
model.is_enabled = is_enabled
|
||||
model.is_recommended = is_recommended
|
||||
model.capabilities = capabilities or {}
|
||||
model.creator = creator
|
||||
model.costs = costs or []
|
||||
|
||||
meta = Mock()
|
||||
meta.provider = provider_key
|
||||
meta.context_window = context_window
|
||||
meta.max_output_tokens = max_output_tokens
|
||||
meta.price_tier = price_tier
|
||||
model.metadata = meta
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /llm/models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_list_models_enabled_only(mocker):
    """The default (enabled_only=True) serves get_enabled_models with full shape."""
    mocker.patch(
        "backend.server.v2.llm.routes.get_enabled_models",
        return_value=[_make_mock_model()],
    )
    mocker.patch("backend.server.v2.llm.routes.get_all_models", return_value=[])

    resp = client.get("/llm/models")

    assert resp.status_code == 200
    payload = resp.json()
    assert payload["total"] == 1
    assert len(payload["models"]) == 1
    entry = payload["models"][0]
    assert entry["slug"] == "gpt-4"
    assert entry["display_name"] == "GPT-4"
    assert entry["provider_name"] == "OpenAI"
    assert entry["is_enabled"] is True
    assert entry["is_recommended"] is False
    assert entry["context_window"] == 128000
    assert entry["price_tier"] == 2
    assert entry["creator"] is None
    assert entry["costs"] == []
|
||||
|
||||
|
||||
def test_list_models_all(mocker):
    """enabled_only=false routes the query through get_all_models only."""
    get_all_mock = mocker.patch(
        "backend.server.v2.llm.routes.get_all_models",
        return_value=[_make_mock_model(is_enabled=False, slug="gpt-3")],
    )
    get_enabled_mock = mocker.patch(
        "backend.server.v2.llm.routes.get_enabled_models", return_value=[]
    )

    resp = client.get("/llm/models?enabled_only=false")

    assert resp.status_code == 200
    get_all_mock.assert_called_once()
    get_enabled_mock.assert_not_called()
    payload = resp.json()
    assert payload["total"] == 1
    assert payload["models"][0]["slug"] == "gpt-3"
|
||||
|
||||
|
||||
def test_list_models_empty(mocker):
    """An empty registry yields total=0 and no models."""
    mocker.patch("backend.server.v2.llm.routes.get_enabled_models", return_value=[])

    resp = client.get("/llm/models")

    assert resp.status_code == 200
    payload = resp.json()
    assert payload["total"] == 0
    assert payload["models"] == []
|
||||
|
||||
|
||||
def test_list_models_with_creator(mocker):
    """Creator details are carried through into the serialized model."""
    mocker.patch(
        "backend.server.v2.llm.routes.get_enabled_models",
        return_value=[_make_mock_model(creator=_make_mock_creator())],
    )

    resp = client.get("/llm/models")

    assert resp.status_code == 200
    creator_payload = resp.json()["models"][0]["creator"]
    assert creator_payload is not None
    assert creator_payload["id"] == "creator-1"
    assert creator_payload["name"] == "openai"
    assert creator_payload["display_name"] == "OpenAI"
    assert creator_payload["website_url"] == "https://openai.com"
|
||||
|
||||
|
||||
def test_list_models_with_costs(mocker):
    """Cost entries are carried through into the serialized model."""
    token_cost = _make_mock_cost(
        unit="TOKENS", credit_cost=5, credential_provider="openai"
    )
    mocker.patch(
        "backend.server.v2.llm.routes.get_enabled_models",
        return_value=[_make_mock_model(costs=[token_cost])],
    )

    resp = client.get("/llm/models")

    assert resp.status_code == 200
    cost_payload = resp.json()["models"][0]["costs"]
    assert len(cost_payload) == 1
    assert cost_payload[0]["unit"] == "TOKENS"
    assert cost_payload[0]["credit_cost"] == 5
    assert cost_payload[0]["credential_provider"] == "openai"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /llm/providers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_list_providers(mocker):
    """Two models under one provider collapse into a single provider entry."""
    mocker.patch(
        "backend.server.v2.llm.routes.get_enabled_models",
        return_value=[
            _make_mock_model(slug="gpt-4", display_name="GPT-4", provider_key="openai"),
            _make_mock_model(
                slug="gpt-3.5", display_name="GPT-3.5", provider_key="openai"
            ),
        ],
    )

    resp = client.get("/llm/providers")

    assert resp.status_code == 200
    providers = resp.json()["providers"]
    assert len(providers) == 1
    only_provider = providers[0]
    assert only_provider["name"] == "openai"
    assert only_provider["display_name"] == "OpenAI"
    assert len(only_provider["models"]) == 2
|
||||
|
||||
|
||||
def test_list_providers_multiple_providers(mocker):
    """Distinct providers each appear once, sorted alphabetically by key."""
    mocker.patch(
        "backend.server.v2.llm.routes.get_enabled_models",
        return_value=[
            _make_mock_model(
                slug="gpt-4",
                display_name="GPT-4",
                provider_key="openai",
                provider_display_name="OpenAI",
            ),
            _make_mock_model(
                slug="claude-3",
                display_name="Claude 3",
                provider_key="anthropic",
                provider_display_name="Anthropic",
            ),
        ],
    )

    resp = client.get("/llm/providers")

    assert resp.status_code == 200
    providers = resp.json()["providers"]
    assert len(providers) == 2
    names = [entry["name"] for entry in providers]
    # Alphabetical ordering: anthropic before openai.
    assert names == sorted(names)
    assert "anthropic" in names
    assert "openai" in names
|
||||
|
||||
|
||||
def test_list_providers_empty(mocker):
    """An empty registry yields an empty providers list."""
    mocker.patch("backend.server.v2.llm.routes.get_enabled_models", return_value=[])

    resp = client.get("/llm/providers")

    assert resp.status_code == 200
    assert resp.json()["providers"] == []
|
||||
@@ -0,0 +1,148 @@
|
||||
-- LLM registry schema: providers, model creators, models, per-model costs,
-- and an audit trail of forced model migrations.

-- CreateEnum
CREATE TYPE "LlmCostUnit" AS ENUM ('RUN', 'TOKENS');

-- CreateTable: a hosting provider (who serves the model, e.g. OpenRouter)
CREATE TABLE "LlmProvider" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "name" TEXT NOT NULL,
    "displayName" TEXT NOT NULL,
    "description" TEXT,
    "defaultCredentialProvider" TEXT,
    "defaultCredentialId" TEXT,
    "defaultCredentialType" TEXT,
    "metadata" JSONB NOT NULL DEFAULT '{}',

    CONSTRAINT "LlmProvider_pkey" PRIMARY KEY ("id")
);

-- CreateTable: the organization that trained the model (e.g. Meta, OpenAI)
CREATE TABLE "LlmModelCreator" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "name" TEXT NOT NULL,
    "displayName" TEXT NOT NULL,
    "description" TEXT,
    "websiteUrl" TEXT,
    "logoUrl" TEXT,
    "metadata" JSONB NOT NULL DEFAULT '{}',

    CONSTRAINT "LlmModelCreator_pkey" PRIMARY KEY ("id")
);

-- CreateTable: one row per model slug; capability flags default to false
CREATE TABLE "LlmModel" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "slug" TEXT NOT NULL,
    "displayName" TEXT NOT NULL,
    "description" TEXT,
    "providerId" TEXT NOT NULL,
    "creatorId" TEXT,
    "contextWindow" INTEGER NOT NULL,
    "maxOutputTokens" INTEGER,
    "priceTier" INTEGER NOT NULL DEFAULT 1,
    "isEnabled" BOOLEAN NOT NULL DEFAULT true,
    "isRecommended" BOOLEAN NOT NULL DEFAULT false,
    "supportsTools" BOOLEAN NOT NULL DEFAULT false,
    "supportsJsonOutput" BOOLEAN NOT NULL DEFAULT false,
    "supportsReasoning" BOOLEAN NOT NULL DEFAULT false,
    "supportsParallelToolCalls" BOOLEAN NOT NULL DEFAULT false,
    "capabilities" JSONB NOT NULL DEFAULT '{}',
    "metadata" JSONB NOT NULL DEFAULT '{}',

    CONSTRAINT "LlmModel_pkey" PRIMARY KEY ("id")
);

-- CreateTable: credit cost per model, optionally per user credential
CREATE TABLE "LlmModelCost" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "unit" "LlmCostUnit" NOT NULL DEFAULT 'RUN',
    "creditCost" INTEGER NOT NULL,
    "credentialProvider" TEXT NOT NULL,
    "credentialId" TEXT,
    "credentialType" TEXT,
    "currency" TEXT,
    "metadata" JSONB NOT NULL DEFAULT '{}',
    "llmModelId" TEXT NOT NULL,

    CONSTRAINT "LlmModelCost_pkey" PRIMARY KEY ("id")
);

-- CreateTable: records bulk moves of workflow nodes from one model to another
CREATE TABLE "LlmModelMigration" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "sourceModelSlug" TEXT NOT NULL,
    "targetModelSlug" TEXT NOT NULL,
    "reason" TEXT,
    "migratedNodeIds" JSONB NOT NULL DEFAULT '[]',
    "nodeCount" INTEGER NOT NULL,
    "customCreditCost" INTEGER,
    "isReverted" BOOLEAN NOT NULL DEFAULT false,
    "revertedAt" TIMESTAMP(3),

    CONSTRAINT "LlmModelMigration_pkey" PRIMARY KEY ("id")
);

-- CreateIndex
CREATE UNIQUE INDEX "LlmProvider_name_key" ON "LlmProvider"("name");

-- CreateIndex
CREATE UNIQUE INDEX "LlmModelCreator_name_key" ON "LlmModelCreator"("name");

-- CreateIndex
CREATE UNIQUE INDEX "LlmModel_slug_key" ON "LlmModel"("slug");

-- CreateIndex
CREATE INDEX "LlmModel_providerId_isEnabled_idx" ON "LlmModel"("providerId", "isEnabled");

-- CreateIndex
CREATE INDEX "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");

-- CreateIndex (partial unique for default costs - no specific credential)
CREATE UNIQUE INDEX "LlmModelCost_default_cost_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "unit") WHERE "credentialId" IS NULL;

-- CreateIndex (partial unique for credential-specific costs)
CREATE UNIQUE INDEX "LlmModelCost_credential_cost_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "credentialId", "unit") WHERE "credentialId" IS NOT NULL;

-- CreateIndex
CREATE INDEX "LlmModelMigration_targetModelSlug_idx" ON "LlmModelMigration"("targetModelSlug");

-- CreateIndex
CREATE INDEX "LlmModelMigration_sourceModelSlug_isReverted_idx" ON "LlmModelMigration"("sourceModelSlug", "isReverted");

-- CreateIndex (partial unique to prevent multiple active migrations per source)
CREATE UNIQUE INDEX "LlmModelMigration_active_source_key" ON "LlmModelMigration"("sourceModelSlug") WHERE "isReverted" = false;

-- AddForeignKey
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_providerId_fkey" FOREIGN KEY ("providerId") REFERENCES "LlmProvider"("id") ON DELETE RESTRICT ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey" FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "LlmModelCost" ADD CONSTRAINT "LlmModelCost_llmModelId_fkey" FOREIGN KEY ("llmModelId") REFERENCES "LlmModel"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "LlmModelMigration" ADD CONSTRAINT "LlmModelMigration_sourceModelSlug_fkey" FOREIGN KEY ("sourceModelSlug") REFERENCES "LlmModel"("slug") ON DELETE RESTRICT ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "LlmModelMigration" ADD CONSTRAINT "LlmModelMigration_targetModelSlug_fkey" FOREIGN KEY ("targetModelSlug") REFERENCES "LlmModel"("slug") ON DELETE RESTRICT ON UPDATE CASCADE;

-- AddCheckConstraints (enforce data integrity)
ALTER TABLE "LlmModel"
  ADD CONSTRAINT "LlmModel_priceTier_check" CHECK ("priceTier" BETWEEN 1 AND 3);

ALTER TABLE "LlmModelCost"
  ADD CONSTRAINT "LlmModelCost_creditCost_check" CHECK ("creditCost" >= 0);

ALTER TABLE "LlmModelMigration"
  ADD CONSTRAINT "LlmModelMigration_nodeCount_check" CHECK ("nodeCount" >= 0),
  ADD CONSTRAINT "LlmModelMigration_customCreditCost_check" CHECK ("customCreditCost" IS NULL OR "customCreditCost" >= 0);
|
||||
@@ -0,0 +1,287 @@
|
||||
-- Seed LLM Registry from existing hard-coded data
-- Populates LlmProvider, LlmModelCreator, LlmModel, and LlmModelCost from the
-- previous MODEL_METADATA / MODEL_COST dictionaries. Idempotent via ON CONFLICT.

-- Insert Providers
INSERT INTO "LlmProvider" ("id", "createdAt", "updatedAt", "name", "displayName", "description", "defaultCredentialProvider", "defaultCredentialType", "metadata")
VALUES
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'openai', 'OpenAI', 'OpenAI language models', 'openai', 'api_key', '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'anthropic', 'Anthropic', 'Anthropic Claude models', 'anthropic', 'api_key', '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'groq', 'Groq', 'Groq inference API', 'groq', 'api_key', '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'open_router', 'OpenRouter', 'OpenRouter unified API', 'open_router', 'api_key', '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'aiml_api', 'AI/ML API', 'AI/ML API models', 'aiml_api', 'api_key', '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'ollama', 'Ollama', 'Ollama local models', 'ollama', 'api_key', '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'llama_api', 'Llama API', 'Llama API models', 'llama_api', 'api_key', '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'v0', 'v0', 'v0 by Vercel models', 'v0', 'api_key', '{}'::jsonb)
ON CONFLICT ("name") DO NOTHING;

-- Insert Model Creators
INSERT INTO "LlmModelCreator" ("id", "createdAt", "updatedAt", "name", "displayName", "description", "websiteUrl", "logoUrl", "metadata")
VALUES
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'openai', 'OpenAI', 'Creator of GPT, O1, O3, and DALL-E models', 'https://openai.com', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'anthropic', 'Anthropic', 'Creator of Claude AI models', 'https://anthropic.com', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'meta', 'Meta', 'Creator of Llama foundation models', 'https://llama.meta.com', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'google', 'Google', 'Creator of Gemini and PaLM models', 'https://deepmind.google', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'mistralai', 'Mistral AI', 'Creator of Mistral and Codestral models', 'https://mistral.ai', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'cohere', 'Cohere', 'Creator of Command language models', 'https://cohere.com', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'deepseek', 'DeepSeek', 'Creator of DeepSeek reasoning models', 'https://deepseek.com', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'alibaba', 'Alibaba', 'Creator of Qwen language models', 'https://qwenlm.github.io', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'nvidia', 'NVIDIA', 'Creator of Nemotron models', 'https://nvidia.com', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'vercel', 'Vercel', 'Creator of v0 AI models', 'https://v0.dev', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'microsoft', 'Microsoft', 'Creator of Phi models', 'https://microsoft.com', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'xai', 'xAI', 'Creator of Grok models', 'https://x.ai', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'perplexity', 'Perplexity AI', 'Creator of Sonar search models', 'https://perplexity.ai', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'nousresearch', 'Nous Research', 'Creator of Hermes language models', 'https://nousresearch.com', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'amazon', 'Amazon', 'Creator of Nova language models', 'https://aws.amazon.com', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'gryphe', 'Gryphe', 'Creator of MythoMax models', 'https://huggingface.co/Gryphe', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'moonshotai', 'Moonshot AI', 'Creator of Kimi language models', 'https://moonshot.ai', NULL, '{}'::jsonb)
ON CONFLICT ("name") DO NOTHING;

-- Insert Models (using CTEs to reference provider and creator IDs)
WITH provider_ids AS (
  SELECT "id", "name" FROM "LlmProvider"
),
creator_ids AS (
  SELECT "id", "name" FROM "LlmModelCreator"
)
INSERT INTO "LlmModel" ("id", "createdAt", "updatedAt", "slug", "displayName", "description", "providerId", "creatorId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
SELECT
  gen_random_uuid(),
  CURRENT_TIMESTAMP,
  CURRENT_TIMESTAMP,
  model_slug,
  model_display_name,
  NULL,
  p."id",
  c."id",
  context_window,
  max_output_tokens,
  true,
  '{}'::jsonb,
  '{}'::jsonb
FROM (VALUES
  -- OpenAI models (creator: openai)
  ('o3-2025-04-16', 'O3', 'openai', 'openai', 200000, 100000),
  ('o3-mini', 'O3 Mini', 'openai', 'openai', 200000, 100000),
  ('o1', 'O1', 'openai', 'openai', 200000, 100000),
  ('o1-mini', 'O1 Mini', 'openai', 'openai', 128000, 65536),
  ('gpt-5.2-2025-12-11', 'GPT-5.2', 'openai', 'openai', 400000, 128000),
  ('gpt-5-2025-08-07', 'GPT 5', 'openai', 'openai', 400000, 128000),
  ('gpt-5.1-2025-11-13', 'GPT 5.1', 'openai', 'openai', 400000, 128000),
  ('gpt-5-mini-2025-08-07', 'GPT 5 Mini', 'openai', 'openai', 400000, 128000),
  ('gpt-5-nano-2025-08-07', 'GPT 5 Nano', 'openai', 'openai', 400000, 128000),
  ('gpt-5-chat-latest', 'GPT 5 Chat', 'openai', 'openai', 400000, 16384),
  ('gpt-4.1-2025-04-14', 'GPT 4.1', 'openai', 'openai', 1000000, 32768),
  ('gpt-4.1-mini-2025-04-14', 'GPT 4.1 Mini', 'openai', 'openai', 1047576, 32768),
  ('gpt-4o-mini', 'GPT 4o Mini', 'openai', 'openai', 128000, 16384),
  ('gpt-4o', 'GPT 4o', 'openai', 'openai', 128000, 16384),
  ('gpt-4-turbo', 'GPT 4 Turbo', 'openai', 'openai', 128000, 4096),
  -- Anthropic models (creator: anthropic)
  ('claude-opus-4-6', 'Claude Opus 4.6', 'anthropic', 'anthropic', 200000, 128000),
  ('claude-sonnet-4-6', 'Claude Sonnet 4.6', 'anthropic', 'anthropic', 200000, 64000),
  ('claude-opus-4-1-20250805', 'Claude 4.1 Opus', 'anthropic', 'anthropic', 200000, 32000),
  ('claude-opus-4-20250514', 'Claude 4 Opus', 'anthropic', 'anthropic', 200000, 32000),
  ('claude-sonnet-4-20250514', 'Claude 4 Sonnet', 'anthropic', 'anthropic', 200000, 64000),
  ('claude-opus-4-5-20251101', 'Claude 4.5 Opus', 'anthropic', 'anthropic', 200000, 64000),
  ('claude-sonnet-4-5-20250929', 'Claude 4.5 Sonnet', 'anthropic', 'anthropic', 200000, 64000),
  ('claude-haiku-4-5-20251001', 'Claude 4.5 Haiku', 'anthropic', 'anthropic', 200000, 64000),
  ('claude-3-haiku-20240307', 'Claude 3 Haiku', 'anthropic', 'anthropic', 200000, 4096),
  -- AI/ML API models (creators: alibaba, nvidia, meta)
  ('Qwen/Qwen2.5-72B-Instruct-Turbo', 'Qwen 2.5 72B', 'aiml_api', 'alibaba', 32000, 8000),
  ('nvidia/llama-3.1-nemotron-70b-instruct', 'Llama 3.1 Nemotron 70B', 'aiml_api', 'nvidia', 128000, 40000),
  ('meta-llama/Llama-3.3-70B-Instruct-Turbo', 'Llama 3.3 70B', 'aiml_api', 'meta', 128000, NULL),
  ('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 'Meta Llama 3.1 70B', 'aiml_api', 'meta', 131000, 2000),
  ('meta-llama/Llama-3.2-3B-Instruct-Turbo', 'Llama 3.2 3B', 'aiml_api', 'meta', 128000, NULL),
  -- Groq models (creator: meta for Llama)
  ('llama-3.3-70b-versatile', 'Llama 3.3 70B', 'groq', 'meta', 128000, 32768),
  ('llama-3.1-8b-instant', 'Llama 3.1 8B', 'groq', 'meta', 128000, 8192),
  -- Ollama models (creators: meta for Llama, mistralai for Mistral)
  ('llama3.3', 'Llama 3.3', 'ollama', 'meta', 8192, NULL),
  ('llama3.2', 'Llama 3.2', 'ollama', 'meta', 8192, NULL),
  ('llama3', 'Llama 3', 'ollama', 'meta', 8192, NULL),
  ('llama3.1:405b', 'Llama 3.1 405B', 'ollama', 'meta', 8192, NULL),
  ('dolphin-mistral:latest', 'Dolphin Mistral', 'ollama', 'mistralai', 32768, NULL),
  -- OpenRouter models (various creators)
  ('google/gemini-2.5-pro-preview-03-25', 'Gemini 2.5 Pro', 'open_router', 'google', 1050000, 8192),
  ('google/gemini-2.5-pro', 'Gemini 2.5 Pro', 'open_router', 'google', 1048576, 65536),
  ('google/gemini-3.1-pro-preview', 'Gemini 3.1 Pro Preview', 'open_router', 'google', 1048576, 65536),
  ('google/gemini-3-flash-preview', 'Gemini 3 Flash Preview', 'open_router', 'google', 1048576, 65536),
  ('google/gemini-2.5-flash', 'Gemini 2.5 Flash', 'open_router', 'google', 1048576, 65535),
  ('google/gemini-2.0-flash-001', 'Gemini 2.0 Flash', 'open_router', 'google', 1048576, 8192),
  ('google/gemini-3.1-flash-lite-preview', 'Gemini 3.1 Flash Lite Preview', 'open_router', 'google', 1048576, 65536),
  ('google/gemini-2.5-flash-lite-preview-06-17', 'Gemini 2.5 Flash Lite Preview', 'open_router', 'google', 1048576, 65535),
  ('google/gemini-2.0-flash-lite-001', 'Gemini 2.0 Flash Lite', 'open_router', 'google', 1048576, 8192),
  ('mistralai/mistral-nemo', 'Mistral Nemo', 'open_router', 'mistralai', 128000, 4096),
  ('mistralai/mistral-large-2512', 'Mistral Large 3 2512', 'open_router', 'mistralai', 262144, NULL),
  ('mistralai/mistral-medium-3.1', 'Mistral Medium 3.1', 'open_router', 'mistralai', 131072, NULL),
  ('mistralai/mistral-small-3.2-24b-instruct', 'Mistral Small 3.2 24B', 'open_router', 'mistralai', 131072, 131072),
  ('mistralai/codestral-2508', 'Codestral 2508', 'open_router', 'mistralai', 256000, NULL),
  ('cohere/command-r-08-2024', 'Command R', 'open_router', 'cohere', 128000, 4096),
  ('cohere/command-r-plus-08-2024', 'Command R Plus', 'open_router', 'cohere', 128000, 4096),
  ('cohere/command-a-03-2025', 'Command A 03.2025', 'open_router', 'cohere', 256000, 8192),
  ('cohere/command-a-reasoning-08-2025', 'Command A Reasoning 08.2025', 'open_router', 'cohere', 256000, 32768),
  ('cohere/command-a-translate-08-2025', 'Command A Translate 08.2025', 'open_router', 'cohere', 128000, 8192),
  ('cohere/command-a-vision-07-2025', 'Command A Vision 07.2025', 'open_router', 'cohere', 128000, 8192),
  ('deepseek/deepseek-chat', 'DeepSeek Chat', 'open_router', 'deepseek', 64000, 2048),
  ('deepseek/deepseek-r1-0528', 'DeepSeek R1', 'open_router', 'deepseek', 163840, 163840),
  ('perplexity/sonar', 'Perplexity Sonar', 'open_router', 'perplexity', 127000, 8000),
  ('perplexity/sonar-pro', 'Perplexity Sonar Pro', 'open_router', 'perplexity', 200000, 8000),
  ('perplexity/sonar-deep-research', 'Perplexity Sonar Deep Research', 'open_router', 'perplexity', 128000, 16000),
  ('perplexity/sonar-reasoning-pro', 'Sonar Reasoning Pro', 'open_router', 'perplexity', 128000, 8000),
  ('nousresearch/hermes-3-llama-3.1-405b', 'Hermes 3 Llama 3.1 405B', 'open_router', 'nousresearch', 131000, 4096),
  ('nousresearch/hermes-3-llama-3.1-70b', 'Hermes 3 Llama 3.1 70B', 'open_router', 'nousresearch', 12288, 12288),
  ('openai/gpt-oss-120b', 'GPT OSS 120B', 'open_router', 'openai', 131072, 131072),
  ('openai/gpt-oss-20b', 'GPT OSS 20B', 'open_router', 'openai', 131072, 32768),
  ('amazon/nova-lite-v1', 'Amazon Nova Lite', 'open_router', 'amazon', 300000, 5120),
  ('amazon/nova-micro-v1', 'Amazon Nova Micro', 'open_router', 'amazon', 128000, 5120),
  ('amazon/nova-pro-v1', 'Amazon Nova Pro', 'open_router', 'amazon', 300000, 5120),
  ('microsoft/wizardlm-2-8x22b', 'WizardLM 2 8x22B', 'open_router', 'microsoft', 65536, 4096),
  ('microsoft/phi-4', 'Phi-4', 'open_router', 'microsoft', 16384, 16384),
  ('gryphe/mythomax-l2-13b', 'MythoMax L2 13B', 'open_router', 'gryphe', 4096, 4096),
  ('meta-llama/llama-4-scout', 'Llama 4 Scout', 'open_router', 'meta', 131072, 131072),
  ('meta-llama/llama-4-maverick', 'Llama 4 Maverick', 'open_router', 'meta', 1048576, 1000000),
  ('x-ai/grok-3', 'Grok 3', 'open_router', 'xai', 131072, 131072),
  ('x-ai/grok-4', 'Grok 4', 'open_router', 'xai', 256000, 256000),
  ('x-ai/grok-4-fast', 'Grok 4 Fast', 'open_router', 'xai', 2000000, 30000),
  ('x-ai/grok-4.1-fast', 'Grok 4.1 Fast', 'open_router', 'xai', 2000000, 30000),
  ('x-ai/grok-code-fast-1', 'Grok Code Fast 1', 'open_router', 'xai', 256000, 10000),
  ('moonshotai/kimi-k2', 'Kimi K2', 'open_router', 'moonshotai', 131000, 131000),
  ('qwen/qwen3-235b-a22b-thinking-2507', 'Qwen 3 235B Thinking', 'open_router', 'alibaba', 262144, 262144),
  ('qwen/qwen3-coder', 'Qwen 3 Coder', 'open_router', 'alibaba', 262144, 262144),
  -- Llama API models (creator: meta)
  ('Llama-4-Scout-17B-16E-Instruct-FP8', 'Llama 4 Scout', 'llama_api', 'meta', 128000, 4028),
  ('Llama-4-Maverick-17B-128E-Instruct-FP8', 'Llama 4 Maverick', 'llama_api', 'meta', 128000, 4028),
  ('Llama-3.3-8B-Instruct', 'Llama 3.3 8B', 'llama_api', 'meta', 128000, 4028),
  ('Llama-3.3-70B-Instruct', 'Llama 3.3 70B', 'llama_api', 'meta', 128000, 4028),
  -- v0 models (creator: vercel)
  ('v0-1.5-md', 'v0 1.5 MD', 'v0', 'vercel', 128000, 64000),
  ('v0-1.5-lg', 'v0 1.5 LG', 'v0', 'vercel', 512000, 64000),
  ('v0-1.0-md', 'v0 1.0 MD', 'v0', 'vercel', 128000, 64000)
) AS models(model_slug, model_display_name, provider_name, creator_name, context_window, max_output_tokens)
JOIN provider_ids p ON p."name" = models.provider_name
JOIN creator_ids c ON c."name" = models.creator_name
ON CONFLICT ("slug") DO NOTHING;

-- Insert Costs (using CTEs to reference model IDs)
WITH model_ids AS (
  SELECT "id", "slug", "providerId" FROM "LlmModel"
),
provider_ids AS (
  SELECT "id", "name" FROM "LlmProvider"
)
INSERT INTO "LlmModelCost" ("id", "createdAt", "updatedAt", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
SELECT
  gen_random_uuid(),
  CURRENT_TIMESTAMP,
  CURRENT_TIMESTAMP,
  'RUN'::"LlmCostUnit",
  cost,
  p."name",
  NULL,
  'api_key',
  NULL,
  '{}'::jsonb,
  m."id"
FROM (VALUES
  -- OpenAI costs
  ('o3-2025-04-16', 4),
  ('o3-mini', 2),
  ('o1', 16),
  ('o1-mini', 4),
  ('gpt-5.2-2025-12-11', 5),
  ('gpt-5-2025-08-07', 2),
  ('gpt-5.1-2025-11-13', 5),
  ('gpt-5-mini-2025-08-07', 1),
  ('gpt-5-nano-2025-08-07', 1),
  ('gpt-5-chat-latest', 5),
  ('gpt-4.1-2025-04-14', 2),
  ('gpt-4.1-mini-2025-04-14', 1),
  ('gpt-4o-mini', 1),
  ('gpt-4o', 3),
  ('gpt-4-turbo', 10),
  -- Anthropic costs
  ('claude-opus-4-6', 21),
  ('claude-sonnet-4-6', 5),
  ('claude-opus-4-1-20250805', 21),
  ('claude-opus-4-20250514', 21),
  ('claude-sonnet-4-20250514', 5),
  ('claude-haiku-4-5-20251001', 4),
  ('claude-opus-4-5-20251101', 14),
  ('claude-sonnet-4-5-20250929', 9),
  ('claude-3-haiku-20240307', 1),
  -- AI/ML API costs
  ('Qwen/Qwen2.5-72B-Instruct-Turbo', 1),
  ('nvidia/llama-3.1-nemotron-70b-instruct', 1),
  ('meta-llama/Llama-3.3-70B-Instruct-Turbo', 1),
  ('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 1),
  ('meta-llama/Llama-3.2-3B-Instruct-Turbo', 1),
  -- Groq costs
  ('llama-3.3-70b-versatile', 1),
  ('llama-3.1-8b-instant', 1),
  -- Ollama costs
  ('llama3.3', 1),
  ('llama3.2', 1),
  ('llama3', 1),
  ('llama3.1:405b', 1),
  ('dolphin-mistral:latest', 1),
  -- OpenRouter costs
  ('google/gemini-2.5-pro-preview-03-25', 4),
  ('google/gemini-2.5-pro', 4),
  ('google/gemini-3.1-pro-preview', 5),
  ('google/gemini-3-flash-preview', 3),
  ('google/gemini-3.1-flash-lite-preview', 1),
  ('mistralai/mistral-nemo', 1),
  ('mistralai/mistral-large-2512', 3),
  ('mistralai/mistral-medium-3.1', 2),
  ('mistralai/mistral-small-3.2-24b-instruct', 1),
  ('mistralai/codestral-2508', 2),
  ('cohere/command-r-08-2024', 1),
  ('cohere/command-r-plus-08-2024', 3),
  ('cohere/command-a-03-2025', 2),
  ('cohere/command-a-reasoning-08-2025', 3),
  ('cohere/command-a-translate-08-2025', 1),
  ('cohere/command-a-vision-07-2025', 2),
  ('deepseek/deepseek-chat', 2),
  ('perplexity/sonar', 1),
  ('perplexity/sonar-pro', 5),
  ('perplexity/sonar-deep-research', 10),
  ('perplexity/sonar-reasoning-pro', 5),
  ('nousresearch/hermes-3-llama-3.1-405b', 1),
  ('nousresearch/hermes-3-llama-3.1-70b', 1),
  ('amazon/nova-lite-v1', 1),
  ('amazon/nova-micro-v1', 1),
  ('amazon/nova-pro-v1', 1),
  ('microsoft/wizardlm-2-8x22b', 1),
  ('microsoft/phi-4', 1),
  ('gryphe/mythomax-l2-13b', 1),
  ('meta-llama/llama-4-scout', 1),
  ('meta-llama/llama-4-maverick', 1),
  ('x-ai/grok-3', 5),
  ('x-ai/grok-4', 9),
  ('x-ai/grok-4-fast', 1),
  ('x-ai/grok-4.1-fast', 1),
  ('x-ai/grok-code-fast-1', 1),
  ('moonshotai/kimi-k2', 1),
  ('qwen/qwen3-235b-a22b-thinking-2507', 1),
  ('qwen/qwen3-coder', 9),
  ('google/gemini-2.5-flash', 1),
  ('google/gemini-2.0-flash-001', 1),
  ('google/gemini-2.5-flash-lite-preview-06-17', 1),
  ('google/gemini-2.0-flash-lite-001', 1),
  ('deepseek/deepseek-r1-0528', 1),
  ('openai/gpt-oss-120b', 1),
  ('openai/gpt-oss-20b', 1),
  -- Llama API costs
  ('Llama-4-Scout-17B-16E-Instruct-FP8', 1),
  ('Llama-4-Maverick-17B-128E-Instruct-FP8', 1),
  ('Llama-3.3-8B-Instruct', 1),
  ('Llama-3.3-70B-Instruct', 1),
  -- v0 costs
  ('v0-1.5-md', 1),
  ('v0-1.5-lg', 2),
  ('v0-1.0-md', 1)
) AS costs(model_slug, cost)
JOIN model_ids m ON m."slug" = costs.model_slug
JOIN provider_ids p ON p."id" = m."providerId"
ON CONFLICT ("llmModelId", "credentialProvider", "unit") WHERE "credentialId" IS NULL DO NOTHING;
|
||||
|
||||
@@ -1301,3 +1301,164 @@ model OAuthRefreshToken {
|
||||
@@index([userId, applicationId])
|
||||
@@index([expiresAt]) // For cleanup
|
||||
}
|
||||
|
||||
// ============================================================================
// LLM Registry Models
// ============================================================================

enum LlmCostUnit {
  RUN
  TOKENS
}

model LlmProvider {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  name        String @unique
  displayName String
  description String?

  defaultCredentialProvider String?
  defaultCredentialId       String?
  defaultCredentialType     String?

  metadata Json @default("{}")

  Models LlmModel[]
}

model LlmModel {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  slug        String @unique
  displayName String
  description String?

  providerId String
  Provider   LlmProvider @relation(fields: [providerId], references: [id], onDelete: Restrict)

  // Creator is the organization that created/trained the model (e.g., OpenAI, Meta);
  // distinct from the provider that hosts/serves it (e.g., OpenRouter).
  creatorId String?
  Creator   LlmModelCreator? @relation(fields: [creatorId], references: [id], onDelete: SetNull)

  contextWindow   Int
  maxOutputTokens Int?
  priceTier       Int     @default(1) // 1=cheapest, 2=medium, 3=expensive (DB constraint: 1-3)
  isEnabled       Boolean @default(true)
  isRecommended   Boolean @default(false)

  // Model-specific capabilities; these vary per model even within one provider.
  // Default to false for safety - partially-seeded rows should not be assumed capable.
  supportsTools             Boolean @default(false)
  supportsJsonOutput        Boolean @default(false)
  supportsReasoning         Boolean @default(false)
  supportsParallelToolCalls Boolean @default(false)

  capabilities Json @default("{}")
  metadata     Json @default("{}")

  Costs            LlmModelCost[]
  SourceMigrations LlmModelMigration[] @relation("SourceMigrations")
  TargetMigrations LlmModelMigration[] @relation("TargetMigrations")

  @@index([providerId, isEnabled])
  @@index([creatorId])
  // Note: slug already has @unique which creates an implicit index
}

model LlmModelCost {
  id        String      @id @default(uuid())
  createdAt DateTime    @default(now())
  updatedAt DateTime    @updatedAt
  unit      LlmCostUnit @default(RUN)

  creditCost Int // DB constraint: >= 0

  // Provider identifier (e.g., "openai", "anthropic", "openrouter") that
  // determines which credential system supplies the API key. Allows distinct
  // pricing for provider defaults (credentialId IS NULL) versus a user's own
  // API key (credentialId IS NOT NULL).
  credentialProvider String
  credentialId       String?
  credentialType     String?
  currency           String?

  metadata Json @default("{}")

  llmModelId String
  Model      LlmModel @relation(fields: [llmModelId], references: [id], onDelete: Cascade)

  // Note: uniqueness is enforced by two partial indexes in the migration SQL
  // (one WHERE credentialId IS NULL, one WHERE credentialId IS NOT NULL), so
  // provider-level defaults and credential-specific overrides can coexist.
}

model LlmModelCreator {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  name        String  @unique // e.g., "openai", "anthropic", "meta"
  displayName String // e.g., "OpenAI", "Anthropic", "Meta"
  description String?
  websiteUrl  String? // Link to creator's website
  logoUrl     String? // URL to creator's logo

  metadata Json @default("{}")

  Models LlmModel[]
}

model LlmModelMigration {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  sourceModelSlug String // The original model that was disabled
  targetModelSlug String // The model workflows were migrated to
  reason          String? // Why the migration happened (e.g., "Provider outage")

  // FK constraints ensure slugs reference valid models
  SourceModel LlmModel @relation("SourceMigrations", fields: [sourceModelSlug], references: [slug], onDelete: Restrict, onUpdate: Cascade)
  TargetModel LlmModel @relation("TargetMigrations", fields: [targetModelSlug], references: [slug], onDelete: Restrict, onUpdate: Cascade)

  // Affected nodes, stored as a JSON array of node IDs: ["node-uuid-1", ...]
  migratedNodeIds Json @default("[]")
  nodeCount       Int // Number of nodes migrated (DB constraint: >= 0)

  // Custom pricing override for migrated workflows during the migration period.
  // Use case: when moving users from an expensive model to a cheaper one you may
  // want to keep the original pricing temporarily (no billing surprises) or
  // offer a transition discount.
  //
  // IMPORTANT: intended for billing-system integration. When billing computes
  // costs for nodes covered by this migration it should use customCreditCost
  // when set; when null, the target model's normal cost applies.
  //
  // TODO: Integrate with billing system to apply this override during cost calculation.
  // LIMITATION: a plain Int cannot distinguish RUN vs TOKENS pricing, which is
  // ambiguous for token-priced models. Consider a relation to LlmModelCost or a
  // dedicated override model in a follow-up PR.
  customCreditCost Int? // DB constraint: >= 0 when not null

  // Revert tracking
  isReverted Boolean   @default(false)
  revertedAt DateTime?

  // Note: a partial unique index in the migration SQL prevents multiple active
  // migrations per source: UNIQUE (sourceModelSlug) WHERE isReverted = false
  @@index([targetModelSlug])
  @@index([sourceModelSlug, isReverted]) // Composite index for active-migration queries
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { Sidebar } from "@/components/__legacy__/Sidebar";
|
||||
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
|
||||
import { Gauge, ChatsCircle } from "@phosphor-icons/react/dist/ssr";
|
||||
|
||||
import { IconSliders } from "@/components/__legacy__/ui/icons";
|
||||
|
||||
@@ -21,11 +22,21 @@ const sidebarLinkGroups = [
|
||||
href: "/admin/impersonation",
|
||||
icon: <UserSearch className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "Rate Limits",
|
||||
href: "/admin/rate-limits",
|
||||
icon: <Gauge className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "Execution Analytics",
|
||||
href: "/admin/execution-analytics",
|
||||
icon: <FileText className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "LLM Registry",
|
||||
href: "/admin/llms",
|
||||
icon: <ChatsCircle className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "Admin User Management",
|
||||
href: "/admin/settings",
|
||||
|
||||
@@ -0,0 +1,411 @@
|
||||
"use server";
|
||||
|
||||
import { revalidatePath } from "next/cache";
|
||||
import {
|
||||
createRequestHeaders,
|
||||
getServerAuthToken,
|
||||
} from "@/lib/autogpt-server-api/helpers";
|
||||
import { environment } from "@/services/environment";
|
||||
|
||||
const ADMIN_LLM_PATH = "/admin/llms";
|
||||
|
||||
// =============================================================================
|
||||
// Authenticated Fetch Helper
|
||||
// =============================================================================
|
||||
|
||||
async function adminFetch(
|
||||
endpoint: string,
|
||||
options: RequestInit = {},
|
||||
): Promise<{ status: number; data: unknown }> {
|
||||
const baseUrl = environment.getAGPTServerBaseUrl();
|
||||
const token = await getServerAuthToken();
|
||||
const headers = createRequestHeaders(
|
||||
token,
|
||||
!!options.body,
|
||||
"application/json",
|
||||
);
|
||||
|
||||
const response = await fetch(`${baseUrl}${endpoint}`, {
|
||||
...options,
|
||||
headers: {
|
||||
...headers,
|
||||
...((options.headers as Record<string, string>) || {}),
|
||||
},
|
||||
});
|
||||
|
||||
let data: unknown = null;
|
||||
if (response.status !== 204) {
|
||||
const contentType = response.headers.get("content-type");
|
||||
const text = await response.text();
|
||||
if (text && contentType?.includes("application/json")) {
|
||||
try {
|
||||
data = JSON.parse(text);
|
||||
} catch {
|
||||
data = text;
|
||||
}
|
||||
} else {
|
||||
data = text;
|
||||
}
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = data as Record<string, string> | null;
|
||||
const errorMessage =
|
||||
errorData?.detail || errorData?.message || `HTTP ${response.status}`;
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
|
||||
return { status: response.status, data };
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Utilities
|
||||
// =============================================================================
|
||||
|
||||
function getRequiredFormField(
|
||||
formData: FormData,
|
||||
fieldName: string,
|
||||
displayName?: string,
|
||||
): string {
|
||||
const raw = formData.get(fieldName);
|
||||
const value = raw ? String(raw).trim() : "";
|
||||
if (!value) {
|
||||
throw new Error(`${displayName || fieldName} is required`);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
function getRequiredPositiveNumber(
|
||||
formData: FormData,
|
||||
fieldName: string,
|
||||
displayName?: string,
|
||||
): number {
|
||||
const raw = formData.get(fieldName);
|
||||
const value = Number(raw);
|
||||
if (raw === null || raw === "" || !Number.isFinite(value) || value <= 0) {
|
||||
throw new Error(`${displayName || fieldName} must be a positive number`);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
function getRequiredNumber(
|
||||
formData: FormData,
|
||||
fieldName: string,
|
||||
displayName?: string,
|
||||
): number {
|
||||
const raw = formData.get(fieldName);
|
||||
const value = Number(raw);
|
||||
if (raw === null || raw === "" || !Number.isFinite(value)) {
|
||||
throw new Error(`${displayName || fieldName} is required`);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Provider Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function fetchLlmProviders() {
|
||||
const { data } = await adminFetch("/api/llm/admin/providers");
|
||||
return data;
|
||||
}
|
||||
|
||||
export async function createLlmProviderAction(formData: FormData) {
|
||||
const payload = {
|
||||
name: String(formData.get("name") || "").trim(),
|
||||
display_name: String(formData.get("display_name") || "").trim(),
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
default_credential_provider: formData.get("default_credential_provider")
|
||||
? String(formData.get("default_credential_provider")).trim()
|
||||
: undefined,
|
||||
default_credential_id: formData.get("default_credential_id")
|
||||
? String(formData.get("default_credential_id")).trim()
|
||||
: undefined,
|
||||
default_credential_type: formData.get("default_credential_type")
|
||||
? String(formData.get("default_credential_type")).trim()
|
||||
: "api_key",
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
await adminFetch("/api/llm/providers", {
|
||||
method: "POST",
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function deleteLlmProviderAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const providerName = getRequiredFormField(
|
||||
formData,
|
||||
"provider_id",
|
||||
"Provider",
|
||||
);
|
||||
await adminFetch(`/api/llm/providers/${providerName}`, { method: "DELETE" });
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function updateLlmProviderAction(formData: FormData) {
|
||||
const providerName = getRequiredFormField(
|
||||
formData,
|
||||
"provider_id",
|
||||
"Provider",
|
||||
);
|
||||
|
||||
const payload = {
|
||||
display_name: String(formData.get("display_name") || "").trim(),
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
default_credential_provider: formData.get("default_credential_provider")
|
||||
? String(formData.get("default_credential_provider")).trim()
|
||||
: undefined,
|
||||
default_credential_id: formData.get("default_credential_id")
|
||||
? String(formData.get("default_credential_id")).trim()
|
||||
: undefined,
|
||||
default_credential_type: formData.get("default_credential_type")
|
||||
? String(formData.get("default_credential_type")).trim()
|
||||
: "api_key",
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
await adminFetch(`/api/llm/providers/${providerName}`, {
|
||||
method: "PATCH",
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Model Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function fetchLlmModels(page?: number, pageSize?: number) {
|
||||
const params = new URLSearchParams();
|
||||
if (page) params.set("page", String(page));
|
||||
if (pageSize) params.set("page_size", String(pageSize));
|
||||
params.set("enabled_only", "false");
|
||||
const query = params.toString() ? `?${params.toString()}` : "";
|
||||
const { data } = await adminFetch(`/api/llm/admin/models${query}`);
|
||||
return data;
|
||||
}
|
||||
|
||||
export async function createLlmModelAction(formData: FormData) {
|
||||
const creditCost = getRequiredNumber(formData, "credit_cost", "Credit cost");
|
||||
|
||||
const payload = {
|
||||
slug: String(formData.get("slug") || "").trim(),
|
||||
display_name: String(formData.get("display_name") || "").trim(),
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
provider_name: getRequiredFormField(formData, "provider_id", "Provider"),
|
||||
creator_id: formData.get("creator_id")
|
||||
? String(formData.get("creator_id"))
|
||||
: undefined,
|
||||
context_window: getRequiredPositiveNumber(
|
||||
formData,
|
||||
"context_window",
|
||||
"Context window",
|
||||
),
|
||||
max_output_tokens: formData.get("max_output_tokens")
|
||||
? Number(formData.get("max_output_tokens"))
|
||||
: undefined,
|
||||
price_tier: Number(formData.get("price_tier") || 1),
|
||||
is_enabled: formData.getAll("is_enabled").includes("on"),
|
||||
capabilities: {},
|
||||
metadata: {},
|
||||
costs: [
|
||||
{
|
||||
unit: String(formData.get("unit") || "RUN"),
|
||||
credit_cost: creditCost,
|
||||
metadata: {},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
await adminFetch("/api/llm/models", {
|
||||
method: "POST",
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function updateLlmModelAction(formData: FormData) {
|
||||
const modelSlug = getRequiredFormField(formData, "model_id", "Model");
|
||||
|
||||
const payload: Record<string, string | number | boolean> = {};
|
||||
|
||||
const displayName = formData.get("display_name");
|
||||
if (displayName !== null && String(displayName).trim())
|
||||
payload.display_name = String(displayName).trim();
|
||||
const description = formData.get("description");
|
||||
if (description !== null) payload.description = String(description);
|
||||
const creatorId = formData.get("creator_id");
|
||||
if (creatorId !== null && String(creatorId).trim())
|
||||
payload.creator_id = String(creatorId).trim();
|
||||
const contextWindow = formData.get("context_window");
|
||||
if (contextWindow !== null && String(contextWindow).trim())
|
||||
payload.context_window = Number(contextWindow);
|
||||
const maxOutputTokens = formData.get("max_output_tokens");
|
||||
if (maxOutputTokens !== null && String(maxOutputTokens).trim())
|
||||
payload.max_output_tokens = Number(maxOutputTokens);
|
||||
if (formData.has("is_enabled"))
|
||||
payload.is_enabled = formData.getAll("is_enabled").includes("on");
|
||||
|
||||
await adminFetch(`/api/llm/models/${modelSlug}`, {
|
||||
method: "PATCH",
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function toggleLlmModelAction(formData: FormData): Promise<void> {
|
||||
const modelSlug = getRequiredFormField(formData, "model_id", "Model");
|
||||
const shouldEnable = formData.get("is_enabled") === "true";
|
||||
|
||||
const payload: Record<string, string | number | boolean> = {
|
||||
is_enabled: shouldEnable,
|
||||
};
|
||||
|
||||
// Migration params (only when disabling)
|
||||
if (!shouldEnable) {
|
||||
const migrateToSlug = formData.get("migrate_to_slug");
|
||||
if (migrateToSlug) payload.migrate_to_slug = String(migrateToSlug);
|
||||
const reason = formData.get("migration_reason");
|
||||
if (reason) payload.migration_reason = String(reason);
|
||||
const customCost = formData.get("custom_credit_cost");
|
||||
if (customCost) payload.custom_credit_cost = Number(customCost);
|
||||
}
|
||||
|
||||
await adminFetch(`/api/llm/models/${modelSlug}/toggle`, {
|
||||
method: "POST",
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function deleteLlmModelAction(formData: FormData): Promise<void> {
|
||||
const modelSlug = getRequiredFormField(formData, "model_id", "Model");
|
||||
const replacementSlug = formData.get("replacement_model_slug");
|
||||
const params = new URLSearchParams();
|
||||
if (replacementSlug)
|
||||
params.set("replacement_model_slug", String(replacementSlug));
|
||||
const query = params.toString() ? `?${params.toString()}` : "";
|
||||
await adminFetch(`/api/llm/models/${modelSlug}${query}`, {
|
||||
method: "DELETE",
|
||||
});
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function fetchLlmModelUsage(modelSlug: string) {
|
||||
const { data } = await adminFetch(`/api/llm/models/${modelSlug}/usage`);
|
||||
return data;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Migration Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function fetchLlmMigrations(includeReverted: boolean = false) {
|
||||
const params = new URLSearchParams();
|
||||
if (includeReverted) params.set("include_reverted", "true");
|
||||
const query = params.toString() ? `?${params.toString()}` : "";
|
||||
const { data } = await adminFetch(`/api/llm/migrations${query}`);
|
||||
return data;
|
||||
}
|
||||
|
||||
export async function revertLlmMigrationAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const migrationId = getRequiredFormField(
|
||||
formData,
|
||||
"migration_id",
|
||||
"Migration",
|
||||
);
|
||||
await adminFetch(`/api/llm/migrations/${migrationId}/revert`, {
|
||||
method: "POST",
|
||||
});
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Creator Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function fetchLlmCreators() {
|
||||
const { data } = await adminFetch(`/api/llm/creators`);
|
||||
return data;
|
||||
}
|
||||
|
||||
export async function createLlmCreatorAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const payload = {
|
||||
name: String(formData.get("name") || "").trim(),
|
||||
display_name: String(formData.get("display_name") || "").trim(),
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
website_url: formData.get("website_url")
|
||||
? String(formData.get("website_url"))
|
||||
: undefined,
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
await adminFetch("/api/llm/creators", {
|
||||
method: "POST",
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function updateLlmCreatorAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const creatorName = getRequiredFormField(formData, "creator_id", "Creator");
|
||||
|
||||
const payload: Record<string, string> = {};
|
||||
const displayName = formData.get("display_name");
|
||||
if (displayName !== null && String(displayName).trim())
|
||||
payload.display_name = String(displayName).trim();
|
||||
const description = formData.get("description");
|
||||
if (description !== null) payload.description = String(description);
|
||||
const websiteUrl = formData.get("website_url");
|
||||
if (websiteUrl !== null && String(websiteUrl).trim())
|
||||
payload.website_url = String(websiteUrl).trim();
|
||||
|
||||
await adminFetch(`/api/llm/creators/${creatorName}`, {
|
||||
method: "PATCH",
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function deleteLlmCreatorAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const creatorName = getRequiredFormField(formData, "creator_id", "Creator");
|
||||
await adminFetch(`/api/llm/creators/${creatorName}`, { method: "DELETE" });
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Recommended Model Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function setRecommendedModelAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const modelSlug = getRequiredFormField(formData, "model_id", "Model");
|
||||
|
||||
// Set recommended by updating the model
|
||||
await adminFetch(`/api/llm/models/${modelSlug}`, {
|
||||
method: "PATCH",
|
||||
body: JSON.stringify({ is_recommended: true }),
|
||||
});
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { createLlmCreatorAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
export function AddCreatorModal() {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await createLlmCreatorAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to create creator");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Add Creator"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "512px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="primary" size="small">
|
||||
Add Creator
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
Add a new model creator (the organization that made/trained the
|
||||
model).
|
||||
</div>
|
||||
|
||||
<form action={handleSubmit} className="space-y-4">
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Name (slug) <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="name"
|
||||
required
|
||||
name="name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="openai"
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Lowercase identifier (e.g., openai, meta, anthropic)
|
||||
</p>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="display_name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Display Name <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="display_name"
|
||||
required
|
||||
name="display_name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="OpenAI"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="description"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Description
|
||||
</label>
|
||||
<textarea
|
||||
id="description"
|
||||
name="description"
|
||||
rows={2}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="Creator of GPT models..."
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="website_url"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Website URL
|
||||
</label>
|
||||
<input
|
||||
id="website_url"
|
||||
name="website_url"
|
||||
type="url"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="https://openai.com"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Creating..." : "Add Creator"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,314 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmProvider } from "../types";
|
||||
import type { LlmModelCreator } from "../types";
|
||||
import { createLlmModelAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
interface Props {
|
||||
providers: LlmProvider[];
|
||||
creators: LlmModelCreator[];
|
||||
}
|
||||
|
||||
export function AddModelModal({ providers, creators }: Props) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [selectedCreatorId, setSelectedCreatorId] = useState("");
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await createLlmModelAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to create model");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
// When provider changes, auto-select matching creator if one exists
|
||||
function handleProviderChange(providerName: string) {
|
||||
const provider = providers.find((p) => p.name === providerName);
|
||||
if (provider) {
|
||||
// Find creator with same name as provider (e.g., "openai" -> "openai")
|
||||
const matchingCreator = creators.find((c) => c.name === provider.name);
|
||||
if (matchingCreator) {
|
||||
setSelectedCreatorId(matchingCreator.id);
|
||||
} else {
|
||||
// No matching creator (e.g., OpenRouter hosts other creators' models)
|
||||
setSelectedCreatorId("");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Add Model"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="primary" size="small">
|
||||
Add Model
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
Register a new model slug, metadata, and pricing.
|
||||
</div>
|
||||
|
||||
<form action={handleSubmit} className="space-y-6">
|
||||
{/* Basic Information */}
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Basic Information
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Core model details
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="slug"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Model Slug <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="slug"
|
||||
required
|
||||
name="slug"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="gpt-4.1-mini-2025-04-14"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="display_name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Display Name <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="display_name"
|
||||
required
|
||||
name="display_name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="GPT 4.1 Mini"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="description"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Description
|
||||
</label>
|
||||
<textarea
|
||||
id="description"
|
||||
name="description"
|
||||
rows={3}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="Optional description..."
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Model Configuration */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Model Configuration
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Model capabilities and limits
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="provider_id"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Provider <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<select
|
||||
id="provider_id"
|
||||
required
|
||||
name="provider_id"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
defaultValue=""
|
||||
onChange={(e) => handleProviderChange(e.target.value)}
|
||||
>
|
||||
<option value="" disabled>
|
||||
Select provider
|
||||
</option>
|
||||
{providers.map((provider) => (
|
||||
<option key={provider.name} value={provider.name}>
|
||||
{provider.display_name} ({provider.name})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Who hosts/serves the model
|
||||
</p>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="creator_id"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Creator
|
||||
</label>
|
||||
<select
|
||||
id="creator_id"
|
||||
name="creator_id"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
value={selectedCreatorId}
|
||||
onChange={(e) => setSelectedCreatorId(e.target.value)}
|
||||
>
|
||||
<option value="">No creator selected</option>
|
||||
{creators.map((creator) => (
|
||||
<option key={creator.id} value={creator.id}>
|
||||
{creator.display_name} ({creator.name})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Who made/trained the model (e.g., OpenAI, Meta)
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="context_window"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Context Window <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="context_window"
|
||||
required
|
||||
type="number"
|
||||
name="context_window"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="128000"
|
||||
min={1}
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="max_output_tokens"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Max Output Tokens
|
||||
</label>
|
||||
<input
|
||||
id="max_output_tokens"
|
||||
type="number"
|
||||
name="max_output_tokens"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="16384"
|
||||
min={1}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Pricing */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">Pricing</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Credit cost per run (credentials are managed via the provider)
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-1">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="credit_cost"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Credit Cost <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="credit_cost"
|
||||
required
|
||||
type="number"
|
||||
name="credit_cost"
|
||||
step="1"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="5"
|
||||
min={0}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Credit cost is always in platform credits. Credentials are
|
||||
inherited from the selected provider.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Enabled Toggle */}
|
||||
<div className="flex items-center gap-3 border-t border-border pt-6">
|
||||
<input type="hidden" name="is_enabled" value="off" />
|
||||
<input
|
||||
id="is_enabled"
|
||||
type="checkbox"
|
||||
name="is_enabled"
|
||||
defaultChecked
|
||||
className="h-4 w-4 rounded border-input"
|
||||
/>
|
||||
<label
|
||||
htmlFor="is_enabled"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Enabled by default
|
||||
</label>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Creating..." : "Save Model"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,268 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { createLlmProviderAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
export function AddProviderModal() {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await createLlmProviderAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(
|
||||
err instanceof Error ? err.message : "Failed to create provider",
|
||||
);
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Add Provider"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="primary" size="small">
|
||||
Add Provider
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
Define a new upstream provider and default credential information.
|
||||
</div>
|
||||
|
||||
{/* Setup Instructions */}
|
||||
<div className="mb-6 rounded-lg border border-primary/30 bg-primary/5 p-4">
|
||||
<div className="space-y-2">
|
||||
<h4 className="text-sm font-semibold text-foreground">
|
||||
Before Adding a Provider
|
||||
</h4>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
To use a new provider, you must first configure its credentials in
|
||||
the backend:
|
||||
</p>
|
||||
<ol className="list-inside list-decimal space-y-1 text-xs text-muted-foreground">
|
||||
<li>
|
||||
Add the credential to{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono">
|
||||
backend/integrations/credentials_store.py
|
||||
</code>{" "}
|
||||
with a UUID, provider name, and settings secret reference
|
||||
</li>
|
||||
<li>
|
||||
Add it to the{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono">
|
||||
PROVIDER_CREDENTIALS
|
||||
</code>{" "}
|
||||
dictionary in{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono">
|
||||
backend/data/block_cost_config.py
|
||||
</code>
|
||||
</li>
|
||||
<li>
|
||||
Use the <strong>same provider name</strong> in the
|
||||
"Credential Provider" field below that matches the key
|
||||
in{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono">
|
||||
PROVIDER_CREDENTIALS
|
||||
</code>
|
||||
</li>
|
||||
</ol>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<form action={handleSubmit} className="space-y-6">
|
||||
{/* Basic Information */}
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Basic Information
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Core provider details
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Provider Slug <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="name"
|
||||
required
|
||||
name="name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="e.g. openai"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="display_name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Display Name <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="display_name"
|
||||
required
|
||||
name="display_name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="OpenAI"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="description"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Description
|
||||
</label>
|
||||
<textarea
|
||||
id="description"
|
||||
name="description"
|
||||
rows={3}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="Optional description..."
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Default Credentials */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Default Credentials
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Credential provider name that matches the key in{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
|
||||
PROVIDER_CREDENTIALS
|
||||
</code>
|
||||
</p>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="default_credential_provider"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Credential Provider <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="default_credential_provider"
|
||||
name="default_credential_provider"
|
||||
required
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="openai"
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
<strong>Important:</strong> This must exactly match the key in
|
||||
the{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
|
||||
PROVIDER_CREDENTIALS
|
||||
</code>{" "}
|
||||
dictionary in{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
|
||||
block_cost_config.py
|
||||
</code>
|
||||
. Common values: "openai", "anthropic",
|
||||
"groq", "open_router", etc.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Capabilities */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Capabilities
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Provider feature flags
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-3 sm:grid-cols-2">
|
||||
{[
|
||||
{ name: "supports_tools", label: "Supports tools" },
|
||||
{ name: "supports_json_output", label: "Supports JSON output" },
|
||||
{ name: "supports_reasoning", label: "Supports reasoning" },
|
||||
{
|
||||
name: "supports_parallel_tool",
|
||||
label: "Supports parallel tool calls",
|
||||
},
|
||||
].map(({ name, label }) => (
|
||||
<div
|
||||
key={name}
|
||||
className="flex items-center gap-3 rounded-md border border-border bg-muted/30 px-4 py-3 transition-colors hover:bg-muted/50"
|
||||
>
|
||||
<input type="hidden" name={name} value="off" />
|
||||
<input
|
||||
id={name}
|
||||
type="checkbox"
|
||||
name={name}
|
||||
defaultChecked={
|
||||
name !== "supports_reasoning" &&
|
||||
name !== "supports_parallel_tool"
|
||||
}
|
||||
className="h-4 w-4 rounded border-input"
|
||||
/>
|
||||
<label
|
||||
htmlFor={name}
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
{label}
|
||||
</label>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Creating..." : "Save Provider"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,195 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import type { LlmModelCreator } from "../types";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/atoms/Table/Table";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { updateLlmCreatorAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { DeleteCreatorModal } from "./DeleteCreatorModal";
|
||||
|
||||
export function CreatorsTable({ creators }: { creators: LlmModelCreator[] }) {
|
||||
if (!creators.length) {
|
||||
return (
|
||||
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
|
||||
No creators registered yet.
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Creator</TableHead>
|
||||
<TableHead>Description</TableHead>
|
||||
<TableHead>Website</TableHead>
|
||||
<TableHead>Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{creators.map((creator) => (
|
||||
<TableRow key={creator.id}>
|
||||
<TableCell>
|
||||
<div className="font-medium">{creator.display_name}</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{creator.name}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<span className="text-sm text-muted-foreground">
|
||||
{creator.description || "—"}
|
||||
</span>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{creator.website_url ? (
|
||||
<a
|
||||
href={creator.website_url}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-sm text-primary hover:underline"
|
||||
>
|
||||
{(() => {
|
||||
try {
|
||||
return new URL(creator.website_url).hostname;
|
||||
} catch {
|
||||
return creator.website_url;
|
||||
}
|
||||
})()}
|
||||
</a>
|
||||
) : (
|
||||
<span className="text-muted-foreground">—</span>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
<EditCreatorModal creator={creator} />
|
||||
<DeleteCreatorModal creator={creator} />
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function EditCreatorModal({ creator }: { creator: LlmModelCreator }) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await updateLlmCreatorAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to update creator");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Edit Creator"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "512px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="outline" size="small" className="min-w-0">
|
||||
Edit
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<form action={handleSubmit} className="space-y-4">
|
||||
<input type="hidden" name="creator_id" value={creator.name} />
|
||||
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Name (slug)</label>
|
||||
<input
|
||||
required
|
||||
name="name"
|
||||
defaultValue={creator.name}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Display Name</label>
|
||||
<input
|
||||
required
|
||||
name="display_name"
|
||||
defaultValue={creator.display_name}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Description</label>
|
||||
<textarea
|
||||
name="description"
|
||||
rows={2}
|
||||
defaultValue={creator.description ?? ""}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Website URL</label>
|
||||
<input
|
||||
name="website_url"
|
||||
type="url"
|
||||
defaultValue={creator.website_url ?? ""}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Updating..." : "Update"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmModelCreator } from "../types";
|
||||
import { deleteLlmCreatorAction } from "../actions";
|
||||
|
||||
export function DeleteCreatorModal({ creator }: { creator: LlmModelCreator }) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isDeleting, setIsDeleting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleDelete(formData: FormData) {
|
||||
setIsDeleting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await deleteLlmCreatorAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to delete creator");
|
||||
} finally {
|
||||
setIsDeleting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Delete Creator"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "480px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
className="min-w-0 text-destructive hover:bg-destructive/10"
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="space-y-4">
|
||||
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
|
||||
<div className="flex items-start gap-3">
|
||||
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
|
||||
⚠️
|
||||
</div>
|
||||
<div className="text-sm text-foreground">
|
||||
<p className="font-semibold">You are about to delete:</p>
|
||||
<p className="mt-1">
|
||||
<span className="font-medium">{creator.display_name}</span>{" "}
|
||||
<span className="text-muted-foreground">
|
||||
({creator.name})
|
||||
</span>
|
||||
</p>
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
Models using this creator will have their creator field
|
||||
cleared. This is safe and won't affect model
|
||||
functionality.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<form action={handleDelete} className="space-y-4">
|
||||
<input type="hidden" name="creator_id" value={creator.name} />
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isDeleting}
|
||||
type="button"
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
size="small"
|
||||
disabled={isDeleting}
|
||||
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
|
||||
>
|
||||
{isDeleting ? "Deleting..." : "Delete Creator"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,224 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmModel } from "../types";
|
||||
import { deleteLlmModelAction, fetchLlmModelUsage } from "../actions";
|
||||
|
||||
export function DeleteModelModal({
|
||||
model,
|
||||
availableModels,
|
||||
}: {
|
||||
model: LlmModel;
|
||||
availableModels: LlmModel[];
|
||||
}) {
|
||||
const router = useRouter();
|
||||
const [open, setOpen] = useState(false);
|
||||
const [selectedReplacement, setSelectedReplacement] = useState<string>("");
|
||||
const [isDeleting, setIsDeleting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [usageCount, setUsageCount] = useState<number | null>(null);
|
||||
const [usageLoading, setUsageLoading] = useState(false);
|
||||
const [usageError, setUsageError] = useState<string | null>(null);
|
||||
|
||||
// Filter out the current model and disabled models from replacement options
|
||||
const replacementOptions = availableModels.filter(
|
||||
(m) => m.id !== model.id && m.is_enabled,
|
||||
);
|
||||
|
||||
// Check if migration is required (has blocks using this model)
|
||||
const requiresMigration = usageCount !== null && usageCount > 0;
|
||||
|
||||
async function fetchUsage() {
|
||||
setUsageLoading(true);
|
||||
setUsageError(null);
|
||||
try {
|
||||
const usage = await fetchLlmModelUsage(model.slug);
|
||||
setUsageCount(usage.usage_count);
|
||||
} catch (err) {
|
||||
console.error("Failed to fetch model usage:", err);
|
||||
setUsageError("Failed to load usage count");
|
||||
setUsageCount(null);
|
||||
} finally {
|
||||
setUsageLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleDelete(formData: FormData) {
|
||||
setIsDeleting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await deleteLlmModelAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to delete model");
|
||||
} finally {
|
||||
setIsDeleting(false);
|
||||
}
|
||||
}
|
||||
|
||||
// Determine if delete button should be enabled
|
||||
const canDelete =
|
||||
!isDeleting &&
|
||||
!usageLoading &&
|
||||
usageCount !== null &&
|
||||
(requiresMigration
|
||||
? selectedReplacement && replacementOptions.length > 0
|
||||
: true);
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Delete Model"
|
||||
controlled={{
|
||||
isOpen: open,
|
||||
set: async (isOpen) => {
|
||||
setOpen(isOpen);
|
||||
if (isOpen) {
|
||||
setUsageCount(null);
|
||||
setUsageError(null);
|
||||
setError(null);
|
||||
setSelectedReplacement("");
|
||||
await fetchUsage();
|
||||
}
|
||||
},
|
||||
}}
|
||||
styling={{ maxWidth: "600px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
className="min-w-0 text-destructive hover:bg-destructive/10"
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
{requiresMigration
|
||||
? "This action cannot be undone. All workflows using this model will be migrated to the replacement model you select."
|
||||
: "This action cannot be undone."}
|
||||
</div>
|
||||
|
||||
<div className="space-y-4">
|
||||
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
|
||||
<div className="flex items-start gap-3">
|
||||
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
|
||||
⚠️
|
||||
</div>
|
||||
<div className="text-sm text-foreground">
|
||||
<p className="font-semibold">You are about to delete:</p>
|
||||
<p className="mt-1">
|
||||
<span className="font-medium">{model.display_name}</span>{" "}
|
||||
<span className="text-muted-foreground">({model.slug})</span>
|
||||
</p>
|
||||
{usageLoading && (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
Loading usage count...
|
||||
</p>
|
||||
)}
|
||||
{usageError && (
|
||||
<p className="mt-2 text-destructive">{usageError}</p>
|
||||
)}
|
||||
{!usageLoading && !usageError && usageCount !== null && (
|
||||
<p className="mt-2 font-semibold">
|
||||
Impact: {usageCount} block{usageCount !== 1 ? "s" : ""}{" "}
|
||||
currently use this model
|
||||
</p>
|
||||
)}
|
||||
{requiresMigration && (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
All workflows currently using this model will be
|
||||
automatically updated to use the replacement model you
|
||||
choose below.
|
||||
</p>
|
||||
)}
|
||||
{!usageLoading && usageCount === 0 && (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
No workflows are using this model. It can be safely deleted.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<form action={handleDelete} className="space-y-4">
|
||||
<input type="hidden" name="model_id" value={model.slug} />
|
||||
<input
|
||||
type="hidden"
|
||||
name="replacement_model_slug"
|
||||
value={selectedReplacement}
|
||||
/>
|
||||
|
||||
{requiresMigration && (
|
||||
<label className="text-sm font-medium">
|
||||
<span className="mb-2 block">
|
||||
Select Replacement Model{" "}
|
||||
<span className="text-destructive">*</span>
|
||||
</span>
|
||||
<select
|
||||
required
|
||||
value={selectedReplacement}
|
||||
onChange={(e) => setSelectedReplacement(e.target.value)}
|
||||
className="w-full rounded border border-input bg-background p-2 text-sm"
|
||||
>
|
||||
<option value="">-- Choose a replacement model --</option>
|
||||
{replacementOptions.map((m) => (
|
||||
<option key={m.id} value={m.slug}>
|
||||
{m.display_name} ({m.slug})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
{replacementOptions.length === 0 && (
|
||||
<p className="mt-2 text-xs text-destructive">
|
||||
No replacement models available. You must have at least one
|
||||
other enabled model before deleting this one.
|
||||
</p>
|
||||
)}
|
||||
</label>
|
||||
)}
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setSelectedReplacement("");
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isDeleting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
size="small"
|
||||
disabled={!canDelete}
|
||||
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
|
||||
>
|
||||
{isDeleting
|
||||
? "Deleting..."
|
||||
: requiresMigration
|
||||
? "Delete and Migrate"
|
||||
: "Delete"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmProvider } from "../types";
|
||||
import { deleteLlmProviderAction } from "../actions";
|
||||
|
||||
export function DeleteProviderModal({ provider }: { provider: LlmProvider }) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isDeleting, setIsDeleting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
const modelCount = provider.models?.length ?? 0;
|
||||
const hasModels = modelCount > 0;
|
||||
|
||||
async function handleDelete(formData: FormData) {
|
||||
setIsDeleting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await deleteLlmProviderAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(
|
||||
err instanceof Error ? err.message : "Failed to delete provider",
|
||||
);
|
||||
} finally {
|
||||
setIsDeleting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Delete Provider"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "480px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
className="min-w-0 text-destructive hover:bg-destructive/10"
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="space-y-4">
|
||||
<div
|
||||
className={`rounded-lg border p-4 ${
|
||||
hasModels
|
||||
? "border-destructive/30 bg-destructive/10"
|
||||
: "border-amber-500/30 bg-amber-500/10 dark:border-amber-400/30 dark:bg-amber-400/10"
|
||||
}`}
|
||||
>
|
||||
<div className="flex items-start gap-3">
|
||||
<div
|
||||
className={`flex-shrink-0 ${
|
||||
hasModels
|
||||
? "text-destructive"
|
||||
: "text-amber-600 dark:text-amber-400"
|
||||
}`}
|
||||
>
|
||||
{hasModels ? "🚫" : "⚠️"}
|
||||
</div>
|
||||
<div className="text-sm text-foreground">
|
||||
<p className="font-semibold">You are about to delete:</p>
|
||||
<p className="mt-1">
|
||||
<span className="font-medium">{provider.display_name}</span>{" "}
|
||||
<span className="text-muted-foreground">
|
||||
({provider.name})
|
||||
</span>
|
||||
</p>
|
||||
{hasModels ? (
|
||||
<p className="mt-2 text-destructive">
|
||||
This provider has {modelCount} model(s). You must delete all
|
||||
models before you can delete this provider.
|
||||
</p>
|
||||
) : (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
This provider has no models and can be safely deleted.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<form action={handleDelete} className="space-y-4">
|
||||
<input type="hidden" name="provider_id" value={provider.name} />
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isDeleting}
|
||||
type="button"
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
size="small"
|
||||
disabled={isDeleting || hasModels}
|
||||
className="bg-destructive text-destructive-foreground hover:bg-destructive/90 disabled:opacity-50"
|
||||
>
|
||||
{isDeleting ? "Deleting..." : "Delete Provider"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,365 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmModel } from "../types";
|
||||
import { toggleLlmModelAction, fetchLlmModelUsage } from "../actions";
|
||||
|
||||
export function DisableModelModal({
|
||||
model,
|
||||
availableModels,
|
||||
}: {
|
||||
model: LlmModel;
|
||||
availableModels: LlmModel[];
|
||||
}) {
|
||||
const router = useRouter();
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isDisabling, setIsDisabling] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [usageCount, setUsageCount] = useState<number | null>(null);
|
||||
const [selectedMigration, setSelectedMigration] = useState<string>("");
|
||||
const [migrationReason, setMigrationReason] = useState("");
|
||||
const [customCreditCost, setCustomCreditCost] = useState<string>("");
|
||||
|
||||
const migrationOptions = availableModels.filter(
|
||||
(m) => m.id !== model.id && m.is_enabled,
|
||||
);
|
||||
|
||||
async function fetchUsage() {
|
||||
try {
|
||||
const usage = await fetchLlmModelUsage(model.slug);
|
||||
setUsageCount(usage.node_count);
|
||||
} catch {
|
||||
setUsageCount(null);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleDisable(formData: FormData) {
|
||||
setIsDisabling(true);
|
||||
setError(null);
|
||||
try {
|
||||
await toggleLlmModelAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to disable model");
|
||||
} finally {
|
||||
setIsDisabling(false);
|
||||
}
|
||||
}
|
||||
|
||||
function resetState() {
|
||||
setError(null);
|
||||
setSelectedMigration("");
|
||||
setMigrationReason("");
|
||||
setCustomCreditCost("");
|
||||
setUsageCount(null);
|
||||
}
|
||||
|
||||
const hasUsage = usageCount !== null && usageCount > 0;
|
||||
const isLoading = usageCount === null && !model.is_recommended;
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Disable Model"
|
||||
controlled={{
|
||||
isOpen: open,
|
||||
set: async (isOpen) => {
|
||||
setOpen(isOpen);
|
||||
if (isOpen) {
|
||||
resetState();
|
||||
if (!model.is_recommended) {
|
||||
await fetchUsage();
|
||||
}
|
||||
}
|
||||
},
|
||||
}}
|
||||
styling={{ maxWidth: "600px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
className="min-w-0"
|
||||
>
|
||||
Disable
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
{model.is_recommended ? (
|
||||
<RecommendedModelBlock model={model} onClose={() => setOpen(false)} />
|
||||
) : (
|
||||
<DisableForm
|
||||
model={model}
|
||||
usageCount={usageCount}
|
||||
isLoading={isLoading}
|
||||
hasUsage={hasUsage}
|
||||
migrationOptions={migrationOptions}
|
||||
selectedMigration={selectedMigration}
|
||||
setSelectedMigration={setSelectedMigration}
|
||||
migrationReason={migrationReason}
|
||||
setMigrationReason={setMigrationReason}
|
||||
customCreditCost={customCreditCost}
|
||||
setCustomCreditCost={setCustomCreditCost}
|
||||
isDisabling={isDisabling}
|
||||
error={error}
|
||||
onClose={() => {
|
||||
setOpen(false);
|
||||
resetState();
|
||||
}}
|
||||
onSubmit={handleDisable}
|
||||
/>
|
||||
)}
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * Dialog body shown when the admin attempts to disable the model that is
 * currently marked as recommended. No disable action is offered here — the
 * admin must first promote a different enabled model to "recommended".
 */
function RecommendedModelBlock(props: {
  model: LlmModel;
  onClose: () => void;
}) {
  const { model, onClose } = props;

  // Footer is hoisted into a local so the returned markup stays flat.
  const footer = (
    <Dialog.Footer>
      <Button variant="ghost" size="small" onClick={onClose}>
        Close
      </Button>
    </Dialog.Footer>
  );

  return (
    <div className="space-y-4">
      <div className="rounded-lg border border-destructive/30 bg-destructive/10 p-4">
        <div className="flex items-start gap-3">
          <div className="flex-shrink-0 text-destructive">🔒</div>
          <div className="text-sm">
            <p className="font-semibold text-destructive">
              Cannot disable the recommended model
            </p>
            <p className="mt-1 text-foreground">
              <span className="font-medium">{model.display_name}</span>{" "}
              <span className="text-muted-foreground">({model.slug})</span> is
              currently set as the recommended model.
            </p>
            <p className="mt-2 text-muted-foreground">
              Change the recommended model to a different enabled model before
              disabling this one.
            </p>
          </div>
        </div>
      </div>
      {footer}
    </div>
  );
}
|
||||
|
||||
/**
 * Props for the disable-model form. All state is owned by the parent modal;
 * this component is purely controlled.
 */
interface DisableFormProps {
  // Model the admin is about to disable.
  model: LlmModel;
  // Number of workflow blocks using the model; null while still loading.
  usageCount: number | null;
  // True while usage data has not yet been fetched.
  isLoading: boolean;
  // True when usageCount is a positive number (a migration is then required).
  hasUsage: boolean;
  // Models offered as replacement/migration targets.
  migrationOptions: LlmModel[];
  // Slug of the chosen replacement model ("" = none chosen yet).
  selectedMigration: string;
  setSelectedMigration: (v: string) => void;
  // Optional free-text reason recorded with the migration.
  migrationReason: string;
  setMigrationReason: (v: string) => void;
  // Optional credit-cost override; kept as a string to back the number input.
  customCreditCost: string;
  setCustomCreditCost: (v: string) => void;
  // True while the disable server action is in flight.
  isDisabling: boolean;
  // Error message from the last submit attempt, if any.
  error: string | null;
  // Closes the dialog (parent also resets its state).
  onClose: () => void;
  // Server-action style submit handler receiving the form's FormData.
  onSubmit: (formData: FormData) => Promise<void>;
}
|
||||
|
||||
/**
 * Controlled form for disabling a (non-recommended) model.
 *
 * Layout note: the migration controls (replacement select, reason, custom
 * cost) live OUTSIDE the <form> element; their values are mirrored into
 * hidden inputs inside the form so they are included in the submitted
 * FormData. Submission is therefore gated by `submitDisabled` rather than by
 * native `required` validation (which does not apply across form boundaries).
 */
function DisableForm({
  model,
  usageCount,
  isLoading,
  hasUsage,
  migrationOptions,
  selectedMigration,
  setSelectedMigration,
  migrationReason,
  setMigrationReason,
  customCreditCost,
  setCustomCreditCost,
  isDisabling,
  error,
  onClose,
  onSubmit,
}: DisableFormProps) {
  // Block submission while busy/loading, or when usage exists but no
  // replacement model has been (or can be) chosen.
  const submitDisabled =
    isDisabling ||
    isLoading ||
    (hasUsage && !selectedMigration) ||
    (hasUsage && migrationOptions.length === 0);

  return (
    <div className="space-y-4">
      <div className="text-sm text-muted-foreground">
        Disabling a model will hide it from users when creating new workflows.
      </div>

      {/* Impact summary: loading, migration-required, or no-usage. */}
      <div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4">
        <div className="flex items-start gap-3">
          <div className="flex-shrink-0 text-amber-600">⚠️</div>
          <div className="text-sm text-foreground">
            <p className="font-semibold">You are about to disable:</p>
            <p className="mt-1">
              <span className="font-medium">{model.display_name}</span>{" "}
              <span className="text-muted-foreground">({model.slug})</span>
            </p>
            {isLoading ? (
              <p className="mt-2 text-muted-foreground">
                Loading usage data...
              </p>
            ) : hasUsage ? (
              <p className="mt-2 font-semibold text-amber-700">
                Impact: {usageCount} block{usageCount !== 1 ? "s" : ""}{" "}
                currently use this model — migration required
              </p>
            ) : (
              <p className="mt-2 text-muted-foreground">
                No workflows are currently using this model.
              </p>
            )}
          </div>
        </div>
      </div>

      {/* Migration controls — only rendered when workflows use this model.
          These sit outside the <form>; see the function docstring. */}
      {hasUsage && (
        <div className="space-y-4 rounded-lg border border-border bg-muted/50 p-4">
          <div className="text-sm">
            <p className="font-medium">Migration required</p>
            <p className="mt-1 text-muted-foreground">
              Workflows using this model must be migrated to a replacement
              before disabling. This creates a revertible migration record.
            </p>
          </div>

          <div className="space-y-4 border-t border-border pt-4">
            <label className="block text-sm font-medium">
              <span className="mb-2 block">
                Replacement Model <span className="text-destructive">*</span>
              </span>
              <select
                required
                value={selectedMigration}
                onChange={(e) => setSelectedMigration(e.target.value)}
                className="w-full rounded border border-input bg-background p-2 text-sm"
              >
                <option value="">-- Choose a replacement model --</option>
                {migrationOptions.map((m) => (
                  <option key={m.id} value={m.slug}>
                    {m.display_name} ({m.slug})
                  </option>
                ))}
              </select>
              {migrationOptions.length === 0 && (
                <p className="mt-2 text-xs text-destructive">
                  No other enabled models available. Enable another model before
                  disabling this one.
                </p>
              )}
            </label>

            <label className="block text-sm font-medium">
              <span className="mb-2 block">
                Migration Reason{" "}
                <span className="font-normal text-muted-foreground">
                  (optional)
                </span>
              </span>
              <input
                type="text"
                value={migrationReason}
                onChange={(e) => setMigrationReason(e.target.value)}
                placeholder="e.g., Provider outage, Cost reduction"
                className="w-full rounded border border-input bg-background p-2 text-sm"
              />
            </label>

            <label className="block text-sm font-medium">
              <span className="mb-2 block">
                Custom Credit Cost{" "}
                <span className="font-normal text-muted-foreground">
                  (optional)
                </span>
              </span>
              <input
                type="number"
                min="0"
                value={customCreditCost}
                onChange={(e) => setCustomCreditCost(e.target.value)}
                placeholder="Leave blank to use target model's cost"
                className="w-full rounded border border-input bg-background p-2 text-sm"
              />
              <p className="mt-1 text-xs text-muted-foreground">
                Override pricing for migrated workflows.
              </p>
            </label>
          </div>
        </div>
      )}

      <form action={onSubmit} className="space-y-4">
        <input type="hidden" name="model_id" value={model.slug} />
        <input type="hidden" name="is_enabled" value="false" />
        {/* Mirror the out-of-form migration state into the submitted data.
            Reason and cost are omitted entirely when left blank. */}
        {hasUsage && selectedMigration && (
          <>
            <input
              type="hidden"
              name="migrate_to_slug"
              value={selectedMigration}
            />
            {migrationReason && (
              <input
                type="hidden"
                name="migration_reason"
                value={migrationReason}
              />
            )}
            {customCreditCost && (
              <input
                type="hidden"
                name="custom_credit_cost"
                value={customCreditCost}
              />
            )}
          </>
        )}

        {error && (
          <div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
            {error}
          </div>
        )}

        <Dialog.Footer>
          <Button
            variant="ghost"
            size="small"
            onClick={onClose}
            disabled={isDisabling}
          >
            Cancel
          </Button>
          <Button
            type="submit"
            variant="primary"
            size="small"
            disabled={submitDisabled}
          >
            {isDisabling
              ? "Disabling..."
              : hasUsage && selectedMigration
                ? "Disable & Migrate"
                : "Disable Model"}
          </Button>
        </Dialog.Footer>
      </form>
    </div>
  );
}
|
||||
@@ -0,0 +1,223 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmModel } from "../types";
|
||||
import type { LlmModelCreator } from "../types";
|
||||
import type { LlmProvider } from "../types";
|
||||
import { updateLlmModelAction } from "../actions";
|
||||
|
||||
/**
 * Modal for editing an existing model's metadata and pricing.
 *
 * Submits via the `updateLlmModelAction` server action; on success the
 * dialog closes and `router.refresh()` re-renders the server data.
 *
 * NOTE(review): `error` is only cleared at the start of a submit, not when
 * the dialog is closed via Cancel — a stale error may still be visible on
 * reopen (EditProviderModal clears it on Cancel). Confirm whether that is
 * intentional.
 */
export function EditModelModal({
  model,
  providers,
  creators,
}: {
  model: LlmModel;
  providers: LlmProvider[];
  creators: LlmModelCreator[];
}) {
  const router = useRouter();
  const [open, setOpen] = useState(false);
  const [isSubmitting, setIsSubmitting] = useState(false);
  const [error, setError] = useState<string | null>(null);
  // Existing pricing row (first cost entry) and the model's current provider,
  // used to pre-fill the pricing fields below.
  const cost = model.costs?.[0];
  const provider = providers.find((p) => p.id === model.provider_id);

  // Runs the update server action, surfacing failures inline.
  async function handleSubmit(formData: FormData) {
    setIsSubmitting(true);
    setError(null);
    try {
      await updateLlmModelAction(formData);
      setOpen(false);
      router.refresh();
    } catch (err) {
      setError(err instanceof Error ? err.message : "Failed to update model");
    } finally {
      setIsSubmitting(false);
    }
  }

  return (
    <Dialog
      title="Edit Model"
      controlled={{ isOpen: open, set: setOpen }}
      styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
    >
      <Dialog.Trigger>
        <Button variant="outline" size="small" className="min-w-0">
          Edit
        </Button>
      </Dialog.Trigger>
      <Dialog.Content>
        <div className="mb-4 text-sm text-muted-foreground">
          Update model metadata and pricing information.
        </div>
        {error && (
          <div className="mb-4 rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
            {error}
          </div>
        )}
        <form action={handleSubmit} className="space-y-4">
          {/* The model is identified by slug on the server side. */}
          <input type="hidden" name="model_id" value={model.slug} />

          <div className="grid gap-4 md:grid-cols-2">
            <label className="text-sm font-medium">
              Display Name
              <input
                required
                name="display_name"
                defaultValue={model.display_name}
                className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
              />
            </label>
            <label className="text-sm font-medium">
              Provider
              <select
                required
                name="provider_id"
                className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
                defaultValue={model.provider_id}
              >
                {providers.map((p) => (
                  <option key={p.id} value={p.id}>
                    {p.display_name} ({p.name})
                  </option>
                ))}
              </select>
              <span className="text-xs text-muted-foreground">
                Who hosts/serves the model
              </span>
            </label>
          </div>

          <div className="grid gap-4 md:grid-cols-2">
            <label className="text-sm font-medium">
              Creator
              <select
                name="creator_id"
                className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
                defaultValue={model.creator_id ?? ""}
              >
                <option value="">No creator selected</option>
                {creators.map((c) => (
                  <option key={c.id} value={c.id}>
                    {c.display_name} ({c.name})
                  </option>
                ))}
              </select>
              <span className="text-xs text-muted-foreground">
                Who made/trained the model (e.g., OpenAI, Meta)
              </span>
            </label>
          </div>

          <label className="text-sm font-medium">
            Description
            <textarea
              name="description"
              rows={2}
              defaultValue={model.description ?? ""}
              className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
              placeholder="Optional description..."
            />
          </label>

          <div className="grid gap-4 md:grid-cols-2">
            <label className="text-sm font-medium">
              Context Window
              <input
                required
                type="number"
                name="context_window"
                defaultValue={model.context_window}
                className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
                min={1}
              />
            </label>
            <label className="text-sm font-medium">
              Max Output Tokens
              <input
                type="number"
                name="max_output_tokens"
                defaultValue={model.max_output_tokens ?? undefined}
                className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
                min={1}
              />
            </label>
          </div>

          <div className="grid gap-4 md:grid-cols-2">
            <label className="text-sm font-medium">
              Credit Cost
              <input
                required
                type="number"
                name="credit_cost"
                defaultValue={cost?.credit_cost ?? 0}
                className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
                min={0}
              />
              <span className="text-xs text-muted-foreground">
                Credits charged per run
              </span>
            </label>
            <label className="text-sm font-medium">
              Credential Provider
              <select
                required
                name="credential_provider"
                defaultValue={cost?.credential_provider ?? provider?.name ?? ""}
                className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
              >
                <option value="" disabled>
                  Select provider
                </option>
                {providers.map((p) => (
                  <option key={p.id} value={p.name}>
                    {p.display_name} ({p.name})
                  </option>
                ))}
              </select>
              <span className="text-xs text-muted-foreground">
                Must match a key in PROVIDER_CREDENTIALS
              </span>
            </label>
          </div>
          {/* Hidden defaults for credential_type and unit — preserved from the
              existing cost row, else the provider default, else "api_key"/"RUN". */}
          <input
            type="hidden"
            name="credential_type"
            value={
              cost?.credential_type ??
              provider?.default_credential_type ??
              "api_key"
            }
          />
          <input type="hidden" name="unit" value={cost?.unit ?? "RUN"} />

          <Dialog.Footer>
            <Button
              type="button"
              variant="ghost"
              size="small"
              onClick={() => setOpen(false)}
              disabled={isSubmitting}
            >
              Cancel
            </Button>
            <Button
              variant="primary"
              size="small"
              type="submit"
              disabled={isSubmitting}
            >
              {isSubmitting ? "Updating..." : "Update Model"}
            </Button>
          </Dialog.Footer>
        </form>
      </Dialog.Content>
    </Dialog>
  );
}
|
||||
@@ -0,0 +1,263 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { updateLlmProviderAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
import type { LlmProvider } from "../types";
|
||||
|
||||
/**
 * Modal for editing an LLM provider's configuration, default credentials,
 * and capability flags. Submits via the `updateLlmProviderAction` server
 * action; on success the dialog closes and the router is refreshed.
 */
export function EditProviderModal({ provider }: { provider: LlmProvider }) {
  const [open, setOpen] = useState(false);
  const [isSubmitting, setIsSubmitting] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const router = useRouter();

  // Runs the update server action, surfacing failures inline.
  async function handleSubmit(formData: FormData) {
    setIsSubmitting(true);
    setError(null);
    try {
      await updateLlmProviderAction(formData);
      setOpen(false);
      router.refresh();
    } catch (err) {
      setError(
        err instanceof Error ? err.message : "Failed to update provider",
      );
    } finally {
      setIsSubmitting(false);
    }
  }

  return (
    <Dialog
      title="Edit Provider"
      controlled={{ isOpen: open, set: setOpen }}
      styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
    >
      <Dialog.Trigger>
        <Button variant="outline" size="small">
          Edit
        </Button>
      </Dialog.Trigger>
      <Dialog.Content>
        <div className="mb-4 text-sm text-muted-foreground">
          Update provider configuration and capabilities.
        </div>

        <form action={handleSubmit} className="space-y-6">
          {/* NOTE(review): provider_id carries provider.name (the slug), not a
              numeric/uuid id — presumably the action looks providers up by
              slug; verify against updateLlmProviderAction. */}
          <input type="hidden" name="provider_id" value={provider.name} />

          {/* Basic Information */}
          <div className="space-y-4">
            <div className="space-y-1">
              <h3 className="text-sm font-semibold text-foreground">
                Basic Information
              </h3>
              <p className="text-xs text-muted-foreground">
                Core provider details
              </p>
            </div>
            <div className="grid gap-4 sm:grid-cols-2">
              <div className="space-y-2">
                <label
                  htmlFor="name"
                  className="text-sm font-medium text-foreground"
                >
                  Provider Slug <span className="text-destructive">*</span>
                </label>
                <input
                  id="name"
                  required
                  name="name"
                  defaultValue={provider.name}
                  className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
                  placeholder="e.g. openai"
                />
              </div>
              <div className="space-y-2">
                <label
                  htmlFor="display_name"
                  className="text-sm font-medium text-foreground"
                >
                  Display Name <span className="text-destructive">*</span>
                </label>
                <input
                  id="display_name"
                  required
                  name="display_name"
                  defaultValue={provider.display_name}
                  className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
                  placeholder="OpenAI"
                />
              </div>
            </div>
            <div className="space-y-2">
              <label
                htmlFor="description"
                className="text-sm font-medium text-foreground"
              >
                Description
              </label>
              <textarea
                id="description"
                name="description"
                rows={3}
                defaultValue={provider.description ?? ""}
                className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
                placeholder="Optional description..."
              />
            </div>
          </div>

          {/* Default Credentials */}
          <div className="space-y-4 border-t border-border pt-6">
            <div className="space-y-1">
              <h3 className="text-sm font-semibold text-foreground">
                Default Credentials
              </h3>
              <p className="text-xs text-muted-foreground">
                Credential provider name that matches the key in{" "}
                <code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
                  PROVIDER_CREDENTIALS
                </code>
              </p>
            </div>
            <div className="grid gap-4 sm:grid-cols-2">
              <div className="space-y-2">
                <label
                  htmlFor="default_credential_provider"
                  className="text-sm font-medium text-foreground"
                >
                  Credential Provider
                </label>
                <input
                  id="default_credential_provider"
                  name="default_credential_provider"
                  defaultValue={provider.default_credential_provider ?? ""}
                  className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
                  placeholder="openai"
                />
              </div>
              <div className="space-y-2">
                <label
                  htmlFor="default_credential_id"
                  className="text-sm font-medium text-foreground"
                >
                  Credential ID
                </label>
                <input
                  id="default_credential_id"
                  name="default_credential_id"
                  defaultValue={provider.default_credential_id ?? ""}
                  className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
                  placeholder="Optional credential ID"
                />
              </div>
            </div>
            <div className="space-y-2">
              <label
                htmlFor="default_credential_type"
                className="text-sm font-medium text-foreground"
              >
                Credential Type
              </label>
              <input
                id="default_credential_type"
                name="default_credential_type"
                defaultValue={provider.default_credential_type ?? "api_key"}
                className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
                placeholder="api_key"
              />
            </div>
          </div>

          {/* Capabilities */}
          <div className="space-y-4 border-t border-border pt-6">
            <div className="space-y-1">
              <h3 className="text-sm font-semibold text-foreground">
                Capabilities
              </h3>
              <p className="text-xs text-muted-foreground">
                Provider feature flags
              </p>
            </div>
            <div className="grid gap-3 sm:grid-cols-2">
              {[
                {
                  name: "supports_tools",
                  label: "Supports tools",
                  checked: provider.supports_tools,
                },
                {
                  name: "supports_json_output",
                  label: "Supports JSON output",
                  checked: provider.supports_json_output,
                },
                {
                  name: "supports_reasoning",
                  label: "Supports reasoning",
                  checked: provider.supports_reasoning,
                },
                {
                  name: "supports_parallel_tool",
                  label: "Supports parallel tool calls",
                  checked: provider.supports_parallel_tool,
                },
              ].map(({ name, label, checked }) => (
                <div
                  key={name}
                  className="flex items-center gap-3 rounded-md border border-border bg-muted/30 px-4 py-3 transition-colors hover:bg-muted/50"
                >
                  {/* Hidden "off" guarantees the field is present even when the
                      checkbox is unchecked. NOTE(review): with both inputs
                      sharing a name, FormData.get(name) returns the FIRST
                      entry ("off") even when checked — the server action
                      presumably reads the last value (e.g. getAll(...).pop());
                      verify against updateLlmProviderAction. */}
                  <input type="hidden" name={name} value="off" />
                  <input
                    id={name}
                    type="checkbox"
                    name={name}
                    defaultChecked={checked}
                    className="h-4 w-4 rounded border-input"
                  />
                  <label
                    htmlFor={name}
                    className="text-sm font-medium text-foreground"
                  >
                    {label}
                  </label>
                </div>
              ))}
            </div>
          </div>

          {error && (
            <div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
              {error}
            </div>
          )}

          <Dialog.Footer>
            <Button
              variant="ghost"
              size="small"
              type="button"
              onClick={() => {
                setOpen(false);
                setError(null);
              }}
              disabled={isSubmitting}
            >
              Cancel
            </Button>
            <Button
              variant="primary"
              size="small"
              type="submit"
              disabled={isSubmitting}
            >
              {isSubmitting ? "Saving..." : "Save Changes"}
            </Button>
          </Dialog.Footer>
        </form>
      </Dialog.Content>
    </Dialog>
  );
}
|
||||
@@ -0,0 +1,131 @@
|
||||
"use client";
|
||||
|
||||
import type { LlmModel } from "../types";
|
||||
import type { LlmModelCreator } from "../types";
|
||||
import type { LlmModelMigration } from "../types";
|
||||
import type { LlmProvider } from "../types";
|
||||
import { ErrorBoundary } from "@/components/molecules/ErrorBoundary/ErrorBoundary";
|
||||
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||
import { AddProviderModal } from "./AddProviderModal";
|
||||
import { AddModelModal } from "./AddModelModal";
|
||||
import { AddCreatorModal } from "./AddCreatorModal";
|
||||
import { ProviderList } from "./ProviderList";
|
||||
import { ModelsTable } from "./ModelsTable";
|
||||
import { MigrationsTable } from "./MigrationsTable";
|
||||
import { CreatorsTable } from "./CreatorsTable";
|
||||
import { RecommendedModelSelector } from "./RecommendedModelSelector";
|
||||
|
||||
interface Props {
|
||||
providers: LlmProvider[];
|
||||
models: LlmModel[];
|
||||
migrations: LlmModelMigration[];
|
||||
creators: LlmModelCreator[];
|
||||
}
|
||||
|
||||
function AdminErrorFallback() {
|
||||
return (
|
||||
<div className="mx-auto max-w-xl p-6">
|
||||
<ErrorCard
|
||||
responseError={{
|
||||
message:
|
||||
"An error occurred while loading the LLM Registry. Please refresh the page.",
|
||||
}}
|
||||
context="llm-registry"
|
||||
onRetry={() => window.location.reload()}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * Top-level admin dashboard for the LLM Registry.
 *
 * Renders, in order: active (revertible) migrations when any exist,
 * providers and creators side by side, then the models table with the
 * recommended-model selector. The whole tree is wrapped in an ErrorBoundary
 * that shows {@link AdminErrorFallback} on a render-time crash.
 */
export function LlmRegistryDashboard({
  providers,
  models,
  migrations,
  creators,
}: Props) {
  return (
    <ErrorBoundary fallback={<AdminErrorFallback />} context="llm-registry">
      <div className="mx-auto p-6">
        <div className="flex flex-col gap-6">
          {/* Header */}
          <div>
            <h1 className="text-3xl font-bold">LLM Registry</h1>
            <p className="text-muted-foreground">
              Manage providers, creators, models, and credit pricing
            </p>
          </div>

          {/* Active Migrations Section - Only show if there are migrations */}
          {migrations.length > 0 && (
            <div className="rounded-lg border border-primary/30 bg-primary/5 p-6 shadow-sm">
              <div className="mb-4">
                <h2 className="text-xl font-semibold">Active Migrations</h2>
                <p className="mt-1 text-sm text-muted-foreground">
                  These migrations can be reverted to restore workflows to their
                  original model
                </p>
              </div>
              <MigrationsTable migrations={migrations} />
            </div>
          )}

          {/* Providers & Creators Section - Side by Side */}
          <div className="grid gap-6 lg:grid-cols-2">
            {/* Providers */}
            <div className="rounded-lg border bg-card p-6 shadow-sm">
              <div className="mb-4 flex items-center justify-between">
                <div>
                  <h2 className="text-xl font-semibold">Providers</h2>
                  <p className="mt-1 text-sm text-muted-foreground">
                    Who hosts/serves the models
                  </p>
                </div>
                <AddProviderModal />
              </div>
              <ProviderList providers={providers} />
            </div>

            {/* Creators */}
            <div className="rounded-lg border bg-card p-6 shadow-sm">
              <div className="mb-4 flex items-center justify-between">
                <div>
                  <h2 className="text-xl font-semibold">Creators</h2>
                  <p className="mt-1 text-sm text-muted-foreground">
                    Who made/trained the models
                  </p>
                </div>
                <AddCreatorModal />
              </div>
              <CreatorsTable creators={creators} />
            </div>
          </div>

          {/* Models Section */}
          <div className="rounded-lg border bg-card p-6 shadow-sm">
            <div className="mb-4 flex items-center justify-between">
              <div>
                <h2 className="text-xl font-semibold">Models</h2>
                <p className="mt-1 text-sm text-muted-foreground">
                  Toggle availability, adjust context windows, and update credit
                  pricing
                </p>
              </div>
              <AddModelModal providers={providers} creators={creators} />
            </div>

            {/* Recommended Model Selector */}
            <div className="mb-6">
              <RecommendedModelSelector models={models} />
            </div>

            <ModelsTable
              models={models}
              providers={providers}
              creators={creators}
            />
          </div>
        </div>
      </div>
    </ErrorBoundary>
  );
}
|
||||
@@ -0,0 +1,133 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";

import { useRouter } from "next/navigation";

import { Button } from "@/components/atoms/Button/Button";
import {
  Table,
  TableBody,
  TableCell,
  TableHead,
  TableHeader,
  TableRow,
} from "@/components/atoms/Table/Table";

import { revertLlmMigrationAction } from "../actions";
import type { LlmModelMigration } from "../types";
|
||||
|
||||
/**
 * Table of active (revertible) model migrations, or an explanatory empty
 * state when there are none.
 */
export function MigrationsTable({
  migrations,
}: {
  migrations: LlmModelMigration[];
}) {
  if (migrations.length === 0) {
    return (
      <div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
        No active migrations. Migrations are created when you disable a model
        with the "Migrate existing workflows" option.
      </div>
    );
  }

  // Plain-text column headers; the "Actions" column is rendered separately
  // because it alone carries right-alignment.
  const columnLabels = [
    "Migration",
    "Reason",
    "Nodes Affected",
    "Custom Cost",
    "Created",
  ];

  return (
    <div className="rounded-lg border">
      <Table>
        <TableHeader>
          <TableRow>
            {columnLabels.map((label) => (
              <TableHead key={label}>{label}</TableHead>
            ))}
            <TableHead className="text-right">Actions</TableHead>
          </TableRow>
        </TableHeader>
        <TableBody>
          {migrations.map((m) => (
            <MigrationRow key={m.id} migration={m} />
          ))}
        </TableBody>
      </Table>
    </div>
  );
}
|
||||
|
||||
function MigrationRow({ migration }: { migration: LlmModelMigration }) {
|
||||
const [isReverting, setIsReverting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
async function handleRevert(formData: FormData) {
|
||||
setIsReverting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await revertLlmMigrationAction(formData);
|
||||
} catch (err) {
|
||||
setError(
|
||||
err instanceof Error ? err.message : "Failed to revert migration",
|
||||
);
|
||||
} finally {
|
||||
setIsReverting(false);
|
||||
}
|
||||
}
|
||||
|
||||
const createdDate = new Date(migration.created_at);
|
||||
|
||||
return (
|
||||
<>
|
||||
<TableRow>
|
||||
<TableCell>
|
||||
<div className="text-sm">
|
||||
<span className="font-medium">{migration.source_model_slug}</span>
|
||||
<span className="mx-2 text-muted-foreground">→</span>
|
||||
<span className="font-medium">{migration.target_model_slug}</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="text-sm text-muted-foreground">
|
||||
{migration.reason || "—"}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="text-sm">{migration.node_count}</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="text-sm">
|
||||
{migration.custom_credit_cost !== null &&
|
||||
migration.custom_credit_cost !== undefined
|
||||
? `${migration.custom_credit_cost} credits`
|
||||
: "—"}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="text-sm text-muted-foreground">
|
||||
{createdDate.toLocaleDateString()}{" "}
|
||||
{createdDate.toLocaleTimeString([], {
|
||||
hour: "2-digit",
|
||||
minute: "2-digit",
|
||||
})}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell className="text-right">
|
||||
<form action={handleRevert} className="inline">
|
||||
<input type="hidden" name="migration_id" value={migration.id} />
|
||||
<Button
|
||||
type="submit"
|
||||
variant="outline"
|
||||
size="small"
|
||||
disabled={isReverting}
|
||||
>
|
||||
{isReverting ? "Reverting..." : "Revert"}
|
||||
</Button>
|
||||
</form>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
{error && (
|
||||
<TableRow>
|
||||
<TableCell colSpan={6}>
|
||||
<div className="rounded border border-destructive/30 bg-destructive/10 p-2 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,254 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect, useRef } from "react";
|
||||
import type { LlmModel } from "../types";
|
||||
import type { LlmModelCreator } from "../types";
|
||||
import type { LlmProvider } from "../types";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/atoms/Table/Table";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { toggleLlmModelAction, fetchLlmModels } from "../actions";
|
||||
import { DeleteModelModal } from "./DeleteModelModal";
|
||||
import { DisableModelModal } from "./DisableModelModal";
|
||||
import { EditModelModal } from "./EditModelModal";
|
||||
import { Star, Spinner } from "@phosphor-icons/react";
|
||||
|
||||
const PAGE_SIZE = 50;
|
||||
|
||||
export function ModelsTable({
|
||||
models: initialModels,
|
||||
providers,
|
||||
creators,
|
||||
}: {
|
||||
models: LlmModel[];
|
||||
providers: LlmProvider[];
|
||||
creators: LlmModelCreator[];
|
||||
}) {
|
||||
const [models, setModels] = useState<LlmModel[]>(initialModels);
|
||||
const [currentPage, setCurrentPage] = useState(1);
|
||||
const [hasMore, setHasMore] = useState(initialModels.length === PAGE_SIZE);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const loadedPagesRef = useRef(1);
|
||||
|
||||
// Sync with parent when initialModels changes (e.g., after enable/disable)
|
||||
// Re-fetch all loaded pages to preserve expanded state
|
||||
useEffect(() => {
|
||||
async function refetchAllPages() {
|
||||
const pagesToLoad = loadedPagesRef.current;
|
||||
|
||||
if (pagesToLoad === 1) {
|
||||
// Only first page loaded, just use initialModels
|
||||
setModels(initialModels);
|
||||
setHasMore(initialModels.length === PAGE_SIZE);
|
||||
return;
|
||||
}
|
||||
|
||||
// Re-fetch all pages we had loaded
|
||||
const allModels: LlmModel[] = [...initialModels];
|
||||
let lastPageHadFullResults = initialModels.length === PAGE_SIZE;
|
||||
|
||||
for (let page = 2; page <= pagesToLoad; page++) {
|
||||
try {
|
||||
const response = await fetchLlmModels(page, PAGE_SIZE);
|
||||
allModels.push(...response.models);
|
||||
lastPageHadFullResults = response.models.length === PAGE_SIZE;
|
||||
} catch (err) {
|
||||
console.error(`Error refetching page ${page}:`, err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
setModels(allModels);
|
||||
setHasMore(lastPageHadFullResults);
|
||||
}
|
||||
|
||||
refetchAllPages();
|
||||
}, [initialModels]);
|
||||
|
||||
async function loadMore() {
|
||||
if (isLoading) return;
|
||||
setIsLoading(true);
|
||||
|
||||
try {
|
||||
const nextPage = currentPage + 1;
|
||||
const response = await fetchLlmModels(nextPage, PAGE_SIZE);
|
||||
|
||||
setModels((prev) => [...prev, ...response.models]);
|
||||
setCurrentPage(nextPage);
|
||||
loadedPagesRef.current = nextPage;
|
||||
setHasMore(response.models.length === PAGE_SIZE);
|
||||
} catch (err) {
|
||||
console.error("Error loading more models:", err);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}
|
||||
if (!models.length) {
|
||||
return (
|
||||
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
|
||||
No models registered yet.
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const providerLookup = new Map(
|
||||
providers.map((provider) => [provider.id, provider]),
|
||||
);
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className="rounded-lg border">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Model</TableHead>
|
||||
<TableHead>Provider</TableHead>
|
||||
<TableHead>Creator</TableHead>
|
||||
<TableHead>Context Window</TableHead>
|
||||
<TableHead>Max Output</TableHead>
|
||||
<TableHead>Cost</TableHead>
|
||||
<TableHead>Status</TableHead>
|
||||
<TableHead>Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{models.map((model) => {
|
||||
const cost = model.costs?.[0];
|
||||
const provider = providerLookup.get(model.provider_id);
|
||||
return (
|
||||
<TableRow
|
||||
key={model.id}
|
||||
className={model.is_enabled ? "" : "opacity-60"}
|
||||
>
|
||||
<TableCell>
|
||||
<div className="font-medium">{model.display_name}</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{model.slug}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{provider ? (
|
||||
<>
|
||||
<div>{provider.display_name}</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{provider.name}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
model.provider_id
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{model.creator ? (
|
||||
<>
|
||||
<div>{model.creator.display_name}</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{model.creator.name}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<span className="text-muted-foreground">—</span>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>{model.context_window.toLocaleString()}</TableCell>
|
||||
<TableCell>
|
||||
{model.max_output_tokens
|
||||
? model.max_output_tokens.toLocaleString()
|
||||
: "—"}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{cost ? (
|
||||
<>
|
||||
<div className="font-medium">
|
||||
{cost.credit_cost} credits
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{cost.credential_provider}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
"—"
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex flex-col gap-1">
|
||||
<span
|
||||
className={`inline-flex rounded-full px-2.5 py-1 text-xs font-semibold ${
|
||||
model.is_enabled
|
||||
? "bg-primary/10 text-primary"
|
||||
: "bg-muted text-muted-foreground"
|
||||
}`}
|
||||
>
|
||||
{model.is_enabled ? "Enabled" : "Disabled"}
|
||||
</span>
|
||||
{model.is_recommended && (
|
||||
<span className="inline-flex items-center gap-1 rounded-full bg-amber-500/10 px-2.5 py-1 text-xs font-semibold text-amber-600 dark:text-amber-400">
|
||||
<Star size={12} weight="fill" />
|
||||
Recommended
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
{model.is_enabled ? (
|
||||
<DisableModelModal
|
||||
model={model}
|
||||
availableModels={models}
|
||||
/>
|
||||
) : (
|
||||
<EnableModelButton modelId={model.slug} />
|
||||
)}
|
||||
<EditModelModal
|
||||
model={model}
|
||||
providers={providers}
|
||||
creators={creators}
|
||||
/>
|
||||
<DeleteModelModal
|
||||
model={model}
|
||||
availableModels={models}
|
||||
/>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
);
|
||||
})}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
|
||||
{hasMore && (
|
||||
<div className="mt-4 flex justify-center">
|
||||
<Button onClick={loadMore} disabled={isLoading} variant="outline">
|
||||
{isLoading ? (
|
||||
<>
|
||||
<Spinner className="mr-2 h-4 w-4 animate-spin" />
|
||||
Loading...
|
||||
</>
|
||||
) : (
|
||||
"Load More"
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function EnableModelButton({ modelId }: { modelId: string }) {
|
||||
return (
|
||||
<form action={toggleLlmModelAction} className="inline">
|
||||
<input type="hidden" name="model_id" value={modelId} />
|
||||
<input type="hidden" name="is_enabled" value="true" />
|
||||
<Button type="submit" variant="outline" size="small" className="min-w-0">
|
||||
Enable
|
||||
</Button>
|
||||
</form>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/atoms/Table/Table";
|
||||
import type { LlmProvider } from "../types";
|
||||
import { DeleteProviderModal } from "./DeleteProviderModal";
|
||||
import { EditProviderModal } from "./EditProviderModal";
|
||||
|
||||
export function ProviderList({ providers }: { providers: LlmProvider[] }) {
|
||||
if (!providers.length) {
|
||||
return (
|
||||
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
|
||||
No providers configured yet.
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>Display Name</TableHead>
|
||||
<TableHead>Default Credential</TableHead>
|
||||
<TableHead>Capabilities</TableHead>
|
||||
<TableHead>Models</TableHead>
|
||||
<TableHead className="w-[100px]">Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{providers.map((provider) => (
|
||||
<TableRow key={provider.id}>
|
||||
<TableCell className="font-medium">{provider.name}</TableCell>
|
||||
<TableCell>{provider.display_name}</TableCell>
|
||||
<TableCell>
|
||||
{provider.default_credential_provider
|
||||
? `${provider.default_credential_provider} (${provider.default_credential_id ?? "id?"})`
|
||||
: "—"}
|
||||
</TableCell>
|
||||
<TableCell className="text-sm text-muted-foreground">
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{provider.supports_tools && (
|
||||
<span className="rounded bg-muted px-2 py-0.5 text-xs">
|
||||
Tools
|
||||
</span>
|
||||
)}
|
||||
{provider.supports_json_output && (
|
||||
<span className="rounded bg-muted px-2 py-0.5 text-xs">
|
||||
JSON
|
||||
</span>
|
||||
)}
|
||||
{provider.supports_reasoning && (
|
||||
<span className="rounded bg-muted px-2 py-0.5 text-xs">
|
||||
Reasoning
|
||||
</span>
|
||||
)}
|
||||
{provider.supports_parallel_tool && (
|
||||
<span className="rounded bg-muted px-2 py-0.5 text-xs">
|
||||
Parallel Tools
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell className="text-sm">
|
||||
<span
|
||||
className={
|
||||
(provider.models?.length ?? 0) > 0
|
||||
? "text-foreground"
|
||||
: "text-muted-foreground"
|
||||
}
|
||||
>
|
||||
{provider.models?.length ?? 0}
|
||||
</span>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex gap-2">
|
||||
<EditProviderModal provider={provider} />
|
||||
<DeleteProviderModal provider={provider} />
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import type { LlmModel } from "../types";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { setRecommendedModelAction } from "../actions";
|
||||
import { Star } from "@phosphor-icons/react";
|
||||
|
||||
export function RecommendedModelSelector({ models }: { models: LlmModel[] }) {
|
||||
const router = useRouter();
|
||||
const enabledModels = models.filter((m) => m.is_enabled);
|
||||
const currentRecommended = models.find((m) => m.is_recommended);
|
||||
|
||||
const [selectedModelSlug, setSelectedModelSlug] = useState<string>(
|
||||
currentRecommended?.slug || "",
|
||||
);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const hasChanges = selectedModelSlug !== (currentRecommended?.slug || "");
|
||||
|
||||
async function handleSave() {
|
||||
if (!selectedModelSlug) return;
|
||||
|
||||
setIsSaving(true);
|
||||
setError(null);
|
||||
try {
|
||||
const formData = new FormData();
|
||||
formData.set("model_id", selectedModelSlug);
|
||||
await setRecommendedModelAction(formData);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to save");
|
||||
} finally {
|
||||
setIsSaving(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border border-border bg-card p-4">
|
||||
<div className="mb-3 flex items-center gap-2">
|
||||
<Star size={20} weight="fill" className="text-amber-500" />
|
||||
<h3 className="text-sm font-semibold">Recommended Model</h3>
|
||||
</div>
|
||||
<p className="mb-3 text-xs text-muted-foreground">
|
||||
The recommended model is shown as the default suggestion in model
|
||||
selection dropdowns throughout the platform.
|
||||
</p>
|
||||
|
||||
<div className="flex items-center gap-3">
|
||||
<select
|
||||
value={selectedModelSlug}
|
||||
onChange={(e) => setSelectedModelSlug(e.target.value)}
|
||||
className="flex-1 rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
disabled={isSaving}
|
||||
>
|
||||
<option value="">-- Select a model --</option>
|
||||
{enabledModels.map((model) => (
|
||||
<option key={model.slug} value={model.slug}>
|
||||
{model.display_name} ({model.slug})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
|
||||
<Button
|
||||
type="button"
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={handleSave}
|
||||
disabled={!hasChanges || !selectedModelSlug || isSaving}
|
||||
>
|
||||
{isSaving ? "Saving..." : "Save"}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{error && <p className="mt-2 text-xs text-destructive">{error}</p>}
|
||||
|
||||
{currentRecommended && !hasChanges && (
|
||||
<p className="mt-2 text-xs text-muted-foreground">
|
||||
Currently set to:{" "}
|
||||
<span className="font-medium">{currentRecommended.display_name}</span>
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
import {
|
||||
fetchLlmProviders,
|
||||
fetchLlmModels,
|
||||
fetchLlmCreators,
|
||||
fetchLlmMigrations,
|
||||
} from "./actions";
|
||||
|
||||
export async function getLlmRegistryPageData() {
|
||||
// Fetch all data in parallel
|
||||
const [providersData, modelsData, creatorsData, migrationsData] =
|
||||
await Promise.all([
|
||||
fetchLlmProviders(),
|
||||
fetchLlmModels(),
|
||||
fetchLlmCreators(),
|
||||
fetchLlmMigrations(),
|
||||
]);
|
||||
|
||||
return {
|
||||
providers: providersData.providers || [],
|
||||
models: modelsData.models || [],
|
||||
creators: creatorsData.creators || [],
|
||||
migrations: migrationsData.migrations || [],
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
import { withRoleAccess } from "@/lib/withRoleAccess";
|
||||
import { getLlmRegistryPageData } from "./getLlmRegistryPage";
|
||||
import { LlmRegistryDashboard } from "./components/LlmRegistryDashboard";
|
||||
|
||||
/** Server component: loads registry data and renders the dashboard. */
async function LlmRegistryPage() {
  return <LlmRegistryDashboard {...(await getLlmRegistryPageData())} />;
}
|
||||
|
||||
/** Route entry point: gates the registry dashboard behind the admin role. */
export default async function AdminLlmRegistryPage() {
  const requireAdmin = await withRoleAccess(["admin"]);
  const GuardedPage = await requireAdmin(LlmRegistryPage);
  return <GuardedPage />;
}
|
||||
@@ -0,0 +1,76 @@
|
||||
// Type definitions for LLM registry admin UI
|
||||
// These match the API response formats from our admin endpoints
|
||||
|
||||
/** A configured LLM provider as returned by the admin API. */
export interface LlmProvider {
  id: string;
  // Machine name (e.g. used for lookups); distinct from display_name.
  name: string;
  display_name: string;
  description: string | null;
  // Default credential wiring; all three are null when none is configured.
  default_credential_provider: string | null;
  default_credential_id: string | null;
  default_credential_type: string | null;
  metadata: Record<string, any>;
  created_at: string | null;
  updated_at: string | null;
  // Populated only by endpoints that embed the provider's models.
  models?: LlmModel[];
  // Capability flags; optional because some endpoints omit them.
  supports_tools?: boolean;
  supports_json_output?: boolean;
  supports_reasoning?: boolean;
  supports_parallel_tool?: boolean;
}

/** One pricing entry for a model under a specific credential provider. */
export interface LlmModelCost {
  unit: string;
  credit_cost: number;
  credential_provider: string;
  credential_type?: string;
  metadata: Record<string, any>;
}

/** A registered LLM model row as returned by the admin API. */
export interface LlmModel {
  id: string;
  // Stable identifier used by actions (enable/disable, recommend).
  slug: string;
  display_name: string;
  description: string | null;
  provider_id: string;
  creator_id: string | null;
  // Embedded creator record when the endpoint includes it.
  creator?: LlmModelCreator;
  context_window: number;
  max_output_tokens: number | null;
  price_tier: number;
  is_enabled: boolean;
  // At most one model should be recommended at a time (UI assumption).
  is_recommended: boolean;
  supports_tools: boolean;
  supports_json_output: boolean;
  supports_reasoning: boolean;
  supports_parallel_tool_calls: boolean;
  capabilities: Record<string, any>;
  metadata: Record<string, any>;
  // Pricing entries; the table shows only the first one.
  costs?: LlmModelCost[];
  created_at: string | null;
  updated_at: string | null;
}

/** The organization that created a model (e.g. for attribution/logos). */
export interface LlmModelCreator {
  id: string;
  name: string;
  display_name: string;
  description: string | null;
  website_url: string | null;
  logo_url: string | null;
  metadata: Record<string, any>;
}

/** A bulk migration of graph nodes from one model slug to another. */
export interface LlmModelMigration {
  id: string;
  source_model_slug: string;
  target_model_slug: string;
  reason: string | null;
  // IDs of the nodes that were rewritten by this migration.
  migrated_node_ids: any[];
  node_count: number;
  // Optional per-migration cost override; null means provider default.
  custom_credit_cost: number | null;
  is_reverted: boolean;
  reverted_at: string | null;
  created_at: string;
  updated_at: string;
}
|
||||
File diff suppressed because it is too large
Load Diff
123
autogpt_platform/frontend/src/components/atoms/Table/Table.tsx
Normal file
123
autogpt_platform/frontend/src/components/atoms/Table/Table.tsx
Normal file
@@ -0,0 +1,123 @@
|
||||
import * as React from "react";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
// shadcn/ui-style table primitives: thin forwardRef wrappers around the
// semantic HTML table elements, each merging caller classes via cn().

/** Scrollable wrapper + <table>; the outer div provides horizontal overflow. */
const Table = React.forwardRef<
  HTMLTableElement,
  React.HTMLAttributes<HTMLTableElement>
>(({ className, ...props }, ref) => (
  <div className="relative w-full overflow-auto">
    <table
      ref={ref}
      className={cn("w-full caption-bottom text-sm", className)}
      {...props}
    />
  </div>
));
Table.displayName = "Table";

/** <thead>; draws a bottom border under every header row. */
const TableHeader = React.forwardRef<
  HTMLTableSectionElement,
  React.HTMLAttributes<HTMLTableSectionElement>
>(({ className, ...props }, ref) => (
  <thead ref={ref} className={cn("[&_tr]:border-b", className)} {...props} />
));
TableHeader.displayName = "TableHeader";

/** <tbody>; removes the border from the final row. */
const TableBody = React.forwardRef<
  HTMLTableSectionElement,
  React.HTMLAttributes<HTMLTableSectionElement>
>(({ className, ...props }, ref) => (
  <tbody
    ref={ref}
    className={cn("[&_tr:last-child]:border-0", className)}
    {...props}
  />
));
TableBody.displayName = "TableBody";

/** <tfoot>; subtle background and top border, dark-mode aware. */
const TableFooter = React.forwardRef<
  HTMLTableSectionElement,
  React.HTMLAttributes<HTMLTableSectionElement>
>(({ className, ...props }, ref) => (
  <tfoot
    ref={ref}
    className={cn(
      "border-t bg-neutral-100/50 font-medium dark:bg-neutral-800/50 [&>tr]:last:border-b-0",
      className,
    )}
    {...props}
  />
));
TableFooter.displayName = "TableFooter";

/** <tr>; hover highlight plus a data-[state=selected] style hook. */
const TableRow = React.forwardRef<
  HTMLTableRowElement,
  React.HTMLAttributes<HTMLTableRowElement>
>(({ className, ...props }, ref) => (
  <tr
    ref={ref}
    className={cn(
      "border-b transition-colors data-[state=selected]:bg-neutral-100 hover:bg-neutral-100/50 dark:data-[state=selected]:bg-neutral-800 dark:hover:bg-neutral-800/50",
      className,
    )}
    {...props}
  />
));
TableRow.displayName = "TableRow";

/** <th>; left-aligned muted header cell with checkbox spacing tweaks. */
const TableHead = React.forwardRef<
  HTMLTableCellElement,
  React.ThHTMLAttributes<HTMLTableCellElement>
>(({ className, ...props }, ref) => (
  <th
    ref={ref}
    className={cn(
      "h-10 px-2 text-left align-middle font-medium text-neutral-500 dark:text-neutral-400 [&:has([role=checkbox])]:pr-0 [&>[role=checkbox]]:translate-y-[2px]",
      className,
    )}
    {...props}
  />
));
TableHead.displayName = "TableHead";

/** <td>; padded body cell with the same checkbox spacing tweaks as TableHead. */
const TableCell = React.forwardRef<
  HTMLTableCellElement,
  React.TdHTMLAttributes<HTMLTableCellElement>
>(({ className, ...props }, ref) => (
  <td
    ref={ref}
    className={cn(
      "p-2 align-middle [&:has([role=checkbox])]:pr-0 [&>[role=checkbox]]:translate-y-[2px]",
      className,
    )}
    {...props}
  />
));
TableCell.displayName = "TableCell";

/** <caption>; muted caption rendered below the table (caption-bottom). */
const TableCaption = React.forwardRef<
  HTMLTableCaptionElement,
  React.HTMLAttributes<HTMLTableCaptionElement>
>(({ className, ...props }, ref) => (
  <caption
    ref={ref}
    className={cn(
      "mt-4 text-sm text-neutral-500 dark:text-neutral-400",
      className,
    )}
    {...props}
  />
));
TableCaption.displayName = "TableCaption";

export {
  Table,
  TableHeader,
  TableBody,
  TableFooter,
  TableHead,
  TableRow,
  TableCell,
  TableCaption,
};
|
||||
Reference in New Issue
Block a user