mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Implement intervention (#8)
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
import asyncio
|
||||
from asyncio import Future
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, Dict, Generic, List, Set, Type, TypeVar
|
||||
from typing import Awaitable, Dict, Generic, List, Set, Type, TypeVar, cast
|
||||
|
||||
from agnext.core.cancellation_token import CancellationToken
|
||||
from agnext.core.exceptions import MessageDroppedException
|
||||
from agnext.core.intervention import DropMessage, InterventionHandler
|
||||
|
||||
from ..core.agent import Agent
|
||||
from ..core.agent_runtime import AgentRuntime
|
||||
@@ -11,7 +13,7 @@ from ..core.agent_runtime import AgentRuntime
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(kw_only=True)
|
||||
class BroadcastMessageEnvelope(Generic[T]):
|
||||
"""A message envelope for broadcasting messages to all agents that can handle
|
||||
the message of the type T."""
|
||||
@@ -19,37 +21,42 @@ class BroadcastMessageEnvelope(Generic[T]):
|
||||
message: T
|
||||
future: Future[List[T]]
|
||||
cancellation_token: CancellationToken
|
||||
sender: Agent[T] | None
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(kw_only=True)
|
||||
class SendMessageEnvelope(Generic[T]):
|
||||
"""A message envelope for sending a message to a specific agent that can handle
|
||||
the message of the type T."""
|
||||
|
||||
message: T
|
||||
destination: Agent[T]
|
||||
sender: Agent[T] | None
|
||||
recipient: Agent[T]
|
||||
future: Future[T]
|
||||
cancellation_token: CancellationToken
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(kw_only=True)
|
||||
class ResponseMessageEnvelope(Generic[T]):
|
||||
"""A message envelope for sending a response to a message."""
|
||||
|
||||
message: T
|
||||
future: Future[T]
|
||||
sender: Agent[T]
|
||||
recipient: Agent[T] | None
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(kw_only=True)
|
||||
class BroadcastResponseMessageEnvelope(Generic[T]):
|
||||
"""A message envelope for sending a response to a message."""
|
||||
|
||||
message: List[T]
|
||||
future: Future[List[T]]
|
||||
recipient: Agent[T] | None
|
||||
|
||||
|
||||
class SingleThreadedAgentRuntime(AgentRuntime[T]):
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, *, before_send: InterventionHandler[T] | None = None) -> None:
|
||||
self._message_queue: List[
|
||||
BroadcastMessageEnvelope[T]
|
||||
| SendMessageEnvelope[T]
|
||||
@@ -58,6 +65,7 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
|
||||
] = []
|
||||
self._per_type_subscribers: Dict[Type[T], List[Agent[T]]] = {}
|
||||
self._agents: Set[Agent[T]] = set()
|
||||
self._before_send = before_send
|
||||
|
||||
def add_agent(self, agent: Agent[T]) -> None:
|
||||
for message_type in agent.subscriptions:
|
||||
@@ -68,7 +76,12 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
|
||||
|
||||
# Returns the response of the message
|
||||
def send_message(
|
||||
self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None
|
||||
self,
|
||||
message: T,
|
||||
recipient: Agent[T],
|
||||
*,
|
||||
sender: Agent[T] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[T]:
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
@@ -76,21 +89,35 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
|
||||
loop = asyncio.get_event_loop()
|
||||
future: Future[T] = loop.create_future()
|
||||
|
||||
self._message_queue.append(SendMessageEnvelope(message, destination, future, cancellation_token))
|
||||
self._message_queue.append(
|
||||
SendMessageEnvelope(
|
||||
message=message,
|
||||
recipient=recipient,
|
||||
future=future,
|
||||
cancellation_token=cancellation_token,
|
||||
sender=sender,
|
||||
)
|
||||
)
|
||||
|
||||
return future
|
||||
|
||||
# Returns the response of all handling agents
|
||||
def broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]:
|
||||
def broadcast_message(
|
||||
self, message: T, *, sender: Agent[T] | None = None, cancellation_token: CancellationToken | None = None
|
||||
) -> Future[List[T]]:
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
future: Future[List[T]] = asyncio.get_event_loop().create_future()
|
||||
self._message_queue.append(BroadcastMessageEnvelope(message, future, cancellation_token))
|
||||
self._message_queue.append(
|
||||
BroadcastMessageEnvelope(
|
||||
message=message, future=future, cancellation_token=cancellation_token, sender=sender
|
||||
)
|
||||
)
|
||||
return future
|
||||
|
||||
async def _process_send(self, message_envelope: SendMessageEnvelope[T]) -> None:
|
||||
recipient = message_envelope.destination
|
||||
recipient = message_envelope.recipient
|
||||
if recipient not in self._agents:
|
||||
message_envelope.future.set_exception(Exception("Recipient not found"))
|
||||
return
|
||||
@@ -103,7 +130,14 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
|
||||
message_envelope.future.set_exception(e)
|
||||
return
|
||||
|
||||
self._message_queue.append(ResponseMessageEnvelope(response, message_envelope.future))
|
||||
self._message_queue.append(
|
||||
ResponseMessageEnvelope(
|
||||
message=response,
|
||||
future=message_envelope.future,
|
||||
sender=message_envelope.recipient,
|
||||
recipient=message_envelope.sender,
|
||||
)
|
||||
)
|
||||
|
||||
async def _process_broadcast(self, message_envelope: BroadcastMessageEnvelope[T]) -> None:
|
||||
responses: List[Awaitable[T]] = []
|
||||
@@ -117,7 +151,11 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
|
||||
message_envelope.future.set_exception(e)
|
||||
return
|
||||
|
||||
self._message_queue.append(BroadcastResponseMessageEnvelope(all_responses, message_envelope.future))
|
||||
self._message_queue.append(
|
||||
BroadcastResponseMessageEnvelope(
|
||||
message=all_responses, future=message_envelope.future, recipient=message_envelope.sender
|
||||
)
|
||||
)
|
||||
|
||||
async def _process_response(self, message_envelope: ResponseMessageEnvelope[T]) -> None:
|
||||
message_envelope.future.set_result(message_envelope.message)
|
||||
@@ -134,18 +172,51 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
|
||||
message_envelope = self._message_queue.pop(0)
|
||||
|
||||
match message_envelope:
|
||||
case SendMessageEnvelope(message, destination, future, cancellation_token):
|
||||
asyncio.create_task(
|
||||
self._process_send(SendMessageEnvelope(message, destination, future, cancellation_token))
|
||||
)
|
||||
case BroadcastMessageEnvelope(message, future, cancellation_token):
|
||||
asyncio.create_task(
|
||||
self._process_broadcast(BroadcastMessageEnvelope(message, future, cancellation_token))
|
||||
)
|
||||
case ResponseMessageEnvelope(message, future):
|
||||
asyncio.create_task(self._process_response(ResponseMessageEnvelope(message, future)))
|
||||
case BroadcastResponseMessageEnvelope(message, future):
|
||||
asyncio.create_task(self._process_broadcast_response(BroadcastResponseMessageEnvelope(message, future)))
|
||||
case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
||||
if self._before_send is not None:
|
||||
temp_message = await self._before_send.on_send(message, sender=sender, recipient=recipient)
|
||||
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
||||
future.set_exception(MessageDroppedException())
|
||||
return
|
||||
|
||||
message_envelope.message = cast(T, temp_message)
|
||||
|
||||
asyncio.create_task(self._process_send(message_envelope))
|
||||
case BroadcastMessageEnvelope(
|
||||
message=message,
|
||||
sender=sender,
|
||||
future=future,
|
||||
):
|
||||
if self._before_send is not None:
|
||||
temp_message = await self._before_send.on_broadcast(message, sender=sender)
|
||||
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
||||
future.set_exception(MessageDroppedException())
|
||||
return
|
||||
|
||||
message_envelope.message = cast(T, temp_message)
|
||||
|
||||
asyncio.create_task(self._process_broadcast(message_envelope))
|
||||
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
||||
if self._before_send is not None:
|
||||
temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient)
|
||||
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
||||
future.set_exception(MessageDroppedException())
|
||||
return
|
||||
|
||||
message_envelope.message = cast(T, temp_message)
|
||||
|
||||
asyncio.create_task(self._process_response(message_envelope))
|
||||
|
||||
case BroadcastResponseMessageEnvelope(message=message, recipient=recipient, future=future):
|
||||
if self._before_send is not None:
|
||||
temp_message_list = await self._before_send.on_broadcast_response(message, recipient=recipient)
|
||||
if temp_message_list is DropMessage or isinstance(temp_message_list, DropMessage):
|
||||
future.set_exception(MessageDroppedException())
|
||||
return
|
||||
|
||||
message_envelope.message = list(temp_message_list) # type: ignore
|
||||
|
||||
asyncio.create_task(self._process_broadcast_response(message_envelope))
|
||||
|
||||
# Yield control to the message loop to allow other tasks to run
|
||||
await asyncio.sleep(0)
|
||||
|
||||
@@ -14,8 +14,15 @@ class AgentRuntime(Protocol[T]):
|
||||
|
||||
# Returns the response of the message
|
||||
def send_message(
|
||||
self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None
|
||||
self,
|
||||
message: T,
|
||||
recipient: Agent[T],
|
||||
*,
|
||||
sender: Agent[T] | None = None,
|
||||
cancellation_token: CancellationToken | None,
|
||||
) -> Future[T]: ...
|
||||
|
||||
# Returns the response of all handling agents
|
||||
def broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]: ...
|
||||
def broadcast_message(
|
||||
self, message: T, *, sender: Agent[T] | None = None, cancellation_token: CancellationToken | None = None
|
||||
) -> Future[List[T]]: ...
|
||||
|
||||
@@ -28,17 +28,19 @@ class BaseAgent(ABC, Agent[T]):
|
||||
async def on_message(self, message: T, cancellation_token: CancellationToken) -> T: ...
|
||||
|
||||
def _send_message(
|
||||
self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None
|
||||
self, message: T, recipient: Agent[T], cancellation_token: CancellationToken | None = None
|
||||
) -> Future[T]:
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
future = self._router.send_message(message, destination, cancellation_token)
|
||||
future = self._router.send_message(
|
||||
message, sender=self, recipient=recipient, cancellation_token=cancellation_token
|
||||
)
|
||||
cancellation_token.link_future(future)
|
||||
return future
|
||||
|
||||
def _broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]:
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
future = self._router.broadcast_message(message, cancellation_token)
|
||||
future = self._router.broadcast_message(message, sender=self, cancellation_token=cancellation_token)
|
||||
cancellation_token.link_future(future)
|
||||
return future
|
||||
|
||||
@@ -4,3 +4,7 @@ class CantHandleException(Exception):
|
||||
|
||||
class UndeliverableException(Exception):
|
||||
"""Raised when a message can't be delivered."""
|
||||
|
||||
|
||||
class MessageDroppedException(Exception):
|
||||
"""Raised when a message is dropped."""
|
||||
|
||||
39
src/agnext/core/intervention.py
Normal file
39
src/agnext/core/intervention.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import Awaitable, Callable, Protocol, Sequence, TypeVar, final
|
||||
|
||||
from agnext.core.agent import Agent
|
||||
|
||||
|
||||
@final
|
||||
class DropMessage: ...
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
InterventionFunction = Callable[[T], T | Awaitable[type[DropMessage]]]
|
||||
|
||||
|
||||
class InterventionHandler(Protocol[T]):
|
||||
async def on_send(self, message: T, *, sender: Agent[T] | None, recipient: Agent[T]) -> T | type[DropMessage]: ...
|
||||
async def on_broadcast(self, message: T, *, sender: Agent[T] | None) -> T | type[DropMessage]: ...
|
||||
async def on_response(
|
||||
self, message: T, *, sender: Agent[T], recipient: Agent[T] | None
|
||||
) -> T | type[DropMessage]: ...
|
||||
async def on_broadcast_response(
|
||||
self, message: Sequence[T], *, recipient: Agent[T] | None
|
||||
) -> Sequence[T] | type[DropMessage]: ...
|
||||
|
||||
|
||||
class DefaultInterventionHandler(InterventionHandler[T]):
|
||||
async def on_send(self, message: T, *, sender: Agent[T] | None, recipient: Agent[T]) -> T | type[DropMessage]:
|
||||
return message
|
||||
|
||||
async def on_broadcast(self, message: T, *, sender: Agent[T] | None) -> T | type[DropMessage]:
|
||||
return message
|
||||
|
||||
async def on_response(self, message: T, *, sender: Agent[T], recipient: Agent[T] | None) -> T | type[DropMessage]:
|
||||
return message
|
||||
|
||||
async def on_broadcast_response(
|
||||
self, message: Sequence[T], *, recipient: Agent[T] | None
|
||||
) -> Sequence[T] | type[DropMessage]:
|
||||
return message
|
||||
2
test.sh
2
test.sh
@@ -19,3 +19,5 @@ echo "--- Running pyright ---"
|
||||
pyright
|
||||
echo "--- Running mypy ---"
|
||||
mypy
|
||||
echo "--- Running pytest ---"
|
||||
pytest
|
||||
@@ -58,7 +58,7 @@ async def test_cancellation_with_token() -> None:
|
||||
|
||||
long_running = LongRunningAgent("name", router)
|
||||
token = CancellationToken()
|
||||
response = router.send_message(MessageType(), long_running, token)
|
||||
response = router.send_message(MessageType(), recipient=long_running, cancellation_token=token)
|
||||
assert not response.done()
|
||||
|
||||
await router.process_next()
|
||||
@@ -81,7 +81,7 @@ async def test_nested_cancellation_only_outer_called() -> None:
|
||||
nested = NestingLongRunningAgent("nested", router, long_running)
|
||||
|
||||
token = CancellationToken()
|
||||
response = router.send_message(MessageType(), nested, token)
|
||||
response = router.send_message(MessageType(), nested, cancellation_token=token)
|
||||
assert not response.done()
|
||||
|
||||
await router.process_next()
|
||||
@@ -104,7 +104,7 @@ async def test_nested_cancellation_inner_called() -> None:
|
||||
nested = NestingLongRunningAgent("nested", router, long_running)
|
||||
|
||||
token = CancellationToken()
|
||||
response = router.send_message(MessageType(), nested, token)
|
||||
response = router.send_message(MessageType(), nested, cancellation_token=token)
|
||||
assert not response.done()
|
||||
|
||||
await router.process_next()
|
||||
|
||||
91
tests/test_intervention.py
Normal file
91
tests/test_intervention.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import pytest
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
|
||||
from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime
|
||||
from agnext.core.agent import Agent
|
||||
from agnext.core.agent_runtime import AgentRuntime
|
||||
from agnext.core.cancellation_token import CancellationToken
|
||||
from agnext.core.exceptions import MessageDroppedException
|
||||
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
|
||||
|
||||
@dataclass
|
||||
class MessageType:
|
||||
...
|
||||
|
||||
class LoopbackAgent(TypeRoutedAgent[MessageType]):
|
||||
def __init__(self, name: str, router: AgentRuntime[MessageType]) -> None:
|
||||
super().__init__(name, router)
|
||||
self.num_calls = 0
|
||||
|
||||
|
||||
@message_handler(MessageType)
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||
self.num_calls += 1
|
||||
return message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_intervention_count_messages() -> None:
|
||||
|
||||
class DebugInterventionHandler(DefaultInterventionHandler[MessageType]):
|
||||
def __init__(self):
|
||||
self.num_messages = 0
|
||||
|
||||
async def on_send(self, message: MessageType, *, sender: Agent[MessageType] | None, recipient: Agent[MessageType]) -> MessageType:
|
||||
self.num_messages += 1
|
||||
return message
|
||||
|
||||
handler = DebugInterventionHandler()
|
||||
router = SingleThreadedAgentRuntime[MessageType](before_send=handler)
|
||||
|
||||
long_running = LoopbackAgent("name", router)
|
||||
response = router.send_message(MessageType(), recipient=long_running)
|
||||
|
||||
while not response.done():
|
||||
await router.process_next()
|
||||
|
||||
assert handler.num_messages == 1
|
||||
assert long_running.num_calls == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_intervention_drop_send() -> None:
|
||||
|
||||
class DropSendInterventionHandler(DefaultInterventionHandler[MessageType]):
|
||||
async def on_send(self, message: MessageType, *, sender: Agent[MessageType] | None, recipient: Agent[MessageType]) -> MessageType | type[DropMessage]:
|
||||
return DropMessage
|
||||
|
||||
handler = DropSendInterventionHandler()
|
||||
router = SingleThreadedAgentRuntime[MessageType](before_send=handler)
|
||||
|
||||
long_running = LoopbackAgent("name", router)
|
||||
response = router.send_message(MessageType(), recipient=long_running)
|
||||
|
||||
while not response.done():
|
||||
await router.process_next()
|
||||
|
||||
with pytest.raises(MessageDroppedException):
|
||||
await response
|
||||
|
||||
assert long_running.num_calls == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_intervention_drop_response() -> None:
|
||||
|
||||
class DropResponseInterventionHandler(DefaultInterventionHandler[MessageType]):
|
||||
async def on_response(self, message: MessageType, *, sender: Agent[MessageType], recipient: Agent[MessageType] | None) -> MessageType | type[DropMessage]:
|
||||
return DropMessage
|
||||
|
||||
handler = DropResponseInterventionHandler()
|
||||
router = SingleThreadedAgentRuntime[MessageType](before_send=handler)
|
||||
|
||||
long_running = LoopbackAgent("name", router)
|
||||
response = router.send_message(MessageType(), recipient=long_running)
|
||||
|
||||
while not response.done():
|
||||
await router.process_next()
|
||||
|
||||
with pytest.raises(MessageDroppedException):
|
||||
await response
|
||||
|
||||
assert long_running.num_calls == 1
|
||||
Reference in New Issue
Block a user