feat(backend/llm-registry): add Redis-backed cache and cross-process pub/sub sync

- Wrap the DB fetch with @cached(shared_cache=True) so results are stored in
  Redis automatically — other workers skip the DB on warm cache
- Add notifications.py with publish/subscribe helpers on the llm_registry:refresh
  pub/sub channel for cross-process invalidation
- clear_registry_cache() invalidates the shared Redis entry before a forced
  DB refresh (called by admin mutations; see the sketch below)
- rest_api.py: start a background subscription task so every worker reloads
  its in-process cache when another worker refreshes the registry
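
Taken together, the admin-side write path is expected to look roughly like this
(a sketch: disable_model and the DB mutation are hypothetical placeholders, the
admin call site is not part of this diff, and only the three llm_registry calls
are real):

    from backend.data import llm_registry

    async def disable_model(slug: str) -> None:
        # 1. Apply the DB change for `slug` (hypothetical: whatever the admin
        #    mutation actually does).
        # 2. Drop the shared Redis entry so the next fetch hits the database.
        llm_registry.clear_registry_cache()
        # 3. Re-fetch from the DB and repopulate the Redis and in-process caches.
        await llm_registry.refresh_llm_registry()
        # 4. Signal every other worker to reload its in-process cache.
        await llm_registry.publish_registry_refresh_notification()
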
Bentlybro
2026-04-07 18:34:45 +01:00
parent 696b273afc
commit c5dfe3333d
4 changed files with 164 additions and 64 deletions

View File

@@ -1,3 +1,4 @@
import asyncio
import contextlib
import logging
import platform
@@ -118,18 +119,26 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Refresh LLM registry before initializing blocks so blocks can use registry data
# Load LLM registry before initializing blocks so blocks can use registry data.
# Tries Redis first (fast path on warm restart), falls back to DB.
# Note: Graceful fallback for now since no blocks consume registry yet (comes in PR #5)
# When block integration lands, this should fail hard or skip block initialization
try:
await backend.data.llm_registry.refresh_llm_registry()
logger.info("LLM registry refreshed successfully at startup")
logger.info("LLM registry loaded successfully at startup")
except Exception as e:
logger.warning(
f"Failed to refresh LLM registry at startup: {e}. "
f"Failed to load LLM registry at startup: {e}. "
"Blocks will initialize with empty registry."
)
# Start background task so this worker reloads its in-process cache whenever
# another worker (e.g. the admin API) refreshes the registry.
_registry_subscription_task = asyncio.create_task(
backend.data.llm_registry.subscribe_to_registry_refresh(
backend.data.llm_registry.refresh_llm_registry
)
)
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
@@ -154,6 +163,12 @@ async def lifespan_context(app: fastapi.FastAPI):
with launch_darkly_context():
yield
_registry_subscription_task.cancel()
try:
await _registry_subscription_task
except asyncio.CancelledError:
pass
try:
await shutdown_cloud_storage_handler()
except Exception as e:

View File

@@ -1,10 +1,15 @@
"""LLM Registry - Dynamic model management system."""
from backend.blocks.llm import ModelMetadata
from .notifications import (
publish_registry_refresh_notification,
subscribe_to_registry_refresh,
)
from .registry import (
RegistryModel,
RegistryModelCost,
RegistryModelCreator,
clear_registry_cache,
get_all_model_slugs_for_validation,
get_all_models,
get_default_model_slug,
@@ -20,7 +25,11 @@ __all__ = [
"RegistryModel",
"RegistryModelCost",
"RegistryModelCreator",
# Functions
# Cache management
"clear_registry_cache",
"publish_registry_refresh_notification",
"subscribe_to_registry_refresh",
# Read functions
"refresh_llm_registry",
"get_model",
"get_all_models",

View File

@@ -0,0 +1,84 @@
"""Pub/sub notifications for LLM registry cross-process synchronisation."""
from __future__ import annotations
import asyncio
import logging
from typing import Awaitable, Callable
logger = logging.getLogger(__name__)
REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"
async def publish_registry_refresh_notification() -> None:
"""Publish a refresh signal so all other workers reload their in-process cache."""
from backend.data.redis_client import get_redis_async
try:
redis = await get_redis_async()
await redis.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
logger.debug("Published LLM registry refresh notification")
except Exception as e:
logger.warning("Failed to publish registry refresh notification: %s", e)
async def subscribe_to_registry_refresh(
on_refresh: Callable[[], Awaitable[None]],
) -> None:
"""Listen for registry refresh signals and call on_refresh each time one arrives.
Designed to run as a long-lived background asyncio.Task. Automatically
reconnects if the Redis connection drops.
Args:
on_refresh: Async callable invoked on each refresh signal.
Typically ``llm_registry.refresh_llm_registry``.
"""
from backend.data.redis_client import HOST, PASSWORD, PORT
from redis.asyncio import Redis as AsyncRedis
while True:
try:
# Dedicated connection — pub/sub must not share a connection used
# for regular commands.
redis_sub = AsyncRedis(
host=HOST, port=PORT, password=PASSWORD, decode_responses=True
)
pubsub = redis_sub.pubsub()
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
logger.info("Subscribed to LLM registry refresh channel")
while True:
try:
message = await pubsub.get_message(
ignore_subscribe_messages=True, timeout=1.0
)
if (
message
and message["type"] == "message"
and message["channel"] == REGISTRY_REFRESH_CHANNEL
):
logger.debug("LLM registry refresh signal received")
try:
await on_refresh()
except Exception as e:
logger.error(
"Error in registry on_refresh callback: %s", e
)
except asyncio.CancelledError:
raise
except Exception as e:
logger.warning(
"Error processing registry refresh message: %s", e
)
await asyncio.sleep(1)
except asyncio.CancelledError:
logger.info("LLM registry subscription task cancelled")
break
except Exception as e:
logger.warning(
"LLM registry subscription error: %s. Retrying in 5s...", e
)
await asyncio.sleep(5)
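
Outside the REST API, any long-lived worker can keep its in-process registry
cache in sync with the same two calls (a sketch of a hypothetical standalone
service; the real wiring for the API server is in the rest_api.py hunk above):

    import asyncio
    import contextlib

    from backend.data import llm_registry

    async def main() -> None:
        # Initial load: Redis fast path on warm restart, DB otherwise.
        await llm_registry.refresh_llm_registry()
        # Reload whenever another worker publishes a refresh notification.
        task = asyncio.create_task(
            llm_registry.subscribe_to_registry_refresh(
                llm_registry.refresh_llm_registry
            )
        )
        try:
            await asyncio.Event().wait()  # placeholder for the worker's own main loop
        finally:
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task

    if __name__ == "__main__":
        asyncio.run(main())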

View File

@@ -10,6 +10,7 @@ import prisma.models
from pydantic import BaseModel, ConfigDict
from backend.blocks.llm import ModelMetadata
from backend.util.cache import cached
logger = logging.getLogger(__name__)
@@ -65,18 +66,17 @@ class RegistryModel(BaseModel):
supports_parallel_tool_calls: bool = False
# In-memory cache (will be replaced with Redis in PR #6)
# L1 in-process cache — Redis is the shared L2 via @cached(shared_cache=True)
_dynamic_models: dict[str, RegistryModel] = {}
_schema_options: list[dict[str, str]] = []
_lock = asyncio.Lock()
def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel:
def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel: # type: ignore[name-defined]
"""Transform a raw Prisma LlmModel record into a RegistryModel instance."""
# Parse costs
costs = tuple(
RegistryModelCost(
unit=str(cost.unit), # Convert enum to string
unit=str(cost.unit),
credit_cost=cost.creditCost,
credential_provider=cost.credentialProvider,
credential_id=cost.credentialId,
@@ -87,7 +87,6 @@ def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel:
for cost in (record.Costs or [])
)
# Parse creator
creator = None
if record.Creator:
creator = RegistryModelCreator(
@@ -99,35 +98,26 @@ def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel:
logo_url=record.Creator.logoUrl,
)
# Parse capabilities
capabilities = dict(record.capabilities or {})
# Build metadata from record
# Warn if Provider relation is missing (indicates data corruption)
if not record.Provider:
logger.warning(
f"LlmModel {record.slug} has no Provider despite NOT NULL FK - "
f"falling back to providerId {record.providerId}"
"LlmModel %s has no Provider despite NOT NULL FK - "
"falling back to providerId %s",
record.slug,
record.providerId,
)
provider_name = (
record.Provider.name if record.Provider else record.providerId
)
provider_name = record.Provider.name if record.Provider else record.providerId
provider_display = (
record.Provider.displayName
if record.Provider
else record.providerId
record.Provider.displayName if record.Provider else record.providerId
)
creator_name = record.Creator.displayName if record.Creator else "Unknown"
# Extract creator name (fallback to "Unknown" if no creator)
creator_name = (
record.Creator.displayName if record.Creator else "Unknown"
)
# Price tier defaults to 1 if not set
if record.priceTier not in (1, 2, 3):
logger.warning(
f"LlmModel {record.slug} has out-of-range priceTier={record.priceTier}, "
"defaulting to 1"
"LlmModel %s has out-of-range priceTier=%s, defaulting to 1",
record.slug,
record.priceTier,
)
price_tier = record.priceTier if record.priceTier in (1, 2, 3) else 1
@@ -164,41 +154,53 @@ def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel:
)
async def refresh_llm_registry() -> None:
"""
Refresh the LLM registry from the database.
@cached(maxsize=1, ttl_seconds=300, shared_cache=True, refresh_ttl_on_get=True)
async def _fetch_registry_from_db() -> list[RegistryModel]:
"""Fetch all LLM models from the database.
Fetches all models with their costs, providers, and creators,
then updates the in-memory cache.
Results are cached in Redis (shared_cache=True) so subsequent calls within
the TTL window skip the DB entirely — both within this process and across
all other workers that share the same Redis instance.
"""
records = await prisma.models.LlmModel.prisma().find_many( # type: ignore[attr-defined]
include={"Provider": True, "Costs": True, "Creator": True}
)
logger.info("Fetched %d LLM models from database", len(records))
return [_record_to_registry_model(r) for r in records]
def clear_registry_cache() -> None:
"""Invalidate the shared Redis cache for the registry DB fetch.
Call this before refresh_llm_registry() after any admin DB mutation so the
next fetch hits the database rather than serving the now-stale cached data.
"""
_fetch_registry_from_db.cache_clear()
async def refresh_llm_registry() -> None:
"""Refresh the in-process L1 cache from Redis/DB.
On the first call (or after clear_registry_cache()), fetches fresh data
from the database and stores it in Redis. Subsequent calls by other
workers hit the Redis cache instead of the DB.
"""
async with _lock:
try:
records = await prisma.models.LlmModel.prisma().find_many(
include={
"Provider": True,
"Costs": True,
"Creator": True,
}
)
logger.info(f"Fetched {len(records)} LLM models from database")
models = await _fetch_registry_from_db()
new_models = {m.slug: m for m in models}
# Build model instances
new_models: dict[str, RegistryModel] = {}
for record in records:
model = _record_to_registry_model(record)
new_models[record.slug] = model
# Atomic swap
global _dynamic_models, _schema_options
_dynamic_models = new_models
_schema_options = _build_schema_options()
logger.info(
f"LLM registry refreshed: {len(_dynamic_models)} models, "
f"{len(_schema_options)} schema options"
"LLM registry refreshed: %d models, %d schema options",
len(_dynamic_models),
len(_schema_options),
)
except Exception as e:
logger.error(f"Failed to refresh LLM registry: {e}", exc_info=True)
logger.error("Failed to refresh LLM registry: %s", e, exc_info=True)
raise
@@ -240,23 +242,13 @@ def get_schema_options() -> list[dict[str, str]]:
def get_default_model_slug() -> str | None:
"""Get the default model slug (first recommended, or first enabled)."""
# Sort once and use next() to short-circuit on first match
models = sorted(_dynamic_models.values(), key=lambda m: m.display_name)
# Prefer recommended models
recommended = next(
(m.slug for m in models if m.is_recommended and m.is_enabled), None
)
if recommended:
return recommended
# Fallback to first enabled model
return next((m.slug for m in models if m.is_enabled), None)
return recommended or next((m.slug for m in models if m.is_enabled), None)
def get_all_model_slugs_for_validation() -> list[str]:
"""
Get all model slugs for validation (enables migrate_llm_models to work).
Returns slugs for enabled models only.
"""
"""Get all model slugs for validation (enabled models only)."""
return [model.slug for model in _dynamic_models.values() if model.is_enabled]