From 6189fdb05cf129b2b650655439815f581aeb2e29 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Mon, 24 Jun 2024 16:52:09 -0400 Subject: [PATCH] simplify namespace usage (#116) * simplify namespace usage * format * pyright --- python/docs/src/core-concepts/namespace.md | 7 +-- .../guides/termination-with-intervention.md | 2 +- python/examples/coder_reviewer_pub_sub.py | 2 +- python/examples/mixture_of_agents_pub_sub.py | 2 +- .../_single_threaded_agent_runtime.py | 50 ++++++------------- python/src/agnext/core/__init__.py | 3 +- python/src/agnext/core/_agent_runtime.py | 21 ++------ python/tests/test_intervention.py | 10 ++-- python/tests/test_runtime.py | 31 ------------ 9 files changed, 32 insertions(+), 96 deletions(-) diff --git a/python/docs/src/core-concepts/namespace.md b/python/docs/src/core-concepts/namespace.md index 235aad7d8..16f83c039 100644 --- a/python/docs/src/core-concepts/namespace.md +++ b/python/docs/src/core-concepts/namespace.md @@ -1,7 +1,6 @@ # Namespace -A namespace is a logical boundary between agents. By default, agents in one -namespace cannot communicate with agents in another namespace. +Namespace allow for defining logical boundaries between agents. Namespaces are strings, and the default is `default`. @@ -15,4 +14,6 @@ Two possible use cases of agents are: The {py:class}`agnext.core.AgentId` is used to address an agent, it is the combination of the agent's namespace and its name. -When getting an agent reference ({py:meth}`agnext.core.AgentRuntime.get`) or proxy ({py:meth}`agnext.core.AgentRuntime.get_proxy`) from the runtime the namespace can be specified. Agents have an ID property ({py:attr}`agnext.core.Agent.id`) that returns the agent's id. Additionally, the register method takes a factory that can optionally accept the ID as an argument ({py:meth}`agnext.core.AgentRuntime.register`). \ No newline at end of file +When getting an agent reference ({py:meth}`agnext.core.AgentRuntime.get`) or proxy ({py:meth}`agnext.core.AgentRuntime.get_proxy`) from the runtime the namespace can be specified. Agents have an ID property ({py:attr}`agnext.core.Agent.id`) that returns the agent's id. Additionally, the register method takes a factory that can optionally accept the ID as an argument ({py:meth}`agnext.core.AgentRuntime.register`). + +By default, there are no restrictions and are left to the application to enforce. The runtime will however automatically create agents in a namespace if it does not exist. diff --git a/python/docs/src/guides/termination-with-intervention.md b/python/docs/src/guides/termination-with-intervention.md index 1723df4a4..4e4161439 100644 --- a/python/docs/src/guides/termination-with-intervention.md +++ b/python/docs/src/guides/termination-with-intervention.md @@ -68,7 +68,7 @@ Finally, we add this handler to the runtime and use it to detect termination and async def main() -> None: termination_handler = TerminationHandler() runtime = SingleThreadedAgentRuntime( - before_send=termination_handler + intervention_handler=termination_handler ) # Add Agents and kick off task diff --git a/python/examples/coder_reviewer_pub_sub.py b/python/examples/coder_reviewer_pub_sub.py index c9815fde9..a20f0a073 100644 --- a/python/examples/coder_reviewer_pub_sub.py +++ b/python/examples/coder_reviewer_pub_sub.py @@ -267,7 +267,7 @@ class DisplayAgent(TypeRoutedAgent): async def main() -> None: termination_handler = TerminationHandler() - runtime = SingleThreadedAgentRuntime(before_send=termination_handler) + runtime = SingleThreadedAgentRuntime(intervention_handler=termination_handler) runtime.register( "ReviewerAgent", lambda: ReviewerAgent( diff --git a/python/examples/mixture_of_agents_pub_sub.py b/python/examples/mixture_of_agents_pub_sub.py index 6eff0ec43..b134929b4 100644 --- a/python/examples/mixture_of_agents_pub_sub.py +++ b/python/examples/mixture_of_agents_pub_sub.py @@ -133,7 +133,7 @@ class DisplayAgent(TypeRoutedAgent): async def main() -> None: termination_handler = TerminationHandler() - runtime = SingleThreadedAgentRuntime(before_send=termination_handler) + runtime = SingleThreadedAgentRuntime(intervention_handler=termination_handler) # TODO: use different models for each agent. runtime.register( "ReferenceAgent1", diff --git a/python/src/agnext/application/_single_threaded_agent_runtime.py b/python/src/agnext/application/_single_threaded_agent_runtime.py index 0db1ae5e3..83f101074 100644 --- a/python/src/agnext/application/_single_threaded_agent_runtime.py +++ b/python/src/agnext/application/_single_threaded_agent_runtime.py @@ -6,7 +6,7 @@ from asyncio import Future from collections import defaultdict from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast +from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, TypeVar, cast from ..core import ( Agent, @@ -14,7 +14,6 @@ from ..core import ( AgentMetadata, AgentProxy, AgentRuntime, - AllNamespaces, CancellationToken, agent_instantiation_context, ) @@ -82,15 +81,13 @@ class Counter: class SingleThreadedAgentRuntime(AgentRuntime): - def __init__(self, *, before_send: InterventionHandler | None = None) -> None: + def __init__(self, *, intervention_handler: InterventionHandler | None = None) -> None: self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = [] # (namespace, type) -> List[AgentId] self._per_type_subscribers: DefaultDict[tuple[str, type], Set[AgentId]] = defaultdict(set) self._agent_factories: Dict[str, Callable[[], Agent] | Callable[[AgentRuntime, AgentId], Agent]] = {} - # If empty, then all namespaces are valid for that agent type - self._valid_namespaces: Dict[str, Sequence[str]] = {} self._instantiated_agents: Dict[AgentId, Agent] = {} - self._before_send = before_send + self._intervention_handler = intervention_handler self._known_namespaces: set[str] = set() self._outstanding_tasks = Counter() @@ -322,9 +319,11 @@ class SingleThreadedAgentRuntime(AgentRuntime): match message_envelope: case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future): - if self._before_send is not None: + if self._intervention_handler is not None: try: - temp_message = await self._before_send.on_send(message, sender=sender, recipient=recipient) + temp_message = await self._intervention_handler.on_send( + message, sender=sender, recipient=recipient + ) except BaseException as e: future.set_exception(e) return @@ -339,9 +338,9 @@ class SingleThreadedAgentRuntime(AgentRuntime): message=message, sender=sender, ): - if self._before_send is not None: + if self._intervention_handler is not None: try: - temp_message = await self._before_send.on_publish(message, sender=sender) + temp_message = await self._intervention_handler.on_publish(message, sender=sender) except BaseException as e: # TODO: we should raise the intervention exception to the publisher. logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True) @@ -354,9 +353,11 @@ class SingleThreadedAgentRuntime(AgentRuntime): self._outstanding_tasks.increment() asyncio.create_task(self._process_publish(message_envelope)) case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future): - if self._before_send is not None: + if self._intervention_handler is not None: try: - temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient) + temp_message = await self._intervention_handler.on_response( + message, sender=sender, recipient=recipient + ) except BaseException as e: # TODO: should we raise the exception to sender of the response instead? future.set_exception(e) @@ -385,21 +386,14 @@ class SingleThreadedAgentRuntime(AgentRuntime): self, name: str, agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], - *, - valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces, ) -> None: if name in self._agent_factories: raise ValueError(f"Agent with name {name} already exists.") self._agent_factories[name] = agent_factory - if valid_namespaces is not AllNamespaces: - self._valid_namespaces[name] = cast(Sequence[str], valid_namespaces) - else: - self._valid_namespaces[name] = [] # For all already prepared namespaces we need to prepare this agent for namespace in self._known_namespaces: - if self._type_valid_for_namespace(AgentId(name=name, namespace=namespace)): - self._get_agent(AgentId(name=name, namespace=namespace)) + self._get_agent(AgentId(name=name, namespace=namespace)) def _invoke_agent_factory( self, agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], agent_id: AgentId @@ -419,24 +413,11 @@ class SingleThreadedAgentRuntime(AgentRuntime): return agent - def _type_valid_for_namespace(self, agent_id: AgentId) -> bool: - if agent_id.name not in self._agent_factories: - raise KeyError(f"Agent with name {agent_id.name} not found.") - - valid_namespaces = self._valid_namespaces[agent_id.name] - if len(valid_namespaces) == 0: - return True - - return agent_id.namespace in valid_namespaces - def _get_agent(self, agent_id: AgentId) -> Agent: self._process_seen_namespace(agent_id.namespace) if agent_id in self._instantiated_agents: return self._instantiated_agents[agent_id] - if not self._type_valid_for_namespace(agent_id): - raise ValueError(f"Agent with name {agent_id.name} not valid for namespace {agent_id.namespace}.") - if agent_id.name not in self._agent_factories: raise ValueError(f"Agent with name {agent_id.name} not found.") @@ -463,5 +444,4 @@ class SingleThreadedAgentRuntime(AgentRuntime): self._known_namespaces.add(namespace) for name in self._known_agent_names: - if self._type_valid_for_namespace(AgentId(name=name, namespace=namespace)): - self._get_agent(AgentId(name=name, namespace=namespace)) + self._get_agent(AgentId(name=name, namespace=namespace)) diff --git a/python/src/agnext/core/__init__.py b/python/src/agnext/core/__init__.py index 210e87fd3..083456716 100644 --- a/python/src/agnext/core/__init__.py +++ b/python/src/agnext/core/__init__.py @@ -7,7 +7,7 @@ from ._agent_id import AgentId from ._agent_metadata import AgentMetadata from ._agent_props import AgentChildren from ._agent_proxy import AgentProxy -from ._agent_runtime import AgentRuntime, AllNamespaces, agent_instantiation_context +from ._agent_runtime import AgentRuntime, agent_instantiation_context from ._base_agent import BaseAgent from ._cancellation_token import CancellationToken @@ -17,7 +17,6 @@ __all__ = [ "AgentProxy", "AgentMetadata", "AgentRuntime", - "AllNamespaces", "BaseAgent", "CancellationToken", "AgentChildren", diff --git a/python/src/agnext/core/_agent_runtime.py b/python/src/agnext/core/_agent_runtime.py index a372eb1fc..0236878ae 100644 --- a/python/src/agnext/core/_agent_runtime.py +++ b/python/src/agnext/core/_agent_runtime.py @@ -2,7 +2,7 @@ from __future__ import annotations from asyncio import Future from contextvars import ContextVar -from typing import Any, Callable, Mapping, Protocol, Sequence, Type, TypeVar, overload, runtime_checkable +from typing import Any, Callable, Mapping, Protocol, TypeVar, overload, runtime_checkable from ._agent import Agent from ._agent_id import AgentId @@ -17,10 +17,6 @@ T = TypeVar("T", bound=Agent) agent_instantiation_context: ContextVar[tuple[AgentRuntime, AgentId]] = ContextVar("agent_instantiation_context") -class AllNamespaces: - pass - - @runtime_checkable class AgentRuntime(Protocol): # Returns the response of the message @@ -45,7 +41,9 @@ class AgentRuntime(Protocol): @overload def register( - self, name: str, agent_factory: Callable[[], T], *, valid_namespaces: Sequence[str] | Type[AllNamespaces] = ... + self, + name: str, + agent_factory: Callable[[], T], ) -> None: ... @overload @@ -53,23 +51,18 @@ class AgentRuntime(Protocol): self, name: str, agent_factory: Callable[[AgentRuntime, AgentId], T], - *, - valid_namespaces: Sequence[str] | Type[AllNamespaces] = ..., ) -> None: ... def register( self, name: str, agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], - *, - valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces, ) -> None: """Register an agent factory with the runtime associated with a specific name. The name must be unique. Args: name (str): The name of the type agent this factory creates. agent_factory (Callable[[], T] | Callable[[AgentRuntime, AgentId], T]): The factory that creates the agent. - valid_namespaces (Sequence[str] | Type[AllNamespaces], optional): Valid namespaces for this type. Defaults to AllNamespaces. Example: @@ -99,7 +92,6 @@ class AgentRuntime(Protocol): agent_factory: Callable[[], T], *, namespace: str = "default", - valid_namespaces: Sequence[str] | Type[AllNamespaces] = ..., ) -> AgentId: ... @overload @@ -109,7 +101,6 @@ class AgentRuntime(Protocol): agent_factory: Callable[[AgentRuntime, AgentId], T], *, namespace: str = "default", - valid_namespaces: Sequence[str] | Type[AllNamespaces] = ..., ) -> AgentId: ... def register_and_get( @@ -118,7 +109,6 @@ class AgentRuntime(Protocol): agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], *, namespace: str = "default", - valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces, ) -> AgentId: self.register(name, agent_factory) return self.get(name, namespace=namespace) @@ -130,7 +120,6 @@ class AgentRuntime(Protocol): agent_factory: Callable[[], T], *, namespace: str = "default", - valid_namespaces: Sequence[str] | Type[AllNamespaces] = ..., ) -> AgentProxy: ... @overload @@ -140,7 +129,6 @@ class AgentRuntime(Protocol): agent_factory: Callable[[AgentRuntime, AgentId], T], *, namespace: str = "default", - valid_namespaces: Sequence[str] | Type[AllNamespaces] = ..., ) -> AgentProxy: ... def register_and_get_proxy( @@ -149,7 +137,6 @@ class AgentRuntime(Protocol): agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], *, namespace: str = "default", - valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces, ) -> AgentProxy: self.register(name, agent_factory) return self.get_proxy(name, namespace=namespace) diff --git a/python/tests/test_intervention.py b/python/tests/test_intervention.py index ae3683e8b..88e673528 100644 --- a/python/tests/test_intervention.py +++ b/python/tests/test_intervention.py @@ -18,7 +18,7 @@ async def test_intervention_count_messages() -> None: return message handler = DebugInterventionHandler() - runtime = SingleThreadedAgentRuntime(before_send=handler) + runtime = SingleThreadedAgentRuntime(intervention_handler=handler) loopback = runtime.register_and_get("name", LoopbackAgent) response = runtime.send_message(MessageType(), recipient=loopback) @@ -38,7 +38,7 @@ async def test_intervention_drop_send() -> None: return DropMessage handler = DropSendInterventionHandler() - runtime = SingleThreadedAgentRuntime(before_send=handler) + runtime = SingleThreadedAgentRuntime(intervention_handler=handler) loopback = runtime.register_and_get("name", LoopbackAgent) response = runtime.send_message(MessageType(), recipient=loopback) @@ -61,7 +61,7 @@ async def test_intervention_drop_response() -> None: return DropMessage handler = DropResponseInterventionHandler() - runtime = SingleThreadedAgentRuntime(before_send=handler) + runtime = SingleThreadedAgentRuntime(intervention_handler=handler) loopback = runtime.register_and_get("name", LoopbackAgent) response = runtime.send_message(MessageType(), recipient=loopback) @@ -84,7 +84,7 @@ async def test_intervention_raise_exception_on_send() -> None: raise InterventionException handler = ExceptionInterventionHandler() - runtime = SingleThreadedAgentRuntime(before_send=handler) + runtime = SingleThreadedAgentRuntime(intervention_handler=handler) long_running = runtime.register_and_get("name", LoopbackAgent) response = runtime.send_message(MessageType(), recipient=long_running) @@ -109,7 +109,7 @@ async def test_intervention_raise_exception_on_respond() -> None: raise InterventionException handler = ExceptionInterventionHandler() - runtime = SingleThreadedAgentRuntime(before_send=handler) + runtime = SingleThreadedAgentRuntime(intervention_handler=handler) long_running = runtime.register_and_get("name", LoopbackAgent) response = runtime.send_message(MessageType(), recipient=long_running) diff --git a/python/tests/test_runtime.py b/python/tests/test_runtime.py index c9683f16b..8a663a56e 100644 --- a/python/tests/test_runtime.py +++ b/python/tests/test_runtime.py @@ -41,34 +41,3 @@ async def test_register_receives_publish() -> None: other_long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name", namespace="other")) # type: ignore assert other_long_running_agent.num_calls == 0 - - -@pytest.mark.asyncio -async def test_try_instantiate_agent_invalid_namespace() -> None: - runtime = SingleThreadedAgentRuntime() - - runtime.register("name", LoopbackAgent, valid_namespaces=["default"]) - await runtime.publish_message(MessageType(), namespace="non_default") - - while len(runtime.unprocessed_messages) > 0 or runtime.outstanding_tasks > 0: - await runtime.process_next() - - # Agent in default namespace should have received the message - long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name")) # type: ignore - assert long_running_agent.num_calls == 0 - - with pytest.raises(ValueError): - _agent = runtime.get("name", namespace="non_default") - -@pytest.mark.asyncio -async def test_send_crosses_namepace() -> None: - runtime = SingleThreadedAgentRuntime() - - runtime.register("name", LoopbackAgent) - - default_ns_agent = runtime.get("name") - non_default_ns_agent = runtime.get("name", namespace="non_default") - - with pytest.raises(ValueError): - await runtime.send_message(MessageType(), default_ns_agent, sender=non_default_ns_agent) -