diff --git a/openhands/core/message_utils.py b/openhands/core/message_utils.py
index 7683c7c445..1ce4b4f84b 100644
--- a/openhands/core/message_utils.py
+++ b/openhands/core/message_utils.py
@@ -29,6 +29,7 @@ from openhands.events.observation import (
 from openhands.events.observation.error import ErrorObservation
 from openhands.events.observation.observation import Observation
 from openhands.events.serialization.event import truncate_content
+from openhands.llm.metrics import Metrics, TokenUsage
 
 
 def events_to_messages(
@@ -362,3 +363,47 @@ def apply_prompt_caching(messages: list[Message]) -> None:
                 -1
             ].cache_prompt = True  # Last item inside the message content
             break
+
+
+def get_token_usage_for_event(event: Event, metrics: Metrics) -> TokenUsage | None:
+    """
+    Returns at most one token usage record for the `model_response.id` in this event's
+    `tool_call_metadata`.
+
+    If no response_id is found, or none match in metrics.token_usages, returns None.
+    """
+    if event.tool_call_metadata and event.tool_call_metadata.model_response:
+        response_id = event.tool_call_metadata.model_response.get('id')
+        if response_id:
+            return next(
+                (
+                    usage
+                    for usage in metrics.token_usages
+                    if usage.response_id == response_id
+                ),
+                None,
+            )
+    return None
+
+
+def get_token_usage_for_event_id(
+    events: list[Event], event_id: int, metrics: Metrics
+) -> TokenUsage | None:
+    """
+    Starting from the event with .id == event_id and moving backwards in `events`,
+    find the first TokenUsage record (if any) associated with a response_id from
+    tool_call_metadata.model_response.id.
+
+    Returns the first match found, or None if none is found.
+    """
+    # find the index of the event with the given id
+    idx = next((i for i, e in enumerate(events) if e.id == event_id), None)
+    if idx is None:
+        return None
+
+    # search backward from idx down to 0
+    for i in range(idx, -1, -1):
+        usage = get_token_usage_for_event(events[i], metrics)
+        if usage is not None:
+            return usage
+    return None
diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py
index b40f11ca83..66bc6f99cb 100644
--- a/openhands/llm/llm.py
+++ b/openhands/llm/llm.py
@@ -497,20 +497,21 @@ class LLM(RetryMixin, DebugMixin):
             stats += 'Response Latency: %.3f seconds\n' % latest_latency.latency
 
         usage: Usage | None = response.get('usage')
+        response_id = response.get('id', 'unknown')
 
         if usage:
             # keep track of the input and output tokens
-            input_tokens = usage.get('prompt_tokens')
-            output_tokens = usage.get('completion_tokens')
+            prompt_tokens = usage.get('prompt_tokens', 0)
+            completion_tokens = usage.get('completion_tokens', 0)
 
-            if input_tokens:
-                stats += 'Input tokens: ' + str(input_tokens)
+            if prompt_tokens:
+                stats += 'Input tokens: ' + str(prompt_tokens)
 
-            if output_tokens:
+            if completion_tokens:
                 stats += (
-                    (' | ' if input_tokens else '')
+                    (' | ' if prompt_tokens else '')
                     + 'Output tokens: '
-                    + str(output_tokens)
+                    + str(completion_tokens)
                     + '\n'
                 )
 
@@ -519,7 +520,7 @@ class LLM(RetryMixin, DebugMixin):
                 'prompt_tokens_details'
             )
             cache_hit_tokens = (
-                prompt_tokens_details.cached_tokens if prompt_tokens_details else None
+                (prompt_tokens_details and prompt_tokens_details.cached_tokens) or 0
             )
             if cache_hit_tokens:
                 stats += 'Input tokens (cache hit): ' + str(cache_hit_tokens) + '\n'
@@ -528,10 +529,20 @@ class LLM(RetryMixin, DebugMixin):
             # but litellm doesn't separate them in the usage stats
             # so we can read it from the provider-specific extra field
             model_extra = usage.get('model_extra', {})
-            cache_write_tokens = model_extra.get('cache_creation_input_tokens')
+            cache_write_tokens = model_extra.get('cache_creation_input_tokens', 0)
             if cache_write_tokens:
                 stats += 'Input tokens (cache write): ' + str(cache_write_tokens) + '\n'
 
+            # Record in metrics
+            # We'll treat cache_hit_tokens as "cache read" and cache_write_tokens as "cache write"
+            self.metrics.add_token_usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                cache_read_tokens=cache_hit_tokens,
+                cache_write_tokens=cache_write_tokens,
+                response_id=response_id,
+            )
+
         # log the stats
         if stats:
             logger.debug(stats)
diff --git a/openhands/llm/metrics.py b/openhands/llm/metrics.py
index a010bb2691..a5ec0efd75 100644
--- a/openhands/llm/metrics.py
+++ b/openhands/llm/metrics.py
@@ -17,11 +17,23 @@ class ResponseLatency(BaseModel):
     response_id: str
 
 
+class TokenUsage(BaseModel):
+    """Metric tracking detailed token usage per completion call."""
+
+    model: str
+    prompt_tokens: int
+    completion_tokens: int
+    cache_read_tokens: int
+    cache_write_tokens: int
+    response_id: str
+
+
 class Metrics:
     """Metrics class can record various metrics during running and evaluation.
-    Currently, we define the following metrics:
-    accumulated_cost: the total cost (USD $) of the current LLM.
-    response_latency: the time taken for each LLM completion call.
+    We track:
+      - accumulated_cost and costs
+      - A list of ResponseLatency
+      - A list of TokenUsage (one per call).
     """
 
     def __init__(self, model_name: str = 'default') -> None:
@@ -29,6 +41,7 @@ class Metrics:
         self._costs: list[Cost] = []
         self._response_latencies: list[ResponseLatency] = []
         self.model_name = model_name
+        self._token_usages: list[TokenUsage] = []
 
     @property
     def accumulated_cost(self) -> float:
@@ -54,6 +67,16 @@ class Metrics:
     def response_latencies(self, value: list[ResponseLatency]) -> None:
         self._response_latencies = value
 
+    @property
+    def token_usages(self) -> list[TokenUsage]:
+        if not hasattr(self, '_token_usages'):
+            self._token_usages = []
+        return self._token_usages
+
+    @token_usages.setter
+    def token_usages(self, value: list[TokenUsage]) -> None:
+        self._token_usages = value
+
     def add_cost(self, value: float) -> None:
         if value < 0:
             raise ValueError('Added cost cannot be negative.')
@@ -67,10 +90,33 @@ class Metrics:
             )
         )
 
+    def add_token_usage(
+        self,
+        prompt_tokens: int,
+        completion_tokens: int,
+        cache_read_tokens: int,
+        cache_write_tokens: int,
+        response_id: str,
+    ) -> None:
+        """Add a single usage record."""
+        self.token_usages.append(
+            TokenUsage(
+                model=self.model_name,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                cache_read_tokens=cache_read_tokens,
+                cache_write_tokens=cache_write_tokens,
+                response_id=response_id,
+            )
+        )
+
     def merge(self, other: 'Metrics') -> None:
+        """Merge 'other' metrics into this one."""
         self._accumulated_cost += other.accumulated_cost
         self._costs += other._costs
-        self._response_latencies += other._response_latencies
+        # use the property so older pickled objects that lack the field won't crash
+        self.token_usages += other.token_usages
+        self.response_latencies += other.response_latencies
 
     def get(self) -> dict:
         """Return the metrics in a dictionary."""
@@ -80,12 +126,14 @@ class Metrics:
             'response_latencies': [
                 latency.model_dump() for latency in self._response_latencies
             ],
+            'token_usages': [usage.model_dump() for usage in self.token_usages],
         }
 
     def reset(self):
         self._accumulated_cost = 0.0
         self._costs = []
         self._response_latencies = []
+        self._token_usages = []
 
     def log(self):
         """Log the metrics."""
diff --git a/tests/unit/test_llm.py b/tests/unit/test_llm.py
index 1bfee85506..0ec7fe2521 100644
--- a/tests/unit/test_llm.py
+++ b/tests/unit/test_llm.py
@@ -2,6 +2,7 @@ import copy
 from unittest.mock import MagicMock, patch
 
 import pytest
+from litellm import PromptTokensDetails
 from litellm.exceptions import (
     RateLimitError,
 )
@@ -429,3 +430,62 @@ def test_get_token_count_error_handling(
     mock_logger.error.assert_called_once_with(
         'Error getting token count for\n model gpt-4o\nToken counting failed'
     )
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_llm_token_usage(mock_litellm_completion, default_config):
+    # This mock response includes usage details with prompt_tokens,
+    # completion_tokens, prompt_tokens_details.cached_tokens, and model_extra.cache_creation_input_tokens
+    mock_response_1 = {
+        'id': 'test-response-usage',
+        'choices': [{'message': {'content': 'Usage test response'}}],
+        'usage': {
+            'prompt_tokens': 12,
+            'completion_tokens': 3,
+            'prompt_tokens_details': PromptTokensDetails(cached_tokens=2),
+            'model_extra': {'cache_creation_input_tokens': 5},
+        },
+    }
+
+    # Create a second usage scenario to test accumulation and a different response_id
+    mock_response_2 = {
+        'id': 'test-response-usage-2',
+        'choices': [{'message': {'content': 'Second usage test response'}}],
+        'usage': {
+            'prompt_tokens': 7,
+            'completion_tokens': 2,
+            'prompt_tokens_details': PromptTokensDetails(cached_tokens=1),
+            'model_extra': {'cache_creation_input_tokens': 3},
+        },
+    }
+
+    # We'll make mock_litellm_completion return these responses in sequence
+    mock_litellm_completion.side_effect = [mock_response_1, mock_response_2]
+
+    llm = LLM(config=default_config)
+
+    # First call
+    llm.completion(messages=[{'role': 'user', 'content': 'Hello usage!'}])
+
+    # Verify we have exactly one usage record after first call
+    token_usage_list = llm.metrics.get()['token_usages']
+    assert len(token_usage_list) == 1
+    usage_entry_1 = token_usage_list[0]
+    assert usage_entry_1['prompt_tokens'] == 12
+    assert usage_entry_1['completion_tokens'] == 3
+    assert usage_entry_1['cache_read_tokens'] == 2
+    assert usage_entry_1['cache_write_tokens'] == 5
+    assert usage_entry_1['response_id'] == 'test-response-usage'
+
+    # Second call
+    llm.completion(messages=[{'role': 'user', 'content': 'Hello again!'}])
+
+    # Now we expect two usage records total
+    token_usage_list = llm.metrics.get()['token_usages']
+    assert len(token_usage_list) == 2
+    usage_entry_2 = token_usage_list[-1]
+    assert usage_entry_2['prompt_tokens'] == 7
+    assert usage_entry_2['completion_tokens'] == 2
+    assert usage_entry_2['cache_read_tokens'] == 1
+    assert usage_entry_2['cache_write_tokens'] == 3
+    assert usage_entry_2['response_id'] == 'test-response-usage-2'
diff --git a/tests/unit/test_message_utils.py b/tests/unit/test_message_utils.py
index d3114519c8..0f3a189a9c 100644
--- a/tests/unit/test_message_utils.py
+++ b/tests/unit/test_message_utils.py
@@ -3,13 +3,18 @@ from unittest.mock import Mock
 import pytest
 
 from openhands.core.message import ImageContent, TextContent
-from openhands.core.message_utils import get_action_message, get_observation_message
+from openhands.core.message_utils import (
+    get_action_message,
+    get_observation_message,
+    get_token_usage_for_event,
+    get_token_usage_for_event_id,
+)
 from openhands.events.action import (
     AgentFinishAction,
     CmdRunAction,
     MessageAction,
 )
-from openhands.events.event import EventSource, FileEditSource, FileReadSource
+from openhands.events.event import Event, EventSource, FileEditSource, FileReadSource
 from openhands.events.observation.browse import BrowserOutputObservation
 from openhands.events.observation.commands import (
     CmdOutputMetadata,
@@ -21,6 +26,7 @@ from openhands.events.observation.error import ErrorObservation
 from openhands.events.observation.files import FileEditObservation, FileReadObservation
 from openhands.events.observation.reject import UserRejectObservation
 from openhands.events.tool import ToolCallMetadata
+from openhands.llm.metrics import Metrics, TokenUsage
 
 
 def test_cmd_output_observation_message():
@@ -269,3 +275,113 @@ def test_agent_finish_action_with_tool_metadata():
     assert len(result.content) == 1
     assert isinstance(result.content[0], TextContent)
     assert 'Initial thought\nTask completed' in result.content[0].text
+
+
+def test_get_token_usage_for_event():
+    """Test that we get the single matching usage record (if any) based on the event's model_response.id."""
+    metrics = Metrics(model_name='test-model')
+    usage_record = TokenUsage(
+        model='test-model',
+        prompt_tokens=10,
+        completion_tokens=5,
+        cache_read_tokens=2,
+        cache_write_tokens=1,
+        response_id='test-response-id',
+    )
+    metrics.add_token_usage(
+        prompt_tokens=usage_record.prompt_tokens,
+        completion_tokens=usage_record.completion_tokens,
+        cache_read_tokens=usage_record.cache_read_tokens,
+        cache_write_tokens=usage_record.cache_write_tokens,
+        response_id=usage_record.response_id,
+    )
+
+    # Create an event referencing that response_id
+    event = Event()
+    mock_tool_call_metadata = ToolCallMetadata(
+        tool_call_id='test-tool-call',
+        function_name='fake_function',
+        model_response={'id': 'test-response-id'},
+        total_calls_in_response=1,
+    )
+    event._tool_call_metadata = (
+        mock_tool_call_metadata  # normally you'd do event.tool_call_metadata = ...
+    )
+
+    # We should find that usage record
+    found = get_token_usage_for_event(event, metrics)
+    assert found is not None
+    assert found.prompt_tokens == 10
+    assert found.response_id == 'test-response-id'
+
+    # If we change the event's response ID, we won't find anything
+    mock_tool_call_metadata.model_response.id = 'some-other-id'
+    found2 = get_token_usage_for_event(event, metrics)
+    assert found2 is None
+
+    # If the event has no tool_call_metadata, also returns None
+    event._tool_call_metadata = None
+    found3 = get_token_usage_for_event(event, metrics)
+    assert found3 is None
+
+
+def test_get_token_usage_for_event_id():
+    """
+    Test that we search backward from the event with the given id,
+    finding the first usage record that matches a response_id in that or previous events.
+    """
+    metrics = Metrics(model_name='test-model')
+    usage_1 = TokenUsage(
+        model='test-model',
+        prompt_tokens=12,
+        completion_tokens=3,
+        cache_read_tokens=2,
+        cache_write_tokens=5,
+        response_id='resp-1',
+    )
+    usage_2 = TokenUsage(
+        model='test-model',
+        prompt_tokens=7,
+        completion_tokens=2,
+        cache_read_tokens=1,
+        cache_write_tokens=3,
+        response_id='resp-2',
+    )
+    metrics._token_usages.append(usage_1)
+    metrics._token_usages.append(usage_2)
+
+    # Build a list of events
+    events = []
+    for i in range(5):
+        e = Event()
+        e._id = i
+        # We'll attach usage_1 to event 1, usage_2 to event 3
+        if i == 1:
+            e._tool_call_metadata = ToolCallMetadata(
+                tool_call_id='tid1',
+                function_name='fn1',
+                model_response={'id': 'resp-1'},
+                total_calls_in_response=1,
+            )
+        elif i == 3:
+            e._tool_call_metadata = ToolCallMetadata(
+                tool_call_id='tid2',
+                function_name='fn2',
+                model_response={'id': 'resp-2'},
+                total_calls_in_response=1,
+            )
+        events.append(e)
+
+    # If we ask for event_id=3, we find usage_2 immediately
+    found_3 = get_token_usage_for_event_id(events, 3, metrics)
+    assert found_3 is not None
+    assert found_3.response_id == 'resp-2'
+
+    # If we ask for event_id=2, no usage in event2, so we check event1 -> usage_1 found
+    found_2 = get_token_usage_for_event_id(events, 2, metrics)
+    assert found_2 is not None
+    assert found_2.response_id == 'resp-1'
+
+    # If we ask for event_id=0, no usage in event0 or earlier, so return None
+    found_0 = get_token_usage_for_event_id(events, 0, metrics)
+    assert found_0 is None