mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Compare commits
3 Commits
test-scree
...
fix/cost-t
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9710dc76fc | ||
|
|
04ed0e6a4d | ||
|
|
2a47ecc129 |
@@ -15,6 +15,8 @@ import math
|
||||
import re
|
||||
import threading
|
||||
|
||||
from prisma.errors import DataError
|
||||
|
||||
from backend.data.db_accessors import platform_cost_db
|
||||
from backend.data.platform_cost import PlatformCostEntry, usd_to_microdollars
|
||||
|
||||
@@ -50,6 +52,15 @@ def _schedule_cost_log(entry: PlatformCostEntry) -> None:
|
||||
async with _get_log_semaphore():
|
||||
try:
|
||||
await platform_cost_db().log_platform_cost(entry)
|
||||
except DataError as e:
|
||||
# Prisma DataError typically means the DB manager pod is running a
|
||||
# stale Prisma client (e.g. during a rolling deploy after a schema
|
||||
# migration). Log at WARNING so Sentry is not spammed.
|
||||
logger.warning(
|
||||
f"Skipping platform cost log (schema mismatch?) for "
|
||||
f"user={entry.user_id} provider={entry.provider} "
|
||||
f"block={entry.block_name}: {e}"
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to log platform cost for user=%s provider=%s block=%s",
|
||||
|
||||
@@ -9,9 +9,27 @@ from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.errors import DataError
|
||||
|
||||
from backend.data.platform_cost import PlatformCostEntry
|
||||
|
||||
from .model import ChatSession, Usage
|
||||
from .token_tracking import persist_and_record_usage
|
||||
from .token_tracking import _schedule_cost_log, persist_and_record_usage
|
||||
|
||||
|
||||
def _make_data_error(msg: str = "stale schema") -> DataError:
|
||||
"""Construct a valid prisma DataError (requires a dict, not a bare string)."""
|
||||
return DataError(
|
||||
{
|
||||
"user_facing_error": {
|
||||
"is_panic": False,
|
||||
"message": msg,
|
||||
"meta": {},
|
||||
"error_code": "P2006",
|
||||
"batch_request_idx": 0,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _make_session() -> ChatSession:
|
||||
@@ -567,3 +585,53 @@ class TestPlatformCostLogging:
|
||||
# Negative cost rejected — falls back to token-based tracking
|
||||
assert entry.cost_microdollars is None
|
||||
assert entry.metadata["tracking_type"] == "tokens"
|
||||
|
||||
|
||||
def _make_cost_entry(**overrides: object) -> PlatformCostEntry:
|
||||
return PlatformCostEntry.model_validate(
|
||||
{
|
||||
"user_id": "user-1",
|
||||
"block_id": "copilot",
|
||||
"block_name": "copilot:SDK",
|
||||
"provider": "anthropic",
|
||||
**overrides,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestScheduleCostLogDataError:
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_error_logs_warning_not_error(self, caplog):
|
||||
"""DataError from stale Prisma client should be logged at WARNING, not ERROR."""
|
||||
import logging
|
||||
|
||||
mock_log = AsyncMock(side_effect=_make_data_error())
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.platform_cost_db",
|
||||
return_value=type(
|
||||
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
|
||||
)(),
|
||||
):
|
||||
entry = _make_cost_entry()
|
||||
with caplog.at_level(
|
||||
logging.WARNING, logger="backend.copilot.token_tracking"
|
||||
):
|
||||
_schedule_cost_log(entry)
|
||||
await asyncio.sleep(0)
|
||||
warning_records = [r for r in caplog.records if r.levelno == logging.WARNING]
|
||||
assert warning_records, "Expected a WARNING log record for DataError"
|
||||
assert "schema mismatch" in warning_records[0].message.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_error_does_not_propagate(self):
|
||||
"""DataError in the scheduled task must not crash the event loop."""
|
||||
mock_log = AsyncMock(side_effect=_make_data_error())
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.platform_cost_db",
|
||||
return_value=type(
|
||||
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
|
||||
)(),
|
||||
):
|
||||
entry = _make_cost_entry()
|
||||
_schedule_cost_log(entry)
|
||||
await asyncio.sleep(0) # must not raise
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from prisma.errors import DataError
|
||||
from prisma.models import PlatformCostLog as PrismaLog
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -87,6 +88,15 @@ async def log_platform_cost_safe(entry: PlatformCostEntry) -> None:
|
||||
try:
|
||||
async with _log_semaphore:
|
||||
await log_platform_cost(entry)
|
||||
except DataError as e:
|
||||
# Prisma DataError typically means the DB manager pod is running a stale
|
||||
# Prisma client (e.g. during a rolling deploy after a schema migration).
|
||||
# Log at WARNING so Sentry is not spammed.
|
||||
logger.warning(
|
||||
f"Skipping platform cost log (schema mismatch?) for "
|
||||
f"user={entry.user_id} provider={entry.provider} "
|
||||
f"block={entry.block_name}: {e}"
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to log platform cost for user=%s provider=%s block=%s",
|
||||
|
||||
@@ -4,6 +4,7 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.errors import DataError
|
||||
|
||||
from .platform_cost import (
|
||||
PlatformCostEntry,
|
||||
@@ -16,6 +17,21 @@ from .platform_cost import (
|
||||
)
|
||||
|
||||
|
||||
def _make_data_error(msg: str = "stale schema") -> DataError:
|
||||
"""Construct a valid prisma DataError (requires a dict, not a bare string)."""
|
||||
return DataError(
|
||||
{
|
||||
"user_facing_error": {
|
||||
"is_panic": False,
|
||||
"message": msg,
|
||||
"meta": {},
|
||||
"error_code": "P2006",
|
||||
"batch_request_idx": 0,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestMaskEmail:
|
||||
def test_typical_email(self):
|
||||
assert _mask_email("user@example.com") == "us***@example.com"
|
||||
@@ -156,6 +172,28 @@ class TestLogPlatformCostSafe:
|
||||
await log_platform_cost_safe(entry)
|
||||
mock_create.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_error_logs_warning_not_error(self, caplog):
|
||||
"""DataError (stale Prisma client) should be logged at WARNING, not ERROR."""
|
||||
import logging
|
||||
|
||||
with patch("backend.data.platform_cost.PrismaLog.prisma") as mock_prisma:
|
||||
mock_prisma.return_value.create = AsyncMock(side_effect=_make_data_error())
|
||||
entry = _make_entry()
|
||||
with caplog.at_level(logging.WARNING, logger="backend.data.platform_cost"):
|
||||
await log_platform_cost_safe(entry)
|
||||
warning_records = [r for r in caplog.records if r.levelno == logging.WARNING]
|
||||
assert warning_records, "Expected a WARNING log record for DataError"
|
||||
assert "schema mismatch" in warning_records[0].message.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_error_does_not_raise(self):
|
||||
"""DataError must be swallowed — log_platform_cost_safe never raises."""
|
||||
with patch("backend.data.platform_cost.PrismaLog.prisma") as mock_prisma:
|
||||
mock_prisma.return_value.create = AsyncMock(side_effect=_make_data_error())
|
||||
entry = _make_entry()
|
||||
await log_platform_cost_safe(entry) # must not raise
|
||||
|
||||
|
||||
class TestGetPlatformCostDashboard:
|
||||
def setup_method(self):
|
||||
|
||||
Reference in New Issue
Block a user