mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend): Fix cancel_execution can only work once (#9825)
### Changes 🏗️ The recent change to the execution cancelation fix turns out to only work on the first request. This PR change fixes it by reworking how the thread_cached work on async functions. ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: <!-- Put your test plan here: --> - [x] Cancel agent executions multiple times
This commit is contained in:
@@ -1,20 +1,59 @@
|
||||
import inspect
|
||||
import threading
|
||||
from typing import Callable, ParamSpec, TypeVar
|
||||
from typing import Any, Awaitable, Callable, ParamSpec, TypeVar, cast, overload
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
def thread_cached(
|
||||
func: Callable[P, R] | Callable[P, Awaitable[R]],
|
||||
) -> Callable[P, R] | Callable[P, Awaitable[R]]:
|
||||
thread_local = threading.local()
|
||||
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
if key not in cache:
|
||||
cache[key] = func(*args, **kwargs)
|
||||
return cache[key]
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
return wrapper
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = (func, args, tuple(sorted(kwargs.items())))
|
||||
if key not in cache:
|
||||
cache[key] = await cast(Callable[P, Awaitable[R]], func)(
|
||||
*args, **kwargs
|
||||
)
|
||||
return cache[key]
|
||||
|
||||
return async_wrapper
|
||||
else:
|
||||
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
# Include function in the key to prevent collisions between different functions
|
||||
key = (func, args, tuple(sorted(kwargs.items())))
|
||||
if key not in cache:
|
||||
cache[key] = func(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
def clear_thread_cache(func: Callable[..., Any]) -> None:
|
||||
"""Clear the cache for a thread-cached function."""
|
||||
thread_local = threading.local()
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is not None:
|
||||
# Clear all entries that match the function
|
||||
for key in list(cache.keys()):
|
||||
if key and len(key) > 0 and key[0] == func:
|
||||
del cache[key]
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Coroutine, Sequence
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Sequence
|
||||
|
||||
import pydantic
|
||||
import stripe
|
||||
@@ -87,13 +87,10 @@ def execution_scheduler_client() -> Scheduler:
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_queue_client() -> Coroutine[None, None, AsyncRabbitMQ]:
|
||||
async def f() -> AsyncRabbitMQ:
|
||||
client = AsyncRabbitMQ(create_execution_queue_config())
|
||||
await client.connect()
|
||||
return client
|
||||
|
||||
return f()
|
||||
async def execution_queue_client() -> AsyncRabbitMQ:
|
||||
client = AsyncRabbitMQ(create_execution_queue_config())
|
||||
await client.connect()
|
||||
return client
|
||||
|
||||
|
||||
@thread_cached
|
||||
|
||||
Reference in New Issue
Block a user