mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Update send_message to be a single async operation. Add start helper to runtime to manage this (#165)
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
from asyncio import CancelledError, Future
|
||||
from asyncio import CancelledError, Future, Task
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, TypeVar, cast
|
||||
|
||||
from ..core import (
|
||||
@@ -80,6 +83,40 @@ class Counter:
|
||||
self.threadLock.release()
|
||||
|
||||
|
||||
class RunContext:
|
||||
class RunState(Enum):
|
||||
RUNNING = 0
|
||||
CANCELLED = 1
|
||||
UNTIL_IDLE = 2
|
||||
|
||||
def __init__(self, runtime: SingleThreadedAgentRuntime) -> None:
|
||||
self._runtime = runtime
|
||||
self._run_state = RunContext.RunState.RUNNING
|
||||
self._run_task = asyncio.create_task(self._run())
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _run(self) -> None:
|
||||
while True:
|
||||
async with self._lock:
|
||||
if self._run_state == RunContext.RunState.CANCELLED:
|
||||
return
|
||||
elif self._run_state == RunContext.RunState.UNTIL_IDLE:
|
||||
if self._runtime.idle:
|
||||
return
|
||||
|
||||
await self._runtime.process_next()
|
||||
|
||||
async def stop(self) -> None:
|
||||
async with self._lock:
|
||||
self._run_state = RunContext.RunState.CANCELLED
|
||||
await self._run_task
|
||||
|
||||
async def stop_when_idle(self) -> None:
|
||||
async with self._lock:
|
||||
self._run_state = RunContext.RunState.UNTIL_IDLE
|
||||
await self._run_task
|
||||
|
||||
|
||||
class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
def __init__(self, *, intervention_handler: InterventionHandler | None = None) -> None:
|
||||
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
|
||||
@@ -90,6 +127,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
self._intervention_handler = intervention_handler
|
||||
self._known_namespaces: set[str] = set()
|
||||
self._outstanding_tasks = Counter()
|
||||
self._background_tasks: Set[Task[Any]] = set()
|
||||
|
||||
@property
|
||||
def unprocessed_messages(
|
||||
@@ -113,7 +151,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
*,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[Any | None]:
|
||||
) -> Any:
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
@@ -149,7 +187,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
)
|
||||
)
|
||||
|
||||
return future
|
||||
cancellation_token.link_future(future)
|
||||
|
||||
return await future
|
||||
|
||||
async def publish_message(
|
||||
self,
|
||||
@@ -334,7 +374,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
|
||||
message_envelope.message = temp_message
|
||||
self._outstanding_tasks.increment()
|
||||
asyncio.create_task(self._process_send(message_envelope))
|
||||
task = asyncio.create_task(self._process_send(message_envelope))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case PublishMessageEnvelope(
|
||||
message=message,
|
||||
sender=sender,
|
||||
@@ -352,7 +394,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
|
||||
message_envelope.message = temp_message
|
||||
self._outstanding_tasks.increment()
|
||||
asyncio.create_task(self._process_publish(message_envelope))
|
||||
task = asyncio.create_task(self._process_publish(message_envelope))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
||||
if self._intervention_handler is not None:
|
||||
try:
|
||||
@@ -369,16 +413,19 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
|
||||
message_envelope.message = temp_message
|
||||
self._outstanding_tasks.increment()
|
||||
asyncio.create_task(self._process_response(message_envelope))
|
||||
task = asyncio.create_task(self._process_response(message_envelope))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
# Yield control to the message loop to allow other tasks to run
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def process_until_idle(self) -> None:
|
||||
"""Process messages until there is no unprocessed message and no message currently being processed."""
|
||||
@property
|
||||
def idle(self) -> bool:
|
||||
return len(self._message_queue) == 0 and self._outstanding_tasks.get() == 0
|
||||
|
||||
while len(self.unprocessed_messages) > 0 or self.outstanding_tasks > 0:
|
||||
await self.process_next()
|
||||
def start(self) -> RunContext:
|
||||
return RunContext(self)
|
||||
|
||||
def agent_metadata(self, agent: AgentId) -> AgentMetadata:
|
||||
return self._get_agent(agent).metadata
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from asyncio import Future
|
||||
from typing import TYPE_CHECKING, Any, Mapping
|
||||
|
||||
from ._agent_id import AgentId
|
||||
@@ -32,7 +31,7 @@ class AgentProxy:
|
||||
*,
|
||||
sender: AgentId,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[Any]:
|
||||
) -> Any:
|
||||
return await self._runtime.send_message(
|
||||
message,
|
||||
recipient=self._agent,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from asyncio import Future
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Callable, Mapping, Protocol, TypeVar, overload, runtime_checkable
|
||||
|
||||
@@ -19,8 +18,6 @@ agent_instantiation_context: ContextVar[tuple[AgentRuntime, AgentId]] = ContextV
|
||||
|
||||
@runtime_checkable
|
||||
class AgentRuntime(Protocol):
|
||||
# Returns the response of the message
|
||||
# Can raise CantHandleException
|
||||
async def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
@@ -28,17 +25,8 @@ class AgentRuntime(Protocol):
|
||||
*,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[Any]:
|
||||
"""Send a message to an agent and return a future that will resolve to the response.
|
||||
|
||||
The act of sending a message may be asynchronous, and the response to the message itself is also asynchronous. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
response_future = await runtime.send_message(MyMessage("Hello"), recipient=agent_id)
|
||||
response = await response_future
|
||||
|
||||
The returned future only needs to be awaited if the response is needed. If the response is not needed, the future can be ignored.
|
||||
) -> Any:
|
||||
"""Send a message to an agent and get a response.
|
||||
|
||||
Args:
|
||||
message (Any): The message to send.
|
||||
@@ -49,14 +37,14 @@ class AgentRuntime(Protocol):
|
||||
Raises:
|
||||
CantHandleException: If the recipient cannot handle the message.
|
||||
UndeliverableException: If the message cannot be delivered.
|
||||
Other: Any other exception raised by the recipient.
|
||||
|
||||
Returns:
|
||||
Future[Any]: A future that will resolve to the response of the message.
|
||||
Any: The response from the agent.
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
# No responses from publishing
|
||||
async def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import Future
|
||||
from typing import Any, Mapping, Sequence
|
||||
|
||||
from ._agent import Agent
|
||||
@@ -55,19 +54,17 @@ class BaseAgent(ABC, Agent):
|
||||
recipient: AgentId,
|
||||
*,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[Any]:
|
||||
) -> Any:
|
||||
"""See :py:meth:`agnext.core.AgentRuntime.send_message` for more information."""
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
future = await self._runtime.send_message(
|
||||
return await self._runtime.send_message(
|
||||
message,
|
||||
sender=self.id,
|
||||
recipient=recipient,
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
cancellation_token.link_future(future)
|
||||
return future
|
||||
|
||||
async def publish_message(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user