Use response_id to track token usage for MessageActions (#6913)

Co-authored-by: Calvin Smith <email@cjsmith.io>
This commit is contained in:
Engel Nyst
2025-03-26 21:07:01 +01:00
committed by GitHub
parent c5491e87aa
commit 9850f1767a
4 changed files with 125 additions and 17 deletions

View File

@@ -215,6 +215,13 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
)
)
# Add response id to actions
# This will ensure we can match both actions without tool calls (e.g. MessageAction)
# and actions with tool calls (e.g. CmdRunAction, IPythonRunCellAction, etc.)
# with the token usage data
for action in actions:
action.response_id = response.id
assert len(actions) >= 1
return actions

View File

@@ -4,22 +4,30 @@ from openhands.llm.metrics import Metrics, TokenUsage
def get_token_usage_for_event(event: Event, metrics: Metrics) -> TokenUsage | None:
"""
Returns at most one token usage record for the `model_response.id` in this event's
`tool_call_metadata`.
Returns at most one token usage record for either:
- `tool_call_metadata.model_response.id`, if possible
- otherwise event.response_id, if set
If no response_id is found, or none match in metrics.token_usages, returns None.
If neither exist or none matches in metrics.token_usages, returns None.
"""
# 1) Use the tool_call_metadata's response.id if present
if event.tool_call_metadata and event.tool_call_metadata.model_response:
response_id = event.tool_call_metadata.model_response.get('id')
if response_id:
return next(
(
usage
for usage in metrics.token_usages
if usage.response_id == response_id
),
tool_response_id = event.tool_call_metadata.model_response.get('id')
if tool_response_id:
usage_rec = next(
(u for u in metrics.token_usages if u.response_id == tool_response_id),
None,
)
if usage_rec is not None:
return usage_rec
# 2) Fallback to the top-level event.response_id if present
if event.response_id:
return next(
(u for u in metrics.token_usages if u.response_id == event.response_id),
None,
)
return None
@@ -28,17 +36,17 @@ def get_token_usage_for_event_id(
) -> TokenUsage | None:
"""
Starting from the event with .id == event_id and moving backwards in `events`,
find the first TokenUsage record (if any) associated with a response_id from
tool_call_metadata.model_response.id.
find the first TokenUsage record (if any) associated either with:
- tool_call_metadata.model_response.id, or
- event.response_id
Returns the first match found, or None if none is found.
"""
# find the index of the event with the given id
# Find the index of the event with the given id
idx = next((i for i, e in enumerate(events) if e.id == event_id), None)
if idx is None:
return None
# search backward from idx down to 0
# Search backward from idx down to 0
for i in range(idx, -1, -1):
usage = get_token_usage_for_event(events[i], metrics)
if usage is not None:

View File

@@ -115,7 +115,7 @@ class Event:
def llm_metrics(self, value: Metrics) -> None:
self._llm_metrics = value
# optional field
# optional field, metadata about the tool call, if the event has a tool call
@property
def tool_call_metadata(self) -> ToolCallMetadata | None:
if hasattr(self, '_tool_call_metadata'):
@@ -126,3 +126,14 @@ class Event:
@tool_call_metadata.setter
def tool_call_metadata(self, value: ToolCallMetadata) -> None:
self._tool_call_metadata = value
# optional field, the id of the response from the LLM
@property
def response_id(self) -> str | None:
if hasattr(self, '_response_id'):
return self._response_id # type: ignore[attr-defined]
return None
@response_id.setter
def response_id(self, value: str) -> None:
self._response_id = value

View File

@@ -115,3 +115,85 @@ def test_get_token_usage_for_event_id():
# If we ask for event_id=0, no usage in event0 or earlier, so return None
found_0 = get_token_usage_for_event_id(events, 0, metrics)
assert found_0 is None
def test_get_token_usage_for_event_fallback():
"""
Verify that if tool_call_metadata.model_response.id is missing or mismatched,
but event.response_id is set to a valid usage ID, we find the usage record via fallback.
"""
metrics = Metrics(model_name='fallback-test')
usage_record = TokenUsage(
model='fallback-test',
prompt_tokens=22,
completion_tokens=8,
cache_read_tokens=3,
cache_write_tokens=2,
response_id='fallback-response-id',
)
metrics.add_token_usage(
prompt_tokens=usage_record.prompt_tokens,
completion_tokens=usage_record.completion_tokens,
cache_read_tokens=usage_record.cache_read_tokens,
cache_write_tokens=usage_record.cache_write_tokens,
response_id=usage_record.response_id,
)
event = Event()
# Provide some mismatched tool_call_metadata:
event._tool_call_metadata = ToolCallMetadata(
tool_call_id='irrelevant-tool-call',
function_name='fake_function',
model_response={'id': 'not-matching-any-usage'},
total_calls_in_response=1,
)
# But also set event.response_id to the actual usage ID
event._response_id = 'fallback-response-id'
found = get_token_usage_for_event(event, metrics)
assert found is not None
assert found.prompt_tokens == 22
assert found.response_id == 'fallback-response-id'
def test_get_token_usage_for_event_id_fallback():
"""
Verify that get_token_usage_for_event_id also falls back to event.response_id
if tool_call_metadata.model_response.id is missing or mismatched.
"""
# NOTE: this should never happen (tm), but there is a hint in the code that it might:
# message_utils.py: 166 ("(overwrites any previous message with the same response_id)")
# so we'll handle it gracefully.
metrics = Metrics(model_name='fallback-test')
usage_record = TokenUsage(
model='fallback-test',
prompt_tokens=15,
completion_tokens=4,
cache_read_tokens=1,
cache_write_tokens=0,
response_id='resp-fallback',
)
metrics.token_usages.append(usage_record)
events = []
for i in range(3):
e = Event()
e._id = i
if i == 1:
# Mismatch in tool_call_metadata
e._tool_call_metadata = ToolCallMetadata(
tool_call_id='tool-123',
function_name='whatever',
model_response={'id': 'no-such-response'},
total_calls_in_response=1,
)
# But the event's top-level response_id is correct
e._response_id = 'resp-fallback'
events.append(e)
# Searching from event_id=2 goes back to event1, which has fallback response_id
found_usage = get_token_usage_for_event_id(events, 2, metrics)
assert found_usage is not None
assert found_usage.response_id == 'resp-fallback'
assert found_usage.prompt_tokens == 15