diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index e92b6a0e7..9a10302b9 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -51,12 +51,15 @@ jobs: - run: pip install ".[dev]" - uses: jakebailey/pyright-action@v2 - # test: - # runs-on: ubuntu-latest - # steps: - # - uses: actions/checkout@v4 - # - uses: actions/setup-python@v5 - # with: - # python-version: '3.10' - # - run: pip install ".[dev]" - # - run: pytest + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["pypy3.10", "3.10", "3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - run: pip install ".[dev]" + - run: pytest diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..9b388533a --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/examples/futures.py b/examples/futures.py index 781cbfc38..70cf9ee43 100644 --- a/examples/futures.py +++ b/examples/futures.py @@ -5,11 +5,10 @@ from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_h from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime from agnext.core.agent import Agent from agnext.core.agent_runtime import AgentRuntime -from agnext.core.message import Message @dataclass -class MessageType(Message): +class MessageType: body: str sender: str diff --git a/examples/round_robin_chat.py b/examples/round_robin_chat.py index 220965b45..7f60538ce 100644 --- a/examples/round_robin_chat.py +++ b/examples/round_robin_chat.py @@ -1,18 +1,18 @@ -from dataclasses import dataclass -import random import asyncio +import random +from dataclasses import dataclass from typing import List + from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime from agnext.core.agent_runtime import AgentRuntime -from agnext.core.message import Message # TODO: a runtime should be able to handle multiple types of messages # TODO: allow request and response to be different message types # should support this in handlers. @dataclass -class GroupChatMessage(Message): +class GroupChatMessage: body: str sender: str require_response: bool diff --git a/pyproject.toml b/pyproject.toml index 64235cea6..69f6ab6c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,15 +13,10 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ] -dependencies = [ - "openai>=1.3", - "pillow", - "aiohttp", - "typing-extensions" -] +dependencies = ["openai>=1.3", "pillow", "aiohttp", "typing-extensions"] [project.optional-dependencies] -dev = ["ruff", "pyright", "mypy", "pytest", "types-Pillow"] +dev = ["ruff", "pyright", "mypy", "pytest", "pytest-asyncio", "types-Pillow"] [tool.setuptools.package-data] agnext = ["py.typed"] @@ -61,3 +56,7 @@ include = ["src", "examples"] typeCheckingMode = "strict" reportUnnecessaryIsInstance = false reportMissingTypeStubs = false + +[tool.pytest.ini_options] +minversion = "6.0" +testpaths = ["tests"] diff --git a/src/agnext/agent_components/type_routed_agent.py b/src/agnext/agent_components/type_routed_agent.py index b694a9b13..92eb58169 100644 --- a/src/agnext/agent_components/type_routed_agent.py +++ b/src/agnext/agent_components/type_routed_agent.py @@ -2,11 +2,10 @@ from typing import Any, Awaitable, Callable, Dict, Sequence, Type, TypeVar from agnext.core.agent_runtime import AgentRuntime from agnext.core.base_agent import BaseAgent +from agnext.core.cancellation_token import CancellationToken from agnext.core.exceptions import CantHandleException -from ..core.message import Message - -T = TypeVar("T", bound=Message) +T = TypeVar("T") # NOTE: this works on concrete types and not inheritance @@ -22,7 +21,7 @@ class TypeRoutedAgent(BaseAgent[T]): def __init__(self, name: str, router: AgentRuntime[T]) -> None: super().__init__(name, router) - self._handlers: Dict[Type[Any], Callable[[T], Awaitable[T]]] = {} + self._handlers: Dict[Type[Any], Callable[[T, CancellationToken], Awaitable[T]]] = {} router.add_agent(self) @@ -37,12 +36,12 @@ class TypeRoutedAgent(BaseAgent[T]): def subscriptions(self) -> Sequence[Type[T]]: return list(self._handlers.keys()) - async def on_message(self, message: T) -> T: + async def on_message(self, message: T, cancellation_token: CancellationToken) -> T: handler = self._handlers.get(type(message)) if handler is not None: - return await handler(message) + return await handler(message, cancellation_token) else: - return await self.on_unhandled_message(message) + return await self.on_unhandled_message(message, cancellation_token) - async def on_unhandled_message(self, message: T) -> T: + async def on_unhandled_message(self, message: T, cancellation_token: CancellationToken) -> T: raise CantHandleException() diff --git a/src/agnext/application_components/single_threaded_agent_runtime.py b/src/agnext/application_components/single_threaded_agent_runtime.py index fbd9309b0..bda10272c 100644 --- a/src/agnext/application_components/single_threaded_agent_runtime.py +++ b/src/agnext/application_components/single_threaded_agent_runtime.py @@ -3,11 +3,12 @@ from asyncio import Future from dataclasses import dataclass from typing import Awaitable, Dict, Generic, List, Set, Type, TypeVar +from agnext.core.cancellation_token import CancellationToken + from ..core.agent import Agent from ..core.agent_runtime import AgentRuntime -from ..core.message import Message -T = TypeVar("T", bound=Message) +T = TypeVar("T") @dataclass @@ -17,6 +18,7 @@ class BroadcastMessageEnvelope(Generic[T]): message: T future: Future[List[T]] + cancellation_token: CancellationToken @dataclass @@ -27,6 +29,7 @@ class SendMessageEnvelope(Generic[T]): message: T destination: Agent[T] future: Future[T] + cancellation_token: CancellationToken @dataclass @@ -64,17 +67,26 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]): self._agents.add(agent) # Returns the response of the message - def send_message(self, message: T, destination: Agent[T]) -> Future[T]: + def send_message( + self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None + ) -> Future[T]: + if cancellation_token is None: + cancellation_token = CancellationToken() + loop = asyncio.get_event_loop() future: Future[T] = loop.create_future() - self._message_queue.append(SendMessageEnvelope(message, destination, future)) + self._message_queue.append(SendMessageEnvelope(message, destination, future, cancellation_token)) + return future # Returns the response of all handling agents - def broadcast_message(self, message: T) -> Future[List[T]]: + def broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]: + if cancellation_token is None: + cancellation_token = CancellationToken() + future: Future[List[T]] = asyncio.get_event_loop().create_future() - self._message_queue.append(BroadcastMessageEnvelope(message, future)) + self._message_queue.append(BroadcastMessageEnvelope(message, future, cancellation_token)) return future async def _process_send(self, message_envelope: SendMessageEnvelope[T]) -> None: @@ -83,16 +95,28 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]): message_envelope.future.set_exception(Exception("Recipient not found")) return - response = await recipient.on_message(message_envelope.message) + try: + response = await recipient.on_message( + message_envelope.message, cancellation_token=message_envelope.cancellation_token + ) + except BaseException as e: + message_envelope.future.set_exception(e) + return + self._message_queue.append(ResponseMessageEnvelope(response, message_envelope.future)) async def _process_broadcast(self, message_envelope: BroadcastMessageEnvelope[T]) -> None: responses: List[Awaitable[T]] = [] for agent in self._per_type_subscribers.get(type(message_envelope.message), []): - future = agent.on_message(message_envelope.message) + future = agent.on_message(message_envelope.message, cancellation_token=message_envelope.cancellation_token) responses.append(future) - all_responses = await asyncio.gather(*responses) + try: + all_responses = await asyncio.gather(*responses) + except BaseException as e: + message_envelope.future.set_exception(e) + return + self._message_queue.append(BroadcastResponseMessageEnvelope(all_responses, message_envelope.future)) async def _process_response(self, message_envelope: ResponseMessageEnvelope[T]) -> None: @@ -110,10 +134,14 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]): message_envelope = self._message_queue.pop(0) match message_envelope: - case SendMessageEnvelope(message, destination, future): - asyncio.create_task(self._process_send(SendMessageEnvelope(message, destination, future))) - case BroadcastMessageEnvelope(message, future): - asyncio.create_task(self._process_broadcast(BroadcastMessageEnvelope(message, future))) + case SendMessageEnvelope(message, destination, future, cancellation_token): + asyncio.create_task( + self._process_send(SendMessageEnvelope(message, destination, future, cancellation_token)) + ) + case BroadcastMessageEnvelope(message, future, cancellation_token): + asyncio.create_task( + self._process_broadcast(BroadcastMessageEnvelope(message, future, cancellation_token)) + ) case ResponseMessageEnvelope(message, future): asyncio.create_task(self._process_response(ResponseMessageEnvelope(message, future))) case BroadcastResponseMessageEnvelope(message, future): diff --git a/src/agnext/core/agent.py b/src/agnext/core/agent.py index 67e49b623..d6078be68 100644 --- a/src/agnext/core/agent.py +++ b/src/agnext/core/agent.py @@ -1,8 +1,8 @@ from typing import Protocol, Sequence, Type, TypeVar -from .message import Message +from agnext.core.cancellation_token import CancellationToken -T = TypeVar("T", bound=Message) +T = TypeVar("T") class Agent(Protocol[T]): @@ -12,4 +12,4 @@ class Agent(Protocol[T]): @property def subscriptions(self) -> Sequence[Type[T]]: ... - async def on_message(self, message: T) -> T: ... + async def on_message(self, message: T, cancellation_token: CancellationToken) -> T: ... diff --git a/src/agnext/core/agent_runtime.py b/src/agnext/core/agent_runtime.py index af16637a0..6c936cf70 100644 --- a/src/agnext/core/agent_runtime.py +++ b/src/agnext/core/agent_runtime.py @@ -2,10 +2,9 @@ from asyncio import Future from typing import List, Protocol, TypeVar from agnext.core.agent import Agent +from agnext.core.cancellation_token import CancellationToken -from .message import Message - -T = TypeVar("T", bound=Message) +T = TypeVar("T") # Undeliverable - error @@ -14,7 +13,9 @@ class AgentRuntime(Protocol[T]): def add_agent(self, agent: Agent[T]) -> None: ... # Returns the response of the message - def send_message(self, message: T, destination: Agent[T]) -> Future[T]: ... + def send_message( + self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None + ) -> Future[T]: ... # Returns the response of all handling agents - def broadcast_message(self, message: T) -> Future[List[T]]: ... + def broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]: ... diff --git a/src/agnext/core/base_agent.py b/src/agnext/core/base_agent.py index 0e68b711f..5282c59e2 100644 --- a/src/agnext/core/base_agent.py +++ b/src/agnext/core/base_agent.py @@ -3,11 +3,11 @@ from asyncio import Future from typing import List, Sequence, Type, TypeVar from agnext.core.agent_runtime import AgentRuntime +from agnext.core.cancellation_token import CancellationToken from .agent import Agent -from .message import Message -T = TypeVar("T", bound=Message) +T = TypeVar("T") class BaseAgent(ABC, Agent[T]): @@ -25,10 +25,20 @@ class BaseAgent(ABC, Agent[T]): return [] @abstractmethod - async def on_message(self, message: T) -> T: ... + async def on_message(self, message: T, cancellation_token: CancellationToken) -> T: ... - def _send_message(self, message: T, destination: Agent[T]) -> Future[T]: - return self._router.send_message(message, destination) + def _send_message( + self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None + ) -> Future[T]: + if cancellation_token is None: + cancellation_token = CancellationToken() + future = self._router.send_message(message, destination, cancellation_token) + cancellation_token.link_future(future) + return future - def _broadcast_message(self, message: T) -> Future[List[T]]: - return self._router.broadcast_message(message) + def _broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]: + if cancellation_token is None: + cancellation_token = CancellationToken() + future = self._router.broadcast_message(message, cancellation_token) + cancellation_token.link_future(future) + return future diff --git a/src/agnext/core/cancellation_token.py b/src/agnext/core/cancellation_token.py new file mode 100644 index 000000000..841b7f7e8 --- /dev/null +++ b/src/agnext/core/cancellation_token.py @@ -0,0 +1,39 @@ +import threading +from asyncio import Future +from typing import Any, Callable, List + + +class CancellationToken: + def __init__(self) -> None: + self._cancelled: bool = False + self._lock: threading.Lock = threading.Lock() + self._callbacks: List[Callable[[], None]] = [] + + def cancel(self) -> None: + with self._lock: + if not self._cancelled: + self._cancelled = True + for callback in self._callbacks: + callback() + + def is_cancelled(self) -> bool: + with self._lock: + return self._cancelled + + def add_callback(self, callback: Callable[[], None]) -> None: + with self._lock: + if self._cancelled: + callback() + else: + self._callbacks.append(callback) + + def link_future(self, future: Future[Any]) -> None: + with self._lock: + if self._cancelled: + future.cancel() + else: + + def _cancel() -> None: + future.cancel() + + self._callbacks.append(_cancel) diff --git a/src/agnext/core/message.py b/src/agnext/core/message.py deleted file mode 100644 index 8edf09146..000000000 --- a/src/agnext/core/message.py +++ /dev/null @@ -1,6 +0,0 @@ -from typing import Protocol - - -class Message(Protocol): - sender: str - # reply_to: Optional[str] diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py new file mode 100644 index 000000000..f5996bfcb --- /dev/null +++ b/tests/test_cancellation.py @@ -0,0 +1,122 @@ +import asyncio +import pytest +from dataclasses import dataclass + +from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler +from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime +from agnext.core.agent import Agent +from agnext.core.agent_runtime import AgentRuntime +from agnext.core.cancellation_token import CancellationToken + +@dataclass +class MessageType: + ... + +# Note for future reader: +# To do cancellation, only the token should be interacted with as a user +# If you cancel a future, it may not work as you expect. + +class LongRunningAgent(TypeRoutedAgent[MessageType]): + def __init__(self, name: str, router: AgentRuntime[MessageType]) -> None: + super().__init__(name, router) + self.called = False + self.cancelled = False + + @message_handler(MessageType) + async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: + self.called = True + sleep = asyncio.ensure_future(asyncio.sleep(100)) + cancellation_token.link_future(sleep) + try: + await sleep + return MessageType() + except asyncio.CancelledError: + self.cancelled = True + raise + +class NestingLongRunningAgent(TypeRoutedAgent[MessageType]): + def __init__(self, name: str, router: AgentRuntime[MessageType], nested_agent: Agent[MessageType]) -> None: + super().__init__(name, router) + self.called = False + self.cancelled = False + self._nested_agent = nested_agent + + @message_handler(MessageType) + async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: + self.called = True + response = self._send_message(message, self._nested_agent, cancellation_token) + try: + return await response + except asyncio.CancelledError: + self.cancelled = True + raise + + +@pytest.mark.asyncio +async def test_cancellation_with_token() -> None: + router = SingleThreadedAgentRuntime[MessageType]() + + long_running = LongRunningAgent("name", router) + token = CancellationToken() + response = router.send_message(MessageType(), long_running, token) + assert not response.done() + + await router.process_next() + token.cancel() + + with pytest.raises(asyncio.CancelledError): + await response + + assert response.done() + assert long_running.called + assert long_running.cancelled + + + +@pytest.mark.asyncio +async def test_nested_cancellation_only_outer_called() -> None: + router = SingleThreadedAgentRuntime[MessageType]() + + long_running = LongRunningAgent("name", router) + nested = NestingLongRunningAgent("nested", router, long_running) + + token = CancellationToken() + response = router.send_message(MessageType(), nested, token) + assert not response.done() + + await router.process_next() + token.cancel() + + with pytest.raises(asyncio.CancelledError): + await response + + assert response.done() + assert nested.called + assert nested.cancelled + assert long_running.called == False + assert long_running.cancelled == False + +@pytest.mark.asyncio +async def test_nested_cancellation_inner_called() -> None: + router = SingleThreadedAgentRuntime[MessageType]() + + long_running = LongRunningAgent("name", router) + nested = NestingLongRunningAgent("nested", router, long_running) + + token = CancellationToken() + response = router.send_message(MessageType(), nested, token) + assert not response.done() + + await router.process_next() + # allow the inner agent to process + await router.process_next() + token.cancel() + + with pytest.raises(asyncio.CancelledError): + await response + + assert response.done() + assert nested.called + assert nested.cancelled + assert long_running.called + assert long_running.cancelled \ No newline at end of file