fix(backend): fix race condition on _copilot_tasks concurrent iteration during drain

Add _pending_log_tasks_lock to token_tracking.py so that add/discard
operations on _pending_log_tasks are always lock-protected. Update
drain_pending_cost_logs in cost_tracking.py to acquire the copilot
tasks lock (not its own lock) when taking a snapshot of the copilot
set. This prevents "RuntimeError: Set changed size during iteration"
during graceful shutdown when done callbacks fire concurrently.
This commit is contained in:
Zamil Majdy
2026-04-07 22:48:51 +07:00
parent cf605ef5a3
commit 5164fa878f
2 changed files with 17 additions and 3 deletions

View File

@@ -12,6 +12,7 @@ This module extracts that common logic so both paths stay in sync.
import asyncio
import logging
import re
import threading
from backend.data.db_accessors import platform_cost_db
from backend.data.platform_cost import PlatformCostEntry, usd_to_microdollars
@@ -23,6 +24,10 @@ logger = logging.getLogger(__name__)
# Hold strong references to in-flight cost log tasks to prevent GC.
_pending_log_tasks: set[asyncio.Task[None]] = set()
# Guards all reads and writes to _pending_log_tasks. Done callbacks (discard)
# fire from the event loop thread; drain_pending_cost_logs iterates the set
# from any caller — the lock prevents RuntimeError from concurrent modification.
_pending_log_tasks_lock = threading.Lock()
# Per-loop semaphores: asyncio.Semaphore is not thread-safe and must not be
# shared across event loops running in different threads.
_log_semaphores: dict[asyncio.AbstractEventLoop, asyncio.Semaphore] = {}
@@ -53,8 +58,14 @@ def _schedule_cost_log(entry: PlatformCostEntry) -> None:
)
task = asyncio.create_task(_safe_log())
_pending_log_tasks.add(task)
task.add_done_callback(_pending_log_tasks.discard)
with _pending_log_tasks_lock:
_pending_log_tasks.add(task)
def _remove(t: asyncio.Task[None]) -> None:
with _pending_log_tasks_lock:
_pending_log_tasks.discard(t)
task.add_done_callback(_remove)
# Identifiers used by PlatformCostLog for copilot turns (not tied to a real

View File

@@ -86,8 +86,11 @@ async def drain_pending_cost_logs(timeout: float = 5.0) -> None:
from backend.copilot.token_tracking import ( # noqa: PLC0415
_pending_log_tasks as _copilot_tasks,
)
from backend.copilot.token_tracking import ( # noqa: PLC0415
_pending_log_tasks_lock as _copilot_tasks_lock,
)
with _pending_log_tasks_lock:
with _copilot_tasks_lock:
copilot_pending = [t for t in _copilot_tasks if t.get_loop() is current_loop]
if copilot_pending:
logger.info("Draining %d copilot cost log task(s)", len(copilot_pending))