Change send/publish api to better support async and represent reality (#137)

* Make send and publish better represent reality

* fix team-one
This commit is contained in:
Jack Gerrits
2024-06-27 13:40:12 -04:00
committed by GitHub
parent 905e2e3f95
commit a13c971b16
32 changed files with 257 additions and 116 deletions

View File

@@ -27,8 +27,9 @@ class Agent(Protocol):
Returns:
Any: Response to the message. Can be None.
Notes:
If there was a cancellation, this function should raise a `CancelledError`.
Raises:
asyncio.CancelledError: If the message was cancelled.
CantHandleException: If the agent cannot handle the message.
"""
...

View File

@@ -26,14 +26,14 @@ class AgentProxy:
"""Metadata of the agent."""
return self._runtime.agent_metadata(self._agent)
def send_message(
async def send_message(
self,
message: Any,
*,
sender: AgentId,
cancellation_token: CancellationToken | None = None,
) -> Future[Any]:
return self._runtime.send_message(
return await self._runtime.send_message(
message,
recipient=self._agent,
sender=sender,

View File

@@ -20,24 +20,64 @@ agent_instantiation_context: ContextVar[tuple[AgentRuntime, AgentId]] = ContextV
@runtime_checkable
class AgentRuntime(Protocol):
# Returns the response of the message
def send_message(
# Can raise CantHandleException
async def send_message(
self,
message: Any,
recipient: AgentId,
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[Any]: ...
) -> Future[Any]:
"""Send a message to an agent and return a future that will resolve to the response.
The act of sending a message may be asynchronous, and the response to the message itself is also asynchronous. For example:
.. code-block:: python
response_future = await runtime.send_message(MyMessage("Hello"), recipient=agent_id)
response = await response_future
The returned future only needs to be awaited if the response is needed. If the response is not needed, the future can be ignored.
Args:
message (Any): The message to send.
recipient (AgentId): The agent to send the message to.
sender (AgentId | None, optional): Agent which sent the message. Should **only** be None if this was sent from no agent, such as directly to the runtime externally. Defaults to None.
cancellation_token (CancellationToken | None, optional): Token used to cancel an in progress . Defaults to None.
Raises:
CantHandleException: If the recipient cannot handle the message.
UndeliverableException: If the message cannot be delivered.
Returns:
Future[Any]: A future that will resolve to the response of the message.
"""
...
# No responses from publishing
def publish_message(
async def publish_message(
self,
message: Any,
*,
namespace: str | None = None,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[None]: ...
) -> None:
"""Publish a message to all agents in the given namespace, or if no namespace is provided, the namespace of the sender.
No responses are expected from publishing.
Args:
message (Any): The message to publish.
namespace (str | None, optional): The namespace to publish to. Defaults to None.
sender (AgentId | None, optional): The agent which sent the message. Defaults to None.
cancellation_token (CancellationToken | None, optional): Token used to cancel an in progress . Defaults to None.
Raises:
UndeliverableException: If the message cannot be delivered.
"""
@overload
def register(
@@ -62,7 +102,7 @@ class AgentRuntime(Protocol):
Args:
name (str): The name of the type agent this factory creates.
agent_factory (Callable[[], T] | Callable[[AgentRuntime, AgentId], T]): The factory that creates the agent.
agent_factory (Callable[[], T] | Callable[[AgentRuntime, AgentId], T]): The factory that creates the agent, where T is a concrete Agent type.
Example:
@@ -82,8 +122,29 @@ class AgentRuntime(Protocol):
...
def get(self, name: str, *, namespace: str = "default") -> AgentId: ...
def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy: ...
def get(self, name: str, *, namespace: str = "default") -> AgentId:
"""Get an agent by name and namespace.
Args:
name (str): The name of the agent.
namespace (str, optional): The namespace of the agent. Defaults to "default".
Returns:
AgentId: The agent id.
"""
...
def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
"""Get a proxy for an agent by name and namespace.
Args:
name (str): The name of the agent.
namespace (str, optional): The namespace of the agent. Defaults to "default".
Returns:
AgentProxy: The agent proxy.
"""
...
@overload
def register_and_get(
@@ -110,6 +171,16 @@ class AgentRuntime(Protocol):
*,
namespace: str = "default",
) -> AgentId:
"""Register an agent factory with the runtime associated with a specific name and get the agent id. The name must be unique.
Args:
name (str): The name of the type agent this factory creates.
agent_factory (Callable[[], T] | Callable[[AgentRuntime, AgentId], T]): The factory that creates the agent, where T is a concrete Agent type.
namespace (str, optional): The namespace of the agent. Defaults to "default".
Returns:
AgentId: The agent id.
"""
self.register(name, agent_factory)
return self.get(name, namespace=namespace)
@@ -138,15 +209,66 @@ class AgentRuntime(Protocol):
*,
namespace: str = "default",
) -> AgentProxy:
"""Register an agent factory with the runtime associated with a specific name and get the agent proxy. The name must be unique.
Args:
name (str): The name of the type agent this factory creates.
agent_factory (Callable[[], T] | Callable[[AgentRuntime, AgentId], T]): The factory that creates the agent, where T is a concrete Agent type.
namespace (str, optional): The namespace of the agent. Defaults to "default".
Returns:
AgentProxy: The agent proxy.
"""
self.register(name, agent_factory)
return self.get_proxy(name, namespace=namespace)
def save_state(self) -> Mapping[str, Any]: ...
def save_state(self) -> Mapping[str, Any]:
"""Save the state of the entire runtime, including all hosted agents. The only way to restore the state is to pass it to :meth:`load_state`.
def load_state(self, state: Mapping[str, Any]) -> None: ...
The structure of the state is implementation defined and can be any JSON serializable object.
def agent_metadata(self, agent: AgentId) -> AgentMetadata: ...
Returns:
Mapping[str, Any]: The saved state.
"""
...
def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]: ...
def load_state(self, state: Mapping[str, Any]) -> None:
"""Load the state of the entire runtime, including all hosted agents. The state should be the same as the one returned by :meth:`save_state`.
def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None: ...
Args:
state (Mapping[str, Any]): The saved state.
"""
...
def agent_metadata(self, agent: AgentId) -> AgentMetadata:
"""Get the metadata for an agent.
Args:
agent (AgentId): The agent id.
Returns:
AgentMetadata: The agent metadata.
"""
...
def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
"""Save the state of a single agent.
The structure of the state is implementation defined and can be any JSON serializable object.
Args:
agent (AgentId): The agent id.
Returns:
Mapping[str, Any]: The saved state.
"""
...
def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
"""Load the state of a single agent.
Args:
agent (AgentId): The agent id.
state (Mapping[str, Any]): The saved state.
"""
...

View File

@@ -49,18 +49,18 @@ class BaseAgent(ABC, Agent):
@abstractmethod
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: ...
# Returns the response of the message
def send_message(
async def send_message(
self,
message: Any,
recipient: AgentId,
*,
cancellation_token: CancellationToken | None = None,
) -> Future[Any]:
"""See :py:meth:`agnext.core.AgentRuntime.send_message` for more information."""
if cancellation_token is None:
cancellation_token = CancellationToken()
future = self._runtime.send_message(
future = await self._runtime.send_message(
message,
sender=self.id,
recipient=recipient,
@@ -69,17 +69,13 @@ class BaseAgent(ABC, Agent):
cancellation_token.link_future(future)
return future
def publish_message(
async def publish_message(
self,
message: Any,
*,
cancellation_token: CancellationToken | None = None,
) -> Future[None]:
if cancellation_token is None:
cancellation_token = CancellationToken()
future = self._runtime.publish_message(message, sender=self.id, cancellation_token=cancellation_token)
return future
) -> None:
await self._runtime.publish_message(message, sender=self.id, cancellation_token=cancellation_token)
def save_state(self) -> Mapping[str, Any]:
warnings.warn("save_state not implemented", stacklevel=2)