Update send_message to be a single async operation. Add start helper to runtime to manage this (#165)

This commit is contained in:
Jack Gerrits
2024-07-01 11:53:45 -04:00
committed by GitHub
parent 28f11c726d
commit 766635394a
29 changed files with 170 additions and 124 deletions

View File

@@ -1,11 +1,14 @@
from __future__ import annotations
import asyncio
import inspect
import logging
import threading
from asyncio import CancelledError, Future
from asyncio import CancelledError, Future, Task
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, TypeVar, cast
from ..core import (
@@ -80,6 +83,40 @@ class Counter:
self.threadLock.release()
class RunContext:
class RunState(Enum):
RUNNING = 0
CANCELLED = 1
UNTIL_IDLE = 2
def __init__(self, runtime: SingleThreadedAgentRuntime) -> None:
self._runtime = runtime
self._run_state = RunContext.RunState.RUNNING
self._run_task = asyncio.create_task(self._run())
self._lock = asyncio.Lock()
async def _run(self) -> None:
while True:
async with self._lock:
if self._run_state == RunContext.RunState.CANCELLED:
return
elif self._run_state == RunContext.RunState.UNTIL_IDLE:
if self._runtime.idle:
return
await self._runtime.process_next()
async def stop(self) -> None:
async with self._lock:
self._run_state = RunContext.RunState.CANCELLED
await self._run_task
async def stop_when_idle(self) -> None:
async with self._lock:
self._run_state = RunContext.RunState.UNTIL_IDLE
await self._run_task
class SingleThreadedAgentRuntime(AgentRuntime):
def __init__(self, *, intervention_handler: InterventionHandler | None = None) -> None:
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
@@ -90,6 +127,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self._intervention_handler = intervention_handler
self._known_namespaces: set[str] = set()
self._outstanding_tasks = Counter()
self._background_tasks: Set[Task[Any]] = set()
@property
def unprocessed_messages(
@@ -113,7 +151,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[Any | None]:
) -> Any:
if cancellation_token is None:
cancellation_token = CancellationToken()
@@ -149,7 +187,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
)
)
return future
cancellation_token.link_future(future)
return await future
async def publish_message(
self,
@@ -334,7 +374,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
message_envelope.message = temp_message
self._outstanding_tasks.increment()
asyncio.create_task(self._process_send(message_envelope))
task = asyncio.create_task(self._process_send(message_envelope))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
case PublishMessageEnvelope(
message=message,
sender=sender,
@@ -352,7 +394,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
message_envelope.message = temp_message
self._outstanding_tasks.increment()
asyncio.create_task(self._process_publish(message_envelope))
task = asyncio.create_task(self._process_publish(message_envelope))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._intervention_handler is not None:
try:
@@ -369,16 +413,19 @@ class SingleThreadedAgentRuntime(AgentRuntime):
message_envelope.message = temp_message
self._outstanding_tasks.increment()
asyncio.create_task(self._process_response(message_envelope))
task = asyncio.create_task(self._process_response(message_envelope))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
# Yield control to the message loop to allow other tasks to run
await asyncio.sleep(0)
async def process_until_idle(self) -> None:
"""Process messages until there is no unprocessed message and no message currently being processed."""
@property
def idle(self) -> bool:
return len(self._message_queue) == 0 and self._outstanding_tasks.get() == 0
while len(self.unprocessed_messages) > 0 or self.outstanding_tasks > 0:
await self.process_next()
def start(self) -> RunContext:
return RunContext(self)
def agent_metadata(self, agent: AgentId) -> AgentMetadata:
return self._get_agent(agent).metadata

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
from asyncio import Future
from typing import TYPE_CHECKING, Any, Mapping
from ._agent_id import AgentId
@@ -32,7 +31,7 @@ class AgentProxy:
*,
sender: AgentId,
cancellation_token: CancellationToken | None = None,
) -> Future[Any]:
) -> Any:
return await self._runtime.send_message(
message,
recipient=self._agent,

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
from asyncio import Future
from contextvars import ContextVar
from typing import Any, Callable, Mapping, Protocol, TypeVar, overload, runtime_checkable
@@ -19,8 +18,6 @@ agent_instantiation_context: ContextVar[tuple[AgentRuntime, AgentId]] = ContextV
@runtime_checkable
class AgentRuntime(Protocol):
# Returns the response of the message
# Can raise CantHandleException
async def send_message(
self,
message: Any,
@@ -28,17 +25,8 @@ class AgentRuntime(Protocol):
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[Any]:
"""Send a message to an agent and return a future that will resolve to the response.
The act of sending a message may be asynchronous, and the response to the message itself is also asynchronous. For example:
.. code-block:: python
response_future = await runtime.send_message(MyMessage("Hello"), recipient=agent_id)
response = await response_future
The returned future only needs to be awaited if the response is needed. If the response is not needed, the future can be ignored.
) -> Any:
"""Send a message to an agent and get a response.
Args:
message (Any): The message to send.
@@ -49,14 +37,14 @@ class AgentRuntime(Protocol):
Raises:
CantHandleException: If the recipient cannot handle the message.
UndeliverableException: If the message cannot be delivered.
Other: Any other exception raised by the recipient.
Returns:
Future[Any]: A future that will resolve to the response of the message.
Any: The response from the agent.
"""
...
# No responses from publishing
async def publish_message(
self,
message: Any,

View File

@@ -1,6 +1,5 @@
import warnings
from abc import ABC, abstractmethod
from asyncio import Future
from typing import Any, Mapping, Sequence
from ._agent import Agent
@@ -55,19 +54,17 @@ class BaseAgent(ABC, Agent):
recipient: AgentId,
*,
cancellation_token: CancellationToken | None = None,
) -> Future[Any]:
) -> Any:
"""See :py:meth:`agnext.core.AgentRuntime.send_message` for more information."""
if cancellation_token is None:
cancellation_token = CancellationToken()
future = await self._runtime.send_message(
return await self._runtime.send_message(
message,
sender=self.id,
recipient=recipient,
cancellation_token=cancellation_token,
)
cancellation_token.link_future(future)
return future
async def publish_message(
self,