mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-14 01:08:01 -05:00
103 lines
3.4 KiB
Python
103 lines
3.4 KiB
Python
from pydantic import BaseModel, Field
|
|
|
|
from openhands.core.logger import openhands_logger as logger
|
|
from openhands.events.action import (
|
|
Action,
|
|
ChangeAgentStateAction,
|
|
MessageAction,
|
|
NullAction,
|
|
)
|
|
from openhands.events.event import EventSource
|
|
from openhands.events.observation import (
|
|
AgentStateChangedObservation,
|
|
NullObservation,
|
|
Observation,
|
|
)
|
|
from openhands.events.serialization.event import event_to_dict
|
|
from openhands.security.invariant.nodes import Function, Message, ToolCall, ToolOutput
|
|
|
|
TraceElement = Message | ToolCall | ToolOutput | Function
|
|
|
|
|
|
def get_next_id(trace: list[TraceElement]) -> str:
|
|
used_ids = [el.id for el in trace if isinstance(el, ToolCall)]
|
|
for i in range(1, len(used_ids) + 2):
|
|
if str(i) not in used_ids:
|
|
return str(i)
|
|
return '1'
|
|
|
|
|
|
def get_last_id(
|
|
trace: list[TraceElement],
|
|
) -> str | None:
|
|
for el in reversed(trace):
|
|
if isinstance(el, ToolCall):
|
|
return el.id
|
|
return None
|
|
|
|
|
|
def parse_action(trace: list[TraceElement], action: Action) -> list[TraceElement]:
|
|
next_id = get_next_id(trace)
|
|
inv_trace: list[TraceElement] = []
|
|
if isinstance(action, MessageAction):
|
|
if action.source == EventSource.USER:
|
|
inv_trace.append(Message(role='user', content=action.content))
|
|
else:
|
|
inv_trace.append(Message(role='assistant', content=action.content))
|
|
elif isinstance(action, (NullAction, ChangeAgentStateAction)):
|
|
pass
|
|
elif hasattr(action, 'action') and action.action is not None:
|
|
event_dict = event_to_dict(action)
|
|
args = event_dict.get('args', {})
|
|
thought = args.pop('thought', None)
|
|
|
|
function = Function(name=action.action, arguments=args)
|
|
if thought is not None:
|
|
inv_trace.append(Message(role='assistant', content=thought))
|
|
inv_trace.append(ToolCall(id=next_id, type='function', function=function))
|
|
else:
|
|
logger.error(f'Unknown action type: {type(action)}')
|
|
return inv_trace
|
|
|
|
|
|
def parse_observation(
|
|
trace: list[TraceElement], obs: Observation
|
|
) -> list[TraceElement]:
|
|
last_id = get_last_id(trace)
|
|
if isinstance(obs, (NullObservation, AgentStateChangedObservation)):
|
|
return []
|
|
elif hasattr(obs, 'content') and obs.content is not None:
|
|
return [ToolOutput(role='tool', content=obs.content, tool_call_id=last_id)]
|
|
else:
|
|
logger.error(f'Unknown observation type: {type(obs)}')
|
|
return []
|
|
|
|
|
|
def parse_element(
|
|
trace: list[TraceElement], element: Action | Observation
|
|
) -> list[TraceElement]:
|
|
if isinstance(element, Action):
|
|
return parse_action(trace, element)
|
|
return parse_observation(trace, element)
|
|
|
|
|
|
def parse_trace(trace: list[tuple[Action, Observation]]) -> list[TraceElement]:
|
|
inv_trace: list[TraceElement] = []
|
|
for action, obs in trace:
|
|
inv_trace.extend(parse_action(inv_trace, action))
|
|
inv_trace.extend(parse_observation(inv_trace, obs))
|
|
return inv_trace
|
|
|
|
|
|
class InvariantState(BaseModel):
|
|
trace: list[TraceElement] = Field(default_factory=list)
|
|
|
|
def add_action(self, action: Action) -> None:
|
|
self.trace.extend(parse_action(self.trace, action))
|
|
|
|
def add_observation(self, obs: Observation) -> None:
|
|
self.trace.extend(parse_observation(self.trace, obs))
|
|
|
|
def concatenate(self, other: 'InvariantState') -> None:
|
|
self.trace.extend(other.trace)
|