From b5f63c13a42005277db81204a46159b0087f68ad Mon Sep 17 00:00:00 2001 From: Bentlybro Date: Wed, 8 Apr 2026 13:38:26 +0100 Subject: [PATCH] test(backend/llm-registry): comprehensive registry + notifications tests registry_test.py (+8 tests): - clear_registry_cache, get_model (found/not found), get_all_models, get_enabled_models, get_all_model_slugs_for_validation - refresh_llm_registry error re-raise notifications_test.py (new, 9 tests): - publish: happy path and Redis error swallowed - subscribe: valid message triggers on_refresh, non-message types ignored, wrong channel ignored, None (timeout) handled, multiple messages, CancelledError stops loop, connection error triggers reconnect --- .../data/llm_registry/notifications_test.py | 195 ++++++++++++++++++ .../data/llm_registry/registry_test.py | 108 ++++++++++ 2 files changed, 303 insertions(+) create mode 100644 autogpt_platform/backend/backend/data/llm_registry/notifications_test.py diff --git a/autogpt_platform/backend/backend/data/llm_registry/notifications_test.py b/autogpt_platform/backend/backend/data/llm_registry/notifications_test.py new file mode 100644 index 0000000000..828c6faecd --- /dev/null +++ b/autogpt_platform/backend/backend/data/llm_registry/notifications_test.py @@ -0,0 +1,195 @@ +"""Tests for LLM registry pub/sub notifications (notifications.py). + +Covers: +- publish_registry_refresh_notification: happy path and Redis error swallowed +- subscribe_to_registry_refresh: message triggers on_refresh, non-message + types ignored, wrong channel ignored, CancelledError stops the loop, + connection errors trigger reconnect +""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from backend.data.llm_registry.notifications import ( + REGISTRY_REFRESH_CHANNEL, + publish_registry_refresh_notification, + subscribe_to_registry_refresh, +) + + +# --------------------------------------------------------------------------- +# publish_registry_refresh_notification +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_publish_sends_to_correct_channel(mocker): + """publish_registry_refresh_notification publishes on the registry channel.""" + mock_redis = AsyncMock() + mocker.patch( + "backend.data.redis_client.get_redis_async", + return_value=mock_redis, + ) + + await publish_registry_refresh_notification() + + mock_redis.publish.assert_called_once_with(REGISTRY_REFRESH_CHANNEL, "refresh") + + +@pytest.mark.asyncio +async def test_publish_swallows_redis_error(mocker): + """Redis errors during publish are caught and logged, not raised.""" + mocker.patch( + "backend.data.redis_client.get_redis_async", + side_effect=ConnectionError("Redis unavailable"), + ) + + # Should not raise — errors are swallowed to avoid crashing the admin op + await publish_registry_refresh_notification() + + +# --------------------------------------------------------------------------- +# subscribe_to_registry_refresh +# --------------------------------------------------------------------------- + + +def _make_pubsub(messages: list) -> MagicMock: + """Build a mock pubsub that returns messages in sequence then raises CancelledError.""" + pubsub = AsyncMock() + pubsub.subscribe = AsyncMock() + # Once messages are exhausted the next get_message raises CancelledError to + # break the infinite loop cleanly in tests. + pubsub.get_message = AsyncMock( + side_effect=messages + [asyncio.CancelledError()] + ) + return pubsub + + +def _make_message(channel: str = REGISTRY_REFRESH_CHANNEL, msg_type: str = "message"): + return {"type": msg_type, "channel": channel, "data": "refresh"} + + +@pytest.mark.asyncio +async def test_subscribe_calls_on_refresh_for_valid_message(mocker): + """A message on the registry channel triggers the on_refresh callback.""" + on_refresh = AsyncMock() + pubsub = _make_pubsub([_make_message()]) + + mock_redis = MagicMock() + mock_redis.pubsub.return_value = pubsub + mocker.patch("redis.asyncio.Redis", return_value=mock_redis) + + await subscribe_to_registry_refresh(on_refresh) + + on_refresh.assert_called_once() + + +@pytest.mark.asyncio +async def test_subscribe_ignores_non_message_types(mocker): + """Subscribe messages of type 'subscribe' (handshake) do not trigger on_refresh.""" + on_refresh = AsyncMock() + pubsub = _make_pubsub([ + _make_message(msg_type="subscribe"), # handshake — should be ignored + _make_message(msg_type="psubscribe"), # also ignored + ]) + + mock_redis = MagicMock() + mock_redis.pubsub.return_value = pubsub + mocker.patch("redis.asyncio.Redis", return_value=mock_redis) + + await subscribe_to_registry_refresh(on_refresh) + + on_refresh.assert_not_called() + + +@pytest.mark.asyncio +async def test_subscribe_ignores_wrong_channel(mocker): + """Messages on a different channel do not trigger on_refresh.""" + on_refresh = AsyncMock() + pubsub = _make_pubsub([_make_message(channel="some:other:channel")]) + + mock_redis = MagicMock() + mock_redis.pubsub.return_value = pubsub + mocker.patch("redis.asyncio.Redis", return_value=mock_redis) + + await subscribe_to_registry_refresh(on_refresh) + + on_refresh.assert_not_called() + + +@pytest.mark.asyncio +async def test_subscribe_handles_none_message(mocker): + """None returned by get_message (timeout) does not crash or trigger on_refresh.""" + on_refresh = AsyncMock() + pubsub = _make_pubsub([None, None]) # two timeouts, then CancelledError + + mock_redis = MagicMock() + mock_redis.pubsub.return_value = pubsub + mocker.patch("redis.asyncio.Redis", return_value=mock_redis) + + await subscribe_to_registry_refresh(on_refresh) + + on_refresh.assert_not_called() + + +@pytest.mark.asyncio +async def test_subscribe_processes_multiple_messages(mocker): + """Multiple valid messages each trigger on_refresh.""" + on_refresh = AsyncMock() + pubsub = _make_pubsub([_make_message(), _make_message(), _make_message()]) + + mock_redis = MagicMock() + mock_redis.pubsub.return_value = pubsub + mocker.patch("redis.asyncio.Redis", return_value=mock_redis) + + await subscribe_to_registry_refresh(on_refresh) + + assert on_refresh.call_count == 3 + + +@pytest.mark.asyncio +async def test_subscribe_cancelled_error_stops_loop(mocker): + """CancelledError at the outer loop level causes the function to return cleanly.""" + on_refresh = AsyncMock() + pubsub = AsyncMock() + pubsub.subscribe = AsyncMock(side_effect=asyncio.CancelledError()) + + mock_redis = MagicMock() + mock_redis.pubsub.return_value = pubsub + mocker.patch("redis.asyncio.Redis", return_value=mock_redis) + + # Should return normally — CancelledError is caught and the loop breaks + await subscribe_to_registry_refresh(on_refresh) + + on_refresh.assert_not_called() + + +@pytest.mark.asyncio +async def test_subscribe_reconnects_after_connection_error(mocker): + """A connection error on the first attempt triggers a reconnect attempt.""" + on_refresh = AsyncMock() + + # First call raises ConnectionError; second call succeeds then CancelledError + good_pubsub = _make_pubsub([_make_message()]) + bad_redis = MagicMock() + bad_redis.pubsub.side_effect = ConnectionError("Redis down") + good_redis = MagicMock() + good_redis.pubsub.return_value = good_pubsub + + mock_redis_cls = mocker.patch( + "redis.asyncio.Redis", side_effect=[bad_redis, good_redis] + ) + mock_sleep = mocker.patch("asyncio.sleep", new=AsyncMock()) + + await subscribe_to_registry_refresh(on_refresh) + + # Should have slept before retrying + mock_sleep.assert_called_once_with(5) + # Should have tried to create two Redis connections + assert mock_redis_cls.call_count == 2 + # After reconnect, the valid message triggered on_refresh + on_refresh.assert_called_once() diff --git a/autogpt_platform/backend/backend/data/llm_registry/registry_test.py b/autogpt_platform/backend/backend/data/llm_registry/registry_test.py index 2e10a4096c..b56519862b 100644 --- a/autogpt_platform/backend/backend/data/llm_registry/registry_test.py +++ b/autogpt_platform/backend/backend/data/llm_registry/registry_test.py @@ -14,7 +14,12 @@ from backend.data.llm_registry.registry import ( RegistryModelCreator, _build_schema_options, _record_to_registry_model, + clear_registry_cache, + get_all_model_slugs_for_validation, + get_all_models, get_default_model_slug, + get_enabled_models, + get_model, get_schema_options, refresh_llm_registry, ) @@ -356,3 +361,106 @@ async def test_refresh_llm_registry(): # Schema options should be populated too assert len(reg._schema_options) == 1 assert reg._schema_options[0]["value"] == "openai/gpt-4o" + + +# --------------------------------------------------------------------------- +# clear_registry_cache tests +# --------------------------------------------------------------------------- + + +def test_clear_registry_cache(): + """clear_registry_cache calls cache_clear on the cached fetch function.""" + import backend.data.llm_registry.registry as reg + from unittest.mock import patch + + with patch.object(reg._fetch_registry_from_db, "cache_clear") as mock_clear: + clear_registry_cache() + mock_clear.assert_called_once() + + +# --------------------------------------------------------------------------- +# get_model / get_all_models / get_enabled_models / get_all_model_slugs tests +# --------------------------------------------------------------------------- + + +def test_get_model_found(): + """get_model returns the model when the slug exists in the registry.""" + import backend.data.llm_registry.registry as reg + + m = _make_registry_model(slug="openai/gpt-4o") + reg._dynamic_models = {"openai/gpt-4o": m} + + result = get_model("openai/gpt-4o") + + assert result is m + + +def test_get_model_not_found(): + """get_model returns None for an unknown slug.""" + import backend.data.llm_registry.registry as reg + + reg._dynamic_models = {} + + assert get_model("nonexistent/model") is None + + +def test_get_all_models_includes_disabled(): + """get_all_models returns all models, including disabled ones.""" + import backend.data.llm_registry.registry as reg + + enabled = _make_registry_model(slug="openai/gpt-4", is_enabled=True) + disabled = _make_registry_model(slug="openai/old-model", is_enabled=False) + reg._dynamic_models = {"openai/gpt-4": enabled, "openai/old-model": disabled} + + result = get_all_models() + + slugs = [m.slug for m in result] + assert "openai/gpt-4" in slugs + assert "openai/old-model" in slugs + assert len(result) == 2 + + +def test_get_enabled_models_excludes_disabled(): + """get_enabled_models returns only models where is_enabled=True.""" + import backend.data.llm_registry.registry as reg + + enabled = _make_registry_model(slug="openai/gpt-4", is_enabled=True) + disabled = _make_registry_model(slug="openai/old-model", is_enabled=False) + reg._dynamic_models = {"openai/gpt-4": enabled, "openai/old-model": disabled} + + result = get_enabled_models() + + assert len(result) == 1 + assert result[0].slug == "openai/gpt-4" + + +def test_get_all_model_slugs_for_validation(): + """get_all_model_slugs_for_validation returns only enabled model slugs.""" + import backend.data.llm_registry.registry as reg + + reg._dynamic_models = { + "openai/gpt-4": _make_registry_model(slug="openai/gpt-4", is_enabled=True), + "openai/old": _make_registry_model(slug="openai/old", is_enabled=False), + "anthropic/claude": _make_registry_model( + slug="anthropic/claude", is_enabled=True + ), + } + + result = get_all_model_slugs_for_validation() + + assert "openai/gpt-4" in result + assert "anthropic/claude" in result + assert "openai/old" not in result + assert len(result) == 2 + + +@pytest.mark.asyncio +async def test_refresh_llm_registry_error_is_reraised(mocker): + """refresh_llm_registry re-raises exceptions after logging them.""" + mocker.patch( + "backend.data.llm_registry.registry._fetch_registry_from_db", + new=AsyncMock(side_effect=RuntimeError("DB unavailable")), + ) + + with pytest.raises(RuntimeError, match="DB unavailable"): + await refresh_llm_registry()