mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Agent factory can be async (#247)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from agnext.core import AgentRuntime, agent_instantiation_context, AgentId
|
||||
from agnext.core import AgentRuntime, AGENT_INSTANTIATION_CONTEXT_VAR, AgentId
|
||||
|
||||
from test_utils import NoopAgent
|
||||
|
||||
@@ -11,7 +11,7 @@ async def test_base_agent_create(mocker: MockerFixture) -> None:
|
||||
runtime = mocker.Mock(spec=AgentRuntime)
|
||||
|
||||
# Shows how to set the context for the agent instantiation in a test context
|
||||
agent_instantiation_context.set((runtime, AgentId("name", "namespace")))
|
||||
AGENT_INSTANTIATION_CONTEXT_VAR.set((runtime, AgentId("name", "namespace")))
|
||||
|
||||
agent = NoopAgent()
|
||||
assert agent.runtime == runtime
|
||||
|
||||
@@ -57,7 +57,7 @@ class NestingLongRunningAgent(TypeRoutedAgent):
|
||||
async def test_cancellation_with_token() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
long_running = runtime.register_and_get("long_running", LongRunningAgent)
|
||||
long_running = await runtime.register_and_get("long_running", LongRunningAgent)
|
||||
token = CancellationToken()
|
||||
response = asyncio.create_task(runtime.send_message(MessageType(), recipient=long_running, cancellation_token=token))
|
||||
assert not response.done()
|
||||
@@ -73,7 +73,7 @@ async def test_cancellation_with_token() -> None:
|
||||
await response
|
||||
|
||||
assert response.done()
|
||||
long_running_agent: LongRunningAgent = runtime._get_agent(long_running) # type: ignore
|
||||
long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore
|
||||
assert long_running_agent.called
|
||||
assert long_running_agent.cancelled
|
||||
|
||||
@@ -83,8 +83,8 @@ async def test_cancellation_with_token() -> None:
|
||||
async def test_nested_cancellation_only_outer_called() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
long_running = runtime.register_and_get("long_running", LongRunningAgent)
|
||||
nested = runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
|
||||
long_running = await runtime.register_and_get("long_running", LongRunningAgent)
|
||||
nested = await runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
|
||||
|
||||
token = CancellationToken()
|
||||
response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token))
|
||||
@@ -100,10 +100,10 @@ async def test_nested_cancellation_only_outer_called() -> None:
|
||||
await response
|
||||
|
||||
assert response.done()
|
||||
nested_agent: NestingLongRunningAgent = runtime._get_agent(nested) # type: ignore
|
||||
nested_agent: NestingLongRunningAgent = await runtime._get_agent(nested) # type: ignore
|
||||
assert nested_agent.called
|
||||
assert nested_agent.cancelled
|
||||
long_running_agent: LongRunningAgent = runtime._get_agent(long_running) # type: ignore
|
||||
long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore
|
||||
assert long_running_agent.called is False
|
||||
assert long_running_agent.cancelled is False
|
||||
|
||||
@@ -111,8 +111,8 @@ async def test_nested_cancellation_only_outer_called() -> None:
|
||||
async def test_nested_cancellation_inner_called() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
long_running = runtime.register_and_get("long_running", LongRunningAgent )
|
||||
nested = runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
|
||||
long_running = await runtime.register_and_get("long_running", LongRunningAgent )
|
||||
nested = await runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
|
||||
|
||||
token = CancellationToken()
|
||||
response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token))
|
||||
@@ -130,9 +130,9 @@ async def test_nested_cancellation_inner_called() -> None:
|
||||
await response
|
||||
|
||||
assert response.done()
|
||||
nested_agent: NestingLongRunningAgent = runtime._get_agent(nested) # type: ignore
|
||||
nested_agent: NestingLongRunningAgent = await runtime._get_agent(nested) # type: ignore
|
||||
assert nested_agent.called
|
||||
assert nested_agent.cancelled
|
||||
long_running_agent: LongRunningAgent = runtime._get_agent(long_running) # type: ignore
|
||||
long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore
|
||||
assert long_running_agent.called
|
||||
assert long_running_agent.cancelled
|
||||
|
||||
@@ -28,7 +28,7 @@ async def test_register_receives_publish() -> None:
|
||||
namespace = id.namespace
|
||||
await queue.put((namespace, message.content))
|
||||
|
||||
runtime.register("name", lambda: ClosureAgent("My agent", log_message))
|
||||
await runtime.register("name", lambda: ClosureAgent("My agent", log_message))
|
||||
run_context = runtime.start()
|
||||
await runtime.publish_message(Message("first message"), namespace="default")
|
||||
await runtime.publish_message(Message("second message"), namespace="default")
|
||||
|
||||
@@ -19,7 +19,7 @@ async def test_intervention_count_messages() -> None:
|
||||
|
||||
handler = DebugInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
loopback = runtime.register_and_get("name", LoopbackAgent)
|
||||
loopback = await runtime.register_and_get("name", LoopbackAgent)
|
||||
run_context = runtime.start()
|
||||
|
||||
_response = await runtime.send_message(MessageType(), recipient=loopback)
|
||||
@@ -27,7 +27,7 @@ async def test_intervention_count_messages() -> None:
|
||||
await run_context.stop()
|
||||
|
||||
assert handler.num_messages == 1
|
||||
loopback_agent: LoopbackAgent = runtime._get_agent(loopback) # type: ignore
|
||||
loopback_agent: LoopbackAgent = await runtime._get_agent(loopback) # type: ignore
|
||||
assert loopback_agent.num_calls == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -40,7 +40,7 @@ async def test_intervention_drop_send() -> None:
|
||||
handler = DropSendInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
|
||||
loopback = runtime.register_and_get("name", LoopbackAgent)
|
||||
loopback = await runtime.register_and_get("name", LoopbackAgent)
|
||||
run_context = runtime.start()
|
||||
|
||||
with pytest.raises(MessageDroppedException):
|
||||
@@ -48,7 +48,7 @@ async def test_intervention_drop_send() -> None:
|
||||
|
||||
await run_context.stop()
|
||||
|
||||
loopback_agent: LoopbackAgent = runtime._get_agent(loopback) # type: ignore
|
||||
loopback_agent: LoopbackAgent = await runtime._get_agent(loopback) # type: ignore
|
||||
assert loopback_agent.num_calls == 0
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ async def test_intervention_drop_response() -> None:
|
||||
handler = DropResponseInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
|
||||
loopback = runtime.register_and_get("name", LoopbackAgent)
|
||||
loopback = await runtime.register_and_get("name", LoopbackAgent)
|
||||
run_context = runtime.start()
|
||||
|
||||
with pytest.raises(MessageDroppedException):
|
||||
@@ -84,7 +84,7 @@ async def test_intervention_raise_exception_on_send() -> None:
|
||||
handler = ExceptionInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
|
||||
long_running = runtime.register_and_get("name", LoopbackAgent)
|
||||
long_running = await runtime.register_and_get("name", LoopbackAgent)
|
||||
run_context = runtime.start()
|
||||
|
||||
with pytest.raises(InterventionException):
|
||||
@@ -92,7 +92,7 @@ async def test_intervention_raise_exception_on_send() -> None:
|
||||
|
||||
await run_context.stop()
|
||||
|
||||
long_running_agent: LoopbackAgent = runtime._get_agent(long_running) # type: ignore
|
||||
long_running_agent: LoopbackAgent = await runtime._get_agent(long_running) # type: ignore
|
||||
assert long_running_agent.num_calls == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -108,12 +108,12 @@ async def test_intervention_raise_exception_on_respond() -> None:
|
||||
handler = ExceptionInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
|
||||
long_running = runtime.register_and_get("name", LoopbackAgent)
|
||||
long_running = await runtime.register_and_get("name", LoopbackAgent)
|
||||
run_context = runtime.start()
|
||||
with pytest.raises(InterventionException):
|
||||
_response = await runtime.send_message(MessageType(), recipient=long_running)
|
||||
|
||||
await run_context.stop()
|
||||
|
||||
long_running_agent: LoopbackAgent = runtime._get_agent(long_running) # type: ignore
|
||||
long_running_agent: LoopbackAgent = await runtime._get_agent(long_running) # type: ignore
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
@@ -14,31 +14,31 @@ async def test_agent_names_must_be_unique() -> None:
|
||||
assert agent.id == id
|
||||
return agent
|
||||
|
||||
agent1 = runtime.register_and_get("name1", agent_factory)
|
||||
agent1 = await runtime.register_and_get("name1", agent_factory)
|
||||
assert agent1 == AgentId("name1", "default")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_agent1 = runtime.register_and_get("name1", NoopAgent)
|
||||
_agent1 = await runtime.register_and_get("name1", NoopAgent)
|
||||
|
||||
_agent1 = runtime.register_and_get("name3", NoopAgent)
|
||||
_agent1 = await runtime.register_and_get("name3", NoopAgent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_receives_publish() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
runtime.register("name", LoopbackAgent)
|
||||
await runtime.register("name", LoopbackAgent)
|
||||
run_context = runtime.start()
|
||||
await runtime.publish_message(MessageType(), namespace="default")
|
||||
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name")) # type: ignore
|
||||
long_running_agent: LoopbackAgent = await runtime._get_agent(await runtime.get("name")) # type: ignore
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name", namespace="other")) # type: ignore
|
||||
other_long_running_agent: LoopbackAgent = await runtime._get_agent(await runtime.get("name", namespace="other")) # type: ignore
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ async def test_register_receives_publish_cascade() -> None:
|
||||
|
||||
# Register agents
|
||||
for i in range(num_agents):
|
||||
runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds))
|
||||
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds))
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
@@ -67,5 +67,5 @@ async def test_register_receives_publish_cascade() -> None:
|
||||
|
||||
# Check that each agent received the correct number of messages.
|
||||
for i in range(num_agents):
|
||||
agent: CascadingAgent = runtime._get_agent(runtime.get(f"name{i}")) # type: ignore
|
||||
agent: CascadingAgent = await runtime._get_agent(await runtime.get(f"name{i}")) # type: ignore
|
||||
assert agent.num_calls == total_num_calls_expected
|
||||
|
||||
@@ -5,8 +5,8 @@ from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.core import BaseAgent, CancellationToken
|
||||
|
||||
|
||||
class StatefulAgent(BaseAgent): # type: ignore
|
||||
def __init__(self) -> None: # type: ignore
|
||||
class StatefulAgent(BaseAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("A stateful agent", [])
|
||||
self.state = 0
|
||||
|
||||
@@ -14,7 +14,7 @@ class StatefulAgent(BaseAgent): # type: ignore
|
||||
def subscriptions(self) -> Sequence[type]:
|
||||
return []
|
||||
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: # type: ignore
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
@@ -28,8 +28,8 @@ class StatefulAgent(BaseAgent): # type: ignore
|
||||
async def test_agent_can_save_state() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
agent1_id = runtime.register_and_get("name1", StatefulAgent)
|
||||
agent1: StatefulAgent = runtime._get_agent(agent1_id) # type: ignore
|
||||
agent1_id = await runtime.register_and_get("name1", StatefulAgent)
|
||||
agent1: StatefulAgent = await runtime._get_agent(agent1_id) # type: ignore
|
||||
assert agent1.state == 0
|
||||
agent1.state = 1
|
||||
assert agent1.state == 1
|
||||
@@ -46,19 +46,19 @@ async def test_agent_can_save_state() -> None:
|
||||
async def test_runtime_can_save_state() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
agent1_id = runtime.register_and_get("name1", StatefulAgent)
|
||||
agent1: StatefulAgent = runtime._get_agent(agent1_id) # type: ignore
|
||||
agent1_id = await runtime.register_and_get("name1", StatefulAgent)
|
||||
agent1: StatefulAgent = await runtime._get_agent(agent1_id) # type: ignore
|
||||
assert agent1.state == 0
|
||||
agent1.state = 1
|
||||
assert agent1.state == 1
|
||||
|
||||
runtime_state = runtime.save_state()
|
||||
runtime_state = await runtime.save_state()
|
||||
|
||||
runtime2 = SingleThreadedAgentRuntime()
|
||||
agent2_id = runtime2.register_and_get("name1", StatefulAgent)
|
||||
agent2: StatefulAgent = runtime2._get_agent(agent2_id) # type: ignore
|
||||
agent2_id = await runtime2.register_and_get("name1", StatefulAgent)
|
||||
agent2: StatefulAgent = await runtime2._get_agent(agent2_id) # type: ignore
|
||||
|
||||
runtime2.load_state(runtime_state)
|
||||
await runtime2.load_state(runtime_state)
|
||||
assert agent2.state == 1
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user