Add event synchronously (#2700)

* add to event stream sync

* remove async from tests
This commit is contained in:
Engel Nyst
2024-07-05 00:15:51 +02:00
committed by GitHub
parent 1b10e2b9d5
commit 0b8d357bef
7 changed files with 31 additions and 37 deletions

View File

@@ -122,7 +122,7 @@ class AgentController:
self.state.last_error = message
if exception:
self.state.last_error += f': {exception}'
await self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
async def add_history(self, action: Action, observation: Observation):
if isinstance(action, NullAction) and isinstance(observation, NullObservation):
@@ -211,7 +211,7 @@ class AgentController:
if new_state == AgentState.STOPPED or new_state == AgentState.ERROR:
self.reset_task()
await self.event_stream.add_event(
self.event_stream.add_event(
AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT
)
@@ -221,8 +221,6 @@ class AgentController:
def get_agent_state(self):
"""Returns the current state of the agent task."""
if self.delegate is not None:
return self.delegate.get_agent_state()
return self.state.agent_state
async def start_delegate(self, action: AgentDelegateAction):
@@ -301,7 +299,7 @@ class AgentController:
# clean up delegate status
self.delegate = None
self.delegateAction = None
await self.event_stream.add_event(obs, EventSource.AGENT)
self.event_stream.add_event(obs, EventSource.AGENT)
return
logger.info(
@@ -358,7 +356,7 @@ class AgentController:
await self.add_history(action, NullObservation(''))
if not isinstance(action, NullAction):
await self.event_stream.add_event(action, EventSource.AGENT)
self.event_stream.add_event(action, EventSource.AGENT)
await self.update_state_after_step()

View File

@@ -114,6 +114,8 @@ def get_console_handler():
"""
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
if config.debug:
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(console_formatter)
return console_handler

View File

@@ -100,7 +100,7 @@ async def run_agent_controller(
# start event is a MessageAction with the task, either resumed or new
if config.enable_cli_session and initial_state is not None:
# we're resuming the previous session
await event_stream.add_event(
event_stream.add_event(
MessageAction(
content="Let's get back on track. If you experienced errors before, do NOT resume your task. Ask me about it."
),
@@ -108,7 +108,7 @@ async def run_agent_controller(
)
elif initial_state is None:
# init with the provided task
await event_stream.add_event(MessageAction(content=task_str), EventSource.USER)
event_stream.add_event(MessageAction(content=task_str), EventSource.USER)
async def on_event(event: Event):
if isinstance(event, AgentStateChangedObservation):
@@ -120,10 +120,10 @@ async def run_agent_controller(
else:
message = fake_user_response_fn(controller.get_state())
action = MessageAction(content=message)
await event_stream.add_event(action, EventSource.USER)
event_stream.add_event(action, EventSource.USER)
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
while controller.get_agent_state() not in [
while controller.state.agent_state not in [
AgentState.FINISHED,
AgentState.REJECTED,
AgentState.ERROR,

View File

@@ -1,5 +1,6 @@
import asyncio
import json
import threading
from datetime import datetime
from enum import Enum
from typing import Callable, Iterable
@@ -25,7 +26,7 @@ class EventStream:
# when there are agent delegates
_subscribers: dict[str, list[Callable]]
_cur_id: int
_lock: asyncio.Lock
_lock: threading.Lock
_file_store: FileStore
def __init__(self, sid: str):
@@ -33,7 +34,7 @@ class EventStream:
self._file_store = get_file_store()
self._subscribers = {}
self._cur_id = 0
self._lock = asyncio.Lock()
self._lock = threading.Lock()
self._reinitialize_from_file_store()
def _reinitialize_from_file_store(self):
@@ -93,12 +94,11 @@ class EventStream:
if len(self._subscribers[id]) == 0:
del self._subscribers[id]
# TODO: make this not async
async def add_event(self, event: Event, source: EventSource):
logger.debug(f'Adding event {event} from {source}')
async with self._lock:
event._id = self._cur_id # type: ignore[attr-defined]
def add_event(self, event: Event, source: EventSource):
with self._lock:
event._id = self._cur_id # type: ignore [attr-defined]
self._cur_id += 1
logger.debug(f'Adding {type(event).__name__} id={event.id} from {source.name}')
event._timestamp = datetime.now() # type: ignore[attr-defined]
event._source = source # type: ignore[attr-defined]
data = event_to_dict(event)
@@ -108,5 +108,4 @@ class EventStream:
)
for stack in self._subscribers.values():
callback = stack[-1]
logger.debug(f'Notifying subscriber {callback} of event {event}')
await callback(event)
asyncio.create_task(callback(event))

View File

@@ -114,7 +114,7 @@ class Runtime:
observation = await self.run_action(event)
observation._cause = event.id # type: ignore[attr-defined]
source = event.source if event.source else EventSource.AGENT
await self.event_stream.add_event(observation, source)
self.event_stream.add_event(observation, source)
async def run_action(self, action: Action) -> Observation:
"""
@@ -149,7 +149,7 @@ class Runtime:
for _id, cmd in self.sandbox.background_commands.items():
output = cmd.read_logs()
if output:
await self.event_stream.add_event(
self.event_stream.add_event(
CmdOutputObservation(
content=output, command_id=_id, command=cmd.command
),

View File

@@ -61,10 +61,10 @@ class Session:
logger.exception('Error in loop_recv: %s', e)
async def _initialize_agent(self, data: dict):
await self.agent_session.event_stream.add_event(
self.agent_session.event_stream.add_event(
ChangeAgentStateAction(AgentState.LOADING), EventSource.USER
)
await self.agent_session.event_stream.add_event(
self.agent_session.event_stream.add_event(
AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
)
try:
@@ -75,7 +75,7 @@ class Session:
f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
)
return
await self.agent_session.event_stream.add_event(
self.agent_session.event_stream.add_event(
ChangeAgentStateAction(AgentState.INIT), EventSource.USER
)
@@ -102,7 +102,7 @@ class Session:
await self._initialize_agent(data)
return
event = event_from_dict(data.copy())
await self.agent_session.event_stream.add_event(event, EventSource.USER)
self.agent_session.event_stream.add_event(event, EventSource.USER)
async def send(self, data: dict[str, object]) -> bool:
try:

View File

@@ -1,7 +1,5 @@
import json
import pytest
from opendevin.events import EventSource, EventStream
from opendevin.events.action import NullAction
from opendevin.events.observation import NullObservation
@@ -11,17 +9,15 @@ def collect_events(stream):
return [event for event in stream.get_events()]
@pytest.mark.asyncio
async def test_basic_flow():
def test_basic_flow():
stream = EventStream('abc')
await stream.add_event(NullAction(), EventSource.AGENT)
stream.add_event(NullAction(), EventSource.AGENT)
assert len(collect_events(stream)) == 1
@pytest.mark.asyncio
async def test_stream_storage():
def test_stream_storage():
stream = EventStream('def')
await stream.add_event(NullObservation(''), EventSource.AGENT)
stream.add_event(NullObservation(''), EventSource.AGENT)
assert len(collect_events(stream)) == 1
content = stream._file_store.read('sessions/def/events/0.json')
assert content is not None
@@ -38,11 +34,10 @@ async def test_stream_storage():
}
@pytest.mark.asyncio
async def test_rehydration():
def test_rehydration():
stream1 = EventStream('es1')
await stream1.add_event(NullObservation('obs1'), EventSource.AGENT)
await stream1.add_event(NullObservation('obs2'), EventSource.AGENT)
stream1.add_event(NullObservation('obs1'), EventSource.AGENT)
stream1.add_event(NullObservation('obs2'), EventSource.AGENT)
assert len(collect_events(stream1)) == 2
stream2 = EventStream('es2')