fix(backend): Clear RabbitMQ connection cache on execution-manager retry

This commit is contained in:
Zamil Majdy
2025-04-19 07:50:04 +02:00
parent c783f64b33
commit 9052ee7b95
2 changed files with 13 additions and 12 deletions

View File

@@ -1,6 +1,6 @@
import inspect
import threading
from typing import Any, Awaitable, Callable, ParamSpec, TypeVar, cast, overload
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload
P = ParamSpec("P")
R = TypeVar("R")
@@ -19,6 +19,10 @@ def thread_cached(
) -> Callable[P, R] | Callable[P, Awaitable[R]]:
thread_local = threading.local()
def _clear():
if hasattr(thread_local, "cache"):
del thread_local.cache
if inspect.iscoroutinefunction(func):
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
@@ -32,28 +36,24 @@ def thread_cached(
)
return cache[key]
setattr(async_wrapper, "clear_cache", _clear)
return async_wrapper
else:
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
# Include function in the key to prevent collisions between different functions
key = (args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = func(*args, **kwargs)
return cache[key]
setattr(sync_wrapper, "clear_cache", _clear)
return sync_wrapper
def clear_thread_cache(func: Callable[..., Any]) -> None:
"""Clear the cache for a thread-cached function."""
thread_local = threading.local()
cache = getattr(thread_local, "cache", None)
if cache is not None:
# Clear all entries that match the function
for key in list(cache.keys()):
if key and len(key) > 0 and key[0] == func:
del cache[key]
def clear_thread_cache(func: Callable) -> None:
if clear := getattr(func, "clear_cache", None):
clear()

View File

@@ -29,7 +29,7 @@ if TYPE_CHECKING:
from backend.executor import DatabaseManager
from backend.notifications.notifications import NotificationManager
from autogpt_libs.utils.cache import thread_cached
from autogpt_libs.utils.cache import clear_thread_cache, thread_cached
from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis
@@ -935,6 +935,7 @@ class ExecutionManager(AppProcess):
redis.connect()
# Consume Cancel & Run execution requests.
clear_thread_cache(get_execution_queue)
channel = get_execution_queue().get_channel()
channel.basic_qos(prefetch_count=self.pool_size)
channel.basic_consume(