feat(copilot): cost-weighted token rate limiting with cache breakdown

- Rate limiter now uses Anthropic's cost model: cache_read at 10%,
  cache_creation at 25%, uncached and output at 100%
- Track cache_read_tokens and cache_creation_tokens separately in
  Usage model, StreamUsage response, and SDK token extraction
- Pass cache breakdown through to record_token_usage() for accurate
  weighted counting
- Add test for cost-weighted counting (10K cache_read → 1K weighted)

This makes multi-turn conversations fairer: cached system prompts
and tool schemas don't penalize users at full token cost.
This commit is contained in:
Zamil Majdy
2026-03-13 14:36:04 +07:00
parent a52a777b29
commit a5ed8fefa9
5 changed files with 108 additions and 18 deletions

View File

@@ -73,6 +73,9 @@ class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
# Cache breakdown (Anthropic-specific; zero for non-Anthropic models)
cache_read_tokens: int = 0
cache_creation_tokens: int = 0
class ChatSessionInfo(BaseModel):

View File

@@ -169,18 +169,51 @@ async def record_token_usage(
user_id: str,
prompt_tokens: int,
completion_tokens: int,
*,
cache_read_tokens: int = 0,
cache_creation_tokens: int = 0,
) -> None:
"""Record token usage for a user across all windows.
Uses cost-weighted counting so cached tokens don't unfairly penalise
multi-turn conversations. Anthropic's pricing:
- uncached input: 100%
- cache creation: 25%
- cache read: 10%
- output: 100%
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
from the API response). Cache counts are passed separately.
Args:
user_id: The user's ID.
prompt_tokens: Number of prompt tokens used.
completion_tokens: Number of completion tokens used.
prompt_tokens: Uncached input tokens.
completion_tokens: Output tokens.
cache_read_tokens: Tokens served from prompt cache (10% cost).
cache_creation_tokens: Tokens written to prompt cache (25% cost).
"""
total = prompt_tokens + completion_tokens
weighted_input = (
prompt_tokens + int(cache_creation_tokens * 0.25) + int(cache_read_tokens * 0.1)
)
total = weighted_input + completion_tokens
if total <= 0:
return
raw_total = (
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
)
logger.info(
"Recording token usage for %s: raw=%d, weighted=%d "
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
user_id[:8],
raw_total,
total,
prompt_tokens,
cache_read_tokens,
cache_creation_tokens,
completion_tokens,
)
now = datetime.now(UTC)
try:
redis = await get_redis_async()

View File

@@ -293,6 +293,31 @@ class TestRecordTokenUsage:
# Should not raise
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
@pytest.mark.asyncio
async def test_cost_weighted_counting(self):
"""Cached tokens should be weighted: cache_read=10%, cache_create=25%."""
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(
_USER,
prompt_tokens=100, # uncached → 100
completion_tokens=50, # output → 50
cache_read_tokens=10000, # 10% → 1000
cache_creation_tokens=400, # 25% → 100
)
# Expected weighted total: 100 + 1000 + 100 + 50 = 1250
incrby_calls = mock_pipe.incrby.call_args_list
assert len(incrby_calls) == 2
assert incrby_calls[0].args[1] == 1250 # daily
assert incrby_calls[1].args[1] == 1250 # weekly
@pytest.mark.asyncio
async def test_handles_redis_error_during_pipeline_execute(self):
"""Should not raise when pipeline.execute() fails with RedisError."""

View File

@@ -189,9 +189,17 @@ class StreamUsage(StreamBaseResponse):
"""Token usage statistics."""
type: ResponseType = ResponseType.USAGE
promptTokens: int = Field(..., description="Number of prompt tokens")
promptTokens: int = Field(..., description="Number of uncached prompt tokens")
completionTokens: int = Field(..., description="Number of completion tokens")
totalTokens: int = Field(..., description="Total number of tokens")
totalTokens: int = Field(
..., description="Total number of tokens (raw, not weighted)"
)
cacheReadTokens: int = Field(
default=0, description="Prompt tokens served from cache (10% cost)"
)
cacheCreationTokens: int = Field(
default=0, description="Prompt tokens written to cache (25% cost)"
)
class StreamError(StreamBaseResponse):

View File

@@ -739,8 +739,10 @@ async def stream_chat_completion_sdk(
# Make sure there is no more code between the lock acquisition and try-block.
# Token usage accumulators — populated from ResultMessage at end of turn
turn_prompt_tokens = 0
turn_prompt_tokens = 0 # uncached input tokens only
turn_completion_tokens = 0
turn_cache_read_tokens = 0
turn_cache_creation_tokens = 0
turn_cost_usd: float | None = None
try:
@@ -1142,23 +1144,23 @@ async def stream_chat_completion_sdk(
# input_tokens = uncached only
# cache_read_input_tokens = served from cache
# cache_creation_input_tokens = written to cache
# Total input = sum of all three.
if sdk_msg.usage:
turn_prompt_tokens += (
sdk_msg.usage.get("input_tokens", 0)
+ sdk_msg.usage.get("cache_read_input_tokens", 0)
+ sdk_msg.usage.get("cache_creation_input_tokens", 0)
turn_prompt_tokens += sdk_msg.usage.get("input_tokens", 0)
turn_cache_read_tokens += sdk_msg.usage.get(
"cache_read_input_tokens", 0
)
turn_cache_creation_tokens += sdk_msg.usage.get(
"cache_creation_input_tokens", 0
)
turn_completion_tokens += sdk_msg.usage.get(
"output_tokens", 0
)
logger.info(
"%s Token usage: input=%d (uncached=%d, cache_read=%d, cache_create=%d), output=%d",
"%s Token usage: uncached=%d, cache_read=%d, cache_create=%d, output=%d",
log_prefix,
turn_prompt_tokens,
sdk_msg.usage.get("input_tokens", 0),
sdk_msg.usage.get("cache_read_input_tokens", 0),
sdk_msg.usage.get("cache_creation_input_tokens", 0),
turn_cache_read_tokens,
turn_cache_creation_tokens,
turn_completion_tokens,
)
if sdk_msg.total_cost_usd is not None:
@@ -1365,11 +1367,18 @@ async def stream_chat_completion_sdk(
# rate-limit recording even if an exception interrupts between here
# and the finally block.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
total_tokens = turn_prompt_tokens + turn_completion_tokens
total_tokens = (
turn_prompt_tokens
+ turn_cache_read_tokens
+ turn_cache_creation_tokens
+ turn_completion_tokens
)
yield StreamUsage(
promptTokens=turn_prompt_tokens,
completionTokens=turn_completion_tokens,
totalTokens=total_tokens,
cacheReadTokens=turn_cache_read_tokens,
cacheCreationTokens=turn_cache_creation_tokens,
)
# Transcript upload is handled exclusively in the finally block
@@ -1440,19 +1449,29 @@ async def stream_chat_completion_sdk(
# Both must live in finally so they stay consistent even when an
# exception interrupts the try block after StreamUsage was yielded.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
total_tokens = turn_prompt_tokens + turn_completion_tokens
total_tokens = (
turn_prompt_tokens
+ turn_cache_read_tokens
+ turn_cache_creation_tokens
+ turn_completion_tokens
)
if session is not None:
session.usage.append(
Usage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=total_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
)
)
logger.info(
"%s Turn usage: prompt=%d, completion=%d, total=%d, cost_usd=%s",
"%s Turn usage: uncached=%d, cache_read=%d, cache_create=%d, "
"output=%d, total=%d, cost_usd=%s",
log_prefix,
turn_prompt_tokens,
turn_cache_read_tokens,
turn_cache_creation_tokens,
turn_completion_tokens,
total_tokens,
turn_cost_usd,
@@ -1463,6 +1482,8 @@ async def stream_chat_completion_sdk(
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
)
except Exception as usage_err:
logger.warning(