fix(backend): Fix cancel_execution can only work once (#9825)

### Changes 🏗️

The recent change to the execution cancelation fix turns out to only
work on the first request.
This PR change fixes it by reworking how the thread_cached work on async
functions.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Cancel agent executions multiple times
This commit is contained in:
Zamil Majdy
2025-04-16 12:33:49 +02:00
committed by GitHub
parent dc9348ec26
commit 71cdc18674
2 changed files with 55 additions and 19 deletions

View File

@@ -1,20 +1,59 @@
import inspect
import threading
from typing import Callable, ParamSpec, TypeVar
from typing import Any, Awaitable, Callable, ParamSpec, TypeVar, cast, overload
P = ParamSpec("P")
R = TypeVar("R")
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
@overload
def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ...
def thread_cached(
func: Callable[P, R] | Callable[P, Awaitable[R]],
) -> Callable[P, R] | Callable[P, Awaitable[R]]:
thread_local = threading.local()
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = (args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = func(*args, **kwargs)
return cache[key]
if inspect.iscoroutinefunction(func):
return wrapper
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = (func, args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = await cast(Callable[P, Awaitable[R]], func)(
*args, **kwargs
)
return cache[key]
return async_wrapper
else:
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
# Include function in the key to prevent collisions between different functions
key = (func, args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = func(*args, **kwargs)
return cache[key]
return sync_wrapper
def clear_thread_cache(func: Callable[..., Any]) -> None:
"""Clear the cache for a thread-cached function."""
thread_local = threading.local()
cache = getattr(thread_local, "cache", None)
if cache is not None:
# Clear all entries that match the function
for key in list(cache.keys()):
if key and len(key) > 0 and key[0] == func:
del cache[key]

View File

@@ -2,7 +2,7 @@ import asyncio
import logging
from collections import defaultdict
from datetime import datetime
from typing import TYPE_CHECKING, Annotated, Any, Coroutine, Sequence
from typing import TYPE_CHECKING, Annotated, Any, Sequence
import pydantic
import stripe
@@ -87,13 +87,10 @@ def execution_scheduler_client() -> Scheduler:
@thread_cached
def execution_queue_client() -> Coroutine[None, None, AsyncRabbitMQ]:
async def f() -> AsyncRabbitMQ:
client = AsyncRabbitMQ(create_execution_queue_config())
await client.connect()
return client
return f()
async def execution_queue_client() -> AsyncRabbitMQ:
client = AsyncRabbitMQ(create_execution_queue_config())
await client.connect()
return client
@thread_cached