Add require response and remove generic types (#13)

This commit is contained in:
Jack Gerrits
2024-05-23 16:00:05 -04:00
committed by GitHub
parent d77390dc07
commit 8d1f4aedc0
16 changed files with 286 additions and 209 deletions

View File

@@ -16,14 +16,14 @@ class MessageType:
# To do cancellation, only the token should be interacted with as a user
# If you cancel a future, it may not work as you expect.
class LongRunningAgent(TypeRoutedAgent[MessageType]):
def __init__(self, name: str, router: AgentRuntime[MessageType]) -> None:
class LongRunningAgent(TypeRoutedAgent):
def __init__(self, name: str, router: AgentRuntime) -> None:
super().__init__(name, router)
self.called = False
self.cancelled = False
@message_handler(MessageType)
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType:
self.called = True
sleep = asyncio.ensure_future(asyncio.sleep(100))
cancellation_token.link_future(sleep)
@@ -34,19 +34,22 @@ class LongRunningAgent(TypeRoutedAgent[MessageType]):
self.cancelled = True
raise
class NestingLongRunningAgent(TypeRoutedAgent[MessageType]):
def __init__(self, name: str, router: AgentRuntime[MessageType], nested_agent: Agent[MessageType]) -> None:
class NestingLongRunningAgent(TypeRoutedAgent):
def __init__(self, name: str, router: AgentRuntime, nested_agent: Agent) -> None:
super().__init__(name, router)
self.called = False
self.cancelled = False
self._nested_agent = nested_agent
@message_handler(MessageType)
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType:
assert require_response == True
self.called = True
response = self._send_message(message, self._nested_agent, cancellation_token)
response = self._send_message(message, self._nested_agent, require_response=require_response, cancellation_token=cancellation_token)
try:
return await response
val = await response
assert isinstance(val, MessageType)
return val
except asyncio.CancelledError:
self.cancelled = True
raise
@@ -54,7 +57,7 @@ class NestingLongRunningAgent(TypeRoutedAgent[MessageType]):
@pytest.mark.asyncio
async def test_cancellation_with_token() -> None:
router = SingleThreadedAgentRuntime[MessageType]()
router = SingleThreadedAgentRuntime()
long_running = LongRunningAgent("name", router)
token = CancellationToken()
@@ -75,7 +78,7 @@ async def test_cancellation_with_token() -> None:
@pytest.mark.asyncio
async def test_nested_cancellation_only_outer_called() -> None:
router = SingleThreadedAgentRuntime[MessageType]()
router = SingleThreadedAgentRuntime()
long_running = LongRunningAgent("name", router)
nested = NestingLongRunningAgent("nested", router, long_running)
@@ -98,7 +101,7 @@ async def test_nested_cancellation_only_outer_called() -> None:
@pytest.mark.asyncio
async def test_nested_cancellation_inner_called() -> None:
router = SingleThreadedAgentRuntime[MessageType]()
router = SingleThreadedAgentRuntime()
long_running = LongRunningAgent("name", router)
nested = NestingLongRunningAgent("nested", router, long_running)

View File

@@ -13,30 +13,30 @@ from agnext.core.intervention import DefaultInterventionHandler, DropMessage
class MessageType:
...
class LoopbackAgent(TypeRoutedAgent[MessageType]):
def __init__(self, name: str, router: AgentRuntime[MessageType]) -> None:
class LoopbackAgent(TypeRoutedAgent):
def __init__(self, name: str, router: AgentRuntime) -> None:
super().__init__(name, router)
self.num_calls = 0
@message_handler(MessageType)
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType:
self.num_calls += 1
return message
@pytest.mark.asyncio
async def test_intervention_count_messages() -> None:
class DebugInterventionHandler(DefaultInterventionHandler[MessageType]):
class DebugInterventionHandler(DefaultInterventionHandler):
def __init__(self):
self.num_messages = 0
async def on_send(self, message: MessageType, *, sender: Agent[MessageType] | None, recipient: Agent[MessageType]) -> MessageType:
async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType:
self.num_messages += 1
return message
handler = DebugInterventionHandler()
router = SingleThreadedAgentRuntime[MessageType](before_send=handler)
router = SingleThreadedAgentRuntime(before_send=handler)
long_running = LoopbackAgent("name", router)
response = router.send_message(MessageType(), recipient=long_running)
@@ -50,12 +50,12 @@ async def test_intervention_count_messages() -> None:
@pytest.mark.asyncio
async def test_intervention_drop_send() -> None:
class DropSendInterventionHandler(DefaultInterventionHandler[MessageType]):
async def on_send(self, message: MessageType, *, sender: Agent[MessageType] | None, recipient: Agent[MessageType]) -> MessageType | type[DropMessage]:
class DropSendInterventionHandler(DefaultInterventionHandler):
async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType | type[DropMessage]:
return DropMessage
handler = DropSendInterventionHandler()
router = SingleThreadedAgentRuntime[MessageType](before_send=handler)
router = SingleThreadedAgentRuntime(before_send=handler)
long_running = LoopbackAgent("name", router)
response = router.send_message(MessageType(), recipient=long_running)
@@ -72,12 +72,12 @@ async def test_intervention_drop_send() -> None:
@pytest.mark.asyncio
async def test_intervention_drop_response() -> None:
class DropResponseInterventionHandler(DefaultInterventionHandler[MessageType]):
async def on_response(self, message: MessageType, *, sender: Agent[MessageType], recipient: Agent[MessageType] | None) -> MessageType | type[DropMessage]:
class DropResponseInterventionHandler(DefaultInterventionHandler):
async def on_response(self, message: MessageType, *, sender: Agent, recipient: Agent | None) -> MessageType | type[DropMessage]:
return DropMessage
handler = DropResponseInterventionHandler()
router = SingleThreadedAgentRuntime[MessageType](before_send=handler)
router = SingleThreadedAgentRuntime(before_send=handler)
long_running = LoopbackAgent("name", router)
response = router.send_message(MessageType(), recipient=long_running)