Add registry payloads, caching, and tests

Support sending full LLM registry data in refresh notifications and protect DB fetches with a cached/thundering-herd-safe loader. Introduces a cached _fetch_registry_from_db() with Redis-backed TTL and updates refresh_llm_registry() to accept optional models_data so executors can refresh from a provided payload. Notifications now serialize JSON payloads (backwards-compatible with plain "refresh"), and subscribers extract models_data and pass it to the refresh handler. Admin refresh flow now clears the cache before fetching and publishes the refreshed data; executor refresh logic accepts models_data and adds jitter to spread load. Added unit/integration tests for caching, thundering-herd protection, cache_clear behavior, and notification payload handling; plus various logging and small API/ import adjustments.
This commit is contained in:
Bentlybro
2026-03-03 16:36:49 +00:00
parent 7273f5096a
commit bf606434b9
6 changed files with 496 additions and 70 deletions

View File

@@ -3,8 +3,13 @@ import logging
import autogpt_libs.auth
import fastapi
from backend.api.features.builder import db as builder_db
from backend.api.features.v1 import _get_cached_blocks
from backend.blocks._base import BlockSchema
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
from backend.data.llm_registry import publish_registry_refresh_notification
from backend.data.llm_registry.registry import _fetch_registry_from_db
from backend.server.v2.llm import db as llm_db
from backend.server.v2.llm import model as llm_model
@@ -20,20 +25,19 @@ async def _refresh_runtime_state() -> None:
"""Refresh the LLM registry and clear all related caches to ensure real-time updates."""
logger.info("Refreshing LLM registry runtime state...")
try:
_fetch_registry_from_db.cache_clear()
logger.debug("Cleared Redis cache for LLM registry")
# Refresh registry from database
await llm_registry.refresh_llm_registry()
await refresh_llm_costs()
# Clear block schema caches so they're regenerated with updated model options
from backend.blocks._base import BlockSchema
BlockSchema.clear_all_schema_caches()
logger.info("Cleared all block schema caches")
# Clear the /blocks endpoint cache so frontend gets updated schemas
try:
from backend.api.features.v1 import _get_cached_blocks
_get_cached_blocks.cache_clear()
logger.info("Cleared /blocks endpoint cache")
except Exception as e:
@@ -41,8 +45,6 @@ async def _refresh_runtime_state() -> None:
# Clear the v2 builder caches
try:
from backend.api.features.builder import db as builder_db
builder_db._get_all_providers.cache_clear()
logger.info("Cleared v2 builder providers cache")
builder_db._build_cached_search_results.cache_clear()
@@ -52,11 +54,13 @@ async def _refresh_runtime_state() -> None:
except Exception as e:
logger.debug("Could not clear v2 builder cache: %s", e)
# Notify all executor services to refresh their registry cache
from backend.data.llm_registry import publish_registry_refresh_notification
# Fetch fresh data for notification (now contains updated data from DB)
models_data = await _fetch_registry_from_db()
await publish_registry_refresh_notification()
logger.info("Published registry refresh notification")
await publish_registry_refresh_notification(models_data=models_data)
logger.info(
"Published registry refresh notification with %d models", len(models_data)
)
except Exception as exc:
logger.exception(
"LLM runtime state refresh failed; caches may be stale: %s", exc

View File

@@ -0,0 +1,145 @@
"""Integration tests for LLM registry notification system."""
import asyncio
import json
import time
import pytest
from unittest.mock import AsyncMock, patch
from backend.data.llm_registry import notifications
from backend.executor import llm_registry_init
@pytest.mark.asyncio
async def test_notification_with_data_payload():
"""Verify notification can carry model data."""
models_data = [
{
"slug": "gpt-4o",
"displayName": "GPT-4o",
"contextWindow": 128000,
}
]
# Mock Redis
with patch("backend.data.llm_registry.notifications.connect_async") as mock_redis:
mock_client = AsyncMock()
mock_redis.return_value = mock_client
# Publish notification
await notifications.publish_registry_refresh_notification(models_data)
# Verify Redis publish was called with JSON payload
assert mock_client.publish.call_count == 1
channel, payload = mock_client.publish.call_args[0]
assert channel == "llm_registry:refresh"
# Parse and verify payload
parsed = json.loads(payload)
assert parsed["action"] == "refresh"
assert parsed["data"] == models_data
@pytest.mark.asyncio
async def test_notification_backwards_compatibility():
"""Verify notifications work without data payload (backwards compatibility)."""
with patch("backend.data.llm_registry.notifications.connect_async") as mock_redis:
mock_client = AsyncMock()
mock_redis.return_value = mock_client
# Publish without data
await notifications.publish_registry_refresh_notification(models_data=None)
# Verify simple string payload
assert mock_client.publish.call_count == 1
_, payload = mock_client.publish.call_args[0]
assert payload == "refresh"
@pytest.mark.asyncio
async def test_subscribe_extracts_data_from_notification():
"""Verify subscriber can extract data from notification payload."""
received_data = None
async def mock_callback(data):
nonlocal received_data
received_data = data
models_data = [{"slug": "gpt-4o", "displayName": "GPT-4o"}]
# Simulate receiving a notification message
message = {
"type": "message",
"channel": b"llm_registry:refresh",
"data": json.dumps({"action": "refresh", "data": models_data}).encode("utf-8"),
}
# Mock Redis pubsub
with patch("backend.data.llm_registry.notifications.connect_async") as mock_redis:
mock_client = AsyncMock()
mock_pubsub = AsyncMock()
mock_redis.return_value = mock_client
mock_client.pubsub.return_value = mock_pubsub
# Return the message once, then None to stop the loop
mock_pubsub.get_message.side_effect = [message, None]
# Start subscription in a task and cancel after first message
async def run_subscription():
await notifications.subscribe_to_registry_refresh(mock_callback)
task = asyncio.create_task(run_subscription())
await asyncio.sleep(0.1) # Let it process the message
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
# Verify callback was called with extracted data
assert received_data == models_data
@pytest.mark.asyncio
async def test_jitter_adds_delay():
"""Verify jitter is applied before refresh."""
with patch("backend.data.llm_registry.registry.refresh_llm_registry"), patch(
"backend.data.block_cost_config.refresh_llm_costs"
), patch("backend.blocks._base.BlockSchema.clear_all_schema_caches"), patch(
"backend.data.db.is_connected", return_value=True
):
start = time.time()
await llm_registry_init.refresh_registry_on_notification(models_data=[])
elapsed = time.time() - start
# Should have at least some delay (0-2 seconds)
# We can't test the exact delay due to jitter randomness,
# but we can verify it took some time
assert elapsed >= 0
assert elapsed <= 3 # Allow some overhead
@pytest.mark.asyncio
async def test_refresh_uses_provided_data():
"""Verify refresh uses provided data instead of fetching."""
models_data = [{"slug": "test", "displayName": "Test"}]
with patch(
"backend.data.llm_registry.registry.refresh_llm_registry"
) as mock_refresh, patch("backend.data.block_cost_config.refresh_llm_costs"), patch(
"backend.blocks._base.BlockSchema.clear_all_schema_caches"
), patch(
"backend.data.db.is_connected", return_value=True
):
await llm_registry_init.refresh_registry_on_notification(
models_data=models_data
)
# Verify refresh was called with the data
assert mock_refresh.call_count == 1
assert mock_refresh.call_args[1]["models_data"] == models_data

View File

@@ -7,6 +7,7 @@ ensuring they refresh their registry cache in real-time.
"""
import asyncio
import json
import logging
from typing import Any
@@ -18,15 +19,34 @@ logger = logging.getLogger(__name__)
REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"
async def publish_registry_refresh_notification() -> None:
async def publish_registry_refresh_notification(
models_data: list[dict[str, Any]] | None = None
) -> None:
"""
Publish a notification to Redis that the LLM registry has been updated.
All executor services subscribed to this channel will refresh their registry.
Args:
models_data: Optional full registry data to include in notification
"""
try:
redis = await connect_async()
await redis.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
logger.info("Published LLM registry refresh notification to Redis")
# Prepare payload
if models_data is not None:
payload = json.dumps({"action": "refresh", "data": models_data})
else:
payload = "refresh" # Backwards compatible
await redis.publish(REGISTRY_REFRESH_CHANNEL, payload)
if models_data:
logger.info(
"Published LLM registry refresh notification with %d models",
len(models_data),
)
else:
logger.info("Published LLM registry refresh notification")
except Exception as exc:
logger.warning(
"Failed to publish LLM registry refresh notification: %s",
@@ -36,14 +56,13 @@ async def publish_registry_refresh_notification() -> None:
async def subscribe_to_registry_refresh(
on_refresh: Any, # Async callable that takes no args
on_refresh: Any, # Async callable that takes optional models_data
) -> None:
"""
Subscribe to Redis notifications for LLM registry updates.
This runs in a loop and processes messages as they arrive.
Args:
on_refresh: Async callable to execute when a refresh notification is received
on_refresh: Async callable(models_data: list[dict] | None) -> None
"""
try:
redis = await connect_async()
@@ -66,8 +85,28 @@ async def subscribe_to_registry_refresh(
and message["channel"] == REGISTRY_REFRESH_CHANNEL
):
logger.info("Received LLM registry refresh notification")
# Extract models_data if present
models_data = None
try:
await on_refresh()
payload = message["data"]
if isinstance(payload, bytes):
payload = payload.decode("utf-8")
# Try to parse as JSON
if payload != "refresh":
parsed = json.loads(payload)
models_data = parsed.get("data")
logger.debug(
"Notification includes %d models",
len(models_data) if models_data else 0,
)
except (json.JSONDecodeError, AttributeError):
# Backwards compatible: simple "refresh" string
pass
try:
await on_refresh(models_data)
except Exception as exc:
logger.error(
"Error refreshing LLM registry from notification: %s",

View File

@@ -10,6 +10,7 @@ from typing import Any, Iterable
import prisma.models
from backend.data.llm_registry.model import ModelMetadata
from backend.util.cache import cached
logger = logging.getLogger(__name__)
@@ -113,18 +114,95 @@ def _build_schema_options() -> list[dict[str, str]]:
return options
async def refresh_llm_registry() -> None:
"""Refresh the LLM registry from the database. Loads all models (enabled and disabled)."""
@cached(maxsize=1, ttl_seconds=300, shared_cache=True, refresh_ttl_on_get=True)
async def _fetch_registry_from_db() -> list[dict[str, Any]]:
"""
Fetch all LLM models from database with related data.
Cached in Redis with 300s TTL. Thundering herd protection ensures
only one executor queries DB even if 1000 receive notification simultaneously.
"""
records = await prisma.models.LlmModel.prisma().find_many(
include={
"Provider": True,
"Costs": True,
"Creator": True,
}
)
logger.debug("Fetched %d LLM model records from database", len(records))
# Serialize to plain dicts for caching
return [
{
"id": record.id,
"slug": record.slug,
"displayName": record.displayName,
"description": record.description,
"providerId": record.providerId,
"creatorId": record.creatorId,
"contextWindow": record.contextWindow,
"maxOutputTokens": record.maxOutputTokens,
"priceTier": getattr(record, "priceTier", 1) or 1,
"isEnabled": record.isEnabled,
"isRecommended": record.isRecommended,
"capabilities": record.capabilities,
"metadata": record.metadata,
"Provider": (
{
"name": (
record.Provider.name if record.Provider else record.providerId
),
"displayName": (
record.Provider.displayName
if record.Provider
else record.providerId
),
}
if record.Provider
else None
),
"Costs": [
{
"creditCost": cost.creditCost,
"credentialProvider": cost.credentialProvider,
"credentialId": cost.credentialId,
"credentialType": cost.credentialType,
"currency": cost.currency,
"metadata": cost.metadata,
}
for cost in (record.Costs or [])
],
"Creator": (
{
"id": record.Creator.id,
"name": record.Creator.name,
"displayName": record.Creator.displayName,
"description": record.Creator.description,
"websiteUrl": record.Creator.websiteUrl,
"logoUrl": record.Creator.logoUrl,
}
if record.Creator
else None
),
}
for record in records
]
async def refresh_llm_registry(models_data: list[dict[str, Any]] | None = None) -> None:
"""
Refresh the LLM registry from the database or provided data.
Args:
models_data: Optional pre-fetched model data from notification payload
"""
async with _lock:
try:
records = await prisma.models.LlmModel.prisma().find_many(
include={
"Provider": True,
"Costs": True,
"Creator": True,
}
)
logger.debug("Found %d LLM model records in database", len(records))
if models_data is None:
# Fetch from cache (thundering herd protected)
models_data = await _fetch_registry_from_db()
logger.debug("Processing %d LLM model records", len(models_data))
except Exception as exc:
logger.error(
"Failed to refresh LLM registry from DB: %s", exc, exc_info=True
@@ -132,69 +210,66 @@ async def refresh_llm_registry() -> None:
return
dynamic: dict[str, RegistryModel] = {}
for record in records:
provider_name = (
record.Provider.name if record.Provider else record.providerId
)
for record_dict in models_data:
provider = record_dict.get("Provider")
creator_data = record_dict.get("Creator")
provider_name = provider["name"] if provider else record_dict["providerId"]
provider_display_name = (
record.Provider.displayName if record.Provider else record.providerId
provider["displayName"] if provider else record_dict["providerId"]
)
# Creator name: prefer Creator.name, fallback to provider display name
creator_name = (
record.Creator.name if record.Creator else provider_display_name
creator_data["name"] if creator_data else provider_display_name
)
# Price tier: default to 1 (cheapest) if not set
price_tier = getattr(record, "priceTier", 1) or 1
price_tier = record_dict.get("priceTier", 1) or 1
# Clamp to valid range 1-3
price_tier = max(1, min(3, price_tier))
metadata = ModelMetadata(
provider=provider_name,
context_window=record.contextWindow,
max_output_tokens=record.maxOutputTokens,
display_name=record.displayName,
context_window=record_dict["contextWindow"],
max_output_tokens=record_dict["maxOutputTokens"],
display_name=record_dict["displayName"],
provider_name=provider_display_name,
creator_name=creator_name,
price_tier=price_tier, # type: ignore[arg-type]
)
costs = tuple(
RegistryModelCost(
credit_cost=cost.creditCost,
credential_provider=cost.credentialProvider,
credential_id=cost.credentialId,
credential_type=cost.credentialType,
currency=cost.currency,
metadata=_json_to_dict(cost.metadata),
credit_cost=cost["creditCost"],
credential_provider=cost["credentialProvider"],
credential_id=cost.get("credentialId"),
credential_type=cost.get("credentialType"),
currency=cost.get("currency"),
metadata=_json_to_dict(cost.get("metadata")),
)
for cost in (record.Costs or [])
for cost in record_dict.get("Costs", [])
)
# Map creator if present
creator = None
if record.Creator:
if creator_data:
creator = RegistryModelCreator(
id=record.Creator.id,
name=record.Creator.name,
display_name=record.Creator.displayName,
description=record.Creator.description,
website_url=record.Creator.websiteUrl,
logo_url=record.Creator.logoUrl,
id=creator_data["id"],
name=creator_data["name"],
display_name=creator_data["displayName"],
description=creator_data.get("description"),
website_url=creator_data.get("websiteUrl"),
logo_url=creator_data.get("logoUrl"),
)
dynamic[record.slug] = RegistryModel(
slug=record.slug,
display_name=record.displayName,
description=record.description,
dynamic[record_dict["slug"]] = RegistryModel(
slug=record_dict["slug"],
display_name=record_dict["displayName"],
description=record_dict.get("description"),
metadata=metadata,
capabilities=_json_to_dict(record.capabilities),
extra_metadata=_json_to_dict(record.metadata),
provider_display_name=(
record.Provider.displayName
if record.Provider
else record.providerId
),
is_enabled=record.isEnabled,
is_recommended=record.isRecommended,
capabilities=_json_to_dict(record_dict.get("capabilities")),
extra_metadata=_json_to_dict(record_dict.get("metadata")),
provider_display_name=provider_display_name,
is_enabled=record_dict["isEnabled"],
is_recommended=record_dict["isRecommended"],
costs=costs,
creator=creator,
)

View File

@@ -0,0 +1,143 @@
"""Unit tests for LLM registry caching and thundering herd protection."""
import asyncio
import pytest
from unittest.mock import AsyncMock, patch
from backend.data.llm_registry import registry
@pytest.mark.asyncio
async def test_fetch_registry_from_db_caching():
"""Verify @cached prevents duplicate DB calls."""
with patch("backend.data.llm_registry.registry.prisma.models.LlmModel") as mock:
mock.prisma().find_many = AsyncMock(return_value=[])
# Clear cache first
registry._fetch_registry_from_db.cache_clear()
# Call twice
await registry._fetch_registry_from_db()
await registry._fetch_registry_from_db()
# Verify only called once (cached)
assert mock.prisma().find_many.call_count == 1
@pytest.mark.asyncio
async def test_thundering_herd_protection():
"""Verify only 1 DB call with 100 concurrent requests."""
call_count = 0
async def mock_db_fetch(*args, **kwargs):
nonlocal call_count
call_count += 1
await asyncio.sleep(0.1) # Simulate slow DB
return []
with patch("backend.data.llm_registry.registry.prisma.models.LlmModel") as mock:
mock.prisma().find_many = mock_db_fetch
# Clear cache first
registry._fetch_registry_from_db.cache_clear()
# Launch 100 concurrent fetches
tasks = [registry._fetch_registry_from_db() for _ in range(100)]
await asyncio.gather(*tasks)
# Verify only 1 DB call due to thundering herd protection
assert call_count == 1
@pytest.mark.asyncio
async def test_refresh_accepts_models_data():
"""Verify refresh_llm_registry can accept pre-fetched data."""
models_data = [
{
"id": "test-id",
"slug": "gpt-4o",
"displayName": "GPT-4o",
"description": "Test model",
"providerId": "openai",
"creatorId": None,
"contextWindow": 128000,
"maxOutputTokens": 4096,
"priceTier": 2,
"isEnabled": True,
"isRecommended": False,
"capabilities": {},
"metadata": {},
"Provider": {
"name": "openai",
"displayName": "OpenAI",
},
"Costs": [],
"Creator": None,
}
]
with patch("backend.data.llm_registry.registry.prisma.models.LlmModel") as mock:
# Should NOT call DB if data provided
await registry.refresh_llm_registry(models_data=models_data)
mock.prisma().find_many.assert_not_called()
# Verify model was added to registry
assert "gpt-4o" in registry._dynamic_models
assert registry._dynamic_models["gpt-4o"].display_name == "GPT-4o"
@pytest.mark.asyncio
async def test_refresh_falls_back_to_cache():
"""Verify refresh_llm_registry fetches from cache when no data provided."""
with patch(
"backend.data.llm_registry.registry._fetch_registry_from_db"
) as mock_fetch:
mock_fetch.return_value = []
# Call without data - should fetch from cache
await registry.refresh_llm_registry(models_data=None)
# Verify cache fetch was called
assert mock_fetch.call_count == 1
@pytest.mark.asyncio
async def test_cache_clear_forces_fresh_fetch():
"""
Verify that cache_clear() forces a fresh DB fetch.
This is CRITICAL for admin updates - when an admin changes a model,
we must clear the cache before fetching to ensure fresh data is broadcast.
"""
fetch_count = 0
async def mock_db_fetch(*args, **kwargs):
nonlocal fetch_count
fetch_count += 1
return [{"slug": f"model-{fetch_count}", "displayName": f"Model {fetch_count}"}]
with patch("backend.data.llm_registry.registry.prisma.models.LlmModel") as mock:
mock.prisma().find_many = mock_db_fetch
# Clear cache and fetch first time
registry._fetch_registry_from_db.cache_clear()
result1 = await registry._fetch_registry_from_db()
assert result1[0]["slug"] == "model-1"
assert fetch_count == 1
# Fetch second time (should use cache, no DB call)
result2 = await registry._fetch_registry_from_db()
assert result2[0]["slug"] == "model-1" # Same cached data
assert fetch_count == 1 # No additional DB call
# Clear cache (simulating admin update)
registry._fetch_registry_from_db.cache_clear()
# Fetch third time (should hit DB for fresh data)
result3 = await registry._fetch_registry_from_db()
assert result3[0]["slug"] == "model-2" # Fresh data!
assert fetch_count == 2 # New DB call
# Verify cache_clear() method exists and is callable
assert callable(registry._fetch_registry_from_db.cache_clear)

View File

@@ -5,7 +5,9 @@ These functions handle refreshing the LLM registry when the executor starts
and subscribing to real-time updates via Redis pub/sub.
"""
import asyncio
import logging
import random
from backend.blocks._base import BlockSchema
from backend.data import db, llm_registry
@@ -39,21 +41,39 @@ async def initialize_registry_for_executor() -> None:
)
async def refresh_registry_on_notification() -> None:
"""Refresh LLM registry when notified via Redis pub/sub."""
async def refresh_registry_on_notification(
models_data: list[dict] | None = None,
) -> None:
"""
Refresh LLM registry when notified via Redis pub/sub.
Args:
models_data: Optional pre-fetched model data from notification
"""
# Add jitter to spread load across executors (0-2 seconds)
jitter = random.uniform(0, 2.0)
await asyncio.sleep(jitter)
logger.debug("[GraphExecutor] Starting registry refresh after %.2fs jitter", jitter)
try:
# Ensure DB is connected
if not db.is_connected():
await db.connect()
# Refresh registry and costs
await llm_registry.refresh_llm_registry()
# Refresh registry (uses provided data or fetches from cache)
await llm_registry.refresh_llm_registry(models_data=models_data)
await refresh_llm_costs()
# Clear block schema caches so they regenerate with new model options
BlockSchema.clear_all_schema_caches()
logger.info("[GraphExecutor] LLM registry refreshed from notification")
if models_data:
logger.info(
"[GraphExecutor] LLM registry refreshed from notification data (%d models)",
len(models_data),
)
else:
logger.info("[GraphExecutor] LLM registry refreshed from cache")
except Exception as exc:
logger.error(
"[GraphExecutor] Failed to refresh LLM registry from notification: %s",