mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
feat(backend/llm-registry): add Redis-backed cache and cross-process pub/sub sync
- Wrap DB fetch with @cached(shared_cache=True) so results are stored in Redis automatically — other workers skip the DB on warm cache - Add notifications.py with publish/subscribe helpers using llm_registry:refresh pub/sub channel for cross-process invalidation - clear_registry_cache() invalidates the shared Redis entry before a forced DB refresh (called by admin mutations) - rest_api.py: start a background subscription task so every worker reloads its in-process cache when another worker refreshes the registry
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import platform
|
||||
@@ -118,18 +119,26 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Refresh LLM registry before initializing blocks so blocks can use registry data
|
||||
# Load LLM registry before initializing blocks so blocks can use registry data.
|
||||
# Tries Redis first (fast path on warm restart), falls back to DB.
|
||||
# Note: Graceful fallback for now since no blocks consume registry yet (comes in PR #5)
|
||||
# When block integration lands, this should fail hard or skip block initialization
|
||||
try:
|
||||
await backend.data.llm_registry.refresh_llm_registry()
|
||||
logger.info("LLM registry refreshed successfully at startup")
|
||||
logger.info("LLM registry loaded successfully at startup")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to refresh LLM registry at startup: {e}. "
|
||||
f"Failed to load LLM registry at startup: {e}. "
|
||||
"Blocks will initialize with empty registry."
|
||||
)
|
||||
|
||||
# Start background task so this worker reloads its in-process cache whenever
|
||||
# another worker (e.g. the admin API) refreshes the registry.
|
||||
_registry_subscription_task = asyncio.create_task(
|
||||
backend.data.llm_registry.subscribe_to_registry_refresh(
|
||||
backend.data.llm_registry.refresh_llm_registry
|
||||
)
|
||||
)
|
||||
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
@@ -154,6 +163,12 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
with launch_darkly_context():
|
||||
yield
|
||||
|
||||
_registry_subscription_task.cancel()
|
||||
try:
|
||||
await _registry_subscription_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
try:
|
||||
await shutdown_cloud_storage_handler()
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
"""LLM Registry - Dynamic model management system."""
|
||||
|
||||
from backend.blocks.llm import ModelMetadata
|
||||
from .notifications import (
|
||||
publish_registry_refresh_notification,
|
||||
subscribe_to_registry_refresh,
|
||||
)
|
||||
from .registry import (
|
||||
RegistryModel,
|
||||
RegistryModelCost,
|
||||
RegistryModelCreator,
|
||||
clear_registry_cache,
|
||||
get_all_model_slugs_for_validation,
|
||||
get_all_models,
|
||||
get_default_model_slug,
|
||||
@@ -20,7 +25,11 @@ __all__ = [
|
||||
"RegistryModel",
|
||||
"RegistryModelCost",
|
||||
"RegistryModelCreator",
|
||||
# Functions
|
||||
# Cache management
|
||||
"clear_registry_cache",
|
||||
"publish_registry_refresh_notification",
|
||||
"subscribe_to_registry_refresh",
|
||||
# Read functions
|
||||
"refresh_llm_registry",
|
||||
"get_model",
|
||||
"get_all_models",
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
"""Pub/sub notifications for LLM registry cross-process synchronisation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"
|
||||
|
||||
|
||||
async def publish_registry_refresh_notification() -> None:
|
||||
"""Publish a refresh signal so all other workers reload their in-process cache."""
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
|
||||
logger.debug("Published LLM registry refresh notification")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish registry refresh notification: %s", e)
|
||||
|
||||
|
||||
async def subscribe_to_registry_refresh(
|
||||
on_refresh: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Listen for registry refresh signals and call on_refresh each time one arrives.
|
||||
|
||||
Designed to run as a long-lived background asyncio.Task. Automatically
|
||||
reconnects if the Redis connection drops.
|
||||
|
||||
Args:
|
||||
on_refresh: Async callable invoked on each refresh signal.
|
||||
Typically ``llm_registry.refresh_llm_registry``.
|
||||
"""
|
||||
from backend.data.redis_client import HOST, PASSWORD, PORT
|
||||
from redis.asyncio import Redis as AsyncRedis
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Dedicated connection — pub/sub must not share a connection used
|
||||
# for regular commands.
|
||||
redis_sub = AsyncRedis(
|
||||
host=HOST, port=PORT, password=PASSWORD, decode_responses=True
|
||||
)
|
||||
pubsub = redis_sub.pubsub()
|
||||
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
|
||||
logger.info("Subscribed to LLM registry refresh channel")
|
||||
|
||||
while True:
|
||||
try:
|
||||
message = await pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=1.0
|
||||
)
|
||||
if (
|
||||
message
|
||||
and message["type"] == "message"
|
||||
and message["channel"] == REGISTRY_REFRESH_CHANNEL
|
||||
):
|
||||
logger.debug("LLM registry refresh signal received")
|
||||
try:
|
||||
await on_refresh()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error in registry on_refresh callback: %s", e
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Error processing registry refresh message: %s", e
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("LLM registry subscription task cancelled")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"LLM registry subscription error: %s. Retrying in 5s...", e
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
@@ -10,6 +10,7 @@ import prisma.models
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from backend.blocks.llm import ModelMetadata
|
||||
from backend.util.cache import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -65,18 +66,17 @@ class RegistryModel(BaseModel):
|
||||
supports_parallel_tool_calls: bool = False
|
||||
|
||||
|
||||
# In-memory cache (will be replaced with Redis in PR #6)
|
||||
# L1 in-process cache — Redis is the shared L2 via @cached(shared_cache=True)
|
||||
_dynamic_models: dict[str, RegistryModel] = {}
|
||||
_schema_options: list[dict[str, str]] = []
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel:
|
||||
def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel: # type: ignore[name-defined]
|
||||
"""Transform a raw Prisma LlmModel record into a RegistryModel instance."""
|
||||
# Parse costs
|
||||
costs = tuple(
|
||||
RegistryModelCost(
|
||||
unit=str(cost.unit), # Convert enum to string
|
||||
unit=str(cost.unit),
|
||||
credit_cost=cost.creditCost,
|
||||
credential_provider=cost.credentialProvider,
|
||||
credential_id=cost.credentialId,
|
||||
@@ -87,7 +87,6 @@ def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel:
|
||||
for cost in (record.Costs or [])
|
||||
)
|
||||
|
||||
# Parse creator
|
||||
creator = None
|
||||
if record.Creator:
|
||||
creator = RegistryModelCreator(
|
||||
@@ -99,35 +98,26 @@ def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel:
|
||||
logo_url=record.Creator.logoUrl,
|
||||
)
|
||||
|
||||
# Parse capabilities
|
||||
capabilities = dict(record.capabilities or {})
|
||||
|
||||
# Build metadata from record
|
||||
# Warn if Provider relation is missing (indicates data corruption)
|
||||
if not record.Provider:
|
||||
logger.warning(
|
||||
f"LlmModel {record.slug} has no Provider despite NOT NULL FK - "
|
||||
f"falling back to providerId {record.providerId}"
|
||||
"LlmModel %s has no Provider despite NOT NULL FK - "
|
||||
"falling back to providerId %s",
|
||||
record.slug,
|
||||
record.providerId,
|
||||
)
|
||||
provider_name = (
|
||||
record.Provider.name if record.Provider else record.providerId
|
||||
)
|
||||
provider_name = record.Provider.name if record.Provider else record.providerId
|
||||
provider_display = (
|
||||
record.Provider.displayName
|
||||
if record.Provider
|
||||
else record.providerId
|
||||
record.Provider.displayName if record.Provider else record.providerId
|
||||
)
|
||||
creator_name = record.Creator.displayName if record.Creator else "Unknown"
|
||||
|
||||
# Extract creator name (fallback to "Unknown" if no creator)
|
||||
creator_name = (
|
||||
record.Creator.displayName if record.Creator else "Unknown"
|
||||
)
|
||||
|
||||
# Price tier defaults to 1 if not set
|
||||
if record.priceTier not in (1, 2, 3):
|
||||
logger.warning(
|
||||
f"LlmModel {record.slug} has out-of-range priceTier={record.priceTier}, "
|
||||
"defaulting to 1"
|
||||
"LlmModel %s has out-of-range priceTier=%s, defaulting to 1",
|
||||
record.slug,
|
||||
record.priceTier,
|
||||
)
|
||||
price_tier = record.priceTier if record.priceTier in (1, 2, 3) else 1
|
||||
|
||||
@@ -164,41 +154,53 @@ def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel:
|
||||
)
|
||||
|
||||
|
||||
async def refresh_llm_registry() -> None:
|
||||
"""
|
||||
Refresh the LLM registry from the database.
|
||||
@cached(maxsize=1, ttl_seconds=300, shared_cache=True, refresh_ttl_on_get=True)
|
||||
async def _fetch_registry_from_db() -> list[RegistryModel]:
|
||||
"""Fetch all LLM models from the database.
|
||||
|
||||
Fetches all models with their costs, providers, and creators,
|
||||
then updates the in-memory cache.
|
||||
Results are cached in Redis (shared_cache=True) so subsequent calls within
|
||||
the TTL window skip the DB entirely — both within this process and across
|
||||
all other workers that share the same Redis instance.
|
||||
"""
|
||||
records = await prisma.models.LlmModel.prisma().find_many( # type: ignore[attr-defined]
|
||||
include={"Provider": True, "Costs": True, "Creator": True}
|
||||
)
|
||||
logger.info("Fetched %d LLM models from database", len(records))
|
||||
return [_record_to_registry_model(r) for r in records]
|
||||
|
||||
|
||||
def clear_registry_cache() -> None:
|
||||
"""Invalidate the shared Redis cache for the registry DB fetch.
|
||||
|
||||
Call this before refresh_llm_registry() after any admin DB mutation so the
|
||||
next fetch hits the database rather than serving the now-stale cached data.
|
||||
"""
|
||||
_fetch_registry_from_db.cache_clear()
|
||||
|
||||
|
||||
async def refresh_llm_registry() -> None:
|
||||
"""Refresh the in-process L1 cache from Redis/DB.
|
||||
|
||||
On the first call (or after clear_registry_cache()), fetches fresh data
|
||||
from the database and stores it in Redis. Subsequent calls by other
|
||||
workers hit the Redis cache instead of the DB.
|
||||
"""
|
||||
async with _lock:
|
||||
try:
|
||||
records = await prisma.models.LlmModel.prisma().find_many(
|
||||
include={
|
||||
"Provider": True,
|
||||
"Costs": True,
|
||||
"Creator": True,
|
||||
}
|
||||
)
|
||||
logger.info(f"Fetched {len(records)} LLM models from database")
|
||||
models = await _fetch_registry_from_db()
|
||||
new_models = {m.slug: m for m in models}
|
||||
|
||||
# Build model instances
|
||||
new_models: dict[str, RegistryModel] = {}
|
||||
for record in records:
|
||||
model = _record_to_registry_model(record)
|
||||
new_models[record.slug] = model
|
||||
|
||||
# Atomic swap
|
||||
global _dynamic_models, _schema_options
|
||||
_dynamic_models = new_models
|
||||
_schema_options = _build_schema_options()
|
||||
|
||||
logger.info(
|
||||
f"LLM registry refreshed: {len(_dynamic_models)} models, "
|
||||
f"{len(_schema_options)} schema options"
|
||||
"LLM registry refreshed: %d models, %d schema options",
|
||||
len(_dynamic_models),
|
||||
len(_schema_options),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to refresh LLM registry: {e}", exc_info=True)
|
||||
logger.error("Failed to refresh LLM registry: %s", e, exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -240,23 +242,13 @@ def get_schema_options() -> list[dict[str, str]]:
|
||||
|
||||
def get_default_model_slug() -> str | None:
|
||||
"""Get the default model slug (first recommended, or first enabled)."""
|
||||
# Sort once and use next() to short-circuit on first match
|
||||
models = sorted(_dynamic_models.values(), key=lambda m: m.display_name)
|
||||
|
||||
# Prefer recommended models
|
||||
recommended = next(
|
||||
(m.slug for m in models if m.is_recommended and m.is_enabled), None
|
||||
)
|
||||
if recommended:
|
||||
return recommended
|
||||
|
||||
# Fallback to first enabled model
|
||||
return next((m.slug for m in models if m.is_enabled), None)
|
||||
return recommended or next((m.slug for m in models if m.is_enabled), None)
|
||||
|
||||
|
||||
def get_all_model_slugs_for_validation() -> list[str]:
|
||||
"""
|
||||
Get all model slugs for validation (enables migrate_llm_models to work).
|
||||
Returns slugs for enabled models only.
|
||||
"""
|
||||
"""Get all model slugs for validation (enabled models only)."""
|
||||
return [model.slug for model in _dynamic_models.values() if model.is_enabled]
|
||||
|
||||
Reference in New Issue
Block a user