mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Add support for task cancellation (#7)
* Add support for task cancellation * add tests to CI * matrix for python testing
This commit is contained in:
21
.github/workflows/checks.yml
vendored
21
.github/workflows/checks.yml
vendored
@@ -51,12 +51,15 @@ jobs:
|
||||
- run: pip install ".[dev]"
|
||||
- uses: jakebailey/pyright-action@v2
|
||||
|
||||
# test:
|
||||
# runs-on: ubuntu-latest
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# - uses: actions/setup-python@v5
|
||||
# with:
|
||||
# python-version: '3.10'
|
||||
# - run: pip install ".[dev]"
|
||||
# - run: pytest
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["pypy3.10", "3.10", "3.11", "3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- run: pip install ".[dev]"
|
||||
- run: pytest
|
||||
|
||||
7
.vscode/settings.json
vendored
Normal file
7
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"python.testing.pytestArgs": [
|
||||
"tests"
|
||||
],
|
||||
"python.testing.unittestEnabled": false,
|
||||
"python.testing.pytestEnabled": true
|
||||
}
|
||||
@@ -5,11 +5,10 @@ from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_h
|
||||
from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime
|
||||
from agnext.core.agent import Agent
|
||||
from agnext.core.agent_runtime import AgentRuntime
|
||||
from agnext.core.message import Message
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageType(Message):
|
||||
class MessageType:
|
||||
body: str
|
||||
sender: str
|
||||
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
from dataclasses import dataclass
|
||||
import random
|
||||
import asyncio
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
|
||||
from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime
|
||||
from agnext.core.agent_runtime import AgentRuntime
|
||||
from agnext.core.message import Message
|
||||
|
||||
|
||||
# TODO: a runtime should be able to handle multiple types of messages
|
||||
# TODO: allow request and response to be different message types
|
||||
# should support this in handlers.
|
||||
@dataclass
|
||||
class GroupChatMessage(Message):
|
||||
class GroupChatMessage:
|
||||
body: str
|
||||
sender: str
|
||||
require_response: bool
|
||||
|
||||
@@ -13,15 +13,10 @@ classifiers = [
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
]
|
||||
dependencies = [
|
||||
"openai>=1.3",
|
||||
"pillow",
|
||||
"aiohttp",
|
||||
"typing-extensions"
|
||||
]
|
||||
dependencies = ["openai>=1.3", "pillow", "aiohttp", "typing-extensions"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = ["ruff", "pyright", "mypy", "pytest", "types-Pillow"]
|
||||
dev = ["ruff", "pyright", "mypy", "pytest", "pytest-asyncio", "types-Pillow"]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
agnext = ["py.typed"]
|
||||
@@ -61,3 +56,7 @@ include = ["src", "examples"]
|
||||
typeCheckingMode = "strict"
|
||||
reportUnnecessaryIsInstance = false
|
||||
reportMissingTypeStubs = false
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
minversion = "6.0"
|
||||
testpaths = ["tests"]
|
||||
|
||||
@@ -2,11 +2,10 @@ from typing import Any, Awaitable, Callable, Dict, Sequence, Type, TypeVar
|
||||
|
||||
from agnext.core.agent_runtime import AgentRuntime
|
||||
from agnext.core.base_agent import BaseAgent
|
||||
from agnext.core.cancellation_token import CancellationToken
|
||||
from agnext.core.exceptions import CantHandleException
|
||||
|
||||
from ..core.message import Message
|
||||
|
||||
T = TypeVar("T", bound=Message)
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
# NOTE: this works on concrete types and not inheritance
|
||||
@@ -22,7 +21,7 @@ class TypeRoutedAgent(BaseAgent[T]):
|
||||
def __init__(self, name: str, router: AgentRuntime[T]) -> None:
|
||||
super().__init__(name, router)
|
||||
|
||||
self._handlers: Dict[Type[Any], Callable[[T], Awaitable[T]]] = {}
|
||||
self._handlers: Dict[Type[Any], Callable[[T, CancellationToken], Awaitable[T]]] = {}
|
||||
|
||||
router.add_agent(self)
|
||||
|
||||
@@ -37,12 +36,12 @@ class TypeRoutedAgent(BaseAgent[T]):
|
||||
def subscriptions(self) -> Sequence[Type[T]]:
|
||||
return list(self._handlers.keys())
|
||||
|
||||
async def on_message(self, message: T) -> T:
|
||||
async def on_message(self, message: T, cancellation_token: CancellationToken) -> T:
|
||||
handler = self._handlers.get(type(message))
|
||||
if handler is not None:
|
||||
return await handler(message)
|
||||
return await handler(message, cancellation_token)
|
||||
else:
|
||||
return await self.on_unhandled_message(message)
|
||||
return await self.on_unhandled_message(message, cancellation_token)
|
||||
|
||||
async def on_unhandled_message(self, message: T) -> T:
|
||||
async def on_unhandled_message(self, message: T, cancellation_token: CancellationToken) -> T:
|
||||
raise CantHandleException()
|
||||
|
||||
@@ -3,11 +3,12 @@ from asyncio import Future
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, Dict, Generic, List, Set, Type, TypeVar
|
||||
|
||||
from agnext.core.cancellation_token import CancellationToken
|
||||
|
||||
from ..core.agent import Agent
|
||||
from ..core.agent_runtime import AgentRuntime
|
||||
from ..core.message import Message
|
||||
|
||||
T = TypeVar("T", bound=Message)
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -17,6 +18,7 @@ class BroadcastMessageEnvelope(Generic[T]):
|
||||
|
||||
message: T
|
||||
future: Future[List[T]]
|
||||
cancellation_token: CancellationToken
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -27,6 +29,7 @@ class SendMessageEnvelope(Generic[T]):
|
||||
message: T
|
||||
destination: Agent[T]
|
||||
future: Future[T]
|
||||
cancellation_token: CancellationToken
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -64,17 +67,26 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
|
||||
self._agents.add(agent)
|
||||
|
||||
# Returns the response of the message
|
||||
def send_message(self, message: T, destination: Agent[T]) -> Future[T]:
|
||||
def send_message(
|
||||
self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None
|
||||
) -> Future[T]:
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
future: Future[T] = loop.create_future()
|
||||
|
||||
self._message_queue.append(SendMessageEnvelope(message, destination, future))
|
||||
self._message_queue.append(SendMessageEnvelope(message, destination, future, cancellation_token))
|
||||
|
||||
return future
|
||||
|
||||
# Returns the response of all handling agents
|
||||
def broadcast_message(self, message: T) -> Future[List[T]]:
|
||||
def broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]:
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
future: Future[List[T]] = asyncio.get_event_loop().create_future()
|
||||
self._message_queue.append(BroadcastMessageEnvelope(message, future))
|
||||
self._message_queue.append(BroadcastMessageEnvelope(message, future, cancellation_token))
|
||||
return future
|
||||
|
||||
async def _process_send(self, message_envelope: SendMessageEnvelope[T]) -> None:
|
||||
@@ -83,16 +95,28 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
|
||||
message_envelope.future.set_exception(Exception("Recipient not found"))
|
||||
return
|
||||
|
||||
response = await recipient.on_message(message_envelope.message)
|
||||
try:
|
||||
response = await recipient.on_message(
|
||||
message_envelope.message, cancellation_token=message_envelope.cancellation_token
|
||||
)
|
||||
except BaseException as e:
|
||||
message_envelope.future.set_exception(e)
|
||||
return
|
||||
|
||||
self._message_queue.append(ResponseMessageEnvelope(response, message_envelope.future))
|
||||
|
||||
async def _process_broadcast(self, message_envelope: BroadcastMessageEnvelope[T]) -> None:
|
||||
responses: List[Awaitable[T]] = []
|
||||
for agent in self._per_type_subscribers.get(type(message_envelope.message), []):
|
||||
future = agent.on_message(message_envelope.message)
|
||||
future = agent.on_message(message_envelope.message, cancellation_token=message_envelope.cancellation_token)
|
||||
responses.append(future)
|
||||
|
||||
all_responses = await asyncio.gather(*responses)
|
||||
try:
|
||||
all_responses = await asyncio.gather(*responses)
|
||||
except BaseException as e:
|
||||
message_envelope.future.set_exception(e)
|
||||
return
|
||||
|
||||
self._message_queue.append(BroadcastResponseMessageEnvelope(all_responses, message_envelope.future))
|
||||
|
||||
async def _process_response(self, message_envelope: ResponseMessageEnvelope[T]) -> None:
|
||||
@@ -110,10 +134,14 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]):
|
||||
message_envelope = self._message_queue.pop(0)
|
||||
|
||||
match message_envelope:
|
||||
case SendMessageEnvelope(message, destination, future):
|
||||
asyncio.create_task(self._process_send(SendMessageEnvelope(message, destination, future)))
|
||||
case BroadcastMessageEnvelope(message, future):
|
||||
asyncio.create_task(self._process_broadcast(BroadcastMessageEnvelope(message, future)))
|
||||
case SendMessageEnvelope(message, destination, future, cancellation_token):
|
||||
asyncio.create_task(
|
||||
self._process_send(SendMessageEnvelope(message, destination, future, cancellation_token))
|
||||
)
|
||||
case BroadcastMessageEnvelope(message, future, cancellation_token):
|
||||
asyncio.create_task(
|
||||
self._process_broadcast(BroadcastMessageEnvelope(message, future, cancellation_token))
|
||||
)
|
||||
case ResponseMessageEnvelope(message, future):
|
||||
asyncio.create_task(self._process_response(ResponseMessageEnvelope(message, future)))
|
||||
case BroadcastResponseMessageEnvelope(message, future):
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Protocol, Sequence, Type, TypeVar
|
||||
|
||||
from .message import Message
|
||||
from agnext.core.cancellation_token import CancellationToken
|
||||
|
||||
T = TypeVar("T", bound=Message)
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Agent(Protocol[T]):
|
||||
@@ -12,4 +12,4 @@ class Agent(Protocol[T]):
|
||||
@property
|
||||
def subscriptions(self) -> Sequence[Type[T]]: ...
|
||||
|
||||
async def on_message(self, message: T) -> T: ...
|
||||
async def on_message(self, message: T, cancellation_token: CancellationToken) -> T: ...
|
||||
|
||||
@@ -2,10 +2,9 @@ from asyncio import Future
|
||||
from typing import List, Protocol, TypeVar
|
||||
|
||||
from agnext.core.agent import Agent
|
||||
from agnext.core.cancellation_token import CancellationToken
|
||||
|
||||
from .message import Message
|
||||
|
||||
T = TypeVar("T", bound=Message)
|
||||
T = TypeVar("T")
|
||||
|
||||
# Undeliverable - error
|
||||
|
||||
@@ -14,7 +13,9 @@ class AgentRuntime(Protocol[T]):
|
||||
def add_agent(self, agent: Agent[T]) -> None: ...
|
||||
|
||||
# Returns the response of the message
|
||||
def send_message(self, message: T, destination: Agent[T]) -> Future[T]: ...
|
||||
def send_message(
|
||||
self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None
|
||||
) -> Future[T]: ...
|
||||
|
||||
# Returns the response of all handling agents
|
||||
def broadcast_message(self, message: T) -> Future[List[T]]: ...
|
||||
def broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]: ...
|
||||
|
||||
@@ -3,11 +3,11 @@ from asyncio import Future
|
||||
from typing import List, Sequence, Type, TypeVar
|
||||
|
||||
from agnext.core.agent_runtime import AgentRuntime
|
||||
from agnext.core.cancellation_token import CancellationToken
|
||||
|
||||
from .agent import Agent
|
||||
from .message import Message
|
||||
|
||||
T = TypeVar("T", bound=Message)
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BaseAgent(ABC, Agent[T]):
|
||||
@@ -25,10 +25,20 @@ class BaseAgent(ABC, Agent[T]):
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
async def on_message(self, message: T) -> T: ...
|
||||
async def on_message(self, message: T, cancellation_token: CancellationToken) -> T: ...
|
||||
|
||||
def _send_message(self, message: T, destination: Agent[T]) -> Future[T]:
|
||||
return self._router.send_message(message, destination)
|
||||
def _send_message(
|
||||
self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None
|
||||
) -> Future[T]:
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
future = self._router.send_message(message, destination, cancellation_token)
|
||||
cancellation_token.link_future(future)
|
||||
return future
|
||||
|
||||
def _broadcast_message(self, message: T) -> Future[List[T]]:
|
||||
return self._router.broadcast_message(message)
|
||||
def _broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]:
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
future = self._router.broadcast_message(message, cancellation_token)
|
||||
cancellation_token.link_future(future)
|
||||
return future
|
||||
|
||||
39
src/agnext/core/cancellation_token.py
Normal file
39
src/agnext/core/cancellation_token.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import threading
|
||||
from asyncio import Future
|
||||
from typing import Any, Callable, List
|
||||
|
||||
|
||||
class CancellationToken:
|
||||
def __init__(self) -> None:
|
||||
self._cancelled: bool = False
|
||||
self._lock: threading.Lock = threading.Lock()
|
||||
self._callbacks: List[Callable[[], None]] = []
|
||||
|
||||
def cancel(self) -> None:
|
||||
with self._lock:
|
||||
if not self._cancelled:
|
||||
self._cancelled = True
|
||||
for callback in self._callbacks:
|
||||
callback()
|
||||
|
||||
def is_cancelled(self) -> bool:
|
||||
with self._lock:
|
||||
return self._cancelled
|
||||
|
||||
def add_callback(self, callback: Callable[[], None]) -> None:
|
||||
with self._lock:
|
||||
if self._cancelled:
|
||||
callback()
|
||||
else:
|
||||
self._callbacks.append(callback)
|
||||
|
||||
def link_future(self, future: Future[Any]) -> None:
|
||||
with self._lock:
|
||||
if self._cancelled:
|
||||
future.cancel()
|
||||
else:
|
||||
|
||||
def _cancel() -> None:
|
||||
future.cancel()
|
||||
|
||||
self._callbacks.append(_cancel)
|
||||
@@ -1,6 +0,0 @@
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class Message(Protocol):
|
||||
sender: str
|
||||
# reply_to: Optional[str]
|
||||
122
tests/test_cancellation.py
Normal file
122
tests/test_cancellation.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
|
||||
from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime
|
||||
from agnext.core.agent import Agent
|
||||
from agnext.core.agent_runtime import AgentRuntime
|
||||
from agnext.core.cancellation_token import CancellationToken
|
||||
|
||||
@dataclass
|
||||
class MessageType:
|
||||
...
|
||||
|
||||
# Note for future reader:
|
||||
# To do cancellation, only the token should be interacted with as a user
|
||||
# If you cancel a future, it may not work as you expect.
|
||||
|
||||
class LongRunningAgent(TypeRoutedAgent[MessageType]):
|
||||
def __init__(self, name: str, router: AgentRuntime[MessageType]) -> None:
|
||||
super().__init__(name, router)
|
||||
self.called = False
|
||||
self.cancelled = False
|
||||
|
||||
@message_handler(MessageType)
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||
self.called = True
|
||||
sleep = asyncio.ensure_future(asyncio.sleep(100))
|
||||
cancellation_token.link_future(sleep)
|
||||
try:
|
||||
await sleep
|
||||
return MessageType()
|
||||
except asyncio.CancelledError:
|
||||
self.cancelled = True
|
||||
raise
|
||||
|
||||
class NestingLongRunningAgent(TypeRoutedAgent[MessageType]):
|
||||
def __init__(self, name: str, router: AgentRuntime[MessageType], nested_agent: Agent[MessageType]) -> None:
|
||||
super().__init__(name, router)
|
||||
self.called = False
|
||||
self.cancelled = False
|
||||
self._nested_agent = nested_agent
|
||||
|
||||
@message_handler(MessageType)
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||
self.called = True
|
||||
response = self._send_message(message, self._nested_agent, cancellation_token)
|
||||
try:
|
||||
return await response
|
||||
except asyncio.CancelledError:
|
||||
self.cancelled = True
|
||||
raise
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancellation_with_token() -> None:
|
||||
router = SingleThreadedAgentRuntime[MessageType]()
|
||||
|
||||
long_running = LongRunningAgent("name", router)
|
||||
token = CancellationToken()
|
||||
response = router.send_message(MessageType(), long_running, token)
|
||||
assert not response.done()
|
||||
|
||||
await router.process_next()
|
||||
token.cancel()
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await response
|
||||
|
||||
assert response.done()
|
||||
assert long_running.called
|
||||
assert long_running.cancelled
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nested_cancellation_only_outer_called() -> None:
|
||||
router = SingleThreadedAgentRuntime[MessageType]()
|
||||
|
||||
long_running = LongRunningAgent("name", router)
|
||||
nested = NestingLongRunningAgent("nested", router, long_running)
|
||||
|
||||
token = CancellationToken()
|
||||
response = router.send_message(MessageType(), nested, token)
|
||||
assert not response.done()
|
||||
|
||||
await router.process_next()
|
||||
token.cancel()
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await response
|
||||
|
||||
assert response.done()
|
||||
assert nested.called
|
||||
assert nested.cancelled
|
||||
assert long_running.called == False
|
||||
assert long_running.cancelled == False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nested_cancellation_inner_called() -> None:
|
||||
router = SingleThreadedAgentRuntime[MessageType]()
|
||||
|
||||
long_running = LongRunningAgent("name", router)
|
||||
nested = NestingLongRunningAgent("nested", router, long_running)
|
||||
|
||||
token = CancellationToken()
|
||||
response = router.send_message(MessageType(), nested, token)
|
||||
assert not response.done()
|
||||
|
||||
await router.process_next()
|
||||
# allow the inner agent to process
|
||||
await router.process_next()
|
||||
token.cancel()
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await response
|
||||
|
||||
assert response.done()
|
||||
assert nested.called
|
||||
assert nested.cancelled
|
||||
assert long_running.called
|
||||
assert long_running.cancelled
|
||||
Reference in New Issue
Block a user