fix(backend/copilot): route copilot cost logging through DatabaseManagerAsyncClient

The copilot executor's token_tracking.py was using schedule_cost_log()
which calls execute_raw_with_schema() directly on the Prisma singleton.
In the copilot_executor process, Prisma is not reliably connected due to
event-loop binding issues, causing ClientNotConnectedError on every turn.

Fix: route cost logging through platform_cost_db() -> DatabaseManagerAsyncClient
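RPC (same approach already used by the block executor). Also fix
_copilot_block_name() to extract only the leading service tag from the log
prefix (e.g. "[SDK][session-id][T1]" -> "copilot:SDK"); previously the
brackets were only stripped from the ends, so the session and turn tags
leaked into the block name ("copilot:SDK][session-id][T1").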

Update drain_pending_cost_logs() in cost_tracking.py to also wait on
token_tracking._pending_log_tasks, and update the token_tracking_test.py
mocks to patch the new call site (token_tracking.platform_cost_db).
This commit is contained in:
Zamil Majdy
2026-04-07 16:58:41 +07:00
parent 254e6057f4
commit 2a73d1baa9
4 changed files with 79 additions and 26 deletions

View File

@@ -9,19 +9,41 @@ Both the baseline (OpenRouter) and SDK (Anthropic) service layers need to:
This module extracts that common logic so both paths stay in sync.
"""
import asyncio
import logging
import re
from backend.data.platform_cost import (
PlatformCostEntry,
schedule_cost_log,
usd_to_microdollars,
)
from backend.data.db_accessors import platform_cost_db
from backend.data.platform_cost import PlatformCostEntry, usd_to_microdollars
from .model import ChatSession, Usage
from .rate_limit import record_token_usage
logger = logging.getLogger(__name__)
# Hold strong references to in-flight cost log tasks to prevent GC.
_pending_log_tasks: set["asyncio.Task[None]"] = set()
def _schedule_cost_log(entry: PlatformCostEntry) -> None:
    """Queue a best-effort, fire-and-forget write of *entry* to the cost log.

    The write goes through ``platform_cost_db().log_platform_cost`` rather
    than touching Prisma directly.  The call returns as soon as the task is
    created; it never waits for the write itself.

    Must be called from a running event loop, because
    ``asyncio.create_task`` raises ``RuntimeError`` otherwise.

    Args:
        entry: The cost record to persist.
    """

    async def _log_quietly() -> None:
        # Cost logging must never break the caller's turn: any failure is
        # logged with its traceback and then dropped.
        try:
            await platform_cost_db().log_platform_cost(entry)
        except Exception:
            logger.exception(
                "Failed to log platform cost for user=%s provider=%s block=%s",
                entry.user_id,
                entry.provider,
                entry.block_name,
            )

    pending = asyncio.create_task(_log_quietly())
    # The event loop keeps only weak references to tasks, so park the task in
    # the module-level set until it finishes, then let it remove itself.
    _pending_log_tasks.add(pending)
    pending.add_done_callback(_pending_log_tasks.discard)
# Identifiers used by PlatformCostLog for copilot turns (not tied to a real
# block/credential in the block_cost_config or credentials_store tables).
COPILOT_BLOCK_ID = "copilot"
@@ -29,8 +51,10 @@ COPILOT_CREDENTIAL_ID = "copilot_system"
def _copilot_block_name(log_prefix: str) -> str:
    """Derive a stable ``block_name`` for a copilot turn from its log prefix.

    The first bracketed tag that looks like an identifier (a letter followed
    by letters, digits or underscores) is used, so
    ``"[SDK][session][T1]"`` becomes ``"copilot:SDK"``.  When no such tag is
    present, the prefix with surrounding spaces and square brackets stripped
    is used instead; an empty result yields plain ``"copilot"``.

    Args:
        log_prefix: Log prefix such as ``"[SDK]"`` or ``"[SDK][session][T1]"``.

    Returns:
        ``"copilot:<tag>"``, or ``"copilot"`` when no tag can be derived.
    """
    tagged = re.search(r"\[([A-Za-z][A-Za-z0-9_]*)\]", log_prefix)
    if tagged is not None:
        return f"{COPILOT_BLOCK_ID}:{tagged.group(1)}"

    # No identifier-like bracketed tag: fall back to the bare prefix text.
    bare = log_prefix.strip(" []")
    if not bare:
        return COPILOT_BLOCK_ID
    return f"{COPILOT_BLOCK_ID}:{bare}"
@@ -148,7 +172,7 @@ async def persist_and_record_usage(
tracking_type = "tokens"
tracking_amount = total_tokens
schedule_cost_log(
_schedule_cost_log(
PlatformCostEntry(
user_id=user_id,
graph_exec_id=session_id,

View File

@@ -298,8 +298,10 @@ class TestPlatformCostLogging:
new_callable=AsyncMock,
),
patch(
"backend.data.platform_cost.log_platform_cost_safe",
new=mock_log,
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
@@ -337,8 +339,10 @@ class TestPlatformCostLogging:
new_callable=AsyncMock,
),
patch(
"backend.data.platform_cost.log_platform_cost_safe",
new=mock_log,
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
@@ -368,8 +372,10 @@ class TestPlatformCostLogging:
new_callable=AsyncMock,
),
patch(
"backend.data.platform_cost.log_platform_cost_safe",
new=mock_log,
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
@@ -391,8 +397,10 @@ class TestPlatformCostLogging:
new_callable=AsyncMock,
),
patch(
"backend.data.platform_cost.log_platform_cost_safe",
new=mock_log,
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
@@ -418,8 +426,10 @@ class TestPlatformCostLogging:
new_callable=AsyncMock,
),
patch(
"backend.data.platform_cost.log_platform_cost_safe",
new=mock_log,
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
@@ -445,8 +455,10 @@ class TestPlatformCostLogging:
new_callable=AsyncMock,
),
patch(
"backend.data.platform_cost.log_platform_cost_safe",
new=mock_log,
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
@@ -470,8 +482,10 @@ class TestPlatformCostLogging:
new_callable=AsyncMock,
),
patch(
"backend.data.platform_cost.log_platform_cost_safe",
new=mock_log,
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(

View File

@@ -142,3 +142,9 @@ def credit_db():
credit_db = get_database_manager_async_client()
return credit_db
def platform_cost_db():
    """Return the accessor used for platform cost logging.

    This always resolves to the client returned by
    ``get_database_manager_async_client()``.
    """
    # Imported at call time rather than module import time.
    # NOTE(review): presumably to avoid a circular import — confirm.
    from backend.util.clients import get_database_manager_async_client

    client = get_database_manager_async_client()
    return client

View File

@@ -58,12 +58,21 @@ async def drain_pending_cost_logs(timeout: float = 5.0) -> None:
len(still_pending),
timeout,
)
# Also drain copilot cost log tasks (platform_cost._pending_log_tasks)
from backend.data.platform_cost import ( # noqa: PLC0415
drain_pending_cost_logs as _drain_copilot,
# Also drain copilot cost log tasks (token_tracking._pending_log_tasks)
from backend.copilot.token_tracking import ( # noqa: PLC0415
_pending_log_tasks as _copilot_tasks,
)
await _drain_copilot()
copilot_pending = list(_copilot_tasks)
if copilot_pending:
logger.info("Draining %d copilot cost log task(s)", len(copilot_pending))
_, still_pending = await asyncio.wait(copilot_pending, timeout=timeout)
if still_pending:
logger.warning(
"%d copilot cost log task(s) did not complete within %.1fs",
len(still_pending),
timeout,
)
def _schedule_log(