mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Move python code to subdir (#98)
This commit is contained in:
124
python/tests/test_intervention.py
Normal file
124
python/tests/test_intervention.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import pytest
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.core import AgentId
|
||||
from agnext.core.exceptions import MessageDroppedException
|
||||
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
|
||||
from test_utils import LoopbackAgent, MessageType
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_intervention_count_messages() -> None:
|
||||
|
||||
class DebugInterventionHandler(DefaultInterventionHandler):
|
||||
def __init__(self) -> None:
|
||||
self.num_messages = 0
|
||||
|
||||
async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType:
|
||||
self.num_messages += 1
|
||||
return message
|
||||
|
||||
handler = DebugInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(before_send=handler)
|
||||
loopback = runtime.register_and_get("name", LoopbackAgent)
|
||||
|
||||
response = runtime.send_message(MessageType(), recipient=loopback)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
|
||||
assert handler.num_messages == 1
|
||||
loopback_agent: LoopbackAgent = runtime._get_agent(loopback) # type: ignore
|
||||
assert loopback_agent.num_calls == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_intervention_drop_send() -> None:
|
||||
|
||||
class DropSendInterventionHandler(DefaultInterventionHandler):
|
||||
async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType | type[DropMessage]:
|
||||
return DropMessage
|
||||
|
||||
handler = DropSendInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(before_send=handler)
|
||||
|
||||
loopback = runtime.register_and_get("name", LoopbackAgent)
|
||||
response = runtime.send_message(MessageType(), recipient=loopback)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
|
||||
with pytest.raises(MessageDroppedException):
|
||||
await response
|
||||
|
||||
loopback_agent: LoopbackAgent = runtime._get_agent(loopback) # type: ignore
|
||||
assert loopback_agent.num_calls == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_intervention_drop_response() -> None:
|
||||
|
||||
class DropResponseInterventionHandler(DefaultInterventionHandler):
|
||||
async def on_response(self, message: MessageType, *, sender: AgentId, recipient: AgentId | None) -> MessageType | type[DropMessage]:
|
||||
return DropMessage
|
||||
|
||||
handler = DropResponseInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(before_send=handler)
|
||||
|
||||
loopback = runtime.register_and_get("name", LoopbackAgent)
|
||||
response = runtime.send_message(MessageType(), recipient=loopback)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
|
||||
with pytest.raises(MessageDroppedException):
|
||||
await response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_intervention_raise_exception_on_send() -> None:
|
||||
|
||||
class InterventionException(Exception):
|
||||
pass
|
||||
|
||||
class ExceptionInterventionHandler(DefaultInterventionHandler): # type: ignore
|
||||
async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType | type[DropMessage]: # type: ignore
|
||||
raise InterventionException
|
||||
|
||||
handler = ExceptionInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(before_send=handler)
|
||||
|
||||
long_running = runtime.register_and_get("name", LoopbackAgent)
|
||||
response = runtime.send_message(MessageType(), recipient=long_running)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
|
||||
with pytest.raises(InterventionException):
|
||||
await response
|
||||
|
||||
long_running_agent: LoopbackAgent = runtime._get_agent(long_running) # type: ignore
|
||||
assert long_running_agent.num_calls == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_intervention_raise_exception_on_respond() -> None:
|
||||
|
||||
class InterventionException(Exception):
|
||||
pass
|
||||
|
||||
class ExceptionInterventionHandler(DefaultInterventionHandler): # type: ignore
|
||||
async def on_response(self, message: MessageType, *, sender: AgentId, recipient: AgentId | None) -> MessageType | type[DropMessage]: # type: ignore
|
||||
raise InterventionException
|
||||
|
||||
handler = ExceptionInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(before_send=handler)
|
||||
|
||||
long_running = runtime.register_and_get("name", LoopbackAgent)
|
||||
response = runtime.send_message(MessageType(), recipient=long_running)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
|
||||
with pytest.raises(InterventionException):
|
||||
await response
|
||||
|
||||
long_running_agent: LoopbackAgent = runtime._get_agent(long_running) # type: ignore
|
||||
assert long_running_agent.num_calls == 1
|
||||
Reference in New Issue
Block a user