Make responses flow through the message queue (#5)

This commit is contained in:
Jack Gerrits
2024-05-19 17:12:49 -06:00
committed by GitHub
parent b67d096a95
commit b4b1fd5bdd

View File

@@ -1,7 +1,7 @@
import asyncio
from asyncio import Future
from dataclasses import dataclass
from typing import Dict, Generic, List, Set, Type, TypeVar
from typing import Awaitable, Dict, Generic, List, Set, Type, TypeVar
from ..core.agent import Agent
from ..core.agent_runtime import AgentRuntime
@@ -33,12 +33,26 @@ class SendMessageEnvelope(Generic[T]):
class ResponseMessageEnvelope(Generic[T]):
"""A message envelope for sending a response to a message."""
...
message: T
future: Future[T]
@dataclass
class BroadcastResponseMessageEnvelope(Generic[T]):
"""A message envelope for sending a response to a message."""
message: List[T]
future: Future[List[T]]
class SingleThreadedAgentRuntime(AgentRuntime[T]):
def __init__(self) -> None:
self._message_queue: List[BroadcastMessageEnvolope[T] | SendMessageEnvelope[T]] = []
self._message_queue: List[
BroadcastMessageEnvolope[T]
| SendMessageEnvelope[T]
| ResponseMessageEnvelope[T]
| BroadcastResponseMessageEnvelope[T]
] = []
self._per_type_subscribers: Dict[Type[T], List[Agent[T]]] = {}
self._agents: Set[Agent[T]] = set()
@@ -55,7 +69,6 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
future: Future[T] = loop.create_future()
self._message_queue.append(SendMessageEnvelope(message, destination, future))
return future
# Returns the response of all handling agents
@@ -71,14 +84,22 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
return
response = await recipient.on_message(message_envelope.message)
message_envelope.future.set_result(response)
self._message_queue.append(ResponseMessageEnvelope(response, message_envelope.future))
async def _process_broadcast(self, message_envelope: BroadcastMessageEnvolope[T]) -> None:
responses: List[T] = []
responses: List[Awaitable[T]] = []
for agent in self._per_type_subscribers.get(type(message_envelope.message), []):
response = await agent.on_message(message_envelope.message)
responses.append(response)
message_envelope.future.set_result(responses)
future = agent.on_message(message_envelope.message)
responses.append(future)
all_responses = await asyncio.gather(*responses)
self._message_queue.append(BroadcastResponseMessageEnvelope(all_responses, message_envelope.future))
async def _process_response(self, message_envelope: ResponseMessageEnvelope[T]) -> None:
message_envelope.future.set_result(message_envelope.message)
async def _process_broadcast_response(self, message_envelope: BroadcastResponseMessageEnvelope[T]) -> None:
message_envelope.future.set_result(message_envelope.message)
async def process_next(self) -> None:
if len(self._message_queue) == 0:
@@ -93,6 +114,10 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
asyncio.create_task(self._process_send(SendMessageEnvelope(message, destination, future)))
case BroadcastMessageEnvolope(message, future):
asyncio.create_task(self._process_broadcast(BroadcastMessageEnvolope(message, future)))
case ResponseMessageEnvelope(message, future):
asyncio.create_task(self._process_response(ResponseMessageEnvelope(message, future)))
case BroadcastResponseMessageEnvelope(message, future):
asyncio.create_task(self._process_broadcast_response(BroadcastResponseMessageEnvelope(message, future)))
# Yield control to the message loop to allow other tasks to run
await asyncio.sleep(0)