mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
simplify namespace usage (#116)
* simplify namespace usage * format * pyright
This commit is contained in:
@@ -6,7 +6,7 @@ from asyncio import Future
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
|
||||
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, TypeVar, cast
|
||||
|
||||
from ..core import (
|
||||
Agent,
|
||||
@@ -14,7 +14,6 @@ from ..core import (
|
||||
AgentMetadata,
|
||||
AgentProxy,
|
||||
AgentRuntime,
|
||||
AllNamespaces,
|
||||
CancellationToken,
|
||||
agent_instantiation_context,
|
||||
)
|
||||
@@ -82,15 +81,13 @@ class Counter:
|
||||
|
||||
|
||||
class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
def __init__(self, *, before_send: InterventionHandler | None = None) -> None:
|
||||
def __init__(self, *, intervention_handler: InterventionHandler | None = None) -> None:
|
||||
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
|
||||
# (namespace, type) -> List[AgentId]
|
||||
self._per_type_subscribers: DefaultDict[tuple[str, type], Set[AgentId]] = defaultdict(set)
|
||||
self._agent_factories: Dict[str, Callable[[], Agent] | Callable[[AgentRuntime, AgentId], Agent]] = {}
|
||||
# If empty, then all namespaces are valid for that agent type
|
||||
self._valid_namespaces: Dict[str, Sequence[str]] = {}
|
||||
self._instantiated_agents: Dict[AgentId, Agent] = {}
|
||||
self._before_send = before_send
|
||||
self._intervention_handler = intervention_handler
|
||||
self._known_namespaces: set[str] = set()
|
||||
self._outstanding_tasks = Counter()
|
||||
|
||||
@@ -322,9 +319,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
|
||||
match message_envelope:
|
||||
case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
||||
if self._before_send is not None:
|
||||
if self._intervention_handler is not None:
|
||||
try:
|
||||
temp_message = await self._before_send.on_send(message, sender=sender, recipient=recipient)
|
||||
temp_message = await self._intervention_handler.on_send(
|
||||
message, sender=sender, recipient=recipient
|
||||
)
|
||||
except BaseException as e:
|
||||
future.set_exception(e)
|
||||
return
|
||||
@@ -339,9 +338,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
message=message,
|
||||
sender=sender,
|
||||
):
|
||||
if self._before_send is not None:
|
||||
if self._intervention_handler is not None:
|
||||
try:
|
||||
temp_message = await self._before_send.on_publish(message, sender=sender)
|
||||
temp_message = await self._intervention_handler.on_publish(message, sender=sender)
|
||||
except BaseException as e:
|
||||
# TODO: we should raise the intervention exception to the publisher.
|
||||
logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True)
|
||||
@@ -354,9 +353,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
self._outstanding_tasks.increment()
|
||||
asyncio.create_task(self._process_publish(message_envelope))
|
||||
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
||||
if self._before_send is not None:
|
||||
if self._intervention_handler is not None:
|
||||
try:
|
||||
temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient)
|
||||
temp_message = await self._intervention_handler.on_response(
|
||||
message, sender=sender, recipient=recipient
|
||||
)
|
||||
except BaseException as e:
|
||||
# TODO: should we raise the exception to sender of the response instead?
|
||||
future.set_exception(e)
|
||||
@@ -385,21 +386,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
|
||||
) -> None:
|
||||
if name in self._agent_factories:
|
||||
raise ValueError(f"Agent with name {name} already exists.")
|
||||
self._agent_factories[name] = agent_factory
|
||||
if valid_namespaces is not AllNamespaces:
|
||||
self._valid_namespaces[name] = cast(Sequence[str], valid_namespaces)
|
||||
else:
|
||||
self._valid_namespaces[name] = []
|
||||
|
||||
# For all already prepared namespaces we need to prepare this agent
|
||||
for namespace in self._known_namespaces:
|
||||
if self._type_valid_for_namespace(AgentId(name=name, namespace=namespace)):
|
||||
self._get_agent(AgentId(name=name, namespace=namespace))
|
||||
self._get_agent(AgentId(name=name, namespace=namespace))
|
||||
|
||||
def _invoke_agent_factory(
|
||||
self, agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], agent_id: AgentId
|
||||
@@ -419,24 +413,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
|
||||
return agent
|
||||
|
||||
def _type_valid_for_namespace(self, agent_id: AgentId) -> bool:
|
||||
if agent_id.name not in self._agent_factories:
|
||||
raise KeyError(f"Agent with name {agent_id.name} not found.")
|
||||
|
||||
valid_namespaces = self._valid_namespaces[agent_id.name]
|
||||
if len(valid_namespaces) == 0:
|
||||
return True
|
||||
|
||||
return agent_id.namespace in valid_namespaces
|
||||
|
||||
def _get_agent(self, agent_id: AgentId) -> Agent:
|
||||
self._process_seen_namespace(agent_id.namespace)
|
||||
if agent_id in self._instantiated_agents:
|
||||
return self._instantiated_agents[agent_id]
|
||||
|
||||
if not self._type_valid_for_namespace(agent_id):
|
||||
raise ValueError(f"Agent with name {agent_id.name} not valid for namespace {agent_id.namespace}.")
|
||||
|
||||
if agent_id.name not in self._agent_factories:
|
||||
raise ValueError(f"Agent with name {agent_id.name} not found.")
|
||||
|
||||
@@ -463,5 +444,4 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
|
||||
self._known_namespaces.add(namespace)
|
||||
for name in self._known_agent_names:
|
||||
if self._type_valid_for_namespace(AgentId(name=name, namespace=namespace)):
|
||||
self._get_agent(AgentId(name=name, namespace=namespace))
|
||||
self._get_agent(AgentId(name=name, namespace=namespace))
|
||||
|
||||
@@ -7,7 +7,7 @@ from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_props import AgentChildren
|
||||
from ._agent_proxy import AgentProxy
|
||||
from ._agent_runtime import AgentRuntime, AllNamespaces, agent_instantiation_context
|
||||
from ._agent_runtime import AgentRuntime, agent_instantiation_context
|
||||
from ._base_agent import BaseAgent
|
||||
from ._cancellation_token import CancellationToken
|
||||
|
||||
@@ -17,7 +17,6 @@ __all__ = [
|
||||
"AgentProxy",
|
||||
"AgentMetadata",
|
||||
"AgentRuntime",
|
||||
"AllNamespaces",
|
||||
"BaseAgent",
|
||||
"CancellationToken",
|
||||
"AgentChildren",
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from asyncio import Future
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Callable, Mapping, Protocol, Sequence, Type, TypeVar, overload, runtime_checkable
|
||||
from typing import Any, Callable, Mapping, Protocol, TypeVar, overload, runtime_checkable
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
@@ -17,10 +17,6 @@ T = TypeVar("T", bound=Agent)
|
||||
agent_instantiation_context: ContextVar[tuple[AgentRuntime, AgentId]] = ContextVar("agent_instantiation_context")
|
||||
|
||||
|
||||
class AllNamespaces:
|
||||
pass
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AgentRuntime(Protocol):
|
||||
# Returns the response of the message
|
||||
@@ -45,7 +41,9 @@ class AgentRuntime(Protocol):
|
||||
|
||||
@overload
|
||||
def register(
|
||||
self, name: str, agent_factory: Callable[[], T], *, valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T],
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
@@ -53,23 +51,18 @@ class AgentRuntime(Protocol):
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
|
||||
) -> None: ...
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
|
||||
) -> None:
|
||||
"""Register an agent factory with the runtime associated with a specific name. The name must be unique.
|
||||
|
||||
Args:
|
||||
name (str): The name of the type agent this factory creates.
|
||||
agent_factory (Callable[[], T] | Callable[[AgentRuntime, AgentId], T]): The factory that creates the agent.
|
||||
valid_namespaces (Sequence[str] | Type[AllNamespaces], optional): Valid namespaces for this type. Defaults to AllNamespaces.
|
||||
|
||||
|
||||
Example:
|
||||
@@ -99,7 +92,6 @@ class AgentRuntime(Protocol):
|
||||
agent_factory: Callable[[], T],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
|
||||
) -> AgentId: ...
|
||||
|
||||
@overload
|
||||
@@ -109,7 +101,6 @@ class AgentRuntime(Protocol):
|
||||
agent_factory: Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
|
||||
) -> AgentId: ...
|
||||
|
||||
def register_and_get(
|
||||
@@ -118,7 +109,6 @@ class AgentRuntime(Protocol):
|
||||
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
|
||||
) -> AgentId:
|
||||
self.register(name, agent_factory)
|
||||
return self.get(name, namespace=namespace)
|
||||
@@ -130,7 +120,6 @@ class AgentRuntime(Protocol):
|
||||
agent_factory: Callable[[], T],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
|
||||
) -> AgentProxy: ...
|
||||
|
||||
@overload
|
||||
@@ -140,7 +129,6 @@ class AgentRuntime(Protocol):
|
||||
agent_factory: Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
|
||||
) -> AgentProxy: ...
|
||||
|
||||
def register_and_get_proxy(
|
||||
@@ -149,7 +137,6 @@ class AgentRuntime(Protocol):
|
||||
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
|
||||
) -> AgentProxy:
|
||||
self.register(name, agent_factory)
|
||||
return self.get_proxy(name, namespace=namespace)
|
||||
|
||||
Reference in New Issue
Block a user