mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Adds 14 unit tests covering all branches of the shared token-usage persistence helper: total_tokens semantics (prompt+completion, cache excluded), session Usage record appending, multi-turn accumulation, None session/user handling, rate-limit recording dispatch, fail-open on Redis errors, and early-exit when tokens are zero. Covers both baseline (no cache) and SDK (with cache breakdown) paths so the ~100 lines of token_tracking.py have explicit test coverage.
282 lines
9.4 KiB
Python
282 lines
9.4 KiB
Python
"""Unit tests for token_tracking.persist_and_record_usage.
|
|
|
|
Covers both the baseline (prompt+completion only) and SDK (with cache breakdown)
|
|
calling conventions, session persistence, and rate-limit recording.
|
|
"""
|
|
|
|
from datetime import UTC, datetime
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
|
|
from .model import ChatSession, Usage
|
|
from .token_tracking import persist_and_record_usage
|
|
|
|
|
|
def _make_session() -> ChatSession:
    """Build a throwaway, in-memory ChatSession for the tests in this module.

    The session carries fixed ids, no title, empty ``messages`` and ``usage``
    lists, and timezone-aware UTC timestamps taken at call time.
    """
    session_fields = {
        "session_id": "sess-test",
        "user_id": "user-test",
        "title": None,
        "messages": [],
        "usage": [],
        "started_at": datetime.now(UTC),
        "updated_at": datetime.now(UTC),
    }
    return ChatSession(**session_fields)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Return value / total_tokens semantics
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestTotalTokens:
    """Checks on the integer returned by ``persist_and_record_usage``."""

    # Dotted path of the recorder that is swapped for an AsyncMock.
    _RECORD_TARGET = "backend.copilot.token_tracking.record_token_usage"

    @pytest.mark.asyncio
    async def test_returns_prompt_plus_completion(self):
        """300 prompt + 200 completion tokens come back as a total of 500."""
        with patch(self._RECORD_TARGET, new_callable=AsyncMock):
            result = await persist_and_record_usage(
                session=None, user_id=None, prompt_tokens=300, completion_tokens=200
            )

        assert result == 500

    @pytest.mark.asyncio
    async def test_returns_zero_when_no_tokens(self):
        """Zero prompt and zero completion tokens produce a return value of 0.

        The recorder is deliberately left unpatched in this test.
        """
        result = await persist_and_record_usage(
            session=None, user_id=None, prompt_tokens=0, completion_tokens=0
        )

        assert result == 0

    @pytest.mark.asyncio
    async def test_cache_tokens_excluded_from_total(self):
        """Cache read/creation counts must not inflate the returned total."""
        with patch(self._RECORD_TARGET, new_callable=AsyncMock):
            result = await persist_and_record_usage(
                session=None,
                user_id=None,
                prompt_tokens=100,
                completion_tokens=50,
                cache_read_tokens=5000,
                cache_creation_tokens=200,
            )

        # Only 100 + 50 counts; the 5000 + 200 cache tokens are left out.
        assert result == 150

    @pytest.mark.asyncio
    async def test_baseline_path_no_cache(self):
        """Baseline (OpenRouter) style call: no cache arguments, total is 1400."""
        with patch(self._RECORD_TARGET, new_callable=AsyncMock):
            result = await persist_and_record_usage(
                session=None,
                user_id="u1",
                prompt_tokens=1000,
                completion_tokens=400,
                log_prefix="[Baseline]",
            )

        assert result == 1400

    @pytest.mark.asyncio
    async def test_sdk_path_with_cache(self):
        """SDK (Anthropic) style call: cache arguments given, total is still 300."""
        with patch(self._RECORD_TARGET, new_callable=AsyncMock):
            result = await persist_and_record_usage(
                session=None,
                user_id="u2",
                prompt_tokens=200,
                completion_tokens=100,
                cache_read_tokens=8000,
                cache_creation_tokens=400,
                log_prefix="[SDK]",
                cost_usd=0.0015,
            )

        assert result == 300
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Session persistence
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSessionPersistence:
    """Checks on what ``persist_and_record_usage`` appends to ``session.usage``."""

    # Dotted path of the recorder that is swapped for an AsyncMock.
    _RECORD_TARGET = "backend.copilot.token_tracking.record_token_usage"

    @pytest.mark.asyncio
    async def test_appends_usage_to_session(self):
        """One call adds one Usage entry; cache fields are 0 when not passed."""
        chat = _make_session()

        with patch(self._RECORD_TARGET, new_callable=AsyncMock):
            await persist_and_record_usage(
                session=chat, user_id=None, prompt_tokens=100, completion_tokens=50
            )

        assert len(chat.usage) == 1
        record: Usage = chat.usage[0]
        token_counts = (
            record.prompt_tokens,
            record.completion_tokens,
            record.total_tokens,
        )
        assert token_counts == (100, 50, 150)
        cache_counts = (record.cache_read_tokens, record.cache_creation_tokens)
        assert cache_counts == (0, 0)

    @pytest.mark.asyncio
    async def test_appends_cache_breakdown_to_session(self):
        """Cache read/creation counts passed in are stored on the Usage entry."""
        chat = _make_session()

        with patch(self._RECORD_TARGET, new_callable=AsyncMock):
            await persist_and_record_usage(
                session=chat,
                user_id=None,
                prompt_tokens=200,
                completion_tokens=80,
                cache_read_tokens=3000,
                cache_creation_tokens=500,
            )

        record: Usage = chat.usage[0]
        cache_counts = (record.cache_read_tokens, record.cache_creation_tokens)
        assert cache_counts == (3000, 500)

    @pytest.mark.asyncio
    async def test_multiple_turns_append_multiple_records(self):
        """Two consecutive calls on the same session leave two usage entries."""
        chat = _make_session()

        with patch(self._RECORD_TARGET, new_callable=AsyncMock):
            for prompt, completion in ((100, 50), (200, 70)):
                await persist_and_record_usage(
                    session=chat,
                    user_id=None,
                    prompt_tokens=prompt,
                    completion_tokens=completion,
                )

        assert len(chat.usage) == 2

    @pytest.mark.asyncio
    async def test_none_session_does_not_raise(self):
        """A ``None`` session (e.g. error path) is tolerated; total is returned."""
        with patch(self._RECORD_TARGET, new_callable=AsyncMock):
            result = await persist_and_record_usage(
                session=None, user_id=None, prompt_tokens=100, completion_tokens=50
            )

        assert result == 150

    @pytest.mark.asyncio
    async def test_no_append_when_zero_tokens(self):
        """With zero tokens the call returns 0 and the session gains no entry."""
        chat = _make_session()

        result = await persist_and_record_usage(
            session=chat, user_id=None, prompt_tokens=0, completion_tokens=0
        )

        assert result == 0
        assert len(chat.usage) == 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Rate-limit recording
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRateLimitRecording:
    """Checks on when and how ``record_token_usage`` is awaited."""

    # Dotted path of the recorder that is swapped for an AsyncMock.
    _RECORD_TARGET = "backend.copilot.token_tracking.record_token_usage"

    @pytest.mark.asyncio
    async def test_calls_record_token_usage_when_user_id_present(self):
        """With a user id, the recorder is awaited once with the token counts."""
        recorder = AsyncMock()
        # The same keyword arguments are passed in and expected on the recorder.
        expected_kwargs = {
            "user_id": "user-abc",
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "cache_read_tokens": 1000,
            "cache_creation_tokens": 200,
        }

        with patch(self._RECORD_TARGET, new=recorder):
            await persist_and_record_usage(session=None, **expected_kwargs)

        recorder.assert_awaited_once_with(**expected_kwargs)

    @pytest.mark.asyncio
    async def test_skips_record_when_user_id_is_none(self):
        """Without a user id (anonymous session) the recorder is never awaited,
        so no Redis keys should be created."""
        recorder = AsyncMock()

        with patch(self._RECORD_TARGET, new=recorder):
            await persist_and_record_usage(
                session=None, user_id=None, prompt_tokens=100, completion_tokens=50
            )

        recorder.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_record_failure_does_not_raise(self):
        """A ConnectionError from the recorder is swallowed (fail-open)."""
        failing_recorder = AsyncMock(side_effect=ConnectionError("Redis down"))

        with patch(self._RECORD_TARGET, new=failing_recorder):
            # The await below must complete without propagating the error.
            result = await persist_and_record_usage(
                session=None,
                user_id="user-xyz",
                prompt_tokens=100,
                completion_tokens=50,
            )

        assert result == 150

    @pytest.mark.asyncio
    async def test_skips_record_when_zero_tokens(self):
        """With zero tokens the recorder is not awaited, even for a known user."""
        recorder = AsyncMock()

        with patch(self._RECORD_TARGET, new=recorder):
            await persist_and_record_usage(
                session=None,
                user_id="user-abc",
                prompt_tokens=0,
                completion_tokens=0,
            )

        recorder.assert_not_awaited()
|