mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend): fix race condition when _copilot_tasks is modified while being iterated during drain
Add `_pending_log_tasks_lock` to token_tracking.py so that add/discard operations on `_pending_log_tasks` are always lock-protected. Update `drain_pending_cost_logs` in cost_tracking.py to acquire the copilot tasks lock (not its own lock) when taking a snapshot of the copilot task set. This prevents "RuntimeError: Set changed size during iteration" during graceful shutdown, when done callbacks fire concurrently with the drain.
This commit is contained in:
@@ -12,6 +12,7 @@ This module extracts that common logic so both paths stay in sync.
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
|
||||
from backend.data.db_accessors import platform_cost_db
|
||||
from backend.data.platform_cost import PlatformCostEntry, usd_to_microdollars
|
||||
@@ -23,6 +24,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Hold strong references to in-flight cost log tasks to prevent GC.
|
||||
_pending_log_tasks: set[asyncio.Task[None]] = set()
|
||||
# Guards all reads and writes to _pending_log_tasks. Done callbacks (discard)
|
||||
# fire from the event loop thread; drain_pending_cost_logs iterates the set
|
||||
# from any caller — the lock prevents RuntimeError from concurrent modification.
|
||||
_pending_log_tasks_lock = threading.Lock()
|
||||
# Per-loop semaphores: asyncio.Semaphore is not thread-safe and must not be
|
||||
# shared across event loops running in different threads.
|
||||
_log_semaphores: dict[asyncio.AbstractEventLoop, asyncio.Semaphore] = {}
|
||||
@@ -53,8 +58,14 @@ def _schedule_cost_log(entry: PlatformCostEntry) -> None:
|
||||
)
|
||||
|
||||
task = asyncio.create_task(_safe_log())
|
||||
_pending_log_tasks.add(task)
|
||||
task.add_done_callback(_pending_log_tasks.discard)
|
||||
with _pending_log_tasks_lock:
|
||||
_pending_log_tasks.add(task)
|
||||
|
||||
def _remove(t: asyncio.Task[None]) -> None:
|
||||
with _pending_log_tasks_lock:
|
||||
_pending_log_tasks.discard(t)
|
||||
|
||||
task.add_done_callback(_remove)
|
||||
|
||||
|
||||
# Identifiers used by PlatformCostLog for copilot turns (not tied to a real
|
||||
|
||||
@@ -86,8 +86,11 @@ async def drain_pending_cost_logs(timeout: float = 5.0) -> None:
|
||||
from backend.copilot.token_tracking import ( # noqa: PLC0415
|
||||
_pending_log_tasks as _copilot_tasks,
|
||||
)
|
||||
from backend.copilot.token_tracking import ( # noqa: PLC0415
|
||||
_pending_log_tasks_lock as _copilot_tasks_lock,
|
||||
)
|
||||
|
||||
with _pending_log_tasks_lock:
|
||||
with _copilot_tasks_lock:
|
||||
copilot_pending = [t for t in _copilot_tasks if t.get_loop() is current_loop]
|
||||
if copilot_pending:
|
||||
logger.info("Draining %d copilot cost log task(s)", len(copilot_pending))
|
||||
|
||||
Reference in New Issue
Block a user