Agent factory can be async (#247)

This commit is contained in:
Jack Gerrits
2024-07-23 11:49:38 -07:00
committed by GitHub
parent 718fad6e0d
commit a52d3bab53
47 changed files with 352 additions and 299 deletions

View File

@@ -1,6 +1,6 @@
import pytest
from pytest_mock import MockerFixture
from agnext.core import AgentRuntime, agent_instantiation_context, AgentId
from agnext.core import AgentRuntime, AGENT_INSTANTIATION_CONTEXT_VAR, AgentId
from test_utils import NoopAgent
@@ -11,7 +11,7 @@ async def test_base_agent_create(mocker: MockerFixture) -> None:
runtime = mocker.Mock(spec=AgentRuntime)
# Shows how to set the context for the agent instantiation in a test context
agent_instantiation_context.set((runtime, AgentId("name", "namespace")))
AGENT_INSTANTIATION_CONTEXT_VAR.set((runtime, AgentId("name", "namespace")))
agent = NoopAgent()
assert agent.runtime == runtime

View File

@@ -57,7 +57,7 @@ class NestingLongRunningAgent(TypeRoutedAgent):
async def test_cancellation_with_token() -> None:
runtime = SingleThreadedAgentRuntime()
long_running = runtime.register_and_get("long_running", LongRunningAgent)
long_running = await runtime.register_and_get("long_running", LongRunningAgent)
token = CancellationToken()
response = asyncio.create_task(runtime.send_message(MessageType(), recipient=long_running, cancellation_token=token))
assert not response.done()
@@ -73,7 +73,7 @@ async def test_cancellation_with_token() -> None:
await response
assert response.done()
long_running_agent: LongRunningAgent = runtime._get_agent(long_running) # type: ignore
long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore
assert long_running_agent.called
assert long_running_agent.cancelled
@@ -83,8 +83,8 @@ async def test_cancellation_with_token() -> None:
async def test_nested_cancellation_only_outer_called() -> None:
runtime = SingleThreadedAgentRuntime()
long_running = runtime.register_and_get("long_running", LongRunningAgent)
nested = runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
long_running = await runtime.register_and_get("long_running", LongRunningAgent)
nested = await runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
token = CancellationToken()
response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token))
@@ -100,10 +100,10 @@ async def test_nested_cancellation_only_outer_called() -> None:
await response
assert response.done()
nested_agent: NestingLongRunningAgent = runtime._get_agent(nested) # type: ignore
nested_agent: NestingLongRunningAgent = await runtime._get_agent(nested) # type: ignore
assert nested_agent.called
assert nested_agent.cancelled
long_running_agent: LongRunningAgent = runtime._get_agent(long_running) # type: ignore
long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore
assert long_running_agent.called is False
assert long_running_agent.cancelled is False
@@ -111,8 +111,8 @@ async def test_nested_cancellation_only_outer_called() -> None:
async def test_nested_cancellation_inner_called() -> None:
runtime = SingleThreadedAgentRuntime()
long_running = runtime.register_and_get("long_running", LongRunningAgent )
nested = runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
long_running = await runtime.register_and_get("long_running", LongRunningAgent )
nested = await runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
token = CancellationToken()
response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token))
@@ -130,9 +130,9 @@ async def test_nested_cancellation_inner_called() -> None:
await response
assert response.done()
nested_agent: NestingLongRunningAgent = runtime._get_agent(nested) # type: ignore
nested_agent: NestingLongRunningAgent = await runtime._get_agent(nested) # type: ignore
assert nested_agent.called
assert nested_agent.cancelled
long_running_agent: LongRunningAgent = runtime._get_agent(long_running) # type: ignore
long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore
assert long_running_agent.called
assert long_running_agent.cancelled

View File

@@ -28,7 +28,7 @@ async def test_register_receives_publish() -> None:
namespace = id.namespace
await queue.put((namespace, message.content))
runtime.register("name", lambda: ClosureAgent("My agent", log_message))
await runtime.register("name", lambda: ClosureAgent("My agent", log_message))
run_context = runtime.start()
await runtime.publish_message(Message("first message"), namespace="default")
await runtime.publish_message(Message("second message"), namespace="default")

View File

@@ -19,7 +19,7 @@ async def test_intervention_count_messages() -> None:
handler = DebugInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
loopback = runtime.register_and_get("name", LoopbackAgent)
loopback = await runtime.register_and_get("name", LoopbackAgent)
run_context = runtime.start()
_response = await runtime.send_message(MessageType(), recipient=loopback)
@@ -27,7 +27,7 @@ async def test_intervention_count_messages() -> None:
await run_context.stop()
assert handler.num_messages == 1
loopback_agent: LoopbackAgent = runtime._get_agent(loopback) # type: ignore
loopback_agent: LoopbackAgent = await runtime._get_agent(loopback) # type: ignore
assert loopback_agent.num_calls == 1
@pytest.mark.asyncio
@@ -40,7 +40,7 @@ async def test_intervention_drop_send() -> None:
handler = DropSendInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
loopback = runtime.register_and_get("name", LoopbackAgent)
loopback = await runtime.register_and_get("name", LoopbackAgent)
run_context = runtime.start()
with pytest.raises(MessageDroppedException):
@@ -48,7 +48,7 @@ async def test_intervention_drop_send() -> None:
await run_context.stop()
loopback_agent: LoopbackAgent = runtime._get_agent(loopback) # type: ignore
loopback_agent: LoopbackAgent = await runtime._get_agent(loopback) # type: ignore
assert loopback_agent.num_calls == 0
@@ -62,7 +62,7 @@ async def test_intervention_drop_response() -> None:
handler = DropResponseInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
loopback = runtime.register_and_get("name", LoopbackAgent)
loopback = await runtime.register_and_get("name", LoopbackAgent)
run_context = runtime.start()
with pytest.raises(MessageDroppedException):
@@ -84,7 +84,7 @@ async def test_intervention_raise_exception_on_send() -> None:
handler = ExceptionInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
long_running = runtime.register_and_get("name", LoopbackAgent)
long_running = await runtime.register_and_get("name", LoopbackAgent)
run_context = runtime.start()
with pytest.raises(InterventionException):
@@ -92,7 +92,7 @@ async def test_intervention_raise_exception_on_send() -> None:
await run_context.stop()
long_running_agent: LoopbackAgent = runtime._get_agent(long_running) # type: ignore
long_running_agent: LoopbackAgent = await runtime._get_agent(long_running) # type: ignore
assert long_running_agent.num_calls == 0
@pytest.mark.asyncio
@@ -108,12 +108,12 @@ async def test_intervention_raise_exception_on_respond() -> None:
handler = ExceptionInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
long_running = runtime.register_and_get("name", LoopbackAgent)
long_running = await runtime.register_and_get("name", LoopbackAgent)
run_context = runtime.start()
with pytest.raises(InterventionException):
_response = await runtime.send_message(MessageType(), recipient=long_running)
await run_context.stop()
long_running_agent: LoopbackAgent = runtime._get_agent(long_running) # type: ignore
long_running_agent: LoopbackAgent = await runtime._get_agent(long_running) # type: ignore
assert long_running_agent.num_calls == 1

View File

@@ -14,31 +14,31 @@ async def test_agent_names_must_be_unique() -> None:
assert agent.id == id
return agent
agent1 = runtime.register_and_get("name1", agent_factory)
agent1 = await runtime.register_and_get("name1", agent_factory)
assert agent1 == AgentId("name1", "default")
with pytest.raises(ValueError):
_agent1 = runtime.register_and_get("name1", NoopAgent)
_agent1 = await runtime.register_and_get("name1", NoopAgent)
_agent1 = runtime.register_and_get("name3", NoopAgent)
_agent1 = await runtime.register_and_get("name3", NoopAgent)
@pytest.mark.asyncio
async def test_register_receives_publish() -> None:
runtime = SingleThreadedAgentRuntime()
runtime.register("name", LoopbackAgent)
await runtime.register("name", LoopbackAgent)
run_context = runtime.start()
await runtime.publish_message(MessageType(), namespace="default")
await run_context.stop_when_idle()
# Agent in default namespace should have received the message
long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name")) # type: ignore
long_running_agent: LoopbackAgent = await runtime._get_agent(await runtime.get("name")) # type: ignore
assert long_running_agent.num_calls == 1
# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name", namespace="other")) # type: ignore
other_long_running_agent: LoopbackAgent = await runtime._get_agent(await runtime.get("name", namespace="other")) # type: ignore
assert other_long_running_agent.num_calls == 0
@@ -54,7 +54,7 @@ async def test_register_receives_publish_cascade() -> None:
# Register agents
for i in range(num_agents):
runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds))
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds))
run_context = runtime.start()
@@ -67,5 +67,5 @@ async def test_register_receives_publish_cascade() -> None:
# Check that each agent received the correct number of messages.
for i in range(num_agents):
agent: CascadingAgent = runtime._get_agent(runtime.get(f"name{i}")) # type: ignore
agent: CascadingAgent = await runtime._get_agent(await runtime.get(f"name{i}")) # type: ignore
assert agent.num_calls == total_num_calls_expected

View File

@@ -5,8 +5,8 @@ from agnext.application import SingleThreadedAgentRuntime
from agnext.core import BaseAgent, CancellationToken
class StatefulAgent(BaseAgent): # type: ignore
def __init__(self) -> None: # type: ignore
class StatefulAgent(BaseAgent):
def __init__(self) -> None:
super().__init__("A stateful agent", [])
self.state = 0
@@ -14,7 +14,7 @@ class StatefulAgent(BaseAgent): # type: ignore
def subscriptions(self) -> Sequence[type]:
return []
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: # type: ignore
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> None:
raise NotImplementedError
def save_state(self) -> Mapping[str, Any]:
@@ -28,8 +28,8 @@ class StatefulAgent(BaseAgent): # type: ignore
async def test_agent_can_save_state() -> None:
runtime = SingleThreadedAgentRuntime()
agent1_id = runtime.register_and_get("name1", StatefulAgent)
agent1: StatefulAgent = runtime._get_agent(agent1_id) # type: ignore
agent1_id = await runtime.register_and_get("name1", StatefulAgent)
agent1: StatefulAgent = await runtime._get_agent(agent1_id) # type: ignore
assert agent1.state == 0
agent1.state = 1
assert agent1.state == 1
@@ -46,19 +46,19 @@ async def test_agent_can_save_state() -> None:
async def test_runtime_can_save_state() -> None:
runtime = SingleThreadedAgentRuntime()
agent1_id = runtime.register_and_get("name1", StatefulAgent)
agent1: StatefulAgent = runtime._get_agent(agent1_id) # type: ignore
agent1_id = await runtime.register_and_get("name1", StatefulAgent)
agent1: StatefulAgent = await runtime._get_agent(agent1_id) # type: ignore
assert agent1.state == 0
agent1.state = 1
assert agent1.state == 1
runtime_state = runtime.save_state()
runtime_state = await runtime.save_state()
runtime2 = SingleThreadedAgentRuntime()
agent2_id = runtime2.register_and_get("name1", StatefulAgent)
agent2: StatefulAgent = runtime2._get_agent(agent2_id) # type: ignore
agent2_id = await runtime2.register_and_get("name1", StatefulAgent)
agent2: StatefulAgent = await runtime2._get_agent(agent2_id) # type: ignore
runtime2.load_state(runtime_state)
await runtime2.load_state(runtime_state)
assert agent2.state == 1