Use LLM APIs responses in token counting (#5604)

Co-authored-by: Calvin Smith <email@cjsmith.io>
This commit is contained in:
Engel Nyst
2025-02-23 17:58:47 +01:00
committed by GitHub
parent abac25cc4c
commit 2d2dbf1561
5 changed files with 295 additions and 15 deletions

View File

@@ -29,6 +29,7 @@ from openhands.events.observation import (
from openhands.events.observation.error import ErrorObservation
from openhands.events.observation.observation import Observation
from openhands.events.serialization.event import truncate_content
from openhands.llm.metrics import Metrics, TokenUsage
def events_to_messages(
@@ -362,3 +363,47 @@ def apply_prompt_caching(messages: list[Message]) -> None:
-1
].cache_prompt = True # Last item inside the message content
break
def get_token_usage_for_event(event: Event, metrics: Metrics) -> TokenUsage | None:
"""
Returns at most one token usage record for the `model_response.id` in this event's
`tool_call_metadata`.
If no response_id is found, or none match in metrics.token_usages, returns None.
"""
if event.tool_call_metadata and event.tool_call_metadata.model_response:
response_id = event.tool_call_metadata.model_response.get('id')
if response_id:
return next(
(
usage
for usage in metrics.token_usages
if usage.response_id == response_id
),
None,
)
return None
def get_token_usage_for_event_id(
events: list[Event], event_id: int, metrics: Metrics
) -> TokenUsage | None:
"""
Starting from the event with .id == event_id and moving backwards in `events`,
find the first TokenUsage record (if any) associated with a response_id from
tool_call_metadata.model_response.id.
Returns the first match found, or None if none is found.
"""
# find the index of the event with the given id
idx = next((i for i, e in enumerate(events) if e.id == event_id), None)
if idx is None:
return None
# search backward from idx down to 0
for i in range(idx, -1, -1):
usage = get_token_usage_for_event(events[i], metrics)
if usage is not None:
return usage
return None

View File

@@ -497,20 +497,21 @@ class LLM(RetryMixin, DebugMixin):
stats += 'Response Latency: %.3f seconds\n' % latest_latency.latency
usage: Usage | None = response.get('usage')
response_id = response.get('id', 'unknown')
if usage:
# keep track of the input and output tokens
input_tokens = usage.get('prompt_tokens')
output_tokens = usage.get('completion_tokens')
prompt_tokens = usage.get('prompt_tokens', 0)
completion_tokens = usage.get('completion_tokens', 0)
if input_tokens:
stats += 'Input tokens: ' + str(input_tokens)
if prompt_tokens:
stats += 'Input tokens: ' + str(prompt_tokens)
if output_tokens:
if completion_tokens:
stats += (
(' | ' if input_tokens else '')
(' | ' if prompt_tokens else '')
+ 'Output tokens: '
+ str(output_tokens)
+ str(completion_tokens)
+ '\n'
)
@@ -519,7 +520,7 @@ class LLM(RetryMixin, DebugMixin):
'prompt_tokens_details'
)
cache_hit_tokens = (
prompt_tokens_details.cached_tokens if prompt_tokens_details else None
prompt_tokens_details.cached_tokens if prompt_tokens_details else 0
)
if cache_hit_tokens:
stats += 'Input tokens (cache hit): ' + str(cache_hit_tokens) + '\n'
@@ -528,10 +529,20 @@ class LLM(RetryMixin, DebugMixin):
# but litellm doesn't separate them in the usage stats
# so we can read it from the provider-specific extra field
model_extra = usage.get('model_extra', {})
cache_write_tokens = model_extra.get('cache_creation_input_tokens')
cache_write_tokens = model_extra.get('cache_creation_input_tokens', 0)
if cache_write_tokens:
stats += 'Input tokens (cache write): ' + str(cache_write_tokens) + '\n'
# Record in metrics
# We'll treat cache_hit_tokens as "cache read" and cache_write_tokens as "cache write"
self.metrics.add_token_usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cache_read_tokens=cache_hit_tokens,
cache_write_tokens=cache_write_tokens,
response_id=response_id,
)
# log the stats
if stats:
logger.debug(stats)

View File

@@ -17,11 +17,23 @@ class ResponseLatency(BaseModel):
response_id: str
class TokenUsage(BaseModel):
"""Metric tracking detailed token usage per completion call."""
model: str
prompt_tokens: int
completion_tokens: int
cache_read_tokens: int
cache_write_tokens: int
response_id: str
class Metrics:
"""Metrics class can record various metrics during running and evaluation.
Currently, we define the following metrics:
accumulated_cost: the total cost (USD $) of the current LLM.
response_latency: the time taken for each LLM completion call.
We track:
- accumulated_cost and costs
- A list of ResponseLatency
- A list of TokenUsage (one per call).
"""
def __init__(self, model_name: str = 'default') -> None:
@@ -29,6 +41,7 @@ class Metrics:
self._costs: list[Cost] = []
self._response_latencies: list[ResponseLatency] = []
self.model_name = model_name
self._token_usages: list[TokenUsage] = []
@property
def accumulated_cost(self) -> float:
@@ -54,6 +67,16 @@ class Metrics:
def response_latencies(self, value: list[ResponseLatency]) -> None:
self._response_latencies = value
@property
def token_usages(self) -> list[TokenUsage]:
if not hasattr(self, '_token_usages'):
self._token_usages = []
return self._token_usages
@token_usages.setter
def token_usages(self, value: list[TokenUsage]) -> None:
self._token_usages = value
def add_cost(self, value: float) -> None:
if value < 0:
raise ValueError('Added cost cannot be negative.')
@@ -67,10 +90,33 @@ class Metrics:
)
)
def add_token_usage(
self,
prompt_tokens: int,
completion_tokens: int,
cache_read_tokens: int,
cache_write_tokens: int,
response_id: str,
) -> None:
"""Add a single usage record."""
self._token_usages.append(
TokenUsage(
model=self.model_name,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cache_read_tokens=cache_read_tokens,
cache_write_tokens=cache_write_tokens,
response_id=response_id,
)
)
def merge(self, other: 'Metrics') -> None:
"""Merge 'other' metrics into this one."""
self._accumulated_cost += other.accumulated_cost
self._costs += other._costs
self._response_latencies += other._response_latencies
# use the property so older picked objects that lack the field won't crash
self.token_usages += other.token_usages
self.response_latencies += other.response_latencies
def get(self) -> dict:
"""Return the metrics in a dictionary."""
@@ -80,12 +126,14 @@ class Metrics:
'response_latencies': [
latency.model_dump() for latency in self._response_latencies
],
'token_usages': [usage.model_dump() for usage in self._token_usages],
}
def reset(self):
self._accumulated_cost = 0.0
self._costs = []
self._response_latencies = []
self._token_usages = []
def log(self):
"""Log the metrics."""

View File

@@ -2,6 +2,7 @@ import copy
from unittest.mock import MagicMock, patch
import pytest
from litellm import PromptTokensDetails
from litellm.exceptions import (
RateLimitError,
)
@@ -429,3 +430,62 @@ def test_get_token_count_error_handling(
mock_logger.error.assert_called_once_with(
'Error getting token count for\n model gpt-4o\nToken counting failed'
)
@patch('openhands.llm.llm.litellm_completion')
def test_llm_token_usage(mock_litellm_completion, default_config):
# This mock response includes usage details with prompt_tokens,
# completion_tokens, prompt_tokens_details.cached_tokens, and model_extra.cache_creation_input_tokens
mock_response_1 = {
'id': 'test-response-usage',
'choices': [{'message': {'content': 'Usage test response'}}],
'usage': {
'prompt_tokens': 12,
'completion_tokens': 3,
'prompt_tokens_details': PromptTokensDetails(cached_tokens=2),
'model_extra': {'cache_creation_input_tokens': 5},
},
}
# Create a second usage scenario to test accumulation and a different response_id
mock_response_2 = {
'id': 'test-response-usage-2',
'choices': [{'message': {'content': 'Second usage test response'}}],
'usage': {
'prompt_tokens': 7,
'completion_tokens': 2,
'prompt_tokens_details': PromptTokensDetails(cached_tokens=1),
'model_extra': {'cache_creation_input_tokens': 3},
},
}
# We'll make mock_litellm_completion return these responses in sequence
mock_litellm_completion.side_effect = [mock_response_1, mock_response_2]
llm = LLM(config=default_config)
# First call
llm.completion(messages=[{'role': 'user', 'content': 'Hello usage!'}])
# Verify we have exactly one usage record after first call
token_usage_list = llm.metrics.get()['token_usages']
assert len(token_usage_list) == 1
usage_entry_1 = token_usage_list[0]
assert usage_entry_1['prompt_tokens'] == 12
assert usage_entry_1['completion_tokens'] == 3
assert usage_entry_1['cache_read_tokens'] == 2
assert usage_entry_1['cache_write_tokens'] == 5
assert usage_entry_1['response_id'] == 'test-response-usage'
# Second call
llm.completion(messages=[{'role': 'user', 'content': 'Hello again!'}])
# Now we expect two usage records total
token_usage_list = llm.metrics.get()['token_usages']
assert len(token_usage_list) == 2
usage_entry_2 = token_usage_list[-1]
assert usage_entry_2['prompt_tokens'] == 7
assert usage_entry_2['completion_tokens'] == 2
assert usage_entry_2['cache_read_tokens'] == 1
assert usage_entry_2['cache_write_tokens'] == 3
assert usage_entry_2['response_id'] == 'test-response-usage-2'

View File

@@ -3,13 +3,18 @@ from unittest.mock import Mock
import pytest
from openhands.core.message import ImageContent, TextContent
from openhands.core.message_utils import get_action_message, get_observation_message
from openhands.core.message_utils import (
get_action_message,
get_observation_message,
get_token_usage_for_event,
get_token_usage_for_event_id,
)
from openhands.events.action import (
AgentFinishAction,
CmdRunAction,
MessageAction,
)
from openhands.events.event import EventSource, FileEditSource, FileReadSource
from openhands.events.event import Event, EventSource, FileEditSource, FileReadSource
from openhands.events.observation.browse import BrowserOutputObservation
from openhands.events.observation.commands import (
CmdOutputMetadata,
@@ -21,6 +26,7 @@ from openhands.events.observation.error import ErrorObservation
from openhands.events.observation.files import FileEditObservation, FileReadObservation
from openhands.events.observation.reject import UserRejectObservation
from openhands.events.tool import ToolCallMetadata
from openhands.llm.metrics import Metrics, TokenUsage
def test_cmd_output_observation_message():
@@ -269,3 +275,113 @@ def test_agent_finish_action_with_tool_metadata():
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
assert 'Initial thought\nTask completed' in result.content[0].text
def test_get_token_usage_for_event():
"""Test that we get the single matching usage record (if any) based on the event's model_response.id."""
metrics = Metrics(model_name='test-model')
usage_record = TokenUsage(
model='test-model',
prompt_tokens=10,
completion_tokens=5,
cache_read_tokens=2,
cache_write_tokens=1,
response_id='test-response-id',
)
metrics.add_token_usage(
prompt_tokens=usage_record.prompt_tokens,
completion_tokens=usage_record.completion_tokens,
cache_read_tokens=usage_record.cache_read_tokens,
cache_write_tokens=usage_record.cache_write_tokens,
response_id=usage_record.response_id,
)
# Create an event referencing that response_id
event = Event()
mock_tool_call_metadata = ToolCallMetadata(
tool_call_id='test-tool-call',
function_name='fake_function',
model_response={'id': 'test-response-id'},
total_calls_in_response=1,
)
event._tool_call_metadata = (
mock_tool_call_metadata # normally you'd do event.tool_call_metadata = ...
)
# We should find that usage record
found = get_token_usage_for_event(event, metrics)
assert found is not None
assert found.prompt_tokens == 10
assert found.response_id == 'test-response-id'
# If we change the event's response ID, we won't find anything
mock_tool_call_metadata.model_response.id = 'some-other-id'
found2 = get_token_usage_for_event(event, metrics)
assert found2 is None
# If the event has no tool_call_metadata, also returns None
event._tool_call_metadata = None
found3 = get_token_usage_for_event(event, metrics)
assert found3 is None
def test_get_token_usage_for_event_id():
"""
Test that we search backward from the event with the given id,
finding the first usage record that matches a response_id in that or previous events.
"""
metrics = Metrics(model_name='test-model')
usage_1 = TokenUsage(
model='test-model',
prompt_tokens=12,
completion_tokens=3,
cache_read_tokens=2,
cache_write_tokens=5,
response_id='resp-1',
)
usage_2 = TokenUsage(
model='test-model',
prompt_tokens=7,
completion_tokens=2,
cache_read_tokens=1,
cache_write_tokens=3,
response_id='resp-2',
)
metrics._token_usages.append(usage_1)
metrics._token_usages.append(usage_2)
# Build a list of events
events = []
for i in range(5):
e = Event()
e._id = i
# We'll attach usage_1 to event 1, usage_2 to event 3
if i == 1:
e._tool_call_metadata = ToolCallMetadata(
tool_call_id='tid1',
function_name='fn1',
model_response={'id': 'resp-1'},
total_calls_in_response=1,
)
elif i == 3:
e._tool_call_metadata = ToolCallMetadata(
tool_call_id='tid2',
function_name='fn2',
model_response={'id': 'resp-2'},
total_calls_in_response=1,
)
events.append(e)
# If we ask for event_id=3, we find usage_2 immediately
found_3 = get_token_usage_for_event_id(events, 3, metrics)
assert found_3 is not None
assert found_3.response_id == 'resp-2'
# If we ask for event_id=2, no usage in event2, so we check event1 -> usage_1 found
found_2 = get_token_usage_for_event_id(events, 2, metrics)
assert found_2 is not None
assert found_2.response_id == 'resp-1'
# If we ask for event_id=0, no usage in event0 or earlier, so return None
found_0 = get_token_usage_for_event_id(events, 0, metrics)
assert found_0 is None