From afdbae2bb9b154072b666d5d71f373f7a2f111c1 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Fri, 21 Jun 2024 10:47:51 -0400 Subject: [PATCH] remove binding from base agent (#100) --- .../_single_threaded_agent_runtime.py | 18 +++++++--- python/src/agnext/core/__init__.py | 3 +- python/src/agnext/core/_agent_runtime.py | 3 ++ python/src/agnext/core/_base_agent.py | 36 ++++++------------- 4 files changed, 28 insertions(+), 32 deletions(-) diff --git a/python/src/agnext/application/_single_threaded_agent_runtime.py b/python/src/agnext/application/_single_threaded_agent_runtime.py index 6a59d21c2..0db1ae5e3 100644 --- a/python/src/agnext/application/_single_threaded_agent_runtime.py +++ b/python/src/agnext/application/_single_threaded_agent_runtime.py @@ -8,7 +8,16 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast -from ..core import Agent, AgentId, AgentMetadata, AgentProxy, AgentRuntime, AllNamespaces, BaseAgent, CancellationToken +from ..core import ( + Agent, + AgentId, + AgentMetadata, + AgentProxy, + AgentRuntime, + AllNamespaces, + CancellationToken, + agent_instantiation_context, +) from ..core.exceptions import MessageDroppedException from ..core.intervention import DropMessage, InterventionHandler @@ -395,6 +404,8 @@ class SingleThreadedAgentRuntime(AgentRuntime): def _invoke_agent_factory( self, agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], agent_id: AgentId ) -> T: + token = agent_instantiation_context.set((self, agent_id)) + if len(inspect.signature(agent_factory).parameters) == 0: factory_one = cast(Callable[[], T], agent_factory) agent = factory_one() @@ -404,10 +415,7 @@ class SingleThreadedAgentRuntime(AgentRuntime): else: raise ValueError("Agent factory must take 0 or 2 arguments.") - # TODO: should this be part of the base agent interface? - if isinstance(agent, BaseAgent): - agent.bind_id(agent_id) - agent.bind_runtime(self) + agent_instantiation_context.reset(token) return agent diff --git a/python/src/agnext/core/__init__.py b/python/src/agnext/core/__init__.py index 76f404387..210e87fd3 100644 --- a/python/src/agnext/core/__init__.py +++ b/python/src/agnext/core/__init__.py @@ -7,7 +7,7 @@ from ._agent_id import AgentId from ._agent_metadata import AgentMetadata from ._agent_props import AgentChildren from ._agent_proxy import AgentProxy -from ._agent_runtime import AgentRuntime, AllNamespaces +from ._agent_runtime import AgentRuntime, AllNamespaces, agent_instantiation_context from ._base_agent import BaseAgent from ._cancellation_token import CancellationToken @@ -21,4 +21,5 @@ __all__ = [ "BaseAgent", "CancellationToken", "AgentChildren", + "agent_instantiation_context", ] diff --git a/python/src/agnext/core/_agent_runtime.py b/python/src/agnext/core/_agent_runtime.py index 12b2f97c5..a372eb1fc 100644 --- a/python/src/agnext/core/_agent_runtime.py +++ b/python/src/agnext/core/_agent_runtime.py @@ -1,6 +1,7 @@ from __future__ import annotations from asyncio import Future +from contextvars import ContextVar from typing import Any, Callable, Mapping, Protocol, Sequence, Type, TypeVar, overload, runtime_checkable from ._agent import Agent @@ -13,6 +14,8 @@ from ._cancellation_token import CancellationToken T = TypeVar("T", bound=Agent) +agent_instantiation_context: ContextVar[tuple[AgentRuntime, AgentId]] = ContextVar("agent_instantiation_context") + class AllNamespaces: pass diff --git a/python/src/agnext/core/_base_agent.py b/python/src/agnext/core/_base_agent.py index 239bc7a1c..041328f62 100644 --- a/python/src/agnext/core/_base_agent.py +++ b/python/src/agnext/core/_base_agent.py @@ -6,7 +6,7 @@ from typing import Any, Mapping, Sequence from ._agent import Agent from ._agent_id import AgentId from ._agent_metadata import AgentMetadata -from ._agent_runtime import AgentRuntime +from ._agent_runtime import AgentRuntime, agent_instantiation_context from ._cancellation_token import CancellationToken @@ -22,38 +22,28 @@ class BaseAgent(ABC, Agent): ) def __init__(self, description: str, subscriptions: Sequence[type]) -> None: - self._runtime: AgentRuntime | None = None - self._id: AgentId | None = None + try: + runtime, id = agent_instantiation_context.get() + except LookupError as e: + raise RuntimeError( + "BaseAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated." + ) from e + + self._runtime: AgentRuntime = runtime + self._id: AgentId = id self._description = description self._subscriptions = subscriptions - def bind_runtime(self, runtime: AgentRuntime) -> None: - if self._runtime is not None: - raise RuntimeError("Agent has already been bound to a runtime.") - - self._runtime = runtime - - def bind_id(self, agent_id: AgentId) -> None: - if self._id is not None: - raise RuntimeError("Agent has already been bound to an id.") - self._id = agent_id - @property def name(self) -> str: return self.id.name @property def id(self) -> AgentId: - if self._id is None: - raise RuntimeError("Agent has not been bound to an id.") - return self._id @property def runtime(self) -> AgentRuntime: - if self._runtime is None: - raise RuntimeError("Agent has not been bound to a runtime.") - return self._runtime @abstractmethod @@ -67,9 +57,6 @@ class BaseAgent(ABC, Agent): *, cancellation_token: CancellationToken | None = None, ) -> Future[Any]: - if self._runtime is None: - raise RuntimeError("Agent has not been bound to a runtime.") - if cancellation_token is None: cancellation_token = CancellationToken() @@ -88,9 +75,6 @@ class BaseAgent(ABC, Agent): *, cancellation_token: CancellationToken | None = None, ) -> Future[None]: - if self._runtime is None: - raise RuntimeError("Agent has not been bound to a runtime.") - if cancellation_token is None: cancellation_token = CancellationToken()