simplify namespace usage (#116)

* simplify namespace usage

* format

* pyright
This commit is contained in:
Jack Gerrits
2024-06-24 16:52:09 -04:00
committed by GitHub
parent 606e43b325
commit 6189fdb05c
9 changed files with 32 additions and 96 deletions

View File

@@ -6,7 +6,7 @@ from asyncio import Future
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, TypeVar, cast
from ..core import (
Agent,
@@ -14,7 +14,6 @@ from ..core import (
AgentMetadata,
AgentProxy,
AgentRuntime,
AllNamespaces,
CancellationToken,
agent_instantiation_context,
)
@@ -82,15 +81,13 @@ class Counter:
class SingleThreadedAgentRuntime(AgentRuntime):
def __init__(self, *, before_send: InterventionHandler | None = None) -> None:
def __init__(self, *, intervention_handler: InterventionHandler | None = None) -> None:
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
# (namespace, type) -> List[AgentId]
self._per_type_subscribers: DefaultDict[tuple[str, type], Set[AgentId]] = defaultdict(set)
self._agent_factories: Dict[str, Callable[[], Agent] | Callable[[AgentRuntime, AgentId], Agent]] = {}
# If empty, then all namespaces are valid for that agent type
self._valid_namespaces: Dict[str, Sequence[str]] = {}
self._instantiated_agents: Dict[AgentId, Agent] = {}
self._before_send = before_send
self._intervention_handler = intervention_handler
self._known_namespaces: set[str] = set()
self._outstanding_tasks = Counter()
@@ -322,9 +319,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
match message_envelope:
case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._before_send is not None:
if self._intervention_handler is not None:
try:
temp_message = await self._before_send.on_send(message, sender=sender, recipient=recipient)
temp_message = await self._intervention_handler.on_send(
message, sender=sender, recipient=recipient
)
except BaseException as e:
future.set_exception(e)
return
@@ -339,9 +338,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
message=message,
sender=sender,
):
if self._before_send is not None:
if self._intervention_handler is not None:
try:
temp_message = await self._before_send.on_publish(message, sender=sender)
temp_message = await self._intervention_handler.on_publish(message, sender=sender)
except BaseException as e:
# TODO: we should raise the intervention exception to the publisher.
logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True)
@@ -354,9 +353,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self._outstanding_tasks.increment()
asyncio.create_task(self._process_publish(message_envelope))
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._before_send is not None:
if self._intervention_handler is not None:
try:
temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient)
temp_message = await self._intervention_handler.on_response(
message, sender=sender, recipient=recipient
)
except BaseException as e:
# TODO: should we raise the exception to sender of the response instead?
future.set_exception(e)
@@ -385,21 +386,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self,
name: str,
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
*,
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
) -> None:
if name in self._agent_factories:
raise ValueError(f"Agent with name {name} already exists.")
self._agent_factories[name] = agent_factory
if valid_namespaces is not AllNamespaces:
self._valid_namespaces[name] = cast(Sequence[str], valid_namespaces)
else:
self._valid_namespaces[name] = []
# For all already prepared namespaces we need to prepare this agent
for namespace in self._known_namespaces:
if self._type_valid_for_namespace(AgentId(name=name, namespace=namespace)):
self._get_agent(AgentId(name=name, namespace=namespace))
self._get_agent(AgentId(name=name, namespace=namespace))
def _invoke_agent_factory(
self, agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], agent_id: AgentId
@@ -419,24 +413,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
return agent
def _type_valid_for_namespace(self, agent_id: AgentId) -> bool:
if agent_id.name not in self._agent_factories:
raise KeyError(f"Agent with name {agent_id.name} not found.")
valid_namespaces = self._valid_namespaces[agent_id.name]
if len(valid_namespaces) == 0:
return True
return agent_id.namespace in valid_namespaces
def _get_agent(self, agent_id: AgentId) -> Agent:
self._process_seen_namespace(agent_id.namespace)
if agent_id in self._instantiated_agents:
return self._instantiated_agents[agent_id]
if not self._type_valid_for_namespace(agent_id):
raise ValueError(f"Agent with name {agent_id.name} not valid for namespace {agent_id.namespace}.")
if agent_id.name not in self._agent_factories:
raise ValueError(f"Agent with name {agent_id.name} not found.")
@@ -463,5 +444,4 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self._known_namespaces.add(namespace)
for name in self._known_agent_names:
if self._type_valid_for_namespace(AgentId(name=name, namespace=namespace)):
self._get_agent(AgentId(name=name, namespace=namespace))
self._get_agent(AgentId(name=name, namespace=namespace))

View File

@@ -7,7 +7,7 @@ from ._agent_id import AgentId
from ._agent_metadata import AgentMetadata
from ._agent_props import AgentChildren
from ._agent_proxy import AgentProxy
from ._agent_runtime import AgentRuntime, AllNamespaces, agent_instantiation_context
from ._agent_runtime import AgentRuntime, agent_instantiation_context
from ._base_agent import BaseAgent
from ._cancellation_token import CancellationToken
@@ -17,7 +17,6 @@ __all__ = [
"AgentProxy",
"AgentMetadata",
"AgentRuntime",
"AllNamespaces",
"BaseAgent",
"CancellationToken",
"AgentChildren",

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from asyncio import Future
from contextvars import ContextVar
from typing import Any, Callable, Mapping, Protocol, Sequence, Type, TypeVar, overload, runtime_checkable
from typing import Any, Callable, Mapping, Protocol, TypeVar, overload, runtime_checkable
from ._agent import Agent
from ._agent_id import AgentId
@@ -17,10 +17,6 @@ T = TypeVar("T", bound=Agent)
agent_instantiation_context: ContextVar[tuple[AgentRuntime, AgentId]] = ContextVar("agent_instantiation_context")
class AllNamespaces:
pass
@runtime_checkable
class AgentRuntime(Protocol):
# Returns the response of the message
@@ -45,7 +41,9 @@ class AgentRuntime(Protocol):
@overload
def register(
self, name: str, agent_factory: Callable[[], T], *, valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...
self,
name: str,
agent_factory: Callable[[], T],
) -> None: ...
@overload
@@ -53,23 +51,18 @@ class AgentRuntime(Protocol):
self,
name: str,
agent_factory: Callable[[AgentRuntime, AgentId], T],
*,
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
) -> None: ...
def register(
self,
name: str,
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
*,
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
) -> None:
"""Register an agent factory with the runtime associated with a specific name. The name must be unique.
Args:
name (str): The name of the type agent this factory creates.
agent_factory (Callable[[], T] | Callable[[AgentRuntime, AgentId], T]): The factory that creates the agent.
valid_namespaces (Sequence[str] | Type[AllNamespaces], optional): Valid namespaces for this type. Defaults to AllNamespaces.
Example:
@@ -99,7 +92,6 @@ class AgentRuntime(Protocol):
agent_factory: Callable[[], T],
*,
namespace: str = "default",
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
) -> AgentId: ...
@overload
@@ -109,7 +101,6 @@ class AgentRuntime(Protocol):
agent_factory: Callable[[AgentRuntime, AgentId], T],
*,
namespace: str = "default",
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
) -> AgentId: ...
def register_and_get(
@@ -118,7 +109,6 @@ class AgentRuntime(Protocol):
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
*,
namespace: str = "default",
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
) -> AgentId:
self.register(name, agent_factory)
return self.get(name, namespace=namespace)
@@ -130,7 +120,6 @@ class AgentRuntime(Protocol):
agent_factory: Callable[[], T],
*,
namespace: str = "default",
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
) -> AgentProxy: ...
@overload
@@ -140,7 +129,6 @@ class AgentRuntime(Protocol):
agent_factory: Callable[[AgentRuntime, AgentId], T],
*,
namespace: str = "default",
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
) -> AgentProxy: ...
def register_and_get_proxy(
@@ -149,7 +137,6 @@ class AgentRuntime(Protocol):
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
*,
namespace: str = "default",
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
) -> AgentProxy:
self.register(name, agent_factory)
return self.get_proxy(name, namespace=namespace)