From 2d74fa9cafef002069cf45e377f170f3f0da961d Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Sat, 22 Jun 2024 14:50:32 -0400 Subject: [PATCH] add example of use contextvar (#105) --- python/pyproject.toml | 1 + python/src/agnext/core/_agent_id.py | 3 +++ python/tests/test_base_agent.py | 19 +++++++++++++++++++ python/tests/test_runtime.py | 23 +++++++++++------------ python/tests/test_utils/__init__.py | 12 ++++++++++-- 5 files changed, 44 insertions(+), 14 deletions(-) create mode 100644 python/tests/test_base_agent.py diff --git a/python/pyproject.toml b/python/pyproject.toml index 9b407ca00..ce4ecd7c1 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "pip", "pytest", "pytest-xdist", + "pytest-mock" ] [tool.hatch.envs.default.extra-scripts] diff --git a/python/src/agnext/core/_agent_id.py b/python/src/agnext/core/_agent_id.py index 222bae2e4..1a39482e6 100644 --- a/python/src/agnext/core/_agent_id.py +++ b/python/src/agnext/core/_agent_id.py @@ -12,6 +12,9 @@ class AgentId: def __hash__(self) -> int: return hash((self._namespace, self._name)) + def __repr__(self) -> str: + return f"AgentId({self._name}, {self._namespace})" + def __eq__(self, value: object) -> bool: if not isinstance(value, AgentId): return False diff --git a/python/tests/test_base_agent.py b/python/tests/test_base_agent.py new file mode 100644 index 000000000..81e364ced --- /dev/null +++ b/python/tests/test_base_agent.py @@ -0,0 +1,19 @@ +import pytest +from pytest_mock import MockerFixture +from agnext.core import AgentRuntime, agent_instantiation_context, AgentId + +from test_utils import NoopAgent + + + +@pytest.mark.asyncio +async def test_base_agent_create(mocker: MockerFixture) -> None: + runtime = mocker.Mock(spec=AgentRuntime) + + # Shows how to set the context for the agent instantiation in a test context + agent_instantiation_context.set((runtime, AgentId("name", "namespace"))) + + agent = NoopAgent() + assert agent.runtime == runtime + assert agent.id == AgentId("name", "namespace") + diff --git a/python/tests/test_runtime.py b/python/tests/test_runtime.py index ea5798c25..c9683f16b 100644 --- a/python/tests/test_runtime.py +++ b/python/tests/test_runtime.py @@ -1,23 +1,22 @@ -from typing import Any - import pytest + from agnext.application import SingleThreadedAgentRuntime -from agnext.core import BaseAgent, CancellationToken -from test_utils import LoopbackAgent, MessageType +from agnext.core import AgentId, AgentRuntime +from test_utils import LoopbackAgent, MessageType, NoopAgent -class NoopAgent(BaseAgent): # type: ignore - def __init__(self) -> None: # type: ignore - super().__init__("A no op agent", []) - - async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: # type: ignore - raise NotImplementedError - @pytest.mark.asyncio async def test_agent_names_must_be_unique() -> None: runtime = SingleThreadedAgentRuntime() - _agent1 = runtime.register_and_get("name1", NoopAgent) + def agent_factory(runtime: AgentRuntime, id: AgentId) -> NoopAgent: + assert id == AgentId("name1", "default") + agent = NoopAgent() + assert agent.id == id + return agent + + agent1 = runtime.register_and_get("name1", agent_factory) + assert agent1 == AgentId("name1", "default") with pytest.raises(ValueError): _agent1 = runtime.register_and_get("name1", NoopAgent) diff --git a/python/tests/test_utils/__init__.py b/python/tests/test_utils/__init__.py index 12b679774..694bc19c2 100644 --- a/python/tests/test_utils/__init__.py +++ b/python/tests/test_utils/__init__.py @@ -1,7 +1,8 @@ from dataclasses import dataclass +from typing import Any from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import CancellationToken +from agnext.core import CancellationToken, BaseAgent @dataclass @@ -17,4 +18,11 @@ class LoopbackAgent(TypeRoutedAgent): @message_handler async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: self.num_calls += 1 - return message \ No newline at end of file + return message + +class NoopAgent(BaseAgent): + def __init__(self) -> None: + super().__init__("A no op agent", []) + + async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: + raise NotImplementedError \ No newline at end of file