mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 14:57:59 -05:00
Add event synchronously (#2700)
* add to event stream sync * remove async from tests
This commit is contained in:
@@ -122,7 +122,7 @@ class AgentController:
|
||||
self.state.last_error = message
|
||||
if exception:
|
||||
self.state.last_error += f': {exception}'
|
||||
await self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
|
||||
self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
|
||||
|
||||
async def add_history(self, action: Action, observation: Observation):
|
||||
if isinstance(action, NullAction) and isinstance(observation, NullObservation):
|
||||
@@ -211,7 +211,7 @@ class AgentController:
|
||||
if new_state == AgentState.STOPPED or new_state == AgentState.ERROR:
|
||||
self.reset_task()
|
||||
|
||||
await self.event_stream.add_event(
|
||||
self.event_stream.add_event(
|
||||
AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT
|
||||
)
|
||||
|
||||
@@ -221,8 +221,6 @@ class AgentController:
|
||||
|
||||
def get_agent_state(self):
|
||||
"""Returns the current state of the agent task."""
|
||||
if self.delegate is not None:
|
||||
return self.delegate.get_agent_state()
|
||||
return self.state.agent_state
|
||||
|
||||
async def start_delegate(self, action: AgentDelegateAction):
|
||||
@@ -301,7 +299,7 @@ class AgentController:
|
||||
# clean up delegate status
|
||||
self.delegate = None
|
||||
self.delegateAction = None
|
||||
await self.event_stream.add_event(obs, EventSource.AGENT)
|
||||
self.event_stream.add_event(obs, EventSource.AGENT)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
@@ -358,7 +356,7 @@ class AgentController:
|
||||
await self.add_history(action, NullObservation(''))
|
||||
|
||||
if not isinstance(action, NullAction):
|
||||
await self.event_stream.add_event(action, EventSource.AGENT)
|
||||
self.event_stream.add_event(action, EventSource.AGENT)
|
||||
|
||||
await self.update_state_after_step()
|
||||
|
||||
|
||||
@@ -114,6 +114,8 @@ def get_console_handler():
|
||||
"""
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
if config.debug:
|
||||
console_handler.setLevel(logging.DEBUG)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
return console_handler
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ async def run_agent_controller(
|
||||
# start event is a MessageAction with the task, either resumed or new
|
||||
if config.enable_cli_session and initial_state is not None:
|
||||
# we're resuming the previous session
|
||||
await event_stream.add_event(
|
||||
event_stream.add_event(
|
||||
MessageAction(
|
||||
content="Let's get back on track. If you experienced errors before, do NOT resume your task. Ask me about it."
|
||||
),
|
||||
@@ -108,7 +108,7 @@ async def run_agent_controller(
|
||||
)
|
||||
elif initial_state is None:
|
||||
# init with the provided task
|
||||
await event_stream.add_event(MessageAction(content=task_str), EventSource.USER)
|
||||
event_stream.add_event(MessageAction(content=task_str), EventSource.USER)
|
||||
|
||||
async def on_event(event: Event):
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
@@ -120,10 +120,10 @@ async def run_agent_controller(
|
||||
else:
|
||||
message = fake_user_response_fn(controller.get_state())
|
||||
action = MessageAction(content=message)
|
||||
await event_stream.add_event(action, EventSource.USER)
|
||||
event_stream.add_event(action, EventSource.USER)
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
|
||||
while controller.get_agent_state() not in [
|
||||
while controller.state.agent_state not in [
|
||||
AgentState.FINISHED,
|
||||
AgentState.REJECTED,
|
||||
AgentState.ERROR,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Callable, Iterable
|
||||
@@ -25,7 +26,7 @@ class EventStream:
|
||||
# when there are agent delegates
|
||||
_subscribers: dict[str, list[Callable]]
|
||||
_cur_id: int
|
||||
_lock: asyncio.Lock
|
||||
_lock: threading.Lock
|
||||
_file_store: FileStore
|
||||
|
||||
def __init__(self, sid: str):
|
||||
@@ -33,7 +34,7 @@ class EventStream:
|
||||
self._file_store = get_file_store()
|
||||
self._subscribers = {}
|
||||
self._cur_id = 0
|
||||
self._lock = asyncio.Lock()
|
||||
self._lock = threading.Lock()
|
||||
self._reinitialize_from_file_store()
|
||||
|
||||
def _reinitialize_from_file_store(self):
|
||||
@@ -93,12 +94,11 @@ class EventStream:
|
||||
if len(self._subscribers[id]) == 0:
|
||||
del self._subscribers[id]
|
||||
|
||||
# TODO: make this not async
|
||||
async def add_event(self, event: Event, source: EventSource):
|
||||
logger.debug(f'Adding event {event} from {source}')
|
||||
async with self._lock:
|
||||
event._id = self._cur_id # type: ignore[attr-defined]
|
||||
def add_event(self, event: Event, source: EventSource):
|
||||
with self._lock:
|
||||
event._id = self._cur_id # type: ignore [attr-defined]
|
||||
self._cur_id += 1
|
||||
logger.debug(f'Adding {type(event).__name__} id={event.id} from {source.name}')
|
||||
event._timestamp = datetime.now() # type: ignore[attr-defined]
|
||||
event._source = source # type: ignore[attr-defined]
|
||||
data = event_to_dict(event)
|
||||
@@ -108,5 +108,4 @@ class EventStream:
|
||||
)
|
||||
for stack in self._subscribers.values():
|
||||
callback = stack[-1]
|
||||
logger.debug(f'Notifying subscriber {callback} of event {event}')
|
||||
await callback(event)
|
||||
asyncio.create_task(callback(event))
|
||||
|
||||
@@ -114,7 +114,7 @@ class Runtime:
|
||||
observation = await self.run_action(event)
|
||||
observation._cause = event.id # type: ignore[attr-defined]
|
||||
source = event.source if event.source else EventSource.AGENT
|
||||
await self.event_stream.add_event(observation, source)
|
||||
self.event_stream.add_event(observation, source)
|
||||
|
||||
async def run_action(self, action: Action) -> Observation:
|
||||
"""
|
||||
@@ -149,7 +149,7 @@ class Runtime:
|
||||
for _id, cmd in self.sandbox.background_commands.items():
|
||||
output = cmd.read_logs()
|
||||
if output:
|
||||
await self.event_stream.add_event(
|
||||
self.event_stream.add_event(
|
||||
CmdOutputObservation(
|
||||
content=output, command_id=_id, command=cmd.command
|
||||
),
|
||||
|
||||
@@ -61,10 +61,10 @@ class Session:
|
||||
logger.exception('Error in loop_recv: %s', e)
|
||||
|
||||
async def _initialize_agent(self, data: dict):
|
||||
await self.agent_session.event_stream.add_event(
|
||||
self.agent_session.event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.LOADING), EventSource.USER
|
||||
)
|
||||
await self.agent_session.event_stream.add_event(
|
||||
self.agent_session.event_stream.add_event(
|
||||
AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
|
||||
)
|
||||
try:
|
||||
@@ -75,7 +75,7 @@ class Session:
|
||||
f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
|
||||
)
|
||||
return
|
||||
await self.agent_session.event_stream.add_event(
|
||||
self.agent_session.event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.INIT), EventSource.USER
|
||||
)
|
||||
|
||||
@@ -102,7 +102,7 @@ class Session:
|
||||
await self._initialize_agent(data)
|
||||
return
|
||||
event = event_from_dict(data.copy())
|
||||
await self.agent_session.event_stream.add_event(event, EventSource.USER)
|
||||
self.agent_session.event_stream.add_event(event, EventSource.USER)
|
||||
|
||||
async def send(self, data: dict[str, object]) -> bool:
|
||||
try:
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from opendevin.events import EventSource, EventStream
|
||||
from opendevin.events.action import NullAction
|
||||
from opendevin.events.observation import NullObservation
|
||||
@@ -11,17 +9,15 @@ def collect_events(stream):
|
||||
return [event for event in stream.get_events()]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_flow():
|
||||
def test_basic_flow():
|
||||
stream = EventStream('abc')
|
||||
await stream.add_event(NullAction(), EventSource.AGENT)
|
||||
stream.add_event(NullAction(), EventSource.AGENT)
|
||||
assert len(collect_events(stream)) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_storage():
|
||||
def test_stream_storage():
|
||||
stream = EventStream('def')
|
||||
await stream.add_event(NullObservation(''), EventSource.AGENT)
|
||||
stream.add_event(NullObservation(''), EventSource.AGENT)
|
||||
assert len(collect_events(stream)) == 1
|
||||
content = stream._file_store.read('sessions/def/events/0.json')
|
||||
assert content is not None
|
||||
@@ -38,11 +34,10 @@ async def test_stream_storage():
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rehydration():
|
||||
def test_rehydration():
|
||||
stream1 = EventStream('es1')
|
||||
await stream1.add_event(NullObservation('obs1'), EventSource.AGENT)
|
||||
await stream1.add_event(NullObservation('obs2'), EventSource.AGENT)
|
||||
stream1.add_event(NullObservation('obs1'), EventSource.AGENT)
|
||||
stream1.add_event(NullObservation('obs2'), EventSource.AGENT)
|
||||
assert len(collect_events(stream1)) == 2
|
||||
|
||||
stream2 = EventStream('es2')
|
||||
|
||||
Reference in New Issue
Block a user