Compare commits

...

4 Commits

Author SHA1 Message Date
majdyz
079902501e chore: merge dev into fix/openrouter-null-cache-tokens 2026-04-15 15:53:13 +07:00
majdyz
a042c84907 refactor(backend/copilot): extract _apply_token_usage helper and parametrize tests
Extract the four-field null-safe token accumulation from _run_stream_attempt
into a standalone _apply_token_usage() function so tests exercise the
production logic directly (fixing the 4-line codecov coverage gap).

Update TestTokenUsageNullSafety to call _apply_token_usage instead of
a local mirror, and add pytest.mark.parametrize coverage over all four
token fields, as suggested in the CodeRabbit review.
2026-04-15 14:53:23 +07:00
majdyz
dfb7f327de Merge branch 'dev' of https://github.com/Significant-Gravitas/AutoGPT into fix/openrouter-null-cache-tokens 2026-04-15 14:49:00 +07:00
majdyz
0c4931b8f8 fix(backend/copilot): null-safe token accumulation for OpenRouter null cache fields
OpenRouter occasionally returns null (not 0) for cache_read_input_tokens and
cache_creation_input_tokens on the initial streaming event before real counts
are available. Using .get(key, 0) returns None when the key exists with a null
value, causing TypeError on subsequent +=. Switch to .get(key) or 0 which
treats both missing and null keys as 0.

Adds _TokenUsage unit tests covering the null event, real event, absent keys,
and multi-turn accumulation scenarios.
2026-04-15 13:50:34 +07:00
2 changed files with 41 additions and 68 deletions

View File

@@ -298,6 +298,21 @@ class _TokenUsage:
self.cost_usd = None
def _apply_token_usage(acc: _TokenUsage, usage: dict) -> None:
    """Add the token counts found in a ResultMessage usage dict onto *acc*.

    Every count is read as ``usage.get(key) or 0`` rather than
    ``usage.get(key, 0)``.  OpenRouter may send the cache token keys with a
    ``null`` value (instead of leaving them out) on the initial streaming
    event, before real counts are available.  A ``.get()`` default only
    applies to *missing* keys, so it would hand back ``None`` for those and
    the addition would fail with ``int += None`` TypeError.

    Args:
        acc: Running totals; its ``prompt_tokens``, ``cache_read_tokens``,
            ``cache_creation_tokens`` and ``completion_tokens`` attributes
            are updated in place.
        usage: Usage mapping for one message.  Keys may be absent or ``None``.
    """
    # (key in the usage dict, attribute on the accumulator), in the order the
    # fields are accumulated.
    usage_key_to_attr = (
        ("input_tokens", "prompt_tokens"),
        ("cache_read_input_tokens", "cache_read_tokens"),
        ("cache_creation_input_tokens", "cache_creation_tokens"),
        ("output_tokens", "completion_tokens"),
    )
    for usage_key, attr_name in usage_key_to_attr:
        running_total = getattr(acc, attr_name)
        # ``or 0`` folds both "key missing" and "key present but null" to 0.
        running_total += usage.get(usage_key) or 0
        setattr(acc, attr_name, running_total)
@dataclass
class _RetryState:
"""Mutable state passed to `_run_stream_attempt` instead of closures.
@@ -1912,21 +1927,7 @@ async def _run_stream_attempt(
# cache_read_input_tokens = served from cache
# cache_creation_input_tokens = written to cache
if sdk_msg.usage:
# Use `or 0` instead of a default in .get() because
# OpenRouter may include the key with a null value (e.g.
# {"cache_read_input_tokens": null}) for models that don't
# yet report cache tokens, making .get("key", 0) return
# None rather than the fallback 0.
state.usage.prompt_tokens += sdk_msg.usage.get("input_tokens") or 0
state.usage.cache_read_tokens += (
sdk_msg.usage.get("cache_read_input_tokens") or 0
)
state.usage.cache_creation_tokens += (
sdk_msg.usage.get("cache_creation_input_tokens") or 0
)
state.usage.completion_tokens += (
sdk_msg.usage.get("output_tokens") or 0
)
_apply_token_usage(state.usage, sdk_msg.usage)
logger.info(
"%s Token usage: uncached=%d, cache_read=%d, "
"cache_create=%d, output=%d",

View File

@@ -17,6 +17,7 @@ from .conftest import build_test_transcript as _build_transcript
from .service import (
_RETRY_TARGET_TOKENS,
ReducedContext,
_apply_token_usage,
_is_prompt_too_long,
_is_tool_only_message,
_iter_sdk_messages,
@@ -354,47 +355,6 @@ class TestIsParallelContinuation:
assert _is_tool_only_message(msg) is True
# ---------------------------------------------------------------------------
# _normalize_model_name — used by per-request model override
# ---------------------------------------------------------------------------
class TestNormalizeModelName:
    """Checks for the helper that normalises model names.

    When a request overrides the model, _normalize_model_name is called with
    either ``"anthropic/claude-opus-4-6"`` (the 'advanced' choice) or
    ``config.model`` (the 'standard' choice).  The cases below pin down the
    OpenRouter/provider-prefix stripping that keeps the resulting value
    compatible with the Claude CLI.
    """

    def test_strips_anthropic_prefix(self):
        normalized = _normalize_model_name("anthropic/claude-opus-4-6")
        assert normalized == "claude-opus-4-6"

    def test_strips_openai_prefix(self):
        normalized = _normalize_model_name("openai/gpt-4o")
        assert normalized == "gpt-4o"

    def test_strips_google_prefix(self):
        normalized = _normalize_model_name("google/gemini-2.5-flash")
        assert normalized == "gemini-2.5-flash"

    def test_already_normalized_unchanged(self):
        # A name without a provider prefix must come back as-is.
        normalized = _normalize_model_name("claude-sonnet-4-20250514")
        assert normalized == "claude-sonnet-4-20250514"

    def test_empty_string_unchanged(self):
        normalized = _normalize_model_name("")
        assert normalized == ""

    def test_opus_model_roundtrip(self):
        """The literal string behind the 'opus' toggle is stripped correctly."""
        normalized = _normalize_model_name("anthropic/claude-opus-4-6")
        assert normalized == "claude-opus-4-6"

    def test_sonnet_openrouter_model(self):
        """An OpenRouter-prefixed Sonnet name, as kept in config, strips cleanly."""
        normalized = _normalize_model_name("anthropic/claude-sonnet-4")
        assert normalized == "claude-sonnet-4"
# ---------------------------------------------------------------------------
# _TokenUsage — null-safe accumulation (OpenRouter initial-stream-event bug)
# ---------------------------------------------------------------------------
@@ -409,13 +369,6 @@ class TestTokenUsageNullSafety:
when the key existed with a null value, causing 'int += None' TypeError.
"""
def _apply_usage(self, usage: dict, acc: _TokenUsage) -> None:
    """Replicate the accumulation done by production code in sdk/service.py.

    Each of the four counts is read with ``usage.get(key) or 0`` so that a
    key that is missing and a key that is present with ``None`` both add 0.
    """
    # Accumulator attribute -> key it is fed from in the usage dict.
    sources = {
        "prompt_tokens": "input_tokens",
        "cache_read_tokens": "cache_read_input_tokens",
        "cache_creation_tokens": "cache_creation_input_tokens",
        "completion_tokens": "output_tokens",
    }
    for attr_name, usage_key in sources.items():
        subtotal = getattr(acc, attr_name)
        subtotal += usage.get(usage_key) or 0
        setattr(acc, attr_name, subtotal)
def test_null_cache_tokens_do_not_crash(self):
"""OpenRouter initial event: cache keys present with null value."""
usage = {
@@ -425,7 +378,7 @@ class TestTokenUsageNullSafety:
"cache_creation_input_tokens": None,
}
acc = _TokenUsage()
self._apply_usage(usage, acc) # must not raise TypeError
_apply_token_usage(acc, usage) # must not raise TypeError
assert acc.prompt_tokens == 0
assert acc.cache_read_tokens == 0
assert acc.cache_creation_tokens == 0
@@ -440,7 +393,7 @@ class TestTokenUsageNullSafety:
"cache_creation_input_tokens": 512,
}
acc = _TokenUsage()
self._apply_usage(usage, acc)
_apply_token_usage(acc, usage)
assert acc.prompt_tokens == 10
assert acc.cache_read_tokens == 16600
assert acc.cache_creation_tokens == 512
@@ -450,7 +403,7 @@ class TestTokenUsageNullSafety:
"""Minimal usage dict without cache keys defaults correctly."""
usage = {"input_tokens": 5, "output_tokens": 20}
acc = _TokenUsage()
self._apply_usage(usage, acc)
_apply_token_usage(acc, usage)
assert acc.prompt_tokens == 5
assert acc.cache_read_tokens == 0
assert acc.cache_creation_tokens == 0
@@ -471,9 +424,28 @@ class TestTokenUsageNullSafety:
"cache_creation_input_tokens": 512,
}
acc = _TokenUsage()
self._apply_usage(null_event, acc)
self._apply_usage(real_event, acc)
_apply_token_usage(acc, null_event)
_apply_token_usage(acc, real_event)
assert acc.prompt_tokens == 10
assert acc.cache_read_tokens == 16600
assert acc.cache_creation_tokens == 512
assert acc.completion_tokens == 349
@pytest.mark.parametrize(
    ("key", "null_field", "real_value", "acc_attr"),
    [
        ("cache_read_input_tokens", None, 16600, "cache_read_tokens"),
        ("cache_creation_input_tokens", None, 512, "cache_creation_tokens"),
        ("input_tokens", None, 10, "prompt_tokens"),
        ("output_tokens", None, 349, "completion_tokens"),
    ],
)
def test_null_then_real_per_field(
    self, key: str, null_field: None, real_value: int, acc_attr: str
) -> None:
    """A null value followed by a real value works for each field on its own."""
    totals = _TokenUsage()

    # First event: the key is present but its value is null.
    _apply_token_usage(totals, {key: null_field})
    after_null = getattr(totals, acc_attr)
    assert after_null == 0

    # Second event: the same key now carries the real count.
    _apply_token_usage(totals, {key: real_value})
    after_real = getattr(totals, acc_attr)
    assert after_real == real_value