mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend/copilot): route copilot cost logging through DatabaseManagerAsyncClient
The copilot executor's token_tracking.py was using schedule_cost_log() which calls execute_raw_with_schema() directly on the Prisma singleton. In the copilot_executor process, Prisma is not reliably connected due to event-loop binding issues, causing ClientNotConnectedError on every turn. Fix: route cost logging through platform_cost_db() -> DatabaseManagerAsyncClient RPC (same approach already used by the block executor). Also fix _copilot_block_name() to extract only the service tag from the log prefix (e.g. "[SDK][session-id][T1]" -> "copilot:SDK") instead of the full suffix. Update cost_tracking.py drain to drain token_tracking._pending_log_tasks, and update token_tracking_test.py mocks to match new call site.
This commit is contained in:
@@ -9,19 +9,41 @@ Both the baseline (OpenRouter) and SDK (Anthropic) service layers need to:
|
||||
This module extracts that common logic so both paths stay in sync.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
|
||||
from backend.data.platform_cost import (
|
||||
PlatformCostEntry,
|
||||
schedule_cost_log,
|
||||
usd_to_microdollars,
|
||||
)
|
||||
from backend.data.db_accessors import platform_cost_db
|
||||
from backend.data.platform_cost import PlatformCostEntry, usd_to_microdollars
|
||||
|
||||
from .model import ChatSession, Usage
|
||||
from .rate_limit import record_token_usage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Hold strong references to in-flight cost log tasks to prevent GC.
|
||||
_pending_log_tasks: set["asyncio.Task[None]"] = set()
|
||||
|
||||
|
||||
def _schedule_cost_log(entry: PlatformCostEntry) -> None:
|
||||
"""Schedule a fire-and-forget cost log via DatabaseManagerAsyncClient RPC."""
|
||||
|
||||
async def _safe_log() -> None:
|
||||
try:
|
||||
await platform_cost_db().log_platform_cost(entry)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to log platform cost for user=%s provider=%s block=%s",
|
||||
entry.user_id,
|
||||
entry.provider,
|
||||
entry.block_name,
|
||||
)
|
||||
|
||||
task = asyncio.create_task(_safe_log())
|
||||
_pending_log_tasks.add(task)
|
||||
task.add_done_callback(_pending_log_tasks.discard)
|
||||
|
||||
|
||||
# Identifiers used by PlatformCostLog for copilot turns (not tied to a real
|
||||
# block/credential in the block_cost_config or credentials_store tables).
|
||||
COPILOT_BLOCK_ID = "copilot"
|
||||
@@ -29,8 +51,10 @@ COPILOT_CREDENTIAL_ID = "copilot_system"
|
||||
|
||||
|
||||
def _copilot_block_name(log_prefix: str) -> str:
|
||||
"""Turn a log prefix like ``"[SDK]"`` into a stable block_name
|
||||
``"copilot:SDK"``. Empty prefix becomes just ``"copilot"``."""
|
||||
"""Extract stable block_name from ``"[SDK][session][T1]"`` -> ``"copilot:SDK"``."""
|
||||
match = re.search(r"\[([A-Za-z][A-Za-z0-9_]*)\]", log_prefix)
|
||||
if match:
|
||||
return f"{COPILOT_BLOCK_ID}:{match.group(1)}"
|
||||
tag = log_prefix.strip(" []")
|
||||
return f"{COPILOT_BLOCK_ID}:{tag}" if tag else COPILOT_BLOCK_ID
|
||||
|
||||
@@ -148,7 +172,7 @@ async def persist_and_record_usage(
|
||||
tracking_type = "tokens"
|
||||
tracking_amount = total_tokens
|
||||
|
||||
schedule_cost_log(
|
||||
_schedule_cost_log(
|
||||
PlatformCostEntry(
|
||||
user_id=user_id,
|
||||
graph_exec_id=session_id,
|
||||
|
||||
@@ -298,8 +298,10 @@ class TestPlatformCostLogging:
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.log_platform_cost_safe",
|
||||
new=mock_log,
|
||||
"backend.copilot.token_tracking.platform_cost_db",
|
||||
return_value=type(
|
||||
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
|
||||
)(),
|
||||
),
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
@@ -337,8 +339,10 @@ class TestPlatformCostLogging:
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.log_platform_cost_safe",
|
||||
new=mock_log,
|
||||
"backend.copilot.token_tracking.platform_cost_db",
|
||||
return_value=type(
|
||||
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
|
||||
)(),
|
||||
),
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
@@ -368,8 +372,10 @@ class TestPlatformCostLogging:
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.log_platform_cost_safe",
|
||||
new=mock_log,
|
||||
"backend.copilot.token_tracking.platform_cost_db",
|
||||
return_value=type(
|
||||
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
|
||||
)(),
|
||||
),
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
@@ -391,8 +397,10 @@ class TestPlatformCostLogging:
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.log_platform_cost_safe",
|
||||
new=mock_log,
|
||||
"backend.copilot.token_tracking.platform_cost_db",
|
||||
return_value=type(
|
||||
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
|
||||
)(),
|
||||
),
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
@@ -418,8 +426,10 @@ class TestPlatformCostLogging:
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.log_platform_cost_safe",
|
||||
new=mock_log,
|
||||
"backend.copilot.token_tracking.platform_cost_db",
|
||||
return_value=type(
|
||||
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
|
||||
)(),
|
||||
),
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
@@ -445,8 +455,10 @@ class TestPlatformCostLogging:
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.log_platform_cost_safe",
|
||||
new=mock_log,
|
||||
"backend.copilot.token_tracking.platform_cost_db",
|
||||
return_value=type(
|
||||
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
|
||||
)(),
|
||||
),
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
@@ -470,8 +482,10 @@ class TestPlatformCostLogging:
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.log_platform_cost_safe",
|
||||
new=mock_log,
|
||||
"backend.copilot.token_tracking.platform_cost_db",
|
||||
return_value=type(
|
||||
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
|
||||
)(),
|
||||
),
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
|
||||
@@ -142,3 +142,9 @@ def credit_db():
|
||||
credit_db = get_database_manager_async_client()
|
||||
|
||||
return credit_db
|
||||
|
||||
|
||||
def platform_cost_db():
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
return get_database_manager_async_client()
|
||||
|
||||
@@ -58,12 +58,21 @@ async def drain_pending_cost_logs(timeout: float = 5.0) -> None:
|
||||
len(still_pending),
|
||||
timeout,
|
||||
)
|
||||
# Also drain copilot cost log tasks (platform_cost._pending_log_tasks)
|
||||
from backend.data.platform_cost import ( # noqa: PLC0415
|
||||
drain_pending_cost_logs as _drain_copilot,
|
||||
# Also drain copilot cost log tasks (token_tracking._pending_log_tasks)
|
||||
from backend.copilot.token_tracking import ( # noqa: PLC0415
|
||||
_pending_log_tasks as _copilot_tasks,
|
||||
)
|
||||
|
||||
await _drain_copilot()
|
||||
copilot_pending = list(_copilot_tasks)
|
||||
if copilot_pending:
|
||||
logger.info("Draining %d copilot cost log task(s)", len(copilot_pending))
|
||||
_, still_pending = await asyncio.wait(copilot_pending, timeout=timeout)
|
||||
if still_pending:
|
||||
logger.warning(
|
||||
"%d copilot cost log task(s) did not complete within %.1fs",
|
||||
len(still_pending),
|
||||
timeout,
|
||||
)
|
||||
|
||||
|
||||
def _schedule_log(
|
||||
|
||||
Reference in New Issue
Block a user