Use LLM APIs responses in token counting (#5604)
Co-authored-by: Calvin Smith <email@cjsmith.io>
@@ -29,6 +29,7 @@ from openhands.events.observation import (
 from openhands.events.observation.error import ErrorObservation
 from openhands.events.observation.observation import Observation
 from openhands.events.serialization.event import truncate_content
+from openhands.llm.metrics import Metrics, TokenUsage


 def events_to_messages(
@@ -362,3 +363,47 @@ def apply_prompt_caching(messages: list[Message]) -> None:
                     -1
                 ].cache_prompt = True  # Last item inside the message content
             break
+
+
+def get_token_usage_for_event(event: Event, metrics: Metrics) -> TokenUsage | None:
+    """
+    Returns at most one token usage record for the `model_response.id` in this event's
+    `tool_call_metadata`.
+
+    If no response_id is found, or none match in metrics.token_usages, returns None.
+    """
+    if event.tool_call_metadata and event.tool_call_metadata.model_response:
+        response_id = event.tool_call_metadata.model_response.get('id')
+        if response_id:
+            return next(
+                (
+                    usage
+                    for usage in metrics.token_usages
+                    if usage.response_id == response_id
+                ),
+                None,
+            )
+    return None
+
+
+def get_token_usage_for_event_id(
+    events: list[Event], event_id: int, metrics: Metrics
+) -> TokenUsage | None:
+    """
+    Starting from the event with .id == event_id and moving backwards in `events`,
+    find the first TokenUsage record (if any) associated with a response_id from
+    tool_call_metadata.model_response.id.
+
+    Returns the first match found, or None if none is found.
+    """
+    # find the index of the event with the given id
+    idx = next((i for i, e in enumerate(events) if e.id == event_id), None)
+    if idx is None:
+        return None
+
+    # search backward from idx down to 0
+    for i in range(idx, -1, -1):
+        usage = get_token_usage_for_event(events[i], metrics)
+        if usage is not None:
+            return usage
+    return None
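As a quick usage illustration (not part of the diff): the `events` history, the `llm` object, and the event id 42 below are invented for the example, assuming `llm.metrics` is the Metrics instance populated by the LLM class changes further down.

# Hypothetical sketch: recover the token usage of the completion that produced
# event 42, or of the nearest earlier event carrying a model_response id.
usage = get_token_usage_for_event_id(events, event_id=42, metrics=llm.metrics)
if usage is not None:
    print(usage.prompt_tokens, usage.completion_tokens, usage.response_id)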
@@ -497,20 +497,21 @@ class LLM(RetryMixin, DebugMixin):
             stats += 'Response Latency: %.3f seconds\n' % latest_latency.latency

         usage: Usage | None = response.get('usage')
+        response_id = response.get('id', 'unknown')

         if usage:
             # keep track of the input and output tokens
-            input_tokens = usage.get('prompt_tokens')
-            output_tokens = usage.get('completion_tokens')
+            prompt_tokens = usage.get('prompt_tokens', 0)
+            completion_tokens = usage.get('completion_tokens', 0)

-            if input_tokens:
-                stats += 'Input tokens: ' + str(input_tokens)
+            if prompt_tokens:
+                stats += 'Input tokens: ' + str(prompt_tokens)

-            if output_tokens:
+            if completion_tokens:
                 stats += (
-                    (' | ' if input_tokens else '')
+                    (' | ' if prompt_tokens else '')
                     + 'Output tokens: '
-                    + str(output_tokens)
+                    + str(completion_tokens)
                     + '\n'
                 )

@@ -519,7 +520,7 @@ class LLM(RetryMixin, DebugMixin):
                 'prompt_tokens_details'
             )
             cache_hit_tokens = (
-                prompt_tokens_details.cached_tokens if prompt_tokens_details else None
+                prompt_tokens_details.cached_tokens if prompt_tokens_details else 0
             )
             if cache_hit_tokens:
                 stats += 'Input tokens (cache hit): ' + str(cache_hit_tokens) + '\n'

@@ -528,10 +529,20 @@ class LLM(RetryMixin, DebugMixin):
             # but litellm doesn't separate them in the usage stats
             # so we can read it from the provider-specific extra field
             model_extra = usage.get('model_extra', {})
-            cache_write_tokens = model_extra.get('cache_creation_input_tokens')
+            cache_write_tokens = model_extra.get('cache_creation_input_tokens', 0)
             if cache_write_tokens:
                 stats += 'Input tokens (cache write): ' + str(cache_write_tokens) + '\n'

+            # Record in metrics
+            # We'll treat cache_hit_tokens as "cache read" and cache_write_tokens as "cache write"
+            self.metrics.add_token_usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                cache_read_tokens=cache_hit_tokens,
+                cache_write_tokens=cache_write_tokens,
+                response_id=response_id,
+            )
+
         # log the stats
         if stats:
             logger.debug(stats)
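For orientation, this is roughly how the response fields read above map onto the new usage record. The dict below is a hand-written stand-in, not a real litellm object (in an actual response `usage` is a `Usage` and `prompt_tokens_details` a `PromptTokensDetails`, as the test mocks further down show):

# Illustrative shape only -- invented values, not a real litellm response.
fake_response = {
    'id': 'resp-123',  # -> response_id
    'usage': {
        'prompt_tokens': 12,  # -> prompt_tokens
        'completion_tokens': 3,  # -> completion_tokens
        'prompt_tokens_details': {'cached_tokens': 2},  # -> cache_read_tokens (cache hit)
        'model_extra': {'cache_creation_input_tokens': 5},  # -> cache_write_tokens (Anthropic cache write)
    },
}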
@@ -17,11 +17,23 @@ class ResponseLatency(BaseModel):
     response_id: str


+class TokenUsage(BaseModel):
+    """Metric tracking detailed token usage per completion call."""
+
+    model: str
+    prompt_tokens: int
+    completion_tokens: int
+    cache_read_tokens: int
+    cache_write_tokens: int
+    response_id: str
+
+
 class Metrics:
     """Metrics class can record various metrics during running and evaluation.
-    Currently, we define the following metrics:
-        accumulated_cost: the total cost (USD $) of the current LLM.
-        response_latency: the time taken for each LLM completion call.
+    We track:
+      - accumulated_cost and costs
+      - A list of ResponseLatency
+      - A list of TokenUsage (one per call).
     """

     def __init__(self, model_name: str = 'default') -> None:
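A minimal sketch of the new model on its own (field values invented), showing the `model_dump()` form that `Metrics.get()` exposes later in the diff:

from openhands.llm.metrics import TokenUsage

usage = TokenUsage(
    model='test-model',
    prompt_tokens=12,
    completion_tokens=3,
    cache_read_tokens=2,
    cache_write_tokens=5,
    response_id='resp-1',
)
print(usage.model_dump())
# {'model': 'test-model', 'prompt_tokens': 12, 'completion_tokens': 3,
#  'cache_read_tokens': 2, 'cache_write_tokens': 5, 'response_id': 'resp-1'}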
@@ -29,6 +41,7 @@ class Metrics:
         self._costs: list[Cost] = []
         self._response_latencies: list[ResponseLatency] = []
         self.model_name = model_name
+        self._token_usages: list[TokenUsage] = []

     @property
     def accumulated_cost(self) -> float:
@@ -54,6 +67,16 @@
     def response_latencies(self, value: list[ResponseLatency]) -> None:
         self._response_latencies = value

+    @property
+    def token_usages(self) -> list[TokenUsage]:
+        if not hasattr(self, '_token_usages'):
+            self._token_usages = []
+        return self._token_usages
+
+    @token_usages.setter
+    def token_usages(self, value: list[TokenUsage]) -> None:
+        self._token_usages = value
+
     def add_cost(self, value: float) -> None:
         if value < 0:
             raise ValueError('Added cost cannot be negative.')
@@ -67,10 +90,33 @@
             )
         )

+    def add_token_usage(
+        self,
+        prompt_tokens: int,
+        completion_tokens: int,
+        cache_read_tokens: int,
+        cache_write_tokens: int,
+        response_id: str,
+    ) -> None:
+        """Add a single usage record."""
+        self._token_usages.append(
+            TokenUsage(
+                model=self.model_name,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                cache_read_tokens=cache_read_tokens,
+                cache_write_tokens=cache_write_tokens,
+                response_id=response_id,
+            )
+        )
+
     def merge(self, other: 'Metrics') -> None:
+        """Merge 'other' metrics into this one."""
         self._accumulated_cost += other.accumulated_cost
         self._costs += other._costs
-        self._response_latencies += other._response_latencies
+        # use the property so older pickled objects that lack the field won't crash
+        self.token_usages += other.token_usages
+        self.response_latencies += other.response_latencies

     def get(self) -> dict:
         """Return the metrics in a dictionary."""
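A small, self-contained illustration of accumulation and merging with the new API (all numbers invented); the `hasattr` guard in the `token_usages` property above is what lets `merge` work even when the other Metrics object was deserialized from before this field existed:

from openhands.llm.metrics import Metrics

a = Metrics(model_name='model-a')
a.add_token_usage(
    prompt_tokens=12, completion_tokens=3,
    cache_read_tokens=2, cache_write_tokens=5,
    response_id='resp-1',
)

b = Metrics(model_name='model-b')
b.add_token_usage(
    prompt_tokens=7, completion_tokens=2,
    cache_read_tokens=1, cache_write_tokens=3,
    response_id='resp-2',
)

a.merge(b)
assert [u.response_id for u in a.token_usages] == ['resp-1', 'resp-2']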
@@ -80,12 +126,14 @@
             'response_latencies': [
                 latency.model_dump() for latency in self._response_latencies
             ],
+            'token_usages': [usage.model_dump() for usage in self._token_usages],
         }

     def reset(self):
         self._accumulated_cost = 0.0
         self._costs = []
         self._response_latencies = []
+        self._token_usages = []

     def log(self):
         """Log the metrics."""
@@ -2,6 +2,7 @@ import copy
 from unittest.mock import MagicMock, patch

 import pytest
+from litellm import PromptTokensDetails
 from litellm.exceptions import (
     RateLimitError,
 )
@@ -429,3 +430,62 @@ def test_get_token_count_error_handling(
     mock_logger.error.assert_called_once_with(
         'Error getting token count for\n model gpt-4o\nToken counting failed'
     )
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_llm_token_usage(mock_litellm_completion, default_config):
+    # This mock response includes usage details with prompt_tokens,
+    # completion_tokens, prompt_tokens_details.cached_tokens, and model_extra.cache_creation_input_tokens
+    mock_response_1 = {
+        'id': 'test-response-usage',
+        'choices': [{'message': {'content': 'Usage test response'}}],
+        'usage': {
+            'prompt_tokens': 12,
+            'completion_tokens': 3,
+            'prompt_tokens_details': PromptTokensDetails(cached_tokens=2),
+            'model_extra': {'cache_creation_input_tokens': 5},
+        },
+    }
+
+    # Create a second usage scenario to test accumulation and a different response_id
+    mock_response_2 = {
+        'id': 'test-response-usage-2',
+        'choices': [{'message': {'content': 'Second usage test response'}}],
+        'usage': {
+            'prompt_tokens': 7,
+            'completion_tokens': 2,
+            'prompt_tokens_details': PromptTokensDetails(cached_tokens=1),
+            'model_extra': {'cache_creation_input_tokens': 3},
+        },
+    }
+
+    # We'll make mock_litellm_completion return these responses in sequence
+    mock_litellm_completion.side_effect = [mock_response_1, mock_response_2]
+
+    llm = LLM(config=default_config)
+
+    # First call
+    llm.completion(messages=[{'role': 'user', 'content': 'Hello usage!'}])
+
+    # Verify we have exactly one usage record after first call
+    token_usage_list = llm.metrics.get()['token_usages']
+    assert len(token_usage_list) == 1
+    usage_entry_1 = token_usage_list[0]
+    assert usage_entry_1['prompt_tokens'] == 12
+    assert usage_entry_1['completion_tokens'] == 3
+    assert usage_entry_1['cache_read_tokens'] == 2
+    assert usage_entry_1['cache_write_tokens'] == 5
+    assert usage_entry_1['response_id'] == 'test-response-usage'
+
+    # Second call
+    llm.completion(messages=[{'role': 'user', 'content': 'Hello again!'}])
+
+    # Now we expect two usage records total
+    token_usage_list = llm.metrics.get()['token_usages']
+    assert len(token_usage_list) == 2
+    usage_entry_2 = token_usage_list[-1]
+    assert usage_entry_2['prompt_tokens'] == 7
+    assert usage_entry_2['completion_tokens'] == 2
+    assert usage_entry_2['cache_read_tokens'] == 1
+    assert usage_entry_2['cache_write_tokens'] == 3
+    assert usage_entry_2['response_id'] == 'test-response-usage-2'
@@ -3,13 +3,18 @@ from unittest.mock import Mock
 import pytest

 from openhands.core.message import ImageContent, TextContent
-from openhands.core.message_utils import get_action_message, get_observation_message
+from openhands.core.message_utils import (
+    get_action_message,
+    get_observation_message,
+    get_token_usage_for_event,
+    get_token_usage_for_event_id,
+)
 from openhands.events.action import (
     AgentFinishAction,
     CmdRunAction,
     MessageAction,
 )
-from openhands.events.event import EventSource, FileEditSource, FileReadSource
+from openhands.events.event import Event, EventSource, FileEditSource, FileReadSource
 from openhands.events.observation.browse import BrowserOutputObservation
 from openhands.events.observation.commands import (
     CmdOutputMetadata,
@@ -21,6 +26,7 @@ from openhands.events.observation.error import ErrorObservation
 from openhands.events.observation.files import FileEditObservation, FileReadObservation
 from openhands.events.observation.reject import UserRejectObservation
 from openhands.events.tool import ToolCallMetadata
+from openhands.llm.metrics import Metrics, TokenUsage


 def test_cmd_output_observation_message():
@@ -269,3 +275,113 @@ def test_agent_finish_action_with_tool_metadata():
     assert len(result.content) == 1
     assert isinstance(result.content[0], TextContent)
     assert 'Initial thought\nTask completed' in result.content[0].text
+
+
+def test_get_token_usage_for_event():
+    """Test that we get the single matching usage record (if any) based on the event's model_response.id."""
+    metrics = Metrics(model_name='test-model')
+    usage_record = TokenUsage(
+        model='test-model',
+        prompt_tokens=10,
+        completion_tokens=5,
+        cache_read_tokens=2,
+        cache_write_tokens=1,
+        response_id='test-response-id',
+    )
+    metrics.add_token_usage(
+        prompt_tokens=usage_record.prompt_tokens,
+        completion_tokens=usage_record.completion_tokens,
+        cache_read_tokens=usage_record.cache_read_tokens,
+        cache_write_tokens=usage_record.cache_write_tokens,
+        response_id=usage_record.response_id,
+    )
+
+    # Create an event referencing that response_id
+    event = Event()
+    mock_tool_call_metadata = ToolCallMetadata(
+        tool_call_id='test-tool-call',
+        function_name='fake_function',
+        model_response={'id': 'test-response-id'},
+        total_calls_in_response=1,
+    )
+    event._tool_call_metadata = (
+        mock_tool_call_metadata  # normally you'd do event.tool_call_metadata = ...
+    )
+
+    # We should find that usage record
+    found = get_token_usage_for_event(event, metrics)
+    assert found is not None
+    assert found.prompt_tokens == 10
+    assert found.response_id == 'test-response-id'
+
+    # If we change the event's response ID, we won't find anything
+    mock_tool_call_metadata.model_response.id = 'some-other-id'
+    found2 = get_token_usage_for_event(event, metrics)
+    assert found2 is None
+
+    # If the event has no tool_call_metadata, also returns None
+    event._tool_call_metadata = None
+    found3 = get_token_usage_for_event(event, metrics)
+    assert found3 is None
+
+
+def test_get_token_usage_for_event_id():
+    """
+    Test that we search backward from the event with the given id,
+    finding the first usage record that matches a response_id in that or previous events.
+    """
+    metrics = Metrics(model_name='test-model')
+    usage_1 = TokenUsage(
+        model='test-model',
+        prompt_tokens=12,
+        completion_tokens=3,
+        cache_read_tokens=2,
+        cache_write_tokens=5,
+        response_id='resp-1',
+    )
+    usage_2 = TokenUsage(
+        model='test-model',
+        prompt_tokens=7,
+        completion_tokens=2,
+        cache_read_tokens=1,
+        cache_write_tokens=3,
+        response_id='resp-2',
+    )
+    metrics._token_usages.append(usage_1)
+    metrics._token_usages.append(usage_2)
+
+    # Build a list of events
+    events = []
+    for i in range(5):
+        e = Event()
+        e._id = i
+        # We'll attach usage_1 to event 1, usage_2 to event 3
+        if i == 1:
+            e._tool_call_metadata = ToolCallMetadata(
+                tool_call_id='tid1',
+                function_name='fn1',
+                model_response={'id': 'resp-1'},
+                total_calls_in_response=1,
+            )
+        elif i == 3:
+            e._tool_call_metadata = ToolCallMetadata(
+                tool_call_id='tid2',
+                function_name='fn2',
+                model_response={'id': 'resp-2'},
+                total_calls_in_response=1,
+            )
+        events.append(e)
+
+    # If we ask for event_id=3, we find usage_2 immediately
+    found_3 = get_token_usage_for_event_id(events, 3, metrics)
+    assert found_3 is not None
+    assert found_3.response_id == 'resp-2'
+
+    # If we ask for event_id=2, no usage in event2, so we check event1 -> usage_1 found
+    found_2 = get_token_usage_for_event_id(events, 2, metrics)
+    assert found_2 is not None
+    assert found_2.response_id == 'resp-1'
+
+    # If we ask for event_id=0, no usage in event0 or earlier, so return None
+    found_0 = get_token_usage_for_event_id(events, 0, metrics)
+    assert found_0 is None