From c5dfe3333dc5bdf9b1021ccd40fa2262d9b74e3c Mon Sep 17 00:00:00 2001
From: Bentlybro
Date: Tue, 7 Apr 2026 18:34:45 +0100
Subject: [PATCH] feat(backend/llm-registry): add Redis-backed cache and
 cross-process pub/sub sync
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Wrap DB fetch with @cached(shared_cache=True) so results are stored in
  Redis automatically — other workers skip the DB on warm cache
- Add notifications.py with publish/subscribe helpers using the
  llm_registry:refresh pub/sub channel for cross-process invalidation
- clear_registry_cache() invalidates the shared Redis entry before a forced
  DB refresh (called by admin mutations)
- rest_api.py: start a background subscription task so every worker reloads
  its in-process cache when another worker refreshes the registry
---
 .../backend/backend/api/rest_api.py           |  23 +++-
 .../backend/data/llm_registry/__init__.py     |  11 +-
 .../data/llm_registry/notifications.py        |  84 +++++++++++++
 .../backend/data/llm_registry/registry.py     | 110 ++++++++----
 4 files changed, 164 insertions(+), 64 deletions(-)
 create mode 100644 autogpt_platform/backend/backend/data/llm_registry/notifications.py

diff --git a/autogpt_platform/backend/backend/api/rest_api.py b/autogpt_platform/backend/backend/api/rest_api.py
index 771b6a8fc2..948857fc6d 100644
--- a/autogpt_platform/backend/backend/api/rest_api.py
+++ b/autogpt_platform/backend/backend/api/rest_api.py
@@ -1,3 +1,4 @@
+import asyncio
 import contextlib
 import logging
 import platform
@@ -118,18 +119,26 @@ async def lifespan_context(app: fastapi.FastAPI):
 
     AutoRegistry.patch_integrations()
 
-    # Refresh LLM registry before initializing blocks so blocks can use registry data
+    # Load LLM registry before initializing blocks so blocks can use registry data.
+    # Tries Redis first (fast path on warm restart), falls back to DB.
     # Note: Graceful fallback for now since no blocks consume registry yet (comes in PR #5)
-    # When block integration lands, this should fail hard or skip block initialization
     try:
         await backend.data.llm_registry.refresh_llm_registry()
-        logger.info("LLM registry refreshed successfully at startup")
+        logger.info("LLM registry loaded successfully at startup")
     except Exception as e:
         logger.warning(
-            f"Failed to refresh LLM registry at startup: {e}. "
+            f"Failed to load LLM registry at startup: {e}. "
             "Blocks will initialize with empty registry."
         )
 
+    # Start a background task so this worker reloads its in-process cache whenever
+    # another worker (e.g. the admin API) refreshes the registry.
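+    # Keep the task handle: the shutdown path below needs it to cancel the
+    # subscription, and asyncio holds only weak references to tasks, so an
+    # unreferenced task could be garbage-collected mid-run.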
+    _registry_subscription_task = asyncio.create_task(
+        backend.data.llm_registry.subscribe_to_registry_refresh(
+            backend.data.llm_registry.refresh_llm_registry
+        )
+    )
+
     await backend.data.block.initialize_blocks()
 
     await backend.data.user.migrate_and_encrypt_user_integrations()
@@ -154,6 +163,12 @@ async def lifespan_context(app: fastapi.FastAPI):
     with launch_darkly_context():
         yield
 
+    _registry_subscription_task.cancel()
+    try:
+        await _registry_subscription_task
+    except asyncio.CancelledError:
+        pass
+
     try:
         await shutdown_cloud_storage_handler()
     except Exception as e:
diff --git a/autogpt_platform/backend/backend/data/llm_registry/__init__.py b/autogpt_platform/backend/backend/data/llm_registry/__init__.py
index 3e6b9896b3..b5af2a4c7b 100644
--- a/autogpt_platform/backend/backend/data/llm_registry/__init__.py
+++ b/autogpt_platform/backend/backend/data/llm_registry/__init__.py
@@ -1,10 +1,15 @@
 """LLM Registry - Dynamic model management system."""
 
 from backend.blocks.llm import ModelMetadata
+from .notifications import (
+    publish_registry_refresh_notification,
+    subscribe_to_registry_refresh,
+)
 from .registry import (
     RegistryModel,
     RegistryModelCost,
     RegistryModelCreator,
+    clear_registry_cache,
     get_all_model_slugs_for_validation,
     get_all_models,
     get_default_model_slug,
@@ -20,7 +25,11 @@ __all__ = [
     "RegistryModel",
     "RegistryModelCost",
     "RegistryModelCreator",
-    # Functions
+    # Cache management
+    "clear_registry_cache",
+    "publish_registry_refresh_notification",
+    "subscribe_to_registry_refresh",
+    # Read functions
     "refresh_llm_registry",
     "get_model",
     "get_all_models",
diff --git a/autogpt_platform/backend/backend/data/llm_registry/notifications.py b/autogpt_platform/backend/backend/data/llm_registry/notifications.py
new file mode 100644
index 0000000000..a04f3f703d
--- /dev/null
+++ b/autogpt_platform/backend/backend/data/llm_registry/notifications.py
@@ -0,0 +1,84 @@
+"""Pub/sub notifications for LLM registry cross-process synchronisation."""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from typing import Awaitable, Callable
+
+logger = logging.getLogger(__name__)
+
+REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"
+
+
+async def publish_registry_refresh_notification() -> None:
+    """Publish a refresh signal so all other workers reload their in-process cache."""
+    from backend.data.redis_client import get_redis_async
+
+    try:
+        redis = await get_redis_async()
+        await redis.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
+        logger.debug("Published LLM registry refresh notification")
+    except Exception as e:
+        logger.warning("Failed to publish registry refresh notification: %s", e)
+
+
+async def subscribe_to_registry_refresh(
+    on_refresh: Callable[[], Awaitable[None]],
+) -> None:
+    """Listen for registry refresh signals and call on_refresh each time one arrives.
+
+    Designed to run as a long-lived background asyncio.Task. Automatically
+    reconnects if the Redis connection drops.
+
+    Args:
+        on_refresh: Async callable invoked on each refresh signal.
+            Typically ``llm_registry.refresh_llm_registry``.
+    """
+    from backend.data.redis_client import HOST, PASSWORD, PORT
+    from redis.asyncio import Redis as AsyncRedis
+
+    while True:
+        try:
+            # Dedicated connection — pub/sub must not share a connection used
+            # for regular commands.
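+            # (While subscribed, a Redis connection can only issue
+            # subscribe-family commands, so reusing the shared client would
+            # break its other callers.)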
+            redis_sub = AsyncRedis(
+                host=HOST, port=PORT, password=PASSWORD, decode_responses=True
+            )
+            pubsub = redis_sub.pubsub()
+            await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
+            logger.info("Subscribed to LLM registry refresh channel")
+
+            while True:
+                try:
+                    message = await pubsub.get_message(
+                        ignore_subscribe_messages=True, timeout=1.0
+                    )
+                    if (
+                        message
+                        and message["type"] == "message"
+                        and message["channel"] == REGISTRY_REFRESH_CHANNEL
+                    ):
+                        logger.debug("LLM registry refresh signal received")
+                        try:
+                            await on_refresh()
+                        except Exception as e:
+                            logger.error(
+                                "Error in registry on_refresh callback: %s", e
+                            )
+                except asyncio.CancelledError:
+                    raise
+                except Exception as e:
+                    logger.warning(
+                        "Error processing registry refresh message: %s", e
+                    )
+                    await asyncio.sleep(1)
+
+        except asyncio.CancelledError:
+            logger.info("LLM registry subscription task cancelled")
+            break
+        except Exception as e:
+            logger.warning(
+                "LLM registry subscription error: %s. Retrying in 5s...", e
+            )
+            await asyncio.sleep(5)
diff --git a/autogpt_platform/backend/backend/data/llm_registry/registry.py b/autogpt_platform/backend/backend/data/llm_registry/registry.py
index 456cc2ecd5..5c359131d8 100644
--- a/autogpt_platform/backend/backend/data/llm_registry/registry.py
+++ b/autogpt_platform/backend/backend/data/llm_registry/registry.py
@@ -10,6 +10,7 @@ import prisma.models
 from pydantic import BaseModel, ConfigDict
 
 from backend.blocks.llm import ModelMetadata
+from backend.util.cache import cached
 
 logger = logging.getLogger(__name__)
 
@@ -65,18 +66,17 @@ class RegistryModel(BaseModel):
     supports_parallel_tool_calls: bool = False
 
 
-# In-memory cache (will be replaced with Redis in PR #6)
+# L1 in-process cache — Redis is the shared L2 via @cached(shared_cache=True)
 _dynamic_models: dict[str, RegistryModel] = {}
 _schema_options: list[dict[str, str]] = []
 _lock = asyncio.Lock()
 
 
-def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel:
+def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel:  # type: ignore[name-defined]
     """Transform a raw Prisma LlmModel record into a RegistryModel instance."""
-    # Parse costs
     costs = tuple(
         RegistryModelCost(
-            unit=str(cost.unit),  # Convert enum to string
+            unit=str(cost.unit),
             credit_cost=cost.creditCost,
             credential_provider=cost.credentialProvider,
             credential_id=cost.credentialId,
@@ -87,7 +87,6 @@
         for cost in (record.Costs or [])
     )
 
-    # Parse creator
     creator = None
     if record.Creator:
         creator = RegistryModelCreator(
@@ -99,35 +98,26 @@
             logo_url=record.Creator.logoUrl,
         )
 
-    # Parse capabilities
     capabilities = dict(record.capabilities or {})
 
-    # Build metadata from record
-    # Warn if Provider relation is missing (indicates data corruption)
     if not record.Provider:
         logger.warning(
-            f"LlmModel {record.slug} has no Provider despite NOT NULL FK - "
-            f"falling back to providerId {record.providerId}"
+            "LlmModel %s has no Provider despite NOT NULL FK - "
+            "falling back to providerId %s",
+            record.slug,
+            record.providerId,
         )
-    provider_name = (
-        record.Provider.name if record.Provider else record.providerId
-    )
+    provider_name = record.Provider.name if record.Provider else record.providerId
     provider_display = (
-        record.Provider.displayName
-        if record.Provider
-        else record.providerId
+        record.Provider.displayName if record.Provider else record.providerId
     )
+
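+    # Creator is an optional relation; fall back to "Unknown" for display.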
creator_name = record.Creator.displayName if record.Creator else "Unknown" - # Extract creator name (fallback to "Unknown" if no creator) - creator_name = ( - record.Creator.displayName if record.Creator else "Unknown" - ) - - # Price tier defaults to 1 if not set if record.priceTier not in (1, 2, 3): logger.warning( - f"LlmModel {record.slug} has out-of-range priceTier={record.priceTier}, " - "defaulting to 1" + "LlmModel %s has out-of-range priceTier=%s, defaulting to 1", + record.slug, + record.priceTier, ) price_tier = record.priceTier if record.priceTier in (1, 2, 3) else 1 @@ -164,41 +154,53 @@ def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel: ) -async def refresh_llm_registry() -> None: - """ - Refresh the LLM registry from the database. +@cached(maxsize=1, ttl_seconds=300, shared_cache=True, refresh_ttl_on_get=True) +async def _fetch_registry_from_db() -> list[RegistryModel]: + """Fetch all LLM models from the database. - Fetches all models with their costs, providers, and creators, - then updates the in-memory cache. + Results are cached in Redis (shared_cache=True) so subsequent calls within + the TTL window skip the DB entirely — both within this process and across + all other workers that share the same Redis instance. + """ + records = await prisma.models.LlmModel.prisma().find_many( # type: ignore[attr-defined] + include={"Provider": True, "Costs": True, "Creator": True} + ) + logger.info("Fetched %d LLM models from database", len(records)) + return [_record_to_registry_model(r) for r in records] + + +def clear_registry_cache() -> None: + """Invalidate the shared Redis cache for the registry DB fetch. + + Call this before refresh_llm_registry() after any admin DB mutation so the + next fetch hits the database rather than serving the now-stale cached data. + """ + _fetch_registry_from_db.cache_clear() + + +async def refresh_llm_registry() -> None: + """Refresh the in-process L1 cache from Redis/DB. + + On the first call (or after clear_registry_cache()), fetches fresh data + from the database and stores it in Redis. Subsequent calls by other + workers hit the Redis cache instead of the DB. 
""" async with _lock: try: - records = await prisma.models.LlmModel.prisma().find_many( - include={ - "Provider": True, - "Costs": True, - "Creator": True, - } - ) - logger.info(f"Fetched {len(records)} LLM models from database") + models = await _fetch_registry_from_db() + new_models = {m.slug: m for m in models} - # Build model instances - new_models: dict[str, RegistryModel] = {} - for record in records: - model = _record_to_registry_model(record) - new_models[record.slug] = model - - # Atomic swap global _dynamic_models, _schema_options _dynamic_models = new_models _schema_options = _build_schema_options() logger.info( - f"LLM registry refreshed: {len(_dynamic_models)} models, " - f"{len(_schema_options)} schema options" + "LLM registry refreshed: %d models, %d schema options", + len(_dynamic_models), + len(_schema_options), ) except Exception as e: - logger.error(f"Failed to refresh LLM registry: {e}", exc_info=True) + logger.error("Failed to refresh LLM registry: %s", e, exc_info=True) raise @@ -240,23 +242,13 @@ def get_schema_options() -> list[dict[str, str]]: def get_default_model_slug() -> str | None: """Get the default model slug (first recommended, or first enabled).""" - # Sort once and use next() to short-circuit on first match models = sorted(_dynamic_models.values(), key=lambda m: m.display_name) - - # Prefer recommended models recommended = next( (m.slug for m in models if m.is_recommended and m.is_enabled), None ) - if recommended: - return recommended - - # Fallback to first enabled model - return next((m.slug for m in models if m.is_enabled), None) + return recommended or next((m.slug for m in models if m.is_enabled), None) def get_all_model_slugs_for_validation() -> list[str]: - """ - Get all model slugs for validation (enables migrate_llm_models to work). - Returns slugs for enabled models only. - """ + """Get all model slugs for validation (enabled models only).""" return [model.slug for model in _dynamic_models.values() if model.is_enabled]