Bentlybro
2025-12-02 14:49:03 +00:00
parent 7fe6b576ae
commit ec705bbbcf
17 changed files with 125 additions and 98 deletions

View File

@@ -15,6 +15,7 @@ from anthropic.types import ToolParam
from groq import AsyncGroq
from pydantic import BaseModel, SecretStr
from backend.data import llm_registry
from backend.data.block import (
Block,
BlockCategory,
@@ -23,7 +24,6 @@ from backend.data.block import (
BlockSchemaOutput,
)
from backend.data.llm_model_types import ModelMetadata
from backend.data import llm_registry
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -76,7 +76,7 @@ def AICredentialsField() -> AICredentials:
# Get the mapping now - it may be empty initially, but will be refreshed
# when the schema is generated via CredentialsMetaInput._add_json_schema_extra
mapping = llm_registry.get_llm_discriminator_mapping()
return CredentialsField(
description="API key for the LLM provider.",
discriminator="model",
@@ -191,7 +191,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
metadata = llm_registry.get_llm_model_metadata(self.value)
if metadata:
return metadata
raise ValueError(f"Missing metadata for model: {self.value}. Model not found in LLM registry.")
raise ValueError(
f"Missing metadata for model: {self.value}. Model not found in LLM registry."
)
@property
def provider(self) -> str:
@@ -341,41 +343,46 @@ async def llm_call(
provider = llm_model.metadata.provider
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or int(2**15)
# Check if model is enabled - get from registry
from backend.data.llm_registry import _dynamic_models
if llm_model.value in _dynamic_models:
model_info = _dynamic_models[llm_model.value]
if not model_info.is_enabled:
raise ValueError(
f"LLM model '{llm_model.value}' is disabled."
)
raise ValueError(f"LLM model '{llm_model.value}' is disabled.")
except ValueError as e:
# Re-raise if it's our disabled model error
if "is disabled" in str(e):
raise
# Model not in cache - try refreshing the registry once if we have DB access
import logging
logger = logging.getLogger(__name__)
logger.warning(
"Model %s not found in registry cache",
llm_model.value,
)
# Try refreshing the registry if we have database access
from backend.data.db import is_connected
if is_connected():
try:
logger.info("Refreshing LLM registry and retrying lookup for %s", llm_model.value)
logger.info(
"Refreshing LLM registry and retrying lookup for %s",
llm_model.value,
)
await llm_registry.refresh_llm_registry()
# Try again after refresh
try:
provider = llm_model.metadata.provider
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or int(2**15)
# Check if model is enabled after refresh
from backend.data.llm_registry import _dynamic_models
if llm_model.value in _dynamic_models:
model_info = _dynamic_models[llm_model.value]
if not model_info.is_enabled:
@@ -383,7 +390,7 @@ async def llm_call(
f"LLM model '{llm_model.value}' is disabled. "
"Please enable it in the LLM registry via the admin UI to use this model."
)
logger.info(
"Successfully loaded model %s metadata after registry refresh",
llm_model.value,
@@ -398,7 +405,9 @@ async def llm_call(
"Please ensure the model is added and enabled in the LLM registry via the admin UI."
)
except Exception as refresh_exc:
logger.error("Failed to refresh LLM registry: %s", refresh_exc, exc_info=True)
logger.error(
"Failed to refresh LLM registry: %s", refresh_exc, exc_info=True
)
raise ValueError(
f"LLM model '{llm_model.value}' not found in registry and failed to refresh. "
"Please ensure the model is added to the LLM registry via the admin UI."

View File

@@ -15,8 +15,8 @@ from backend.blocks.llm import (
LlmModel,
ModelMetadata,
)
from backend.data import llm_registry
from backend.blocks.stagehand._config import stagehand as stagehand_provider
from backend.data import llm_registry
from backend.sdk import (
APIKeyCredentials,
Block,

View File

@@ -152,13 +152,14 @@ class BlockSchema(BaseModel):
@staticmethod
def clear_all_schema_caches() -> None:
"""Clear cached JSON schemas for all BlockSchema subclasses."""
def clear_recursive(cls: type) -> None:
"""Recursively clear cache for class and all subclasses."""
if hasattr(cls, 'clear_schema_cache'):
if hasattr(cls, "clear_schema_cache"):
cls.clear_schema_cache()
for subclass in cls.__subclasses__():
clear_recursive(subclass)
clear_recursive(BlockSchema)
@classmethod
@@ -172,7 +173,9 @@ class BlockSchema(BaseModel):
# OpenAPI <3.1 does not support sibling fields that have a $ref key
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
keys = {"allOf", "anyOf", "oneOf"}
one_key = next((k for k in keys if k in obj and len(obj[k]) == 1), None)
one_key = next(
(k for k in keys if k in obj and len(obj[k]) == 1), None
)
if one_key:
obj.update(obj[one_key][0])
@@ -187,7 +190,7 @@ class BlockSchema(BaseModel):
return obj
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
# Always post-process to ensure LLM registry data is up-to-date
# This refreshes model options and discriminator mappings even if the schema was cached
update_schema_with_llm_registry(cls.cached_jsonschema, cls)
@@ -760,18 +763,23 @@ async def initialize_blocks() -> None:
try:
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
# Only refresh if we have DB access (check if Prisma is connected)
from backend.data.db import is_connected
if is_connected():
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
logger.info("LLM registry refreshed during block initialization")
else:
logger.warning("Prisma not connected, skipping LLM registry refresh during block initialization")
logger.warning(
"Prisma not connected, skipping LLM registry refresh during block initialization"
)
except Exception as exc:
logger.warning("Failed to refresh LLM registry during block initialization: %s", exc)
logger.warning(
"Failed to refresh LLM registry during block initialization: %s", exc
)
# First, sync all provider costs to blocks
# Imported here to avoid circular import
from backend.sdk.cost_integration import sync_all_provider_costs

View File

@@ -5,4 +5,3 @@ class ModelMetadata(NamedTuple):
provider: str
context_window: int
max_output_tokens: int | None

View File

@@ -93,7 +93,9 @@ async def refresh_llm_registry() -> None:
)
logger.debug("Found %d LLM model records in database", len(records))
except Exception as exc:
logger.error("Failed to refresh LLM registry from DB: %s", exc, exc_info=True)
logger.error(
"Failed to refresh LLM registry from DB: %s", exc, exc_info=True
)
return
dynamic: dict[str, RegistryModel] = {}
@@ -125,9 +127,11 @@ async def refresh_llm_registry() -> None:
metadata=metadata,
capabilities=record.capabilities or {},
extra_metadata=record.metadata or {},
provider_display_name=record.Provider.displayName
if record.Provider
else record.providerId,
provider_display_name=(
record.Provider.displayName
if record.Provider
else record.providerId
),
is_enabled=record.isEnabled,
costs=costs,
)
@@ -211,4 +215,3 @@ def get_dynamic_model_slugs() -> set[str]:
def iter_dynamic_models() -> Iterable[RegistryModel]:
return tuple(_dynamic_models.values())
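
The first hunk in this commit calls llm_registry.get_llm_discriminator_mapping(), which is not shown in the diff. Presumably it derives a slug-to-provider mapping from _dynamic_models so the credentials field can be discriminated on the selected model; a hypothetical sketch under that assumption:

def get_llm_discriminator_mapping() -> dict[str, str]:
    """Hypothetical sketch: map each enabled model slug to its provider name."""
    return {
        slug: model.metadata.provider
        for slug, model in _dynamic_models.items()
        if model.is_enabled
    }

# Example result (illustrative values only):
# {"gpt-4o": "openai", "claude-3-5-sonnet-latest": "anthropic", "llama-3.1-70b": "groq"}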

View File

@@ -5,6 +5,7 @@ When models are added/updated/removed via the admin UI, this module
publishes notifications to Redis that all executor services subscribe to,
ensuring they refresh their registry cache in real-time.
"""
import asyncio
import logging
from typing import Any
@@ -28,7 +29,9 @@ def publish_registry_refresh_notification() -> None:
logger.info("Published LLM registry refresh notification to Redis")
except Exception as exc:
logger.warning(
"Failed to publish LLM registry refresh notification: %s", exc, exc_info=True
"Failed to publish LLM registry refresh notification: %s",
exc,
exc_info=True,
)
@@ -38,23 +41,32 @@ async def subscribe_to_registry_refresh(
"""
Subscribe to Redis notifications for LLM registry updates.
This runs in a loop and processes messages as they arrive.
Args:
on_refresh: Async callable to execute when a refresh notification is received
"""
from backend.data.redis_client import connect_async
try:
redis = await connect_async()
pubsub = redis.pubsub()
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
logger.info("Subscribed to LLM registry refresh notifications on channel: %s", REGISTRY_REFRESH_CHANNEL)
logger.info(
"Subscribed to LLM registry refresh notifications on channel: %s",
REGISTRY_REFRESH_CHANNEL,
)
# Process messages in a loop
while True:
try:
message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
if message and message["type"] == "message" and message["channel"] == REGISTRY_REFRESH_CHANNEL:
message = await pubsub.get_message(
ignore_subscribe_messages=True, timeout=1.0
)
if (
message
and message["type"] == "message"
and message["channel"] == REGISTRY_REFRESH_CHANNEL
):
logger.info("Received LLM registry refresh notification")
try:
await on_refresh()
@@ -77,4 +89,3 @@ async def subscribe_to_registry_refresh(
exc_info=True,
)
raise
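
The body of publish_registry_refresh_notification is largely elided by this hunk; it only needs to push one message onto the shared channel. A minimal sketch, assuming a synchronous Redis connection helper exists alongside connect_async (the connect name here is hypothetical):

def publish_registry_refresh_notification() -> None:
    """Minimal sketch: tell every subscriber that the LLM registry changed."""
    from backend.data.redis_client import connect  # hypothetical sync counterpart of connect_async

    try:
        redis = connect()
        redis.publish(REGISTRY_REFRESH_CHANNEL, "refresh")  # payload content is not inspected by subscribers
        logger.info("Published LLM registry refresh notification to Redis")
    except Exception as exc:
        logger.warning(
            "Failed to publish LLM registry refresh notification: %s",
            exc,
            exc_info=True,
        )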

View File

@@ -4,6 +4,7 @@ Helper utilities for LLM registry integration with block schemas.
This module handles the dynamic injection of discriminator mappings
and model options from the LLM registry into block schemas.
"""
import logging
from typing import Any
@@ -69,8 +70,7 @@ def refresh_llm_discriminator_mapping(field_schema: dict[str, Any]) -> None:
def update_schema_with_llm_registry(
schema: dict[str, Any],
model_class: type | None = None
schema: dict[str, Any], model_class: type | None = None
) -> None:
"""
Update a JSON schema with current LLM registry data.
@@ -98,7 +98,8 @@ def update_schema_with_llm_registry(
except Exception as exc:
logger.warning(
"Failed to refresh LLM options for field %s: %s",
field_name, exc
field_name,
exc,
)
# Refresh discriminator mapping for fields that use model discrimination
@@ -107,5 +108,6 @@ def update_schema_with_llm_registry(
except Exception as exc:
logger.warning(
"Failed to refresh discriminator mapping for field %s: %s",
field_name, exc
field_name,
exc,
)
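
The injection logic itself is not visible in these hunks; conceptually the helper walks the schema's properties and overwrites the model options and the discriminator_mapping with whatever the registry currently holds. An illustrative sketch under that assumption (the marker key on the field schema is hypothetical, not taken from the commit):

def update_schema_with_llm_registry(
    schema: dict[str, Any], model_class: type | None = None
) -> None:
    """Illustrative sketch: push current registry data into an already-built JSON schema."""
    from backend.data import llm_registry

    mapping = llm_registry.get_llm_discriminator_mapping()
    for field_schema in schema.get("properties", {}).values():
        # Refresh the selectable model slugs on model-picker fields ("llm_model_options" is a hypothetical marker).
        if "enum" in field_schema and field_schema.get("llm_model_options"):
            field_schema["enum"] = sorted(llm_registry.get_dynamic_model_slugs())
        # Refresh the model -> provider mapping on credentials fields that discriminate on "model".
        if field_schema.get("discriminator") == "model":
            field_schema["discriminator_mapping"] = dict(mapping)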

View File

@@ -533,7 +533,7 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
else:
schema["credentials_provider"] = allowed_providers
schema["credentials_types"] = model_class.allowed_cred_types()
# Ensure LLM discriminators are populated (delegates to shared helper)
update_schema_with_llm_registry(schema, model_class)
@@ -686,13 +686,13 @@ def CredentialsField(
# Build field_schema_extra - always include discriminator and mapping if discriminator is set
field_schema_extra: dict[str, Any] = {}
# Always include discriminator if provided
if discriminator is not None:
field_schema_extra["discriminator"] = discriminator
# Always include discriminator_mapping when discriminator is set (even if empty initially)
field_schema_extra["discriminator_mapping"] = discriminator_mapping or {}
# Include other optional fields (only if not None)
if required_scopes:
field_schema_extra["credentials_scopes"] = list(required_scopes)

View File

@@ -4,6 +4,7 @@ Helper functions for LLM registry initialization in executor context.
These functions handle refreshing the LLM registry when the executor starts
and subscribing to real-time updates via Redis pub/sub.
"""
import logging
from backend.data import db, llm_registry
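
Only this module's imports are touched by the commit; its two entry points, referenced from ExecutionProcessor in the next file, presumably wrap the one-time refresh and the pub/sub subscription. A hypothetical sketch:

async def initialize_registry_for_executor() -> None:
    """Hypothetical sketch: load the LLM registry once when the executor starts."""
    if db.is_connected():
        await llm_registry.refresh_llm_registry()


async def subscribe_to_registry_updates() -> None:
    """Hypothetical sketch: refresh again whenever an admin change is published to Redis."""
    from backend.data.llm_registry_notifications import subscribe_to_registry_refresh

    await subscribe_to_registry_refresh(on_refresh=llm_registry.refresh_llm_registry)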

View File

@@ -621,7 +621,7 @@ class ExecutionProcessor:
)
self.node_execution_thread.start()
self.node_evaluation_thread.start()
# Initialize LLM registry and subscribe to updates
from backend.executor.llm_registry_init import (
initialize_registry_for_executor,
@@ -634,7 +634,7 @@ class ExecutionProcessor:
asyncio.run_coroutine_threadsafe(
subscribe_to_registry_updates(), self.node_execution_loop
)
logger.info(f"[GraphExecutor] {self.tid} started")
@error_logged(swallow=False)

View File

@@ -20,8 +20,6 @@ import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.user
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
import backend.integrations.webhooks.utils
import backend.server.routers.postmark.postmark
import backend.server.routers.v1
@@ -36,13 +34,14 @@ import backend.server.v2.executions.review.routes
import backend.server.v2.library.db
import backend.server.v2.library.model
import backend.server.v2.library.routes
import backend.server.v2.llm.routes as public_llm_routes
import backend.server.v2.otto.routes
import backend.server.v2.store.model
import backend.server.v2.store.routes
import backend.server.v2.llm.routes as public_llm_routes
import backend.util.service
import backend.util.settings
from backend.blocks.llm import LlmModel
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi
@@ -111,9 +110,10 @@ async def lifespan_context(app: fastapi.FastAPI):
# Refresh LLM registry before initializing blocks so blocks can use registry data
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
# Clear block schema caches so they're regenerated with updated discriminator_mapping
from backend.data.block import BlockSchema
BlockSchema.clear_all_schema_caches()
await backend.data.block.initialize_blocks()

View File

@@ -20,35 +20,41 @@ router = fastapi.APIRouter(
async def _refresh_runtime_state() -> None:
"""Refresh the LLM registry and clear all related caches to ensure real-time updates."""
logger.info("Refreshing LLM registry runtime state...")
# Refresh registry from database
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
# Clear block schema caches so they're regenerated with updated model options
from backend.data.block import BlockSchema
BlockSchema.clear_all_schema_caches()
logger.info("Cleared all block schema caches")
# Clear the /blocks endpoint cache so frontend gets updated schemas
try:
from backend.server.routers.v1 import _get_cached_blocks
_get_cached_blocks.cache_clear()
logger.info("Cleared /blocks endpoint cache")
except Exception as e:
logger.warning("Failed to clear /blocks cache: %s", e)
# Clear the v2 builder providers cache (if it exists)
try:
from backend.server.v2.builder import db as builder_db
if hasattr(builder_db, '_get_all_providers'):
if hasattr(builder_db, "_get_all_providers"):
builder_db._get_all_providers.cache_clear()
logger.info("Cleared v2 builder providers cache")
except Exception as e:
logger.debug("Could not clear v2 builder cache: %s", e)
# Notify all executor services to refresh their registry cache
from backend.data.llm_registry_notifications import publish_registry_refresh_notification
from backend.data.llm_registry_notifications import (
publish_registry_refresh_notification,
)
publish_registry_refresh_notification()
logger.info("Published registry refresh notification")
@@ -133,7 +139,9 @@ async def toggle_llm_model(
request: llm_model.ToggleLlmModelRequest,
):
try:
model = await llm_db.toggle_model(model_id=model_id, is_enabled=request.is_enabled)
model = await llm_db.toggle_model(
model_id=model_id, is_enabled=request.is_enabled
)
await _refresh_runtime_state()
return model
except Exception as exc:
@@ -171,8 +179,7 @@ async def get_llm_model_usage(model_id: str):
async def delete_llm_model(
model_id: str,
replacement_model_slug: str = fastapi.Query(
...,
description="Slug of the model to migrate existing workflows to"
..., description="Slug of the model to migrate existing workflows to"
),
):
"""
@@ -189,15 +196,14 @@ async def delete_llm_model(
"""
try:
result = await llm_db.delete_model(
model_id=model_id,
replacement_model_slug=replacement_model_slug
model_id=model_id, replacement_model_slug=replacement_model_slug
)
await _refresh_runtime_state()
logger.info(
"Deleted model '%s' and migrated %d nodes to '%s'",
result.deleted_model_slug,
result.nodes_migrated,
result.replacement_model_slug
result.replacement_model_slug,
)
return result
except ValueError as exc:
@@ -210,4 +216,3 @@ async def delete_llm_model(
status_code=500,
detail="Failed to delete model and migrate workflows",
) from exc
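
As the test file below confirms, the delete endpoint takes the replacement slug as a query parameter. A client-side usage sketch (base URL, port, model id, and auth header are placeholders, not part of the commit):

import httpx

# Minimal sketch: delete a model and migrate its workflow nodes to a replacement.
response = httpx.delete(
    "http://localhost:8006/admin/llm/models/<model-id>",  # placeholder host and id
    params={"replacement_model_slug": "gpt-4o-mini"},
    headers={"Authorization": "Bearer <admin-token>"},  # placeholder auth
)
response.raise_for_status()
print(response.json()["nodes_migrated"])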

View File

@@ -1,4 +1,4 @@
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
@@ -166,9 +166,7 @@ def test_create_llm_provider_success(
mock_notify.assert_called_once()
# Snapshot test the response
configured_snapshot.assert_match(
response_data, "create_llm_provider_success.json"
)
configured_snapshot.assert_match(response_data, "create_llm_provider_success.json")
def test_create_llm_model_success(
@@ -352,12 +350,12 @@ def test_delete_llm_model_success(
"replacement_model_slug": "gpt-4o-mini",
"nodes_migrated": 42,
"message": "Successfully deleted model 'GPT-3.5 Turbo' (gpt-3.5-turbo) "
"and migrated 42 workflow node(s) to 'gpt-4o-mini'."
"and migrated 42 workflow node(s) to 'gpt-4o-mini'.",
}
mocker.patch(
"backend.server.v2.admin.llm_routes.llm_db.delete_model",
new=AsyncMock(return_value=type('obj', (object,), mock_response)()),
new=AsyncMock(return_value=type("obj", (object,), mock_response)()),
)
mock_refresh = mocker.patch(
@@ -391,9 +389,7 @@ def test_delete_llm_model_validation_error(
new=AsyncMock(side_effect=ValueError("Replacement model 'invalid' not found")),
)
response = client.delete(
"/admin/llm/models/model-1?replacement_model_slug=invalid"
)
response = client.delete("/admin/llm/models/model-1?replacement_model_slug=invalid")
assert response.status_code == 400
assert "Replacement model 'invalid' not found" in response.json()["detail"]

View File

@@ -71,11 +71,7 @@ def _map_provider(record: prisma.models.LlmProvider) -> llm_model.LlmProvider:
async def list_providers(include_models: bool = True) -> list[llm_model.LlmProvider]:
include = (
{"Models": {"include": {"Costs": True}}}
if include_models
else None
)
include = {"Models": {"include": {"Costs": True}}} if include_models else None
records = await prisma.models.LlmProvider.prisma().find_many(include=include)
return [_map_provider(record) for record in records]
@@ -208,9 +204,7 @@ async def get_model_usage(model_id: str) -> llm_model.LlmModelUsageResponse:
"""Get usage count for a model."""
import prisma as prisma_module
model = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id}
)
model = await prisma.models.LlmModel.prisma().find_unique(where={"id": model_id})
if not model:
raise ValueError(f"Model with id '{model_id}' not found")
@@ -220,19 +214,15 @@ async def get_model_usage(model_id: str) -> llm_model.LlmModelUsageResponse:
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
""",
model.slug
model.slug,
)
node_count = int(count_result[0]["count"]) if count_result else 0
return llm_model.LlmModelUsageResponse(
model_slug=model.slug,
node_count=node_count
)
return llm_model.LlmModelUsageResponse(model_slug=model.slug, node_count=node_count)
async def delete_model(
model_id: str,
replacement_model_slug: str
model_id: str, replacement_model_slug: str
) -> llm_model.DeleteLlmModelResponse:
"""
Delete a model and migrate all AgentNodes using it to a replacement model.
@@ -258,8 +248,7 @@ async def delete_model(
# 1. Get the model being deleted
model = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id},
include={"Costs": True}
where={"id": model_id}, include={"Costs": True}
)
if not model:
raise ValueError(f"Model with id '{model_id}' not found")
@@ -286,7 +275,7 @@ async def delete_model(
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
""",
deleted_slug
deleted_slug,
)
nodes_affected = int(count_result[0]["count"]) if count_result else 0
@@ -303,7 +292,7 @@ async def delete_model(
WHERE "constantInput"::jsonb->>'model' = $2
""",
replacement_model_slug,
deleted_slug
deleted_slug,
)
# 5. Delete the model (CASCADE will delete costs automatically)
@@ -317,6 +306,5 @@ async def delete_model(
message=(
f"Successfully deleted model '{deleted_display_name}' ({deleted_slug}) "
f"and migrated {nodes_affected} workflow node(s) to '{replacement_model_slug}'."
)
),
)

View File

@@ -119,4 +119,3 @@ class DeleteLlmModelResponse(pydantic.BaseModel):
class LlmModelUsageResponse(pydantic.BaseModel):
model_slug: str
node_count: int

View File

@@ -21,4 +21,3 @@ async def list_models():
async def list_providers():
providers = await llm_db.list_providers(include_models=True)
return llm_model.LlmProvidersResponse(providers=providers)

View File

@@ -85,11 +85,18 @@ async def event_broadcaster(manager: ConnectionManager):
redis = await connect_async()
pubsub = redis.pubsub()
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
logger.info("Subscribed to LLM registry refresh notifications for WebSocket broadcast")
logger.info(
"Subscribed to LLM registry refresh notifications for WebSocket broadcast"
)
async for message in pubsub.listen():
if message["type"] == "message" and message["channel"] == REGISTRY_REFRESH_CHANNEL:
logger.info("Broadcasting LLM registry refresh to all WebSocket clients")
if (
message["type"] == "message"
and message["channel"] == REGISTRY_REFRESH_CHANNEL
):
logger.info(
"Broadcasting LLM registry refresh to all WebSocket clients"
)
await manager.broadcast_to_all(
method=WSMethod.NOTIFICATION,
data={