mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
test(backend/llm-registry): comprehensive registry + notifications tests
registry_test.py (+8 tests): - clear_registry_cache, get_model (found/not found), get_all_models, get_enabled_models, get_all_model_slugs_for_validation - refresh_llm_registry error re-raise notifications_test.py (new, 9 tests): - publish: happy path and Redis error swallowed - subscribe: valid message triggers on_refresh, non-message types ignored, wrong channel ignored, None (timeout) handled, multiple messages, CancelledError stops loop, connection error triggers reconnect
This commit is contained in:
@@ -0,0 +1,195 @@
|
||||
"""Tests for LLM registry pub/sub notifications (notifications.py).
|
||||
|
||||
Covers:
|
||||
- publish_registry_refresh_notification: happy path and Redis error swallowed
|
||||
- subscribe_to_registry_refresh: message triggers on_refresh, non-message
|
||||
types ignored, wrong channel ignored, CancelledError stops the loop,
|
||||
connection errors trigger reconnect
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.llm_registry.notifications import (
|
||||
REGISTRY_REFRESH_CHANNEL,
|
||||
publish_registry_refresh_notification,
|
||||
subscribe_to_registry_refresh,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# publish_registry_refresh_notification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_sends_to_correct_channel(mocker):
|
||||
"""publish_registry_refresh_notification publishes on the registry channel."""
|
||||
mock_redis = AsyncMock()
|
||||
mocker.patch(
|
||||
"backend.data.redis_client.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
)
|
||||
|
||||
await publish_registry_refresh_notification()
|
||||
|
||||
mock_redis.publish.assert_called_once_with(REGISTRY_REFRESH_CHANNEL, "refresh")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_swallows_redis_error(mocker):
|
||||
"""Redis errors during publish are caught and logged, not raised."""
|
||||
mocker.patch(
|
||||
"backend.data.redis_client.get_redis_async",
|
||||
side_effect=ConnectionError("Redis unavailable"),
|
||||
)
|
||||
|
||||
# Should not raise — errors are swallowed to avoid crashing the admin op
|
||||
await publish_registry_refresh_notification()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# subscribe_to_registry_refresh
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_pubsub(messages: list) -> MagicMock:
|
||||
"""Build a mock pubsub that returns messages in sequence then raises CancelledError."""
|
||||
pubsub = AsyncMock()
|
||||
pubsub.subscribe = AsyncMock()
|
||||
# Once messages are exhausted the next get_message raises CancelledError to
|
||||
# break the infinite loop cleanly in tests.
|
||||
pubsub.get_message = AsyncMock(
|
||||
side_effect=messages + [asyncio.CancelledError()]
|
||||
)
|
||||
return pubsub
|
||||
|
||||
|
||||
def _make_message(channel: str = REGISTRY_REFRESH_CHANNEL, msg_type: str = "message"):
|
||||
return {"type": msg_type, "channel": channel, "data": "refresh"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_calls_on_refresh_for_valid_message(mocker):
|
||||
"""A message on the registry channel triggers the on_refresh callback."""
|
||||
on_refresh = AsyncMock()
|
||||
pubsub = _make_pubsub([_make_message()])
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pubsub.return_value = pubsub
|
||||
mocker.patch("redis.asyncio.Redis", return_value=mock_redis)
|
||||
|
||||
await subscribe_to_registry_refresh(on_refresh)
|
||||
|
||||
on_refresh.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_ignores_non_message_types(mocker):
|
||||
"""Subscribe messages of type 'subscribe' (handshake) do not trigger on_refresh."""
|
||||
on_refresh = AsyncMock()
|
||||
pubsub = _make_pubsub([
|
||||
_make_message(msg_type="subscribe"), # handshake — should be ignored
|
||||
_make_message(msg_type="psubscribe"), # also ignored
|
||||
])
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pubsub.return_value = pubsub
|
||||
mocker.patch("redis.asyncio.Redis", return_value=mock_redis)
|
||||
|
||||
await subscribe_to_registry_refresh(on_refresh)
|
||||
|
||||
on_refresh.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_ignores_wrong_channel(mocker):
|
||||
"""Messages on a different channel do not trigger on_refresh."""
|
||||
on_refresh = AsyncMock()
|
||||
pubsub = _make_pubsub([_make_message(channel="some:other:channel")])
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pubsub.return_value = pubsub
|
||||
mocker.patch("redis.asyncio.Redis", return_value=mock_redis)
|
||||
|
||||
await subscribe_to_registry_refresh(on_refresh)
|
||||
|
||||
on_refresh.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_handles_none_message(mocker):
|
||||
"""None returned by get_message (timeout) does not crash or trigger on_refresh."""
|
||||
on_refresh = AsyncMock()
|
||||
pubsub = _make_pubsub([None, None]) # two timeouts, then CancelledError
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pubsub.return_value = pubsub
|
||||
mocker.patch("redis.asyncio.Redis", return_value=mock_redis)
|
||||
|
||||
await subscribe_to_registry_refresh(on_refresh)
|
||||
|
||||
on_refresh.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_processes_multiple_messages(mocker):
|
||||
"""Multiple valid messages each trigger on_refresh."""
|
||||
on_refresh = AsyncMock()
|
||||
pubsub = _make_pubsub([_make_message(), _make_message(), _make_message()])
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pubsub.return_value = pubsub
|
||||
mocker.patch("redis.asyncio.Redis", return_value=mock_redis)
|
||||
|
||||
await subscribe_to_registry_refresh(on_refresh)
|
||||
|
||||
assert on_refresh.call_count == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_cancelled_error_stops_loop(mocker):
|
||||
"""CancelledError at the outer loop level causes the function to return cleanly."""
|
||||
on_refresh = AsyncMock()
|
||||
pubsub = AsyncMock()
|
||||
pubsub.subscribe = AsyncMock(side_effect=asyncio.CancelledError())
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pubsub.return_value = pubsub
|
||||
mocker.patch("redis.asyncio.Redis", return_value=mock_redis)
|
||||
|
||||
# Should return normally — CancelledError is caught and the loop breaks
|
||||
await subscribe_to_registry_refresh(on_refresh)
|
||||
|
||||
on_refresh.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_reconnects_after_connection_error(mocker):
|
||||
"""A connection error on the first attempt triggers a reconnect attempt."""
|
||||
on_refresh = AsyncMock()
|
||||
|
||||
# First call raises ConnectionError; second call succeeds then CancelledError
|
||||
good_pubsub = _make_pubsub([_make_message()])
|
||||
bad_redis = MagicMock()
|
||||
bad_redis.pubsub.side_effect = ConnectionError("Redis down")
|
||||
good_redis = MagicMock()
|
||||
good_redis.pubsub.return_value = good_pubsub
|
||||
|
||||
mock_redis_cls = mocker.patch(
|
||||
"redis.asyncio.Redis", side_effect=[bad_redis, good_redis]
|
||||
)
|
||||
mock_sleep = mocker.patch("asyncio.sleep", new=AsyncMock())
|
||||
|
||||
await subscribe_to_registry_refresh(on_refresh)
|
||||
|
||||
# Should have slept before retrying
|
||||
mock_sleep.assert_called_once_with(5)
|
||||
# Should have tried to create two Redis connections
|
||||
assert mock_redis_cls.call_count == 2
|
||||
# After reconnect, the valid message triggered on_refresh
|
||||
on_refresh.assert_called_once()
|
||||
@@ -14,7 +14,12 @@ from backend.data.llm_registry.registry import (
|
||||
RegistryModelCreator,
|
||||
_build_schema_options,
|
||||
_record_to_registry_model,
|
||||
clear_registry_cache,
|
||||
get_all_model_slugs_for_validation,
|
||||
get_all_models,
|
||||
get_default_model_slug,
|
||||
get_enabled_models,
|
||||
get_model,
|
||||
get_schema_options,
|
||||
refresh_llm_registry,
|
||||
)
|
||||
@@ -356,3 +361,106 @@ async def test_refresh_llm_registry():
|
||||
# Schema options should be populated too
|
||||
assert len(reg._schema_options) == 1
|
||||
assert reg._schema_options[0]["value"] == "openai/gpt-4o"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# clear_registry_cache tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_clear_registry_cache():
|
||||
"""clear_registry_cache calls cache_clear on the cached fetch function."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch.object(reg._fetch_registry_from_db, "cache_clear") as mock_clear:
|
||||
clear_registry_cache()
|
||||
mock_clear.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_model / get_all_models / get_enabled_models / get_all_model_slugs tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_model_found():
|
||||
"""get_model returns the model when the slug exists in the registry."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
|
||||
m = _make_registry_model(slug="openai/gpt-4o")
|
||||
reg._dynamic_models = {"openai/gpt-4o": m}
|
||||
|
||||
result = get_model("openai/gpt-4o")
|
||||
|
||||
assert result is m
|
||||
|
||||
|
||||
def test_get_model_not_found():
|
||||
"""get_model returns None for an unknown slug."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
|
||||
reg._dynamic_models = {}
|
||||
|
||||
assert get_model("nonexistent/model") is None
|
||||
|
||||
|
||||
def test_get_all_models_includes_disabled():
|
||||
"""get_all_models returns all models, including disabled ones."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
|
||||
enabled = _make_registry_model(slug="openai/gpt-4", is_enabled=True)
|
||||
disabled = _make_registry_model(slug="openai/old-model", is_enabled=False)
|
||||
reg._dynamic_models = {"openai/gpt-4": enabled, "openai/old-model": disabled}
|
||||
|
||||
result = get_all_models()
|
||||
|
||||
slugs = [m.slug for m in result]
|
||||
assert "openai/gpt-4" in slugs
|
||||
assert "openai/old-model" in slugs
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
def test_get_enabled_models_excludes_disabled():
|
||||
"""get_enabled_models returns only models where is_enabled=True."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
|
||||
enabled = _make_registry_model(slug="openai/gpt-4", is_enabled=True)
|
||||
disabled = _make_registry_model(slug="openai/old-model", is_enabled=False)
|
||||
reg._dynamic_models = {"openai/gpt-4": enabled, "openai/old-model": disabled}
|
||||
|
||||
result = get_enabled_models()
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].slug == "openai/gpt-4"
|
||||
|
||||
|
||||
def test_get_all_model_slugs_for_validation():
|
||||
"""get_all_model_slugs_for_validation returns only enabled model slugs."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
|
||||
reg._dynamic_models = {
|
||||
"openai/gpt-4": _make_registry_model(slug="openai/gpt-4", is_enabled=True),
|
||||
"openai/old": _make_registry_model(slug="openai/old", is_enabled=False),
|
||||
"anthropic/claude": _make_registry_model(
|
||||
slug="anthropic/claude", is_enabled=True
|
||||
),
|
||||
}
|
||||
|
||||
result = get_all_model_slugs_for_validation()
|
||||
|
||||
assert "openai/gpt-4" in result
|
||||
assert "anthropic/claude" in result
|
||||
assert "openai/old" not in result
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_llm_registry_error_is_reraised(mocker):
|
||||
"""refresh_llm_registry re-raises exceptions after logging them."""
|
||||
mocker.patch(
|
||||
"backend.data.llm_registry.registry._fetch_registry_from_db",
|
||||
new=AsyncMock(side_effect=RuntimeError("DB unavailable")),
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="DB unavailable"):
|
||||
await refresh_llm_registry()
|
||||
|
||||
Reference in New Issue
Block a user