mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-10 07:18:10 -05:00
Use response_id to track token usage for MessageActions (#6913)
Co-authored-by: Calvin Smith <email@cjsmith.io>
This commit is contained in:
@@ -215,6 +215,13 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
|
||||
)
|
||||
)
|
||||
|
||||
# Add response id to actions
|
||||
# This will ensure we can match both actions without tool calls (e.g. MessageAction)
|
||||
# and actions with tool calls (e.g. CmdRunAction, IPythonRunCellAction, etc.)
|
||||
# with the token usage data
|
||||
for action in actions:
|
||||
action.response_id = response.id
|
||||
|
||||
assert len(actions) >= 1
|
||||
return actions
|
||||
|
||||
|
||||
@@ -4,22 +4,30 @@ from openhands.llm.metrics import Metrics, TokenUsage
|
||||
|
||||
def get_token_usage_for_event(event: Event, metrics: Metrics) -> TokenUsage | None:
    """Return at most one token usage record for the given event.

    The response id used for the lookup is taken from, in order:
    - `tool_call_metadata.model_response.id`, if possible
    - otherwise `event.response_id`, if set

    Args:
        event: The event whose token usage is wanted.
        metrics: Metrics object whose `token_usages` list is searched.

    Returns:
        The first record in `metrics.token_usages` whose `response_id` matches.
        If neither id exists, or none matches in `metrics.token_usages`,
        returns None.
    """

    def _find_usage(response_id: str) -> TokenUsage | None:
        # First recorded usage carrying this response id, or None.
        return next(
            (u for u in metrics.token_usages if u.response_id == response_id),
            None,
        )

    # 1) Use the tool_call_metadata's response.id if present
    if event.tool_call_metadata and event.tool_call_metadata.model_response:
        tool_response_id = event.tool_call_metadata.model_response.get('id')
        if tool_response_id:
            usage_rec = _find_usage(tool_response_id)
            # Only return on a hit; a mismatched tool-call id must not
            # prevent the fallback below from being tried.
            if usage_rec is not None:
                return usage_rec

    # 2) Fallback to the top-level event.response_id if present
    if event.response_id:
        return _find_usage(event.response_id)

    return None
|
||||
|
||||
|
||||
@@ -28,17 +36,17 @@ def get_token_usage_for_event_id(
|
||||
) -> TokenUsage | None:
|
||||
"""
|
||||
Starting from the event with .id == event_id and moving backwards in `events`,
|
||||
find the first TokenUsage record (if any) associated with a response_id from
|
||||
tool_call_metadata.model_response.id.
|
||||
|
||||
find the first TokenUsage record (if any) associated either with:
|
||||
- tool_call_metadata.model_response.id, or
|
||||
- event.response_id
|
||||
Returns the first match found, or None if none is found.
|
||||
"""
|
||||
# find the index of the event with the given id
|
||||
# Find the index of the event with the given id
|
||||
idx = next((i for i, e in enumerate(events) if e.id == event_id), None)
|
||||
if idx is None:
|
||||
return None
|
||||
|
||||
# search backward from idx down to 0
|
||||
# Search backward from idx down to 0
|
||||
for i in range(idx, -1, -1):
|
||||
usage = get_token_usage_for_event(events[i], metrics)
|
||||
if usage is not None:
|
||||
|
||||
@@ -115,7 +115,7 @@ class Event:
|
||||
def llm_metrics(self, value: Metrics) -> None:
|
||||
self._llm_metrics = value
|
||||
|
||||
# optional field
|
||||
# optional field, metadata about the tool call, if the event has a tool call
|
||||
@property
|
||||
def tool_call_metadata(self) -> ToolCallMetadata | None:
|
||||
if hasattr(self, '_tool_call_metadata'):
|
||||
@@ -126,3 +126,14 @@ class Event:
|
||||
@tool_call_metadata.setter
def tool_call_metadata(self, value: ToolCallMetadata) -> None:
    """Store `value` as this event's tool-call metadata (kept on `_tool_call_metadata`)."""
    self._tool_call_metadata = value
|
||||
|
||||
# optional field, the id of the response from the LLM
|
||||
@property
def response_id(self) -> str | None:
    """The id of the response from the LLM, or None if it was never set."""
    # The backing attribute only exists once the setter has run, so read it
    # with a default instead of assuming it is present.
    return getattr(self, '_response_id', None)


@response_id.setter
def response_id(self, value: str) -> None:
    """Record the id of the LLM response on the `_response_id` attribute."""
    self._response_id = value
|
||||
|
||||
@@ -115,3 +115,85 @@ def test_get_token_usage_for_event_id():
|
||||
# If we ask for event_id=0, no usage in event0 or earlier, so return None
|
||||
found_0 = get_token_usage_for_event_id(events, 0, metrics)
|
||||
assert found_0 is None
|
||||
|
||||
|
||||
def test_get_token_usage_for_event_fallback():
    """
    When tool_call_metadata.model_response.id is missing or does not match any
    usage record, but event.response_id names a recorded usage, the lookup must
    fall back to event.response_id and return that record.
    """
    metrics = Metrics(model_name='fallback-test')
    usage_record = TokenUsage(
        model='fallback-test',
        prompt_tokens=22,
        completion_tokens=8,
        cache_read_tokens=3,
        cache_write_tokens=2,
        response_id='fallback-response-id',
    )
    metrics.add_token_usage(
        prompt_tokens=usage_record.prompt_tokens,
        completion_tokens=usage_record.completion_tokens,
        cache_read_tokens=usage_record.cache_read_tokens,
        cache_write_tokens=usage_record.cache_write_tokens,
        response_id=usage_record.response_id,
    )

    # Tool-call metadata whose response id matches no usage record ...
    mismatched_metadata = ToolCallMetadata(
        tool_call_id='irrelevant-tool-call',
        function_name='fake_function',
        model_response={'id': 'not-matching-any-usage'},
        total_calls_in_response=1,
    )
    event = Event()
    event._tool_call_metadata = mismatched_metadata
    # ... while the event's own response id is the one that was recorded.
    event._response_id = 'fallback-response-id'

    found = get_token_usage_for_event(event, metrics)

    assert found is not None
    assert found.prompt_tokens == 22
    assert found.response_id == 'fallback-response-id'
|
||||
|
||||
|
||||
def test_get_token_usage_for_event_id_fallback():
    """
    get_token_usage_for_event_id must also fall back to event.response_id when
    tool_call_metadata.model_response.id is missing or mismatched.
    """

    # NOTE: this should never happen (tm), but there is a hint in the code that it might:
    # message_utils.py: 166 ("(overwrites any previous message with the same response_id)")
    # so we'll handle it gracefully.
    metrics = Metrics(model_name='fallback-test')
    metrics.token_usages.append(
        TokenUsage(
            model='fallback-test',
            prompt_tokens=15,
            completion_tokens=4,
            cache_read_tokens=1,
            cache_write_tokens=0,
            response_id='resp-fallback',
        )
    )

    def _event_with_id(event_id):
        # Bare event carrying only an id.
        ev = Event()
        ev._id = event_id
        return ev

    events = [_event_with_id(n) for n in range(3)]

    # Only the middle event carries response information: its tool-call
    # metadata points at an unknown response, its top-level response_id at
    # the recorded one.
    middle = events[1]
    middle._tool_call_metadata = ToolCallMetadata(
        tool_call_id='tool-123',
        function_name='whatever',
        model_response={'id': 'no-such-response'},
        total_calls_in_response=1,
    )
    middle._response_id = 'resp-fallback'

    # Searching from event_id=2 walks back to event 1 and uses its fallback id.
    found_usage = get_token_usage_for_event_id(events, 2, metrics)

    assert found_usage is not None
    assert found_usage.response_id == 'resp-fallback'
    assert found_usage.prompt_tokens == 15
|
||||
|
||||
Reference in New Issue
Block a user