mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Add registry payloads, caching, and tests
Support sending full LLM registry data in refresh notifications and protect DB fetches with a cached/thundering-herd-safe loader. Introduces a cached _fetch_registry_from_db() with Redis-backed TTL and updates refresh_llm_registry() to accept optional models_data so executors can refresh from a provided payload. Notifications now serialize JSON payloads (backwards-compatible with plain "refresh"), and subscribers extract models_data and pass it to the refresh handler. Admin refresh flow now clears the cache before fetching and publishes the refreshed data; executor refresh logic accepts models_data and adds jitter to spread load. Added unit/integration tests for caching, thundering-herd protection, cache_clear behavior, and notification payload handling; plus various logging and small API/ import adjustments.
This commit is contained in:
@@ -3,8 +3,13 @@ import logging
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
|
||||
from backend.api.features.builder import db as builder_db
|
||||
from backend.api.features.v1 import _get_cached_blocks
|
||||
from backend.blocks._base import BlockSchema
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block_cost_config import refresh_llm_costs
|
||||
from backend.data.llm_registry import publish_registry_refresh_notification
|
||||
from backend.data.llm_registry.registry import _fetch_registry_from_db
|
||||
from backend.server.v2.llm import db as llm_db
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
|
||||
@@ -20,20 +25,19 @@ async def _refresh_runtime_state() -> None:
|
||||
"""Refresh the LLM registry and clear all related caches to ensure real-time updates."""
|
||||
logger.info("Refreshing LLM registry runtime state...")
|
||||
try:
|
||||
_fetch_registry_from_db.cache_clear()
|
||||
logger.debug("Cleared Redis cache for LLM registry")
|
||||
|
||||
# Refresh registry from database
|
||||
await llm_registry.refresh_llm_registry()
|
||||
await refresh_llm_costs()
|
||||
|
||||
# Clear block schema caches so they're regenerated with updated model options
|
||||
from backend.blocks._base import BlockSchema
|
||||
|
||||
BlockSchema.clear_all_schema_caches()
|
||||
logger.info("Cleared all block schema caches")
|
||||
|
||||
# Clear the /blocks endpoint cache so frontend gets updated schemas
|
||||
try:
|
||||
from backend.api.features.v1 import _get_cached_blocks
|
||||
|
||||
_get_cached_blocks.cache_clear()
|
||||
logger.info("Cleared /blocks endpoint cache")
|
||||
except Exception as e:
|
||||
@@ -41,8 +45,6 @@ async def _refresh_runtime_state() -> None:
|
||||
|
||||
# Clear the v2 builder caches
|
||||
try:
|
||||
from backend.api.features.builder import db as builder_db
|
||||
|
||||
builder_db._get_all_providers.cache_clear()
|
||||
logger.info("Cleared v2 builder providers cache")
|
||||
builder_db._build_cached_search_results.cache_clear()
|
||||
@@ -52,11 +54,13 @@ async def _refresh_runtime_state() -> None:
|
||||
except Exception as e:
|
||||
logger.debug("Could not clear v2 builder cache: %s", e)
|
||||
|
||||
# Notify all executor services to refresh their registry cache
|
||||
from backend.data.llm_registry import publish_registry_refresh_notification
|
||||
# Fetch fresh data for notification (now contains updated data from DB)
|
||||
models_data = await _fetch_registry_from_db()
|
||||
|
||||
await publish_registry_refresh_notification()
|
||||
logger.info("Published registry refresh notification")
|
||||
await publish_registry_refresh_notification(models_data=models_data)
|
||||
logger.info(
|
||||
"Published registry refresh notification with %d models", len(models_data)
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"LLM runtime state refresh failed; caches may be stale: %s", exc
|
||||
|
||||
@@ -0,0 +1,145 @@
|
||||
"""Integration tests for LLM registry notification system."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from backend.data.llm_registry import notifications
|
||||
from backend.executor import llm_registry_init
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notification_with_data_payload():
|
||||
"""Verify notification can carry model data."""
|
||||
models_data = [
|
||||
{
|
||||
"slug": "gpt-4o",
|
||||
"displayName": "GPT-4o",
|
||||
"contextWindow": 128000,
|
||||
}
|
||||
]
|
||||
|
||||
# Mock Redis
|
||||
with patch("backend.data.llm_registry.notifications.connect_async") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Publish notification
|
||||
await notifications.publish_registry_refresh_notification(models_data)
|
||||
|
||||
# Verify Redis publish was called with JSON payload
|
||||
assert mock_client.publish.call_count == 1
|
||||
channel, payload = mock_client.publish.call_args[0]
|
||||
|
||||
assert channel == "llm_registry:refresh"
|
||||
|
||||
# Parse and verify payload
|
||||
parsed = json.loads(payload)
|
||||
assert parsed["action"] == "refresh"
|
||||
assert parsed["data"] == models_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notification_backwards_compatibility():
|
||||
"""Verify notifications work without data payload (backwards compatibility)."""
|
||||
with patch("backend.data.llm_registry.notifications.connect_async") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Publish without data
|
||||
await notifications.publish_registry_refresh_notification(models_data=None)
|
||||
|
||||
# Verify simple string payload
|
||||
assert mock_client.publish.call_count == 1
|
||||
_, payload = mock_client.publish.call_args[0]
|
||||
assert payload == "refresh"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_extracts_data_from_notification():
|
||||
"""Verify subscriber can extract data from notification payload."""
|
||||
received_data = None
|
||||
|
||||
async def mock_callback(data):
|
||||
nonlocal received_data
|
||||
received_data = data
|
||||
|
||||
models_data = [{"slug": "gpt-4o", "displayName": "GPT-4o"}]
|
||||
|
||||
# Simulate receiving a notification message
|
||||
message = {
|
||||
"type": "message",
|
||||
"channel": b"llm_registry:refresh",
|
||||
"data": json.dumps({"action": "refresh", "data": models_data}).encode("utf-8"),
|
||||
}
|
||||
|
||||
# Mock Redis pubsub
|
||||
with patch("backend.data.llm_registry.notifications.connect_async") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_pubsub = AsyncMock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_client.pubsub.return_value = mock_pubsub
|
||||
|
||||
# Return the message once, then None to stop the loop
|
||||
mock_pubsub.get_message.side_effect = [message, None]
|
||||
|
||||
# Start subscription in a task and cancel after first message
|
||||
async def run_subscription():
|
||||
await notifications.subscribe_to_registry_refresh(mock_callback)
|
||||
|
||||
task = asyncio.create_task(run_subscription())
|
||||
await asyncio.sleep(0.1) # Let it process the message
|
||||
task.cancel()
|
||||
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Verify callback was called with extracted data
|
||||
assert received_data == models_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jitter_adds_delay():
|
||||
"""Verify jitter is applied before refresh."""
|
||||
with patch("backend.data.llm_registry.registry.refresh_llm_registry"), patch(
|
||||
"backend.data.block_cost_config.refresh_llm_costs"
|
||||
), patch("backend.blocks._base.BlockSchema.clear_all_schema_caches"), patch(
|
||||
"backend.data.db.is_connected", return_value=True
|
||||
):
|
||||
|
||||
start = time.time()
|
||||
await llm_registry_init.refresh_registry_on_notification(models_data=[])
|
||||
elapsed = time.time() - start
|
||||
|
||||
# Should have at least some delay (0-2 seconds)
|
||||
# We can't test the exact delay due to jitter randomness,
|
||||
# but we can verify it took some time
|
||||
assert elapsed >= 0
|
||||
assert elapsed <= 3 # Allow some overhead
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_uses_provided_data():
|
||||
"""Verify refresh uses provided data instead of fetching."""
|
||||
models_data = [{"slug": "test", "displayName": "Test"}]
|
||||
|
||||
with patch(
|
||||
"backend.data.llm_registry.registry.refresh_llm_registry"
|
||||
) as mock_refresh, patch("backend.data.block_cost_config.refresh_llm_costs"), patch(
|
||||
"backend.blocks._base.BlockSchema.clear_all_schema_caches"
|
||||
), patch(
|
||||
"backend.data.db.is_connected", return_value=True
|
||||
):
|
||||
|
||||
await llm_registry_init.refresh_registry_on_notification(
|
||||
models_data=models_data
|
||||
)
|
||||
|
||||
# Verify refresh was called with the data
|
||||
assert mock_refresh.call_count == 1
|
||||
assert mock_refresh.call_args[1]["models_data"] == models_data
|
||||
@@ -7,6 +7,7 @@ ensuring they refresh their registry cache in real-time.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -18,15 +19,34 @@ logger = logging.getLogger(__name__)
|
||||
REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"
|
||||
|
||||
|
||||
async def publish_registry_refresh_notification() -> None:
|
||||
async def publish_registry_refresh_notification(
|
||||
models_data: list[dict[str, Any]] | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Publish a notification to Redis that the LLM registry has been updated.
|
||||
All executor services subscribed to this channel will refresh their registry.
|
||||
|
||||
Args:
|
||||
models_data: Optional full registry data to include in notification
|
||||
"""
|
||||
try:
|
||||
redis = await connect_async()
|
||||
await redis.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
|
||||
logger.info("Published LLM registry refresh notification to Redis")
|
||||
|
||||
# Prepare payload
|
||||
if models_data is not None:
|
||||
payload = json.dumps({"action": "refresh", "data": models_data})
|
||||
else:
|
||||
payload = "refresh" # Backwards compatible
|
||||
|
||||
await redis.publish(REGISTRY_REFRESH_CHANNEL, payload)
|
||||
|
||||
if models_data:
|
||||
logger.info(
|
||||
"Published LLM registry refresh notification with %d models",
|
||||
len(models_data),
|
||||
)
|
||||
else:
|
||||
logger.info("Published LLM registry refresh notification")
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to publish LLM registry refresh notification: %s",
|
||||
@@ -36,14 +56,13 @@ async def publish_registry_refresh_notification() -> None:
|
||||
|
||||
|
||||
async def subscribe_to_registry_refresh(
|
||||
on_refresh: Any, # Async callable that takes no args
|
||||
on_refresh: Any, # Async callable that takes optional models_data
|
||||
) -> None:
|
||||
"""
|
||||
Subscribe to Redis notifications for LLM registry updates.
|
||||
This runs in a loop and processes messages as they arrive.
|
||||
|
||||
Args:
|
||||
on_refresh: Async callable to execute when a refresh notification is received
|
||||
on_refresh: Async callable(models_data: list[dict] | None) -> None
|
||||
"""
|
||||
try:
|
||||
redis = await connect_async()
|
||||
@@ -66,8 +85,28 @@ async def subscribe_to_registry_refresh(
|
||||
and message["channel"] == REGISTRY_REFRESH_CHANNEL
|
||||
):
|
||||
logger.info("Received LLM registry refresh notification")
|
||||
|
||||
# Extract models_data if present
|
||||
models_data = None
|
||||
try:
|
||||
await on_refresh()
|
||||
payload = message["data"]
|
||||
if isinstance(payload, bytes):
|
||||
payload = payload.decode("utf-8")
|
||||
|
||||
# Try to parse as JSON
|
||||
if payload != "refresh":
|
||||
parsed = json.loads(payload)
|
||||
models_data = parsed.get("data")
|
||||
logger.debug(
|
||||
"Notification includes %d models",
|
||||
len(models_data) if models_data else 0,
|
||||
)
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
# Backwards compatible: simple "refresh" string
|
||||
pass
|
||||
|
||||
try:
|
||||
await on_refresh(models_data)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Error refreshing LLM registry from notification: %s",
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Any, Iterable
|
||||
import prisma.models
|
||||
|
||||
from backend.data.llm_registry.model import ModelMetadata
|
||||
from backend.util.cache import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -113,18 +114,95 @@ def _build_schema_options() -> list[dict[str, str]]:
|
||||
return options
|
||||
|
||||
|
||||
async def refresh_llm_registry() -> None:
|
||||
"""Refresh the LLM registry from the database. Loads all models (enabled and disabled)."""
|
||||
@cached(maxsize=1, ttl_seconds=300, shared_cache=True, refresh_ttl_on_get=True)
|
||||
async def _fetch_registry_from_db() -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch all LLM models from database with related data.
|
||||
|
||||
Cached in Redis with 300s TTL. Thundering herd protection ensures
|
||||
only one executor queries DB even if 1000 receive notification simultaneously.
|
||||
"""
|
||||
records = await prisma.models.LlmModel.prisma().find_many(
|
||||
include={
|
||||
"Provider": True,
|
||||
"Costs": True,
|
||||
"Creator": True,
|
||||
}
|
||||
)
|
||||
logger.debug("Fetched %d LLM model records from database", len(records))
|
||||
|
||||
# Serialize to plain dicts for caching
|
||||
return [
|
||||
{
|
||||
"id": record.id,
|
||||
"slug": record.slug,
|
||||
"displayName": record.displayName,
|
||||
"description": record.description,
|
||||
"providerId": record.providerId,
|
||||
"creatorId": record.creatorId,
|
||||
"contextWindow": record.contextWindow,
|
||||
"maxOutputTokens": record.maxOutputTokens,
|
||||
"priceTier": getattr(record, "priceTier", 1) or 1,
|
||||
"isEnabled": record.isEnabled,
|
||||
"isRecommended": record.isRecommended,
|
||||
"capabilities": record.capabilities,
|
||||
"metadata": record.metadata,
|
||||
"Provider": (
|
||||
{
|
||||
"name": (
|
||||
record.Provider.name if record.Provider else record.providerId
|
||||
),
|
||||
"displayName": (
|
||||
record.Provider.displayName
|
||||
if record.Provider
|
||||
else record.providerId
|
||||
),
|
||||
}
|
||||
if record.Provider
|
||||
else None
|
||||
),
|
||||
"Costs": [
|
||||
{
|
||||
"creditCost": cost.creditCost,
|
||||
"credentialProvider": cost.credentialProvider,
|
||||
"credentialId": cost.credentialId,
|
||||
"credentialType": cost.credentialType,
|
||||
"currency": cost.currency,
|
||||
"metadata": cost.metadata,
|
||||
}
|
||||
for cost in (record.Costs or [])
|
||||
],
|
||||
"Creator": (
|
||||
{
|
||||
"id": record.Creator.id,
|
||||
"name": record.Creator.name,
|
||||
"displayName": record.Creator.displayName,
|
||||
"description": record.Creator.description,
|
||||
"websiteUrl": record.Creator.websiteUrl,
|
||||
"logoUrl": record.Creator.logoUrl,
|
||||
}
|
||||
if record.Creator
|
||||
else None
|
||||
),
|
||||
}
|
||||
for record in records
|
||||
]
|
||||
|
||||
|
||||
async def refresh_llm_registry(models_data: list[dict[str, Any]] | None = None) -> None:
|
||||
"""
|
||||
Refresh the LLM registry from the database or provided data.
|
||||
|
||||
Args:
|
||||
models_data: Optional pre-fetched model data from notification payload
|
||||
"""
|
||||
async with _lock:
|
||||
try:
|
||||
records = await prisma.models.LlmModel.prisma().find_many(
|
||||
include={
|
||||
"Provider": True,
|
||||
"Costs": True,
|
||||
"Creator": True,
|
||||
}
|
||||
)
|
||||
logger.debug("Found %d LLM model records in database", len(records))
|
||||
if models_data is None:
|
||||
# Fetch from cache (thundering herd protected)
|
||||
models_data = await _fetch_registry_from_db()
|
||||
|
||||
logger.debug("Processing %d LLM model records", len(models_data))
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to refresh LLM registry from DB: %s", exc, exc_info=True
|
||||
@@ -132,69 +210,66 @@ async def refresh_llm_registry() -> None:
|
||||
return
|
||||
|
||||
dynamic: dict[str, RegistryModel] = {}
|
||||
for record in records:
|
||||
provider_name = (
|
||||
record.Provider.name if record.Provider else record.providerId
|
||||
)
|
||||
for record_dict in models_data:
|
||||
provider = record_dict.get("Provider")
|
||||
creator_data = record_dict.get("Creator")
|
||||
|
||||
provider_name = provider["name"] if provider else record_dict["providerId"]
|
||||
provider_display_name = (
|
||||
record.Provider.displayName if record.Provider else record.providerId
|
||||
provider["displayName"] if provider else record_dict["providerId"]
|
||||
)
|
||||
# Creator name: prefer Creator.name, fallback to provider display name
|
||||
creator_name = (
|
||||
record.Creator.name if record.Creator else provider_display_name
|
||||
creator_data["name"] if creator_data else provider_display_name
|
||||
)
|
||||
# Price tier: default to 1 (cheapest) if not set
|
||||
price_tier = getattr(record, "priceTier", 1) or 1
|
||||
price_tier = record_dict.get("priceTier", 1) or 1
|
||||
# Clamp to valid range 1-3
|
||||
price_tier = max(1, min(3, price_tier))
|
||||
|
||||
metadata = ModelMetadata(
|
||||
provider=provider_name,
|
||||
context_window=record.contextWindow,
|
||||
max_output_tokens=record.maxOutputTokens,
|
||||
display_name=record.displayName,
|
||||
context_window=record_dict["contextWindow"],
|
||||
max_output_tokens=record_dict["maxOutputTokens"],
|
||||
display_name=record_dict["displayName"],
|
||||
provider_name=provider_display_name,
|
||||
creator_name=creator_name,
|
||||
price_tier=price_tier, # type: ignore[arg-type]
|
||||
)
|
||||
costs = tuple(
|
||||
RegistryModelCost(
|
||||
credit_cost=cost.creditCost,
|
||||
credential_provider=cost.credentialProvider,
|
||||
credential_id=cost.credentialId,
|
||||
credential_type=cost.credentialType,
|
||||
currency=cost.currency,
|
||||
metadata=_json_to_dict(cost.metadata),
|
||||
credit_cost=cost["creditCost"],
|
||||
credential_provider=cost["credentialProvider"],
|
||||
credential_id=cost.get("credentialId"),
|
||||
credential_type=cost.get("credentialType"),
|
||||
currency=cost.get("currency"),
|
||||
metadata=_json_to_dict(cost.get("metadata")),
|
||||
)
|
||||
for cost in (record.Costs or [])
|
||||
for cost in record_dict.get("Costs", [])
|
||||
)
|
||||
|
||||
# Map creator if present
|
||||
creator = None
|
||||
if record.Creator:
|
||||
if creator_data:
|
||||
creator = RegistryModelCreator(
|
||||
id=record.Creator.id,
|
||||
name=record.Creator.name,
|
||||
display_name=record.Creator.displayName,
|
||||
description=record.Creator.description,
|
||||
website_url=record.Creator.websiteUrl,
|
||||
logo_url=record.Creator.logoUrl,
|
||||
id=creator_data["id"],
|
||||
name=creator_data["name"],
|
||||
display_name=creator_data["displayName"],
|
||||
description=creator_data.get("description"),
|
||||
website_url=creator_data.get("websiteUrl"),
|
||||
logo_url=creator_data.get("logoUrl"),
|
||||
)
|
||||
|
||||
dynamic[record.slug] = RegistryModel(
|
||||
slug=record.slug,
|
||||
display_name=record.displayName,
|
||||
description=record.description,
|
||||
dynamic[record_dict["slug"]] = RegistryModel(
|
||||
slug=record_dict["slug"],
|
||||
display_name=record_dict["displayName"],
|
||||
description=record_dict.get("description"),
|
||||
metadata=metadata,
|
||||
capabilities=_json_to_dict(record.capabilities),
|
||||
extra_metadata=_json_to_dict(record.metadata),
|
||||
provider_display_name=(
|
||||
record.Provider.displayName
|
||||
if record.Provider
|
||||
else record.providerId
|
||||
),
|
||||
is_enabled=record.isEnabled,
|
||||
is_recommended=record.isRecommended,
|
||||
capabilities=_json_to_dict(record_dict.get("capabilities")),
|
||||
extra_metadata=_json_to_dict(record_dict.get("metadata")),
|
||||
provider_display_name=provider_display_name,
|
||||
is_enabled=record_dict["isEnabled"],
|
||||
is_recommended=record_dict["isRecommended"],
|
||||
costs=costs,
|
||||
creator=creator,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
"""Unit tests for LLM registry caching and thundering herd protection."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from backend.data.llm_registry import registry
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_registry_from_db_caching():
|
||||
"""Verify @cached prevents duplicate DB calls."""
|
||||
with patch("backend.data.llm_registry.registry.prisma.models.LlmModel") as mock:
|
||||
mock.prisma().find_many = AsyncMock(return_value=[])
|
||||
|
||||
# Clear cache first
|
||||
registry._fetch_registry_from_db.cache_clear()
|
||||
|
||||
# Call twice
|
||||
await registry._fetch_registry_from_db()
|
||||
await registry._fetch_registry_from_db()
|
||||
|
||||
# Verify only called once (cached)
|
||||
assert mock.prisma().find_many.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_thundering_herd_protection():
|
||||
"""Verify only 1 DB call with 100 concurrent requests."""
|
||||
call_count = 0
|
||||
|
||||
async def mock_db_fetch(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.1) # Simulate slow DB
|
||||
return []
|
||||
|
||||
with patch("backend.data.llm_registry.registry.prisma.models.LlmModel") as mock:
|
||||
mock.prisma().find_many = mock_db_fetch
|
||||
|
||||
# Clear cache first
|
||||
registry._fetch_registry_from_db.cache_clear()
|
||||
|
||||
# Launch 100 concurrent fetches
|
||||
tasks = [registry._fetch_registry_from_db() for _ in range(100)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# Verify only 1 DB call due to thundering herd protection
|
||||
assert call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_accepts_models_data():
|
||||
"""Verify refresh_llm_registry can accept pre-fetched data."""
|
||||
models_data = [
|
||||
{
|
||||
"id": "test-id",
|
||||
"slug": "gpt-4o",
|
||||
"displayName": "GPT-4o",
|
||||
"description": "Test model",
|
||||
"providerId": "openai",
|
||||
"creatorId": None,
|
||||
"contextWindow": 128000,
|
||||
"maxOutputTokens": 4096,
|
||||
"priceTier": 2,
|
||||
"isEnabled": True,
|
||||
"isRecommended": False,
|
||||
"capabilities": {},
|
||||
"metadata": {},
|
||||
"Provider": {
|
||||
"name": "openai",
|
||||
"displayName": "OpenAI",
|
||||
},
|
||||
"Costs": [],
|
||||
"Creator": None,
|
||||
}
|
||||
]
|
||||
|
||||
with patch("backend.data.llm_registry.registry.prisma.models.LlmModel") as mock:
|
||||
# Should NOT call DB if data provided
|
||||
await registry.refresh_llm_registry(models_data=models_data)
|
||||
mock.prisma().find_many.assert_not_called()
|
||||
|
||||
# Verify model was added to registry
|
||||
assert "gpt-4o" in registry._dynamic_models
|
||||
assert registry._dynamic_models["gpt-4o"].display_name == "GPT-4o"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_falls_back_to_cache():
|
||||
"""Verify refresh_llm_registry fetches from cache when no data provided."""
|
||||
with patch(
|
||||
"backend.data.llm_registry.registry._fetch_registry_from_db"
|
||||
) as mock_fetch:
|
||||
mock_fetch.return_value = []
|
||||
|
||||
# Call without data - should fetch from cache
|
||||
await registry.refresh_llm_registry(models_data=None)
|
||||
|
||||
# Verify cache fetch was called
|
||||
assert mock_fetch.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_clear_forces_fresh_fetch():
|
||||
"""
|
||||
Verify that cache_clear() forces a fresh DB fetch.
|
||||
|
||||
This is CRITICAL for admin updates - when an admin changes a model,
|
||||
we must clear the cache before fetching to ensure fresh data is broadcast.
|
||||
"""
|
||||
fetch_count = 0
|
||||
|
||||
async def mock_db_fetch(*args, **kwargs):
|
||||
nonlocal fetch_count
|
||||
fetch_count += 1
|
||||
return [{"slug": f"model-{fetch_count}", "displayName": f"Model {fetch_count}"}]
|
||||
|
||||
with patch("backend.data.llm_registry.registry.prisma.models.LlmModel") as mock:
|
||||
mock.prisma().find_many = mock_db_fetch
|
||||
|
||||
# Clear cache and fetch first time
|
||||
registry._fetch_registry_from_db.cache_clear()
|
||||
result1 = await registry._fetch_registry_from_db()
|
||||
assert result1[0]["slug"] == "model-1"
|
||||
assert fetch_count == 1
|
||||
|
||||
# Fetch second time (should use cache, no DB call)
|
||||
result2 = await registry._fetch_registry_from_db()
|
||||
assert result2[0]["slug"] == "model-1" # Same cached data
|
||||
assert fetch_count == 1 # No additional DB call
|
||||
|
||||
# Clear cache (simulating admin update)
|
||||
registry._fetch_registry_from_db.cache_clear()
|
||||
|
||||
# Fetch third time (should hit DB for fresh data)
|
||||
result3 = await registry._fetch_registry_from_db()
|
||||
assert result3[0]["slug"] == "model-2" # Fresh data!
|
||||
assert fetch_count == 2 # New DB call
|
||||
|
||||
# Verify cache_clear() method exists and is callable
|
||||
assert callable(registry._fetch_registry_from_db.cache_clear)
|
||||
@@ -5,7 +5,9 @@ These functions handle refreshing the LLM registry when the executor starts
|
||||
and subscribing to real-time updates via Redis pub/sub.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
|
||||
from backend.blocks._base import BlockSchema
|
||||
from backend.data import db, llm_registry
|
||||
@@ -39,21 +41,39 @@ async def initialize_registry_for_executor() -> None:
|
||||
)
|
||||
|
||||
|
||||
async def refresh_registry_on_notification() -> None:
|
||||
"""Refresh LLM registry when notified via Redis pub/sub."""
|
||||
async def refresh_registry_on_notification(
|
||||
models_data: list[dict] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Refresh LLM registry when notified via Redis pub/sub.
|
||||
|
||||
Args:
|
||||
models_data: Optional pre-fetched model data from notification
|
||||
"""
|
||||
# Add jitter to spread load across executors (0-2 seconds)
|
||||
jitter = random.uniform(0, 2.0)
|
||||
await asyncio.sleep(jitter)
|
||||
logger.debug("[GraphExecutor] Starting registry refresh after %.2fs jitter", jitter)
|
||||
|
||||
try:
|
||||
# Ensure DB is connected
|
||||
if not db.is_connected():
|
||||
await db.connect()
|
||||
|
||||
# Refresh registry and costs
|
||||
await llm_registry.refresh_llm_registry()
|
||||
# Refresh registry (uses provided data or fetches from cache)
|
||||
await llm_registry.refresh_llm_registry(models_data=models_data)
|
||||
await refresh_llm_costs()
|
||||
|
||||
# Clear block schema caches so they regenerate with new model options
|
||||
BlockSchema.clear_all_schema_caches()
|
||||
|
||||
logger.info("[GraphExecutor] LLM registry refreshed from notification")
|
||||
if models_data:
|
||||
logger.info(
|
||||
"[GraphExecutor] LLM registry refreshed from notification data (%d models)",
|
||||
len(models_data),
|
||||
)
|
||||
else:
|
||||
logger.info("[GraphExecutor] LLM registry refreshed from cache")
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"[GraphExecutor] Failed to refresh LLM registry from notification: %s",
|
||||
|
||||
Reference in New Issue
Block a user