Add support for task cancellation (#7)

* Add support for task cancellation

* add tests to CI

* matrix for python testing
This commit is contained in:
Jack Gerrits
2024-05-20 13:32:08 -06:00
committed by GitHub
parent f80c42e668
commit 5afbadbe43
13 changed files with 265 additions and 64 deletions

View File

@@ -51,12 +51,15 @@ jobs:
- run: pip install ".[dev]"
- uses: jakebailey/pyright-action@v2
# test:
# runs-on: ubuntu-latest
# steps:
# - uses: actions/checkout@v4
# - uses: actions/setup-python@v5
# with:
# python-version: '3.10'
# - run: pip install ".[dev]"
# - run: pytest
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["pypy3.10", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- run: pip install ".[dev]"
- run: pytest

7
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,7 @@
{
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}

View File

@@ -5,11 +5,10 @@ from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_h
from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime
from agnext.core.agent import Agent
from agnext.core.agent_runtime import AgentRuntime
from agnext.core.message import Message
@dataclass
class MessageType(Message):
class MessageType:
body: str
sender: str

View File

@@ -1,18 +1,18 @@
from dataclasses import dataclass
import random
import asyncio
import random
from dataclasses import dataclass
from typing import List
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime
from agnext.core.agent_runtime import AgentRuntime
from agnext.core.message import Message
# TODO: a runtime should be able to handle multiple types of messages
# TODO: allow request and response to be different message types
# should support this in handlers.
@dataclass
class GroupChatMessage(Message):
class GroupChatMessage:
body: str
sender: str
require_response: bool

View File

@@ -13,15 +13,10 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = [
"openai>=1.3",
"pillow",
"aiohttp",
"typing-extensions"
]
dependencies = ["openai>=1.3", "pillow", "aiohttp", "typing-extensions"]
[project.optional-dependencies]
dev = ["ruff", "pyright", "mypy", "pytest", "types-Pillow"]
dev = ["ruff", "pyright", "mypy", "pytest", "pytest-asyncio", "types-Pillow"]
[tool.setuptools.package-data]
agnext = ["py.typed"]
@@ -61,3 +56,7 @@ include = ["src", "examples"]
typeCheckingMode = "strict"
reportUnnecessaryIsInstance = false
reportMissingTypeStubs = false
[tool.pytest.ini_options]
minversion = "6.0"
testpaths = ["tests"]

View File

@@ -2,11 +2,10 @@ from typing import Any, Awaitable, Callable, Dict, Sequence, Type, TypeVar
from agnext.core.agent_runtime import AgentRuntime
from agnext.core.base_agent import BaseAgent
from agnext.core.cancellation_token import CancellationToken
from agnext.core.exceptions import CantHandleException
from ..core.message import Message
T = TypeVar("T", bound=Message)
T = TypeVar("T")
# NOTE: this works on concrete types and not inheritance
@@ -22,7 +21,7 @@ class TypeRoutedAgent(BaseAgent[T]):
def __init__(self, name: str, router: AgentRuntime[T]) -> None:
super().__init__(name, router)
self._handlers: Dict[Type[Any], Callable[[T], Awaitable[T]]] = {}
self._handlers: Dict[Type[Any], Callable[[T, CancellationToken], Awaitable[T]]] = {}
router.add_agent(self)
@@ -37,12 +36,12 @@ class TypeRoutedAgent(BaseAgent[T]):
def subscriptions(self) -> Sequence[Type[T]]:
return list(self._handlers.keys())
async def on_message(self, message: T) -> T:
async def on_message(self, message: T, cancellation_token: CancellationToken) -> T:
handler = self._handlers.get(type(message))
if handler is not None:
return await handler(message)
return await handler(message, cancellation_token)
else:
return await self.on_unhandled_message(message)
return await self.on_unhandled_message(message, cancellation_token)
async def on_unhandled_message(self, message: T) -> T:
async def on_unhandled_message(self, message: T, cancellation_token: CancellationToken) -> T:
raise CantHandleException()

View File

@@ -3,11 +3,12 @@ from asyncio import Future
from dataclasses import dataclass
from typing import Awaitable, Dict, Generic, List, Set, Type, TypeVar
from agnext.core.cancellation_token import CancellationToken
from ..core.agent import Agent
from ..core.agent_runtime import AgentRuntime
from ..core.message import Message
T = TypeVar("T", bound=Message)
T = TypeVar("T")
@dataclass
@@ -17,6 +18,7 @@ class BroadcastMessageEnvelope(Generic[T]):
message: T
future: Future[List[T]]
cancellation_token: CancellationToken
@dataclass
@@ -27,6 +29,7 @@ class SendMessageEnvelope(Generic[T]):
message: T
destination: Agent[T]
future: Future[T]
cancellation_token: CancellationToken
@dataclass
@@ -64,17 +67,26 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
self._agents.add(agent)
# Returns the response of the message
def send_message(self, message: T, destination: Agent[T]) -> Future[T]:
def send_message(
self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None
) -> Future[T]:
if cancellation_token is None:
cancellation_token = CancellationToken()
loop = asyncio.get_event_loop()
future: Future[T] = loop.create_future()
self._message_queue.append(SendMessageEnvelope(message, destination, future))
self._message_queue.append(SendMessageEnvelope(message, destination, future, cancellation_token))
return future
# Returns the response of all handling agents
def broadcast_message(self, message: T) -> Future[List[T]]:
def broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]:
if cancellation_token is None:
cancellation_token = CancellationToken()
future: Future[List[T]] = asyncio.get_event_loop().create_future()
self._message_queue.append(BroadcastMessageEnvelope(message, future))
self._message_queue.append(BroadcastMessageEnvelope(message, future, cancellation_token))
return future
async def _process_send(self, message_envelope: SendMessageEnvelope[T]) -> None:
@@ -83,16 +95,28 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
message_envelope.future.set_exception(Exception("Recipient not found"))
return
response = await recipient.on_message(message_envelope.message)
try:
response = await recipient.on_message(
message_envelope.message, cancellation_token=message_envelope.cancellation_token
)
except BaseException as e:
message_envelope.future.set_exception(e)
return
self._message_queue.append(ResponseMessageEnvelope(response, message_envelope.future))
async def _process_broadcast(self, message_envelope: BroadcastMessageEnvelope[T]) -> None:
responses: List[Awaitable[T]] = []
for agent in self._per_type_subscribers.get(type(message_envelope.message), []):
future = agent.on_message(message_envelope.message)
future = agent.on_message(message_envelope.message, cancellation_token=message_envelope.cancellation_token)
responses.append(future)
all_responses = await asyncio.gather(*responses)
try:
all_responses = await asyncio.gather(*responses)
except BaseException as e:
message_envelope.future.set_exception(e)
return
self._message_queue.append(BroadcastResponseMessageEnvelope(all_responses, message_envelope.future))
async def _process_response(self, message_envelope: ResponseMessageEnvelope[T]) -> None:
@@ -110,10 +134,14 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
message_envelope = self._message_queue.pop(0)
match message_envelope:
case SendMessageEnvelope(message, destination, future):
asyncio.create_task(self._process_send(SendMessageEnvelope(message, destination, future)))
case BroadcastMessageEnvelope(message, future):
asyncio.create_task(self._process_broadcast(BroadcastMessageEnvelope(message, future)))
case SendMessageEnvelope(message, destination, future, cancellation_token):
asyncio.create_task(
self._process_send(SendMessageEnvelope(message, destination, future, cancellation_token))
)
case BroadcastMessageEnvelope(message, future, cancellation_token):
asyncio.create_task(
self._process_broadcast(BroadcastMessageEnvelope(message, future, cancellation_token))
)
case ResponseMessageEnvelope(message, future):
asyncio.create_task(self._process_response(ResponseMessageEnvelope(message, future)))
case BroadcastResponseMessageEnvelope(message, future):

View File

@@ -1,8 +1,8 @@
from typing import Protocol, Sequence, Type, TypeVar
from .message import Message
from agnext.core.cancellation_token import CancellationToken
T = TypeVar("T", bound=Message)
T = TypeVar("T")
class Agent(Protocol[T]):
@@ -12,4 +12,4 @@ class Agent(Protocol[T]):
@property
def subscriptions(self) -> Sequence[Type[T]]: ...
async def on_message(self, message: T) -> T: ...
async def on_message(self, message: T, cancellation_token: CancellationToken) -> T: ...

View File

@@ -2,10 +2,9 @@ from asyncio import Future
from typing import List, Protocol, TypeVar
from agnext.core.agent import Agent
from agnext.core.cancellation_token import CancellationToken
from .message import Message
T = TypeVar("T", bound=Message)
T = TypeVar("T")
# Undeliverable - error
@@ -14,7 +13,9 @@ class AgentRuntime(Protocol[T]):
def add_agent(self, agent: Agent[T]) -> None: ...
# Returns the response of the message
def send_message(self, message: T, destination: Agent[T]) -> Future[T]: ...
def send_message(
self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None
) -> Future[T]: ...
# Returns the response of all handling agents
def broadcast_message(self, message: T) -> Future[List[T]]: ...
def broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]: ...

View File

@@ -3,11 +3,11 @@ from asyncio import Future
from typing import List, Sequence, Type, TypeVar
from agnext.core.agent_runtime import AgentRuntime
from agnext.core.cancellation_token import CancellationToken
from .agent import Agent
from .message import Message
T = TypeVar("T", bound=Message)
T = TypeVar("T")
class BaseAgent(ABC, Agent[T]):
@@ -25,10 +25,20 @@ class BaseAgent(ABC, Agent[T]):
return []
@abstractmethod
async def on_message(self, message: T) -> T: ...
async def on_message(self, message: T, cancellation_token: CancellationToken) -> T: ...
def _send_message(self, message: T, destination: Agent[T]) -> Future[T]:
return self._router.send_message(message, destination)
def _send_message(
self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None
) -> Future[T]:
if cancellation_token is None:
cancellation_token = CancellationToken()
future = self._router.send_message(message, destination, cancellation_token)
cancellation_token.link_future(future)
return future
def _broadcast_message(self, message: T) -> Future[List[T]]:
return self._router.broadcast_message(message)
def _broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]:
if cancellation_token is None:
cancellation_token = CancellationToken()
future = self._router.broadcast_message(message, cancellation_token)
cancellation_token.link_future(future)
return future

View File

@@ -0,0 +1,39 @@
import threading
from asyncio import Future
from typing import Any, Callable, List
class CancellationToken:
def __init__(self) -> None:
self._cancelled: bool = False
self._lock: threading.Lock = threading.Lock()
self._callbacks: List[Callable[[], None]] = []
def cancel(self) -> None:
with self._lock:
if not self._cancelled:
self._cancelled = True
for callback in self._callbacks:
callback()
def is_cancelled(self) -> bool:
with self._lock:
return self._cancelled
def add_callback(self, callback: Callable[[], None]) -> None:
with self._lock:
if self._cancelled:
callback()
else:
self._callbacks.append(callback)
def link_future(self, future: Future[Any]) -> None:
with self._lock:
if self._cancelled:
future.cancel()
else:
def _cancel() -> None:
future.cancel()
self._callbacks.append(_cancel)

View File

@@ -1,6 +0,0 @@
from typing import Protocol
class Message(Protocol):
sender: str
# reply_to: Optional[str]

122
tests/test_cancellation.py Normal file
View File

@@ -0,0 +1,122 @@
import asyncio
import pytest
from dataclasses import dataclass
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime
from agnext.core.agent import Agent
from agnext.core.agent_runtime import AgentRuntime
from agnext.core.cancellation_token import CancellationToken
@dataclass
class MessageType:
...
# Note for future reader:
# To do cancellation, only the token should be interacted with as a user
# If you cancel a future, it may not work as you expect.
class LongRunningAgent(TypeRoutedAgent[MessageType]):
def __init__(self, name: str, router: AgentRuntime[MessageType]) -> None:
super().__init__(name, router)
self.called = False
self.cancelled = False
@message_handler(MessageType)
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
self.called = True
sleep = asyncio.ensure_future(asyncio.sleep(100))
cancellation_token.link_future(sleep)
try:
await sleep
return MessageType()
except asyncio.CancelledError:
self.cancelled = True
raise
class NestingLongRunningAgent(TypeRoutedAgent[MessageType]):
def __init__(self, name: str, router: AgentRuntime[MessageType], nested_agent: Agent[MessageType]) -> None:
super().__init__(name, router)
self.called = False
self.cancelled = False
self._nested_agent = nested_agent
@message_handler(MessageType)
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
self.called = True
response = self._send_message(message, self._nested_agent, cancellation_token)
try:
return await response
except asyncio.CancelledError:
self.cancelled = True
raise
@pytest.mark.asyncio
async def test_cancellation_with_token() -> None:
router = SingleThreadedAgentRuntime[MessageType]()
long_running = LongRunningAgent("name", router)
token = CancellationToken()
response = router.send_message(MessageType(), long_running, token)
assert not response.done()
await router.process_next()
token.cancel()
with pytest.raises(asyncio.CancelledError):
await response
assert response.done()
assert long_running.called
assert long_running.cancelled
@pytest.mark.asyncio
async def test_nested_cancellation_only_outer_called() -> None:
router = SingleThreadedAgentRuntime[MessageType]()
long_running = LongRunningAgent("name", router)
nested = NestingLongRunningAgent("nested", router, long_running)
token = CancellationToken()
response = router.send_message(MessageType(), nested, token)
assert not response.done()
await router.process_next()
token.cancel()
with pytest.raises(asyncio.CancelledError):
await response
assert response.done()
assert nested.called
assert nested.cancelled
assert long_running.called == False
assert long_running.cancelled == False
@pytest.mark.asyncio
async def test_nested_cancellation_inner_called() -> None:
router = SingleThreadedAgentRuntime[MessageType]()
long_running = LongRunningAgent("name", router)
nested = NestingLongRunningAgent("nested", router, long_running)
token = CancellationToken()
response = router.send_message(MessageType(), nested, token)
assert not response.done()
await router.process_next()
# allow the inner agent to process
await router.process_next()
token.cancel()
with pytest.raises(asyncio.CancelledError):
await response
assert response.done()
assert nested.called
assert nested.cancelled
assert long_running.called
assert long_running.cancelled