Implement intervention (#8)

This commit is contained in:
Jack Gerrits
2024-05-20 17:30:45 -06:00
committed by GitHub
parent 5afbadbe43
commit 77c8cca9ae
8 changed files with 250 additions and 34 deletions

View File

@@ -1,9 +1,11 @@
import asyncio
from asyncio import Future
from dataclasses import dataclass
from typing import Awaitable, Dict, Generic, List, Set, Type, TypeVar
from typing import Awaitable, Dict, Generic, List, Set, Type, TypeVar, cast
from agnext.core.cancellation_token import CancellationToken
from agnext.core.exceptions import MessageDroppedException
from agnext.core.intervention import DropMessage, InterventionHandler
from ..core.agent import Agent
from ..core.agent_runtime import AgentRuntime
@@ -11,7 +13,7 @@ from ..core.agent_runtime import AgentRuntime
T = TypeVar("T")
@dataclass
@dataclass(kw_only=True)
class BroadcastMessageEnvelope(Generic[T]):
"""A message envelope for broadcasting messages to all agents that can handle
the message of the type T."""
@@ -19,37 +21,42 @@ class BroadcastMessageEnvelope(Generic[T]):
message: T
future: Future[List[T]]
cancellation_token: CancellationToken
sender: Agent[T] | None
@dataclass
@dataclass(kw_only=True)
class SendMessageEnvelope(Generic[T]):
"""A message envelope for sending a message to a specific agent that can handle
the message of the type T."""
message: T
destination: Agent[T]
sender: Agent[T] | None
recipient: Agent[T]
future: Future[T]
cancellation_token: CancellationToken
@dataclass
@dataclass(kw_only=True)
class ResponseMessageEnvelope(Generic[T]):
"""A message envelope for sending a response to a message."""
message: T
future: Future[T]
sender: Agent[T]
recipient: Agent[T] | None
@dataclass
@dataclass(kw_only=True)
class BroadcastResponseMessageEnvelope(Generic[T]):
"""A message envelope for sending a response to a message."""
message: List[T]
future: Future[List[T]]
recipient: Agent[T] | None
class SingleThreadedAgentRuntime(AgentRuntime[T]):
def __init__(self) -> None:
def __init__(self, *, before_send: InterventionHandler[T] | None = None) -> None:
self._message_queue: List[
BroadcastMessageEnvelope[T]
| SendMessageEnvelope[T]
@@ -58,6 +65,7 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
] = []
self._per_type_subscribers: Dict[Type[T], List[Agent[T]]] = {}
self._agents: Set[Agent[T]] = set()
self._before_send = before_send
def add_agent(self, agent: Agent[T]) -> None:
for message_type in agent.subscriptions:
@@ -68,7 +76,12 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
# Returns the response of the message
def send_message(
self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None
self,
message: T,
recipient: Agent[T],
*,
sender: Agent[T] | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[T]:
if cancellation_token is None:
cancellation_token = CancellationToken()
@@ -76,21 +89,35 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
loop = asyncio.get_event_loop()
future: Future[T] = loop.create_future()
self._message_queue.append(SendMessageEnvelope(message, destination, future, cancellation_token))
self._message_queue.append(
SendMessageEnvelope(
message=message,
recipient=recipient,
future=future,
cancellation_token=cancellation_token,
sender=sender,
)
)
return future
# Returns the response of all handling agents
def broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]:
def broadcast_message(
self, message: T, *, sender: Agent[T] | None = None, cancellation_token: CancellationToken | None = None
) -> Future[List[T]]:
if cancellation_token is None:
cancellation_token = CancellationToken()
future: Future[List[T]] = asyncio.get_event_loop().create_future()
self._message_queue.append(BroadcastMessageEnvelope(message, future, cancellation_token))
self._message_queue.append(
BroadcastMessageEnvelope(
message=message, future=future, cancellation_token=cancellation_token, sender=sender
)
)
return future
async def _process_send(self, message_envelope: SendMessageEnvelope[T]) -> None:
recipient = message_envelope.destination
recipient = message_envelope.recipient
if recipient not in self._agents:
message_envelope.future.set_exception(Exception("Recipient not found"))
return
@@ -103,7 +130,14 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
message_envelope.future.set_exception(e)
return
self._message_queue.append(ResponseMessageEnvelope(response, message_envelope.future))
self._message_queue.append(
ResponseMessageEnvelope(
message=response,
future=message_envelope.future,
sender=message_envelope.recipient,
recipient=message_envelope.sender,
)
)
async def _process_broadcast(self, message_envelope: BroadcastMessageEnvelope[T]) -> None:
responses: List[Awaitable[T]] = []
@@ -117,7 +151,11 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
message_envelope.future.set_exception(e)
return
self._message_queue.append(BroadcastResponseMessageEnvelope(all_responses, message_envelope.future))
self._message_queue.append(
BroadcastResponseMessageEnvelope(
message=all_responses, future=message_envelope.future, recipient=message_envelope.sender
)
)
async def _process_response(self, message_envelope: ResponseMessageEnvelope[T]) -> None:
message_envelope.future.set_result(message_envelope.message)
@@ -134,18 +172,51 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
message_envelope = self._message_queue.pop(0)
match message_envelope:
case SendMessageEnvelope(message, destination, future, cancellation_token):
asyncio.create_task(
self._process_send(SendMessageEnvelope(message, destination, future, cancellation_token))
)
case BroadcastMessageEnvelope(message, future, cancellation_token):
asyncio.create_task(
self._process_broadcast(BroadcastMessageEnvelope(message, future, cancellation_token))
)
case ResponseMessageEnvelope(message, future):
asyncio.create_task(self._process_response(ResponseMessageEnvelope(message, future)))
case BroadcastResponseMessageEnvelope(message, future):
asyncio.create_task(self._process_broadcast_response(BroadcastResponseMessageEnvelope(message, future)))
case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._before_send is not None:
temp_message = await self._before_send.on_send(message, sender=sender, recipient=recipient)
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
future.set_exception(MessageDroppedException())
return
message_envelope.message = cast(T, temp_message)
asyncio.create_task(self._process_send(message_envelope))
case BroadcastMessageEnvelope(
message=message,
sender=sender,
future=future,
):
if self._before_send is not None:
temp_message = await self._before_send.on_broadcast(message, sender=sender)
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
future.set_exception(MessageDroppedException())
return
message_envelope.message = cast(T, temp_message)
asyncio.create_task(self._process_broadcast(message_envelope))
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._before_send is not None:
temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient)
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
future.set_exception(MessageDroppedException())
return
message_envelope.message = cast(T, temp_message)
asyncio.create_task(self._process_response(message_envelope))
case BroadcastResponseMessageEnvelope(message=message, recipient=recipient, future=future):
if self._before_send is not None:
temp_message_list = await self._before_send.on_broadcast_response(message, recipient=recipient)
if temp_message_list is DropMessage or isinstance(temp_message_list, DropMessage):
future.set_exception(MessageDroppedException())
return
message_envelope.message = list(temp_message_list) # type: ignore
asyncio.create_task(self._process_broadcast_response(message_envelope))
# Yield control to the message loop to allow other tasks to run
await asyncio.sleep(0)

View File

@@ -14,8 +14,15 @@ class AgentRuntime(Protocol[T]):
# Returns the response of the message
def send_message(
self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None
self,
message: T,
recipient: Agent[T],
*,
sender: Agent[T] | None = None,
cancellation_token: CancellationToken | None,
) -> Future[T]: ...
# Returns the response of all handling agents
def broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]: ...
def broadcast_message(
self, message: T, *, sender: Agent[T] | None = None, cancellation_token: CancellationToken | None = None
) -> Future[List[T]]: ...

View File

@@ -28,17 +28,19 @@ class BaseAgent(ABC, Agent[T]):
async def on_message(self, message: T, cancellation_token: CancellationToken) -> T: ...
def _send_message(
self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None
self, message: T, recipient: Agent[T], cancellation_token: CancellationToken | None = None
) -> Future[T]:
if cancellation_token is None:
cancellation_token = CancellationToken()
future = self._router.send_message(message, destination, cancellation_token)
future = self._router.send_message(
message, sender=self, recipient=recipient, cancellation_token=cancellation_token
)
cancellation_token.link_future(future)
return future
def _broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]:
if cancellation_token is None:
cancellation_token = CancellationToken()
future = self._router.broadcast_message(message, cancellation_token)
future = self._router.broadcast_message(message, sender=self, cancellation_token=cancellation_token)
cancellation_token.link_future(future)
return future

View File

@@ -4,3 +4,7 @@ class CantHandleException(Exception):
class UndeliverableException(Exception):
"""Raised when a message can't be delivered."""
class MessageDroppedException(Exception):
"""Raised when a message is dropped."""

View File

@@ -0,0 +1,39 @@
from typing import Awaitable, Callable, Protocol, Sequence, TypeVar, final
from agnext.core.agent import Agent
@final
class DropMessage: ...
T = TypeVar("T")
InterventionFunction = Callable[[T], T | Awaitable[type[DropMessage]]]
class InterventionHandler(Protocol[T]):
async def on_send(self, message: T, *, sender: Agent[T] | None, recipient: Agent[T]) -> T | type[DropMessage]: ...
async def on_broadcast(self, message: T, *, sender: Agent[T] | None) -> T | type[DropMessage]: ...
async def on_response(
self, message: T, *, sender: Agent[T], recipient: Agent[T] | None
) -> T | type[DropMessage]: ...
async def on_broadcast_response(
self, message: Sequence[T], *, recipient: Agent[T] | None
) -> Sequence[T] | type[DropMessage]: ...
class DefaultInterventionHandler(InterventionHandler[T]):
async def on_send(self, message: T, *, sender: Agent[T] | None, recipient: Agent[T]) -> T | type[DropMessage]:
return message
async def on_broadcast(self, message: T, *, sender: Agent[T] | None) -> T | type[DropMessage]:
return message
async def on_response(self, message: T, *, sender: Agent[T], recipient: Agent[T] | None) -> T | type[DropMessage]:
return message
async def on_broadcast_response(
self, message: Sequence[T], *, recipient: Agent[T] | None
) -> Sequence[T] | type[DropMessage]:
return message

View File

@@ -19,3 +19,5 @@ echo "--- Running pyright ---"
pyright
echo "--- Running mypy ---"
mypy
echo "--- Running pytest ---"
pytest

View File

@@ -58,7 +58,7 @@ async def test_cancellation_with_token() -> None:
long_running = LongRunningAgent("name", router)
token = CancellationToken()
response = router.send_message(MessageType(), long_running, token)
response = router.send_message(MessageType(), recipient=long_running, cancellation_token=token)
assert not response.done()
await router.process_next()
@@ -81,7 +81,7 @@ async def test_nested_cancellation_only_outer_called() -> None:
nested = NestingLongRunningAgent("nested", router, long_running)
token = CancellationToken()
response = router.send_message(MessageType(), nested, token)
response = router.send_message(MessageType(), nested, cancellation_token=token)
assert not response.done()
await router.process_next()
@@ -104,7 +104,7 @@ async def test_nested_cancellation_inner_called() -> None:
nested = NestingLongRunningAgent("nested", router, long_running)
token = CancellationToken()
response = router.send_message(MessageType(), nested, token)
response = router.send_message(MessageType(), nested, cancellation_token=token)
assert not response.done()
await router.process_next()

View File

@@ -0,0 +1,91 @@
import pytest
from dataclasses import dataclass
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime
from agnext.core.agent import Agent
from agnext.core.agent_runtime import AgentRuntime
from agnext.core.cancellation_token import CancellationToken
from agnext.core.exceptions import MessageDroppedException
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
@dataclass
class MessageType:
...
class LoopbackAgent(TypeRoutedAgent[MessageType]):
def __init__(self, name: str, router: AgentRuntime[MessageType]) -> None:
super().__init__(name, router)
self.num_calls = 0
@message_handler(MessageType)
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
self.num_calls += 1
return message
@pytest.mark.asyncio
async def test_intervention_count_messages() -> None:
class DebugInterventionHandler(DefaultInterventionHandler[MessageType]):
def __init__(self):
self.num_messages = 0
async def on_send(self, message: MessageType, *, sender: Agent[MessageType] | None, recipient: Agent[MessageType]) -> MessageType:
self.num_messages += 1
return message
handler = DebugInterventionHandler()
router = SingleThreadedAgentRuntime[MessageType](before_send=handler)
long_running = LoopbackAgent("name", router)
response = router.send_message(MessageType(), recipient=long_running)
while not response.done():
await router.process_next()
assert handler.num_messages == 1
assert long_running.num_calls == 1
@pytest.mark.asyncio
async def test_intervention_drop_send() -> None:
class DropSendInterventionHandler(DefaultInterventionHandler[MessageType]):
async def on_send(self, message: MessageType, *, sender: Agent[MessageType] | None, recipient: Agent[MessageType]) -> MessageType | type[DropMessage]:
return DropMessage
handler = DropSendInterventionHandler()
router = SingleThreadedAgentRuntime[MessageType](before_send=handler)
long_running = LoopbackAgent("name", router)
response = router.send_message(MessageType(), recipient=long_running)
while not response.done():
await router.process_next()
with pytest.raises(MessageDroppedException):
await response
assert long_running.num_calls == 0
@pytest.mark.asyncio
async def test_intervention_drop_response() -> None:
class DropResponseInterventionHandler(DefaultInterventionHandler[MessageType]):
async def on_response(self, message: MessageType, *, sender: Agent[MessageType], recipient: Agent[MessageType] | None) -> MessageType | type[DropMessage]:
return DropMessage
handler = DropResponseInterventionHandler()
router = SingleThreadedAgentRuntime[MessageType](before_send=handler)
long_running = LoopbackAgent("name", router)
response = router.send_message(MessageType(), recipient=long_running)
while not response.done():
await router.process_next()
with pytest.raises(MessageDroppedException):
await response
assert long_running.num_calls == 1