format
@@ -15,6 +15,7 @@ from anthropic.types import ToolParam
 from groq import AsyncGroq
 from pydantic import BaseModel, SecretStr
 
+from backend.data import llm_registry
 from backend.data.block import (
     Block,
     BlockCategory,
@@ -23,7 +24,6 @@ from backend.data.block import (
     BlockSchemaOutput,
 )
 from backend.data.llm_model_types import ModelMetadata
-from backend.data import llm_registry
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
@@ -76,7 +76,7 @@ def AICredentialsField() -> AICredentials:
     # Get the mapping now - it may be empty initially, but will be refreshed
     # when the schema is generated via CredentialsMetaInput._add_json_schema_extra
     mapping = llm_registry.get_llm_discriminator_mapping()
 
     return CredentialsField(
         description="API key for the LLM provider.",
         discriminator="model",
@@ -191,7 +191,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
         metadata = llm_registry.get_llm_model_metadata(self.value)
         if metadata:
             return metadata
-        raise ValueError(f"Missing metadata for model: {self.value}. Model not found in LLM registry.")
+        raise ValueError(
+            f"Missing metadata for model: {self.value}. Model not found in LLM registry."
+        )
 
     @property
     def provider(self) -> str:
@@ -341,41 +343,46 @@ async def llm_call(
         provider = llm_model.metadata.provider
         context_window = llm_model.context_window
         model_max_output = llm_model.max_output_tokens or int(2**15)
 
         # Check if model is enabled - get from registry
         from backend.data.llm_registry import _dynamic_models
 
         if llm_model.value in _dynamic_models:
             model_info = _dynamic_models[llm_model.value]
             if not model_info.is_enabled:
-                raise ValueError(
-                    f"LLM model '{llm_model.value}' is disabled."
-                )
+                raise ValueError(f"LLM model '{llm_model.value}' is disabled.")
     except ValueError as e:
         # Re-raise if it's our disabled model error
         if "is disabled" in str(e):
             raise
         # Model not in cache - try refreshing the registry once if we have DB access
         import logging
 
         logger = logging.getLogger(__name__)
         logger.warning(
             "Model %s not found in registry cache",
             llm_model.value,
         )
 
         # Try refreshing the registry if we have database access
         from backend.data.db import is_connected
 
         if is_connected():
             try:
-                logger.info("Refreshing LLM registry and retrying lookup for %s", llm_model.value)
+                logger.info(
+                    "Refreshing LLM registry and retrying lookup for %s",
+                    llm_model.value,
+                )
                 await llm_registry.refresh_llm_registry()
                 # Try again after refresh
                 try:
                     provider = llm_model.metadata.provider
                     context_window = llm_model.context_window
                     model_max_output = llm_model.max_output_tokens or int(2**15)
 
                     # Check if model is enabled after refresh
                     from backend.data.llm_registry import _dynamic_models
 
                     if llm_model.value in _dynamic_models:
                         model_info = _dynamic_models[llm_model.value]
                         if not model_info.is_enabled:
@@ -383,7 +390,7 @@ async def llm_call(
                                 f"LLM model '{llm_model.value}' is disabled. "
                                 "Please enable it in the LLM registry via the admin UI to use this model."
                             )
 
                     logger.info(
                         "Successfully loaded model %s metadata after registry refresh",
                         llm_model.value,
@@ -398,7 +405,9 @@ async def llm_call(
                     "Please ensure the model is added and enabled in the LLM registry via the admin UI."
                 )
             except Exception as refresh_exc:
-                logger.error("Failed to refresh LLM registry: %s", refresh_exc, exc_info=True)
+                logger.error(
+                    "Failed to refresh LLM registry: %s", refresh_exc, exc_info=True
+                )
                 raise ValueError(
                     f"LLM model '{llm_model.value}' not found in registry and failed to refresh. "
                     "Please ensure the model is added to the LLM registry via the admin UI."
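
The hunk above implements a lookup, refresh, retry flow: when model metadata is missing from the registry cache, the code refreshes the registry once (if a DB connection exists) and retries, while the "is disabled" error is re-raised immediately so the retry can never resurrect an intentionally disabled model. A minimal sketch of that control flow, with lookup/refresh as hypothetical stand-ins for the registry calls (not part of the diff):

    async def resolve_with_retry(slug: str, lookup, refresh):
        # lookup/refresh are hypothetical stand-ins for the registry calls above
        try:
            return lookup(slug)
        except ValueError as e:
            if "is disabled" in str(e):
                raise  # never mask an intentionally disabled model
            await refresh()  # refresh the registry once, then make a final attempt
            return lookup(slug)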
@@ -15,8 +15,8 @@ from backend.blocks.llm import (
     LlmModel,
     ModelMetadata,
 )
-from backend.data import llm_registry
 from backend.blocks.stagehand._config import stagehand as stagehand_provider
+from backend.data import llm_registry
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -152,13 +152,14 @@ class BlockSchema(BaseModel):
     @staticmethod
     def clear_all_schema_caches() -> None:
         """Clear cached JSON schemas for all BlockSchema subclasses."""
 
         def clear_recursive(cls: type) -> None:
             """Recursively clear cache for class and all subclasses."""
-            if hasattr(cls, 'clear_schema_cache'):
+            if hasattr(cls, "clear_schema_cache"):
                 cls.clear_schema_cache()
             for subclass in cls.__subclasses__():
                 clear_recursive(subclass)
+
         clear_recursive(BlockSchema)
 
     @classmethod
@@ -172,7 +173,9 @@ class BlockSchema(BaseModel):
         # OpenAPI <3.1 does not support sibling fields that has a $ref key
         # So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
         keys = {"allOf", "anyOf", "oneOf"}
-        one_key = next((k for k in keys if k in obj and len(obj[k]) == 1), None)
+        one_key = next(
+            (k for k in keys if k in obj and len(obj[k]) == 1), None
+        )
         if one_key:
             obj.update(obj[one_key][0])
@@ -187,7 +190,7 @@ class BlockSchema(BaseModel):
             return obj
 
         cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
 
         # Always post-process to ensure LLM registry data is up-to-date
         # This refreshes model options and discriminator mappings even if schema was cached
         update_schema_with_llm_registry(cls.cached_jsonschema, cls)
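
A note on the clear_recursive helper changed above: it walks cls.__subclasses__() depth-first, so every BlockSchema subclass drops its cached JSON schema, not just the base class. A self-contained demo of the traversal with toy classes (not the real hierarchy):

    class Base:
        cached = {"stale": True}

        @classmethod
        def clear_schema_cache(cls):
            cls.cached = {}

    class Child(Base): ...
    class GrandChild(Child): ...

    def clear_recursive(cls: type) -> None:
        if hasattr(cls, "clear_schema_cache"):
            cls.clear_schema_cache()
        for subclass in cls.__subclasses__():
            clear_recursive(subclass)

    clear_recursive(Base)  # Base, Child, and GrandChild are all cleared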
@@ -760,18 +763,23 @@ async def initialize_blocks() -> None:
     try:
         from backend.data import llm_registry
         from backend.data.block_cost_config import refresh_llm_costs
 
         # Only refresh if we have DB access (check if Prisma is connected)
         from backend.data.db import is_connected
 
         if is_connected():
             await llm_registry.refresh_llm_registry()
             refresh_llm_costs()
             logger.info("LLM registry refreshed during block initialization")
         else:
-            logger.warning("Prisma not connected, skipping LLM registry refresh during block initialization")
+            logger.warning(
+                "Prisma not connected, skipping LLM registry refresh during block initialization"
+            )
     except Exception as exc:
-        logger.warning("Failed to refresh LLM registry during block initialization: %s", exc)
+        logger.warning(
+            "Failed to refresh LLM registry during block initialization: %s", exc
+        )
 
     # First, sync all provider costs to blocks
     # Imported here to avoid circular import
     from backend.sdk.cost_integration import sync_all_provider_costs
@@ -5,4 +5,3 @@ class ModelMetadata(NamedTuple):
     provider: str
     context_window: int
     max_output_tokens: int | None
-
@@ -93,7 +93,9 @@ async def refresh_llm_registry() -> None:
         )
         logger.debug("Found %d LLM model records in database", len(records))
     except Exception as exc:
-        logger.error("Failed to refresh LLM registry from DB: %s", exc, exc_info=True)
+        logger.error(
+            "Failed to refresh LLM registry from DB: %s", exc, exc_info=True
+        )
         return
 
     dynamic: dict[str, RegistryModel] = {}
@@ -125,9 +127,11 @@ async def refresh_llm_registry() -> None:
             metadata=metadata,
             capabilities=record.capabilities or {},
             extra_metadata=record.metadata or {},
-            provider_display_name=record.Provider.displayName
-            if record.Provider
-            else record.providerId,
+            provider_display_name=(
+                record.Provider.displayName
+                if record.Provider
+                else record.providerId
+            ),
             is_enabled=record.isEnabled,
             costs=costs,
         )
@@ -211,4 +215,3 @@ def get_dynamic_model_slugs() -> set[str]:
 
 def iter_dynamic_models() -> Iterable[RegistryModel]:
     return tuple(_dynamic_models.values())
-
@@ -5,6 +5,7 @@ When models are added/updated/removed via the admin UI, this module
 publishes notifications to Redis that all executor services subscribe to,
 ensuring they refresh their registry cache in real-time.
 """
+
 import asyncio
 import logging
 from typing import Any
@@ -28,7 +29,9 @@ def publish_registry_refresh_notification() -> None:
         logger.info("Published LLM registry refresh notification to Redis")
     except Exception as exc:
         logger.warning(
-            "Failed to publish LLM registry refresh notification: %s", exc, exc_info=True
+            "Failed to publish LLM registry refresh notification: %s",
+            exc,
+            exc_info=True,
         )
@@ -38,23 +41,32 @@ async def subscribe_to_registry_refresh(
     """
     Subscribe to Redis notifications for LLM registry updates.
     This runs in a loop and processes messages as they arrive.
 
     Args:
         on_refresh: Async callable to execute when a refresh notification is received
     """
     from backend.data.redis_client import connect_async
 
     try:
         redis = await connect_async()
         pubsub = redis.pubsub()
         await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
-        logger.info("Subscribed to LLM registry refresh notifications on channel: %s", REGISTRY_REFRESH_CHANNEL)
+        logger.info(
+            "Subscribed to LLM registry refresh notifications on channel: %s",
+            REGISTRY_REFRESH_CHANNEL,
+        )
 
         # Process messages in a loop
         while True:
             try:
-                message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
-                if message and message["type"] == "message" and message["channel"] == REGISTRY_REFRESH_CHANNEL:
+                message = await pubsub.get_message(
+                    ignore_subscribe_messages=True, timeout=1.0
+                )
+                if (
+                    message
+                    and message["type"] == "message"
+                    and message["channel"] == REGISTRY_REFRESH_CHANNEL
+                ):
                     logger.info("Received LLM registry refresh notification")
                     try:
                         await on_refresh()
@@ -77,4 +89,3 @@ async def subscribe_to_registry_refresh(
                 exc_info=True,
             )
             raise
-
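
The two functions in this file form a publish/subscribe pair: the admin process publishes on REGISTRY_REFRESH_CHANNEL and each executor runs the polling subscriber above. A minimal round-trip sketch using redis.asyncio directly, assuming a local Redis server and a placeholder channel name (not the project's wiring):

    import asyncio

    import redis.asyncio as redis

    CHANNEL = "llm-registry-refresh"  # placeholder channel name

    async def main():
        r = redis.Redis()  # assumes a Redis server on localhost
        pubsub = r.pubsub()
        await pubsub.subscribe(CHANNEL)
        await r.publish(CHANNEL, "refresh")  # what the admin side does
        while True:
            msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
            if msg and msg["type"] == "message":
                print("refresh requested")  # stand-in for on_refresh()
                break

    asyncio.run(main())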
@@ -4,6 +4,7 @@ Helper utilities for LLM registry integration with block schemas.
 This module handles the dynamic injection of discriminator mappings
 and model options from the LLM registry into block schemas.
 """
+
 import logging
 from typing import Any
 
@@ -69,8 +70,7 @@ def refresh_llm_discriminator_mapping(field_schema: dict[str, Any]) -> None:
 
 
 def update_schema_with_llm_registry(
-    schema: dict[str, Any],
-    model_class: type | None = None
+    schema: dict[str, Any], model_class: type | None = None
 ) -> None:
     """
     Update a JSON schema with current LLM registry data.
@@ -98,7 +98,8 @@ def update_schema_with_llm_registry(
         except Exception as exc:
             logger.warning(
                 "Failed to refresh LLM options for field %s: %s",
-                field_name, exc
+                field_name,
+                exc,
             )
 
     # Refresh discriminator mapping for fields that use model discrimination
@@ -107,5 +108,6 @@ def update_schema_with_llm_registry(
         except Exception as exc:
             logger.warning(
                 "Failed to refresh discriminator mapping for field %s: %s",
-                field_name, exc
+                field_name,
+                exc,
             )
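
For context on what these helpers mutate: the discriminator mapping lives inside a credentials field's JSON schema and maps each model slug to the provider whose credentials can run it, which is how the frontend picks the right credential input. A rough sketch of the injection under a simplified schema shape (the real helpers walk full block schemas and handle more cases):

    def inject_discriminator_mapping(
        field_schema: dict, mapping: dict[str, str]
    ) -> None:
        # e.g. mapping = {"gpt-4o-mini": "openai", "claude-3-5-sonnet": "anthropic"}
        field_schema["discriminator"] = "model"
        field_schema["discriminator_mapping"] = dict(mapping)

    schema = {"credentials": {}}
    inject_discriminator_mapping(schema["credentials"], {"gpt-4o-mini": "openai"})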
@@ -533,7 +533,7 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
         else:
             schema["credentials_provider"] = allowed_providers
             schema["credentials_types"] = model_class.allowed_cred_types()
 
         # Ensure LLM discriminators are populated (delegates to shared helper)
         update_schema_with_llm_registry(schema, model_class)
 
@@ -686,13 +686,13 @@ def CredentialsField(
 
     # Build field_schema_extra - always include discriminator and mapping if discriminator is set
     field_schema_extra: dict[str, Any] = {}
 
     # Always include discriminator if provided
     if discriminator is not None:
         field_schema_extra["discriminator"] = discriminator
         # Always include discriminator_mapping when discriminator is set (even if empty initially)
         field_schema_extra["discriminator_mapping"] = discriminator_mapping or {}
 
     # Include other optional fields (only if not None)
     if required_scopes:
         field_schema_extra["credentials_scopes"] = list(required_scopes)
@@ -4,6 +4,7 @@ Helper functions for LLM registry initialization in executor context.
 These functions handle refreshing the LLM registry when the executor starts
 and subscribing to real-time updates via Redis pub/sub.
 """
+
 import logging
 
 from backend.data import db, llm_registry
@@ -621,7 +621,7 @@ class ExecutionProcessor:
         )
         self.node_execution_thread.start()
         self.node_evaluation_thread.start()
 
         # Initialize LLM registry and subscribe to updates
         from backend.executor.llm_registry_init import (
             initialize_registry_for_executor,
@@ -634,7 +634,7 @@ class ExecutionProcessor:
         asyncio.run_coroutine_threadsafe(
             subscribe_to_registry_updates(), self.node_execution_loop
         )
 
         logger.info(f"[GraphExecutor] {self.tid} started")
 
     @error_logged(swallow=False)
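
The subscription above is scheduled with asyncio.run_coroutine_threadsafe because the executor owns a dedicated event-loop thread: the call submits a coroutine to that loop from another thread and returns a concurrent.futures.Future. A generic sketch of the pattern (loop and coroutine names are illustrative, not the executor's actual attributes):

    import asyncio
    import threading

    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()

    async def subscriber():
        await asyncio.sleep(0.1)  # stand-in for the registry subscription loop

    # Submitted from the main thread; returns a concurrent.futures.Future
    future = asyncio.run_coroutine_threadsafe(subscriber(), loop)
    future.result(timeout=5)  # optionally wait for completion / surface errors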
@@ -20,8 +20,6 @@ import backend.data.block
 import backend.data.db
 import backend.data.graph
 import backend.data.user
-from backend.data import llm_registry
-from backend.data.block_cost_config import refresh_llm_costs
 import backend.integrations.webhooks.utils
 import backend.server.routers.postmark.postmark
 import backend.server.routers.v1
@@ -36,13 +34,14 @@ import backend.server.v2.executions.review.routes
 import backend.server.v2.library.db
 import backend.server.v2.library.model
 import backend.server.v2.library.routes
+import backend.server.v2.llm.routes as public_llm_routes
 import backend.server.v2.otto.routes
 import backend.server.v2.store.model
 import backend.server.v2.store.routes
-import backend.server.v2.llm.routes as public_llm_routes
 import backend.util.service
 import backend.util.settings
 from backend.blocks.llm import LlmModel
+from backend.data import llm_registry
+from backend.data.block_cost_config import refresh_llm_costs
 from backend.data.model import Credentials
 from backend.integrations.providers import ProviderName
 from backend.monitoring.instrumentation import instrument_fastapi
@@ -111,9 +110,10 @@ async def lifespan_context(app: fastapi.FastAPI):
     # Refresh LLM registry before initializing blocks so blocks can use registry data
     await llm_registry.refresh_llm_registry()
     refresh_llm_costs()
 
     # Clear block schema caches so they're regenerated with updated discriminator_mapping
     from backend.data.block import BlockSchema
+
     BlockSchema.clear_all_schema_caches()
 
     await backend.data.block.initialize_blocks()
@@ -20,35 +20,41 @@ router = fastapi.APIRouter(
 async def _refresh_runtime_state() -> None:
     """Refresh the LLM registry and clear all related caches to ensure real-time updates."""
     logger.info("Refreshing LLM registry runtime state...")
 
     # Refresh registry from database
     await llm_registry.refresh_llm_registry()
     refresh_llm_costs()
 
     # Clear block schema caches so they're regenerated with updated model options
     from backend.data.block import BlockSchema
 
     BlockSchema.clear_all_schema_caches()
     logger.info("Cleared all block schema caches")
 
     # Clear the /blocks endpoint cache so frontend gets updated schemas
     try:
         from backend.server.routers.v1 import _get_cached_blocks
 
         _get_cached_blocks.cache_clear()
         logger.info("Cleared /blocks endpoint cache")
     except Exception as e:
         logger.warning("Failed to clear /blocks cache: %s", e)
 
     # Clear the v2 builder providers cache (if it exists)
     try:
         from backend.server.v2.builder import db as builder_db
-        if hasattr(builder_db, '_get_all_providers'):
+
+        if hasattr(builder_db, "_get_all_providers"):
             builder_db._get_all_providers.cache_clear()
             logger.info("Cleared v2 builder providers cache")
     except Exception as e:
         logger.debug("Could not clear v2 builder cache: %s", e)
 
     # Notify all executor services to refresh their registry cache
-    from backend.data.llm_registry_notifications import publish_registry_refresh_notification
+    from backend.data.llm_registry_notifications import (
+        publish_registry_refresh_notification,
+    )
 
     publish_registry_refresh_notification()
     logger.info("Published registry refresh notification")
@@ -133,7 +139,9 @@ async def toggle_llm_model(
     request: llm_model.ToggleLlmModelRequest,
 ):
     try:
-        model = await llm_db.toggle_model(model_id=model_id, is_enabled=request.is_enabled)
+        model = await llm_db.toggle_model(
+            model_id=model_id, is_enabled=request.is_enabled
+        )
         await _refresh_runtime_state()
         return model
     except Exception as exc:
@@ -171,8 +179,7 @@ async def get_llm_model_usage(model_id: str):
 async def delete_llm_model(
     model_id: str,
     replacement_model_slug: str = fastapi.Query(
-        ...,
-        description="Slug of the model to migrate existing workflows to"
+        ..., description="Slug of the model to migrate existing workflows to"
     ),
 ):
     """
@@ -189,15 +196,14 @@ async def delete_llm_model(
     """
     try:
         result = await llm_db.delete_model(
-            model_id=model_id,
-            replacement_model_slug=replacement_model_slug
+            model_id=model_id, replacement_model_slug=replacement_model_slug
         )
         await _refresh_runtime_state()
         logger.info(
             "Deleted model '%s' and migrated %d nodes to '%s'",
             result.deleted_model_slug,
             result.nodes_migrated,
-            result.replacement_model_slug
+            result.replacement_model_slug,
         )
         return result
     except ValueError as exc:
@@ -210,4 +216,3 @@ async def delete_llm_model(
             status_code=500,
             detail="Failed to delete model and migrate workflows",
         ) from exc
-
@@ -1,4 +1,4 @@
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import AsyncMock
 
 import fastapi
 import fastapi.testclient
@@ -166,9 +166,7 @@ def test_create_llm_provider_success(
     mock_notify.assert_called_once()
 
     # Snapshot test the response
-    configured_snapshot.assert_match(
-        response_data, "create_llm_provider_success.json"
-    )
+    configured_snapshot.assert_match(response_data, "create_llm_provider_success.json")
 
 
 def test_create_llm_model_success(
@@ -352,12 +350,12 @@ def test_delete_llm_model_success(
         "replacement_model_slug": "gpt-4o-mini",
         "nodes_migrated": 42,
         "message": "Successfully deleted model 'GPT-3.5 Turbo' (gpt-3.5-turbo) "
-        "and migrated 42 workflow node(s) to 'gpt-4o-mini'."
+        "and migrated 42 workflow node(s) to 'gpt-4o-mini'.",
     }
 
     mocker.patch(
         "backend.server.v2.admin.llm_routes.llm_db.delete_model",
-        new=AsyncMock(return_value=type('obj', (object,), mock_response)()),
+        new=AsyncMock(return_value=type("obj", (object,), mock_response)()),
     )
 
     mock_refresh = mocker.patch(
@@ -391,9 +389,7 @@ def test_delete_llm_model_validation_error(
         new=AsyncMock(side_effect=ValueError("Replacement model 'invalid' not found")),
     )
 
-    response = client.delete(
-        "/admin/llm/models/model-1?replacement_model_slug=invalid"
-    )
+    response = client.delete("/admin/llm/models/model-1?replacement_model_slug=invalid")
 
     assert response.status_code == 400
     assert "Replacement model 'invalid' not found" in response.json()["detail"]
@@ -71,11 +71,7 @@ def _map_provider(record: prisma.models.LlmProvider) -> llm_model.LlmProvider:
 
 
 async def list_providers(include_models: bool = True) -> list[llm_model.LlmProvider]:
-    include = (
-        {"Models": {"include": {"Costs": True}}}
-        if include_models
-        else None
-    )
+    include = {"Models": {"include": {"Costs": True}}} if include_models else None
     records = await prisma.models.LlmProvider.prisma().find_many(include=include)
     return [_map_provider(record) for record in records]
@@ -208,9 +204,7 @@ async def get_model_usage(model_id: str) -> llm_model.LlmModelUsageResponse:
     """Get usage count for a model."""
     import prisma as prisma_module
 
-    model = await prisma.models.LlmModel.prisma().find_unique(
-        where={"id": model_id}
-    )
+    model = await prisma.models.LlmModel.prisma().find_unique(where={"id": model_id})
     if not model:
         raise ValueError(f"Model with id '{model_id}' not found")
@@ -220,19 +214,15 @@ async def get_model_usage(model_id: str) -> llm_model.LlmModelUsageResponse:
         FROM "AgentNode"
         WHERE "constantInput"::jsonb->>'model' = $1
         """,
-        model.slug
+        model.slug,
     )
     node_count = int(count_result[0]["count"]) if count_result else 0
 
-    return llm_model.LlmModelUsageResponse(
-        model_slug=model.slug,
-        node_count=node_count
-    )
+    return llm_model.LlmModelUsageResponse(model_slug=model.slug, node_count=node_count)
 
 
 async def delete_model(
-    model_id: str,
-    replacement_model_slug: str
+    model_id: str, replacement_model_slug: str
 ) -> llm_model.DeleteLlmModelResponse:
     """
     Delete a model and migrate all AgentNodes using it to a replacement model.
@@ -258,8 +248,7 @@ async def delete_model(
 
     # 1. Get the model being deleted
     model = await prisma.models.LlmModel.prisma().find_unique(
-        where={"id": model_id},
-        include={"Costs": True}
+        where={"id": model_id}, include={"Costs": True}
     )
     if not model:
         raise ValueError(f"Model with id '{model_id}' not found")
@@ -286,7 +275,7 @@ async def delete_model(
         FROM "AgentNode"
         WHERE "constantInput"::jsonb->>'model' = $1
         """,
-        deleted_slug
+        deleted_slug,
     )
     nodes_affected = int(count_result[0]["count"]) if count_result else 0
 
@@ -303,7 +292,7 @@ async def delete_model(
         WHERE "constantInput"::jsonb->>'model' = $2
         """,
         replacement_model_slug,
-        deleted_slug
+        deleted_slug,
     )
 
     # 5. Delete the model (CASCADE will delete costs automatically)
@@ -317,6 +306,5 @@ async def delete_model(
         message=(
             f"Successfully deleted model '{deleted_display_name}' ({deleted_slug}) "
             f"and migrated {nodes_affected} workflow node(s) to '{replacement_model_slug}'."
-        )
+        ),
     )
-
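
delete_model follows a count, migrate, delete sequence: count the AgentNode rows whose serialized input references the slug, rewrite that field to the replacement slug, then delete the model row and let the FK cascade remove its costs. The diff shows only the WHERE clause of the migration UPDATE; a plausible reconstruction of the full statement, where the jsonb_set-based SET clause is an assumption, not confirmed by the source:

    # Assumed shape of the migration UPDATE ($1/$2 order matches the diff above).
    MIGRATE_SQL = """
    UPDATE "AgentNode"
    SET "constantInput" = jsonb_set(
        "constantInput"::jsonb, '{model}', to_jsonb($1::text)
    )::text
    WHERE "constantInput"::jsonb->>'model' = $2
    """
    # Executed with (replacement_model_slug, deleted_slug).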
@@ -119,4 +119,3 @@ class DeleteLlmModelResponse(pydantic.BaseModel):
 class LlmModelUsageResponse(pydantic.BaseModel):
     model_slug: str
     node_count: int
-
@@ -21,4 +21,3 @@ async def list_models():
 async def list_providers():
     providers = await llm_db.list_providers(include_models=True)
     return llm_model.LlmProvidersResponse(providers=providers)
-
@@ -85,11 +85,18 @@ async def event_broadcaster(manager: ConnectionManager):
     redis = await connect_async()
     pubsub = redis.pubsub()
     await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
-    logger.info("Subscribed to LLM registry refresh notifications for WebSocket broadcast")
+    logger.info(
+        "Subscribed to LLM registry refresh notifications for WebSocket broadcast"
+    )
 
     async for message in pubsub.listen():
-        if message["type"] == "message" and message["channel"] == REGISTRY_REFRESH_CHANNEL:
-            logger.info("Broadcasting LLM registry refresh to all WebSocket clients")
+        if (
+            message["type"] == "message"
+            and message["channel"] == REGISTRY_REFRESH_CHANNEL
+        ):
+            logger.info(
+                "Broadcasting LLM registry refresh to all WebSocket clients"
+            )
             await manager.broadcast_to_all(
                 method=WSMethod.NOTIFICATION,
                 data={