diff --git a/examples/futures.py b/examples/futures.py index 70cf9ee43..c60d9d74c 100644 --- a/examples/futures.py +++ b/examples/futures.py @@ -5,6 +5,7 @@ from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_h from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime from agnext.core.agent import Agent from agnext.core.agent_runtime import AgentRuntime +from agnext.core.cancellation_token import CancellationToken @dataclass @@ -13,30 +14,36 @@ class MessageType: sender: str -class Inner(TypeRoutedAgent[MessageType]): - def __init__(self, name: str, router: AgentRuntime[MessageType]) -> None: +class Inner(TypeRoutedAgent): + def __init__(self, name: str, router: AgentRuntime) -> None: super().__init__(name, router) @message_handler(MessageType) - async def on_new_message(self, message: MessageType) -> MessageType: + async def on_new_message( + self, message: MessageType, require_response: bool, cancellation_token: CancellationToken + ) -> MessageType: + assert require_response return MessageType(body=f"Inner: {message.body}", sender=self.name) -class Outer(TypeRoutedAgent[MessageType]): - def __init__(self, name: str, router: AgentRuntime[MessageType], inner: Agent[MessageType]) -> None: +class Outer(TypeRoutedAgent): + def __init__(self, name: str, router: AgentRuntime, inner: Agent) -> None: super().__init__(name, router) self._inner = inner @message_handler(MessageType) - async def on_new_message(self, message: MessageType) -> MessageType: - inner_response = self._send_message(message, self._inner) + async def on_new_message( + self, message: MessageType, require_response: bool, cancellation_token: CancellationToken + ) -> MessageType: + assert require_response + inner_response = self._send_message(message, self._inner, require_response=True) inner_message = await inner_response + assert isinstance(inner_message, MessageType) return MessageType(body=f"Outer: {inner_message.body}", sender=self.name) async def main() -> None: - router = SingleThreadedAgentRuntime[MessageType]() - + router = SingleThreadedAgentRuntime() inner = Inner("inner", router) outer = Outer("outer", router, inner) response = router.send_message(MessageType(body="Hello", sender="external"), outer) diff --git a/examples/patterns.py b/examples/patterns.py index 7a9ef1cfd..1cb5707cc 100644 --- a/examples/patterns.py +++ b/examples/patterns.py @@ -3,15 +3,15 @@ import asyncio import openai from agnext.agent_components.models_clients.openai_client import OpenAI +from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent from agnext.chat.messages import ChatMessage from agnext.chat.patterns.group_chat import GroupChat from agnext.chat.patterns.orchestrator import Orchestrator -from agnext.chat.runtimes import SingleThreadedRuntime async def group_chat() -> None: - runtime = SingleThreadedRuntime() + runtime = SingleThreadedAgentRuntime() joe_oai_assistant = openai.beta.assistants.create( model="gpt-3.5-turbo", @@ -56,11 +56,11 @@ async def group_chat() -> None: while not response.done(): await runtime.process_next() - print((await response).body) + print((await response).body) # type: ignore async def orchestrator() -> None: - runtime = SingleThreadedRuntime() + runtime = SingleThreadedAgentRuntime() developer_oai_assistant = openai.beta.assistants.create( model="gpt-3.5-turbo", @@ -111,7 +111,7 @@ async def orchestrator() -> None: while not response.done(): await runtime.process_next() - print((await response).body) + print((await response).body) # type: ignore if __name__ == "__main__": diff --git a/src/agnext/agent_components/type_routed_agent.py b/src/agnext/agent_components/type_routed_agent.py index e9dc39e2a..490f3c03c 100644 --- a/src/agnext/agent_components/type_routed_agent.py +++ b/src/agnext/agent_components/type_routed_agent.py @@ -1,29 +1,39 @@ -from typing import Any, Awaitable, Callable, Dict, Sequence, Type, TypeVar +from typing import Any, Callable, Coroutine, Dict, NoReturn, Sequence, Type, TypeVar from agnext.core.agent_runtime import AgentRuntime from agnext.core.base_agent import BaseAgent from agnext.core.cancellation_token import CancellationToken from agnext.core.exceptions import CantHandleException -T = TypeVar("T") +ReceivesT = TypeVar("ReceivesT") +ProducesT = TypeVar("ProducesT", covariant=True) + +# TODO: Generic typevar bound binding U to agent type +# Can't do because python doesnt support it # NOTE: this works on concrete types and not inheritance def message_handler( - target_type: Type[T], -) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]]]: - def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]: + target_type: Type[ReceivesT], +) -> Callable[ + [Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]], + Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]], +]: + def decorator( + func: Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]], + ) -> Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]: func._target_type = target_type # type: ignore return func return decorator -class TypeRoutedAgent(BaseAgent[T]): - def __init__(self, name: str, router: AgentRuntime[T]) -> None: +class TypeRoutedAgent(BaseAgent): + def __init__(self, name: str, router: AgentRuntime) -> None: super().__init__(name, router) - self._handlers: Dict[Type[Any], Callable[[T, CancellationToken], Awaitable[T]]] = {} + # Self is already bound to the handlers + self._handlers: Dict[Type[Any], Callable[[Any, bool, CancellationToken], Coroutine[Any, Any, Any | None]]] = {} router.add_agent(self) @@ -31,19 +41,23 @@ class TypeRoutedAgent(BaseAgent[T]): if callable(getattr(self, attr, None)): handler = getattr(self, attr) if hasattr(handler, "_target_type"): - # TODO do i need to partially apply self? self._handlers[handler._target_type] = handler @property - def subscriptions(self) -> Sequence[Type[T]]: + def subscriptions(self) -> Sequence[Type[Any]]: return list(self._handlers.keys()) - async def on_message(self, message: T, cancellation_token: CancellationToken) -> T: - handler = self._handlers.get(type(message)) + async def on_message( + self, message: Any, require_response: bool, cancellation_token: CancellationToken + ) -> Any | None: + key_type: Type[Any] = type(message) # type: ignore + handler = self._handlers.get(key_type) # type: ignore if handler is not None: - return await handler(message, cancellation_token) + return await handler(message, require_response, cancellation_token) else: - return await self.on_unhandled_message(message, cancellation_token) + return await self.on_unhandled_message(message, require_response, cancellation_token) - async def on_unhandled_message(self, message: T, cancellation_token: CancellationToken) -> T: + async def on_unhandled_message( + self, message: Any, require_response: bool, cancellation_token: CancellationToken + ) -> NoReturn: raise CantHandleException() diff --git a/src/agnext/application_components/single_threaded_agent_runtime.py b/src/agnext/application_components/single_threaded_agent_runtime.py index 749f431f3..b21a382b8 100644 --- a/src/agnext/application_components/single_threaded_agent_runtime.py +++ b/src/agnext/application_components/single_threaded_agent_runtime.py @@ -1,7 +1,7 @@ import asyncio from asyncio import Future from dataclasses import dataclass -from typing import Awaitable, Dict, Generic, List, Set, Type, TypeVar, cast +from typing import Any, Awaitable, Dict, List, Sequence, Set, cast from agnext.core.cancellation_token import CancellationToken from agnext.core.exceptions import MessageDroppedException @@ -10,64 +10,61 @@ from agnext.core.intervention import DropMessage, InterventionHandler from ..core.agent import Agent from ..core.agent_runtime import AgentRuntime -T = TypeVar("T") - @dataclass(kw_only=True) -class BroadcastMessageEnvelope(Generic[T]): +class BroadcastMessageEnvelope: """A message envelope for broadcasting messages to all agents that can handle the message of the type T.""" - message: T - future: Future[List[T]] + message: Any + future: Future[Sequence[Any] | None] cancellation_token: CancellationToken - sender: Agent[T] | None + sender: Agent | None + require_response: bool @dataclass(kw_only=True) -class SendMessageEnvelope(Generic[T]): +class SendMessageEnvelope: """A message envelope for sending a message to a specific agent that can handle the message of the type T.""" - message: T - sender: Agent[T] | None - recipient: Agent[T] - future: Future[T] + message: Any + sender: Agent | None + recipient: Agent + future: Future[Any | None] cancellation_token: CancellationToken + require_response: bool @dataclass(kw_only=True) -class ResponseMessageEnvelope(Generic[T]): +class ResponseMessageEnvelope: """A message envelope for sending a response to a message.""" - message: T - future: Future[T] - sender: Agent[T] - recipient: Agent[T] | None + message: Any + future: Future[Any] + sender: Agent + recipient: Agent | None @dataclass(kw_only=True) -class BroadcastResponseMessageEnvelope(Generic[T]): +class BroadcastResponseMessageEnvelope: """A message envelope for sending a response to a message.""" - message: List[T] - future: Future[List[T]] - recipient: Agent[T] | None + message: Sequence[Any] + future: Future[Sequence[Any]] + recipient: Agent | None -class SingleThreadedAgentRuntime(AgentRuntime[T]): - def __init__(self, *, before_send: InterventionHandler[T] | None = None) -> None: +class SingleThreadedAgentRuntime(AgentRuntime): + def __init__(self, *, before_send: InterventionHandler | None = None) -> None: self._message_queue: List[ - BroadcastMessageEnvelope[T] - | SendMessageEnvelope[T] - | ResponseMessageEnvelope[T] - | BroadcastResponseMessageEnvelope[T] + BroadcastMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope | BroadcastResponseMessageEnvelope ] = [] - self._per_type_subscribers: Dict[Type[T], List[Agent[T]]] = {} - self._agents: Set[Agent[T]] = set() + self._per_type_subscribers: Dict[type, List[Agent]] = {} + self._agents: Set[Agent] = set() self._before_send = before_send - def add_agent(self, agent: Agent[T]) -> None: + def add_agent(self, agent: Agent) -> None: for message_type in agent.subscriptions: if message_type not in self._per_type_subscribers: self._per_type_subscribers[message_type] = [] @@ -77,17 +74,19 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]): # Returns the response of the message def send_message( self, - message: T, - recipient: Agent[T], + message: Any, + recipient: Agent, *, - sender: Agent[T] | None = None, + require_response: bool = True, + sender: Agent | None = None, cancellation_token: CancellationToken | None = None, - ) -> Future[T]: + ) -> Future[Any | None]: if cancellation_token is None: cancellation_token = CancellationToken() - loop = asyncio.get_event_loop() - future: Future[T] = loop.create_future() + future = asyncio.get_event_loop().create_future() + if recipient not in self._agents: + future.set_exception(Exception("Recipient not found")) self._message_queue.append( SendMessageEnvelope( @@ -96,53 +95,78 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]): future=future, cancellation_token=cancellation_token, sender=sender, + require_response=require_response, ) ) return future - # Returns the response of all handling agents + # send message, require_response=False -> returns after delivery, gives None + # send message, require_response=True -> returns after handling, gives Response def broadcast_message( - self, message: T, *, sender: Agent[T] | None = None, cancellation_token: CancellationToken | None = None - ) -> Future[List[T]]: + self, + message: Any, + *, + require_response: bool = True, + sender: Agent | None = None, + cancellation_token: CancellationToken | None = None, + ) -> Future[Sequence[Any] | None]: if cancellation_token is None: cancellation_token = CancellationToken() - future: Future[List[T]] = asyncio.get_event_loop().create_future() + future = asyncio.get_event_loop().create_future() self._message_queue.append( BroadcastMessageEnvelope( - message=message, future=future, cancellation_token=cancellation_token, sender=sender + message=message, + future=future, + cancellation_token=cancellation_token, + sender=sender, + require_response=require_response, ) ) + return future - async def _process_send(self, message_envelope: SendMessageEnvelope[T]) -> None: + async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: recipient = message_envelope.recipient - if recipient not in self._agents: - message_envelope.future.set_exception(Exception("Recipient not found")) - return + assert recipient in self._agents try: response = await recipient.on_message( - message_envelope.message, cancellation_token=message_envelope.cancellation_token + message_envelope.message, + require_response=message_envelope.require_response, + cancellation_token=message_envelope.cancellation_token, ) except BaseException as e: message_envelope.future.set_exception(e) return - self._message_queue.append( - ResponseMessageEnvelope( - message=response, - future=message_envelope.future, - sender=message_envelope.recipient, - recipient=message_envelope.sender, - ) - ) + if not message_envelope.require_response and response is not None: + raise Exception("Recipient returned a response for a message that did not request a response") - async def _process_broadcast(self, message_envelope: BroadcastMessageEnvelope[T]) -> None: - responses: List[Awaitable[T]] = [] - for agent in self._per_type_subscribers.get(type(message_envelope.message), []): - future = agent.on_message(message_envelope.message, cancellation_token=message_envelope.cancellation_token) + if message_envelope.require_response and response is None: + raise Exception("Recipient did not return a response for a message that requested a response") + + if message_envelope.require_response: + self._message_queue.append( + ResponseMessageEnvelope( + message=response, + future=message_envelope.future, + sender=message_envelope.recipient, + recipient=message_envelope.sender, + ) + ) + else: + message_envelope.future.set_result(None) + + async def _process_broadcast(self, message_envelope: BroadcastMessageEnvelope) -> None: + responses: List[Awaitable[Any]] = [] + for agent in self._per_type_subscribers.get(type(message_envelope.message), []): # type: ignore + future = agent.on_message( + message_envelope.message, + require_response=message_envelope.require_response, + cancellation_token=message_envelope.cancellation_token, + ) responses.append(future) try: @@ -151,16 +175,21 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]): message_envelope.future.set_exception(e) return - self._message_queue.append( - BroadcastResponseMessageEnvelope( - message=all_responses, future=message_envelope.future, recipient=message_envelope.sender + if message_envelope.require_response: + self._message_queue.append( + BroadcastResponseMessageEnvelope( + message=all_responses, + future=cast(Future[Sequence[Any]], message_envelope.future), + recipient=message_envelope.sender, + ) ) - ) + else: + message_envelope.future.set_result(None) - async def _process_response(self, message_envelope: ResponseMessageEnvelope[T]) -> None: + async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None: message_envelope.future.set_result(message_envelope.message) - async def _process_broadcast_response(self, message_envelope: BroadcastResponseMessageEnvelope[T]) -> None: + async def _process_broadcast_response(self, message_envelope: BroadcastResponseMessageEnvelope) -> None: message_envelope.future.set_result(message_envelope.message) async def process_next(self) -> None: @@ -179,7 +208,7 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]): future.set_exception(MessageDroppedException()) return - message_envelope.message = cast(T, temp_message) + message_envelope.message = temp_message asyncio.create_task(self._process_send(message_envelope)) case BroadcastMessageEnvelope( @@ -193,7 +222,7 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]): future.set_exception(MessageDroppedException()) return - message_envelope.message = cast(T, temp_message) + message_envelope.message = temp_message asyncio.create_task(self._process_broadcast(message_envelope)) case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future): @@ -203,7 +232,7 @@ class SingleThreadedAgentRuntime(AgentRuntime[T]): future.set_exception(MessageDroppedException()) return - message_envelope.message = cast(T, temp_message) + message_envelope.message = temp_message asyncio.create_task(self._process_response(message_envelope)) diff --git a/src/agnext/chat/agents/base.py b/src/agnext/chat/agents/base.py index 34a6ccd36..3cece5005 100644 --- a/src/agnext/chat/agents/base.py +++ b/src/agnext/chat/agents/base.py @@ -1,13 +1,12 @@ -from ...agent_components.type_routed_agent import TypeRoutedAgent, message_handler -from ...core.cancellation_token import CancellationToken -from ..messages import ChatMessage -from ..runtimes import SingleThreadedRuntime +from agnext.core.agent_runtime import AgentRuntime + +from ...agent_components.type_routed_agent import TypeRoutedAgent -class BaseChatAgent(TypeRoutedAgent[ChatMessage]): +class BaseChatAgent(TypeRoutedAgent): """The BaseAgent class for the chat API.""" - def __init__(self, name: str, description: str, runtime: SingleThreadedRuntime) -> None: + def __init__(self, name: str, description: str, runtime: AgentRuntime) -> None: super().__init__(name, runtime) self._description = description @@ -15,15 +14,3 @@ class BaseChatAgent(TypeRoutedAgent[ChatMessage]): def description(self) -> str: """The description of the agent.""" return self._description - - async def on_chat_message(self, message: ChatMessage) -> ChatMessage: - """The method to handle chat messages.""" - raise NotImplementedError - - # TODO: how should we expose cancellation in chat layer? - @message_handler(ChatMessage) - async def on_chat_message_with_cancellation( - self, message: ChatMessage, cancellation_token: CancellationToken - ) -> ChatMessage: - """The method to handle chat messages with cancellation.""" - return await self.on_chat_message(message) diff --git a/src/agnext/chat/agents/oai_assistant.py b/src/agnext/chat/agents/oai_assistant.py index 891eab25a..11ef741f9 100644 --- a/src/agnext/chat/agents/oai_assistant.py +++ b/src/agnext/chat/agents/oai_assistant.py @@ -1,8 +1,11 @@ import openai -from ..agents.base import BaseChatAgent +from agnext.agent_components.type_routed_agent import message_handler +from agnext.chat.agents.base import BaseChatAgent +from agnext.core.agent_runtime import AgentRuntime +from agnext.core.cancellation_token import CancellationToken + from ..messages import ChatMessage -from ..runtimes import SingleThreadedRuntime class OpenAIAssistantAgent(BaseChatAgent): @@ -10,7 +13,7 @@ class OpenAIAssistantAgent(BaseChatAgent): self, name: str, description: str, - runtime: SingleThreadedRuntime, + runtime: AgentRuntime, client: openai.AsyncClient, assistant_id: str, thread_id: str, @@ -21,7 +24,11 @@ class OpenAIAssistantAgent(BaseChatAgent): self._thread_id = thread_id self._current_session_window_length = 0 - async def on_chat_message(self, message: ChatMessage) -> ChatMessage: + # TODO: use require_response + @message_handler(ChatMessage) + async def on_chat_message_with_cancellation( + self, message: ChatMessage, require_response: bool, cancellation_token: CancellationToken + ) -> ChatMessage | None: print("---------------") print(f"{self.name} received message from {message.sender}: {message.body}") print("---------------") diff --git a/src/agnext/chat/agents/random_agent.py b/src/agnext/chat/agents/random_agent.py index 67029318a..0581ac437 100644 --- a/src/agnext/chat/agents/random_agent.py +++ b/src/agnext/chat/agents/random_agent.py @@ -1,11 +1,18 @@ import random +from agnext.agent_components.type_routed_agent import message_handler +from agnext.core.cancellation_token import CancellationToken + from ..agents.base import BaseChatAgent from ..messages import ChatMessage class RandomResponseAgent(BaseChatAgent): - async def on_chat_message(self, message: ChatMessage) -> ChatMessage: + # TODO: use require_response + @message_handler(ChatMessage) + async def on_chat_message_with_cancellation( + self, message: ChatMessage, require_response: bool, cancellation_token: CancellationToken + ) -> ChatMessage | None: print(f"{self.name} received message from {message.sender}: {message.body}") if message.save_message_only: return ChatMessage(body="OK", sender=self.name) diff --git a/src/agnext/chat/patterns/group_chat.py b/src/agnext/chat/patterns/group_chat.py index 2a6c90d90..f837b6f27 100644 --- a/src/agnext/chat/patterns/group_chat.py +++ b/src/agnext/chat/patterns/group_chat.py @@ -1,8 +1,8 @@ from typing import List, Sequence +from ...core.agent_runtime import AgentRuntime from ..agents.base import BaseChatAgent from ..messages import ChatMessage -from ..runtimes import SingleThreadedRuntime class GroupChat(BaseChatAgent): @@ -10,7 +10,7 @@ class GroupChat(BaseChatAgent): self, name: str, description: str, - runtime: SingleThreadedRuntime, + runtime: AgentRuntime, agents: Sequence[BaseChatAgent], num_rounds: int, ) -> None: @@ -61,8 +61,9 @@ class GroupChat(BaseChatAgent): speaker, ) - # 4. Append the response to the history. - self._history.append(response) + if response is not None: + # 4. Append the response to the history. + self._history.append(response) # 5. Update the previous speaker. previous_speaker = speaker diff --git a/src/agnext/chat/patterns/orchestrator.py b/src/agnext/chat/patterns/orchestrator.py index 30f914a56..0c5e4915d 100644 --- a/src/agnext/chat/patterns/orchestrator.py +++ b/src/agnext/chat/patterns/orchestrator.py @@ -1,11 +1,14 @@ import json from typing import Any, List, Sequence, Tuple +from agnext.core.agent_runtime import AgentRuntime +from agnext.core.cancellation_token import CancellationToken + from ...agent_components.model_client import ModelClient +from ...agent_components.type_routed_agent import message_handler from ...agent_components.types import AssistantMessage, LLMMessage, UserMessage from ..agents.base import BaseChatAgent from ..messages import ChatMessage -from ..runtimes import SingleThreadedRuntime class Orchestrator(BaseChatAgent): @@ -13,7 +16,7 @@ class Orchestrator(BaseChatAgent): self, name: str, description: str, - runtime: SingleThreadedRuntime, + runtime: AgentRuntime, agents: Sequence[BaseChatAgent], model_client: ModelClient, max_turns: int = 30, @@ -28,7 +31,10 @@ class Orchestrator(BaseChatAgent): self._max_retry_attempts_before_educated_guess = max_retry_attempts self._history: List[ChatMessage] = [] - async def on_chat_message(self, message: ChatMessage) -> ChatMessage: + @message_handler(ChatMessage) + async def on_chat_message( + self, message: ChatMessage, require_response: bool, cancellation_token: CancellationToken + ) -> ChatMessage: # A task is received. task = message.body @@ -169,6 +175,8 @@ Some additional points to consider: speaker, ) + assert speaker_response is not None + # Update the ledger. ledger.append( AssistantMessage( diff --git a/src/agnext/chat/runtimes.py b/src/agnext/chat/runtimes.py deleted file mode 100644 index 1284ee232..000000000 --- a/src/agnext/chat/runtimes.py +++ /dev/null @@ -1,12 +0,0 @@ -from ..application_components.single_threaded_agent_runtime import ( - SingleThreadedAgentRuntime, -) -from .messages import ChatMessage - - -# The built-in runtime for the chat API. -class SingleThreadedRuntime(SingleThreadedAgentRuntime[ChatMessage]): - pass - - -# Each new built-in runtime should be able to handle ChatMessage type. diff --git a/src/agnext/core/agent.py b/src/agnext/core/agent.py index d6078be68..038ffd2a3 100644 --- a/src/agnext/core/agent.py +++ b/src/agnext/core/agent.py @@ -1,15 +1,16 @@ -from typing import Protocol, Sequence, Type, TypeVar +from typing import Any, Protocol, Sequence, runtime_checkable from agnext.core.cancellation_token import CancellationToken -T = TypeVar("T") - -class Agent(Protocol[T]): +@runtime_checkable +class Agent(Protocol): @property def name(self) -> str: ... @property - def subscriptions(self) -> Sequence[Type[T]]: ... + def subscriptions(self) -> Sequence[type]: ... - async def on_message(self, message: T, cancellation_token: CancellationToken) -> T: ... + async def on_message( + self, message: Any, require_response: bool, cancellation_token: CancellationToken + ) -> Any | None: ... diff --git a/src/agnext/core/agent_runtime.py b/src/agnext/core/agent_runtime.py index aae215246..67f558c6f 100644 --- a/src/agnext/core/agent_runtime.py +++ b/src/agnext/core/agent_runtime.py @@ -1,28 +1,32 @@ from asyncio import Future -from typing import List, Protocol, TypeVar +from typing import Any, Protocol, Sequence from agnext.core.agent import Agent from agnext.core.cancellation_token import CancellationToken -T = TypeVar("T") - # Undeliverable - error -class AgentRuntime(Protocol[T]): - def add_agent(self, agent: Agent[T]) -> None: ... +class AgentRuntime(Protocol): + def add_agent(self, agent: Agent) -> None: ... # Returns the response of the message def send_message( self, - message: T, - recipient: Agent[T], + message: Any, + recipient: Agent, *, - sender: Agent[T] | None = None, - cancellation_token: CancellationToken | None, - ) -> Future[T]: ... + require_response: bool = True, + sender: Agent | None = None, + cancellation_token: CancellationToken | None = None, + ) -> Future[Any | None]: ... # Returns the response of all handling agents def broadcast_message( - self, message: T, *, sender: Agent[T] | None = None, cancellation_token: CancellationToken | None = None - ) -> Future[List[T]]: ... + self, + message: Any, + *, + require_response: bool = True, + sender: Agent | None = None, + cancellation_token: CancellationToken | None = None, + ) -> Future[Sequence[Any] | None]: ... diff --git a/src/agnext/core/base_agent.py b/src/agnext/core/base_agent.py index ee90258e5..9560328ac 100644 --- a/src/agnext/core/base_agent.py +++ b/src/agnext/core/base_agent.py @@ -1,17 +1,21 @@ from abc import ABC, abstractmethod from asyncio import Future -from typing import List, Sequence, Type, TypeVar +from typing import Any, Sequence, TypeVar from agnext.core.agent_runtime import AgentRuntime from agnext.core.cancellation_token import CancellationToken from .agent import Agent -T = TypeVar("T") +ConsumesT = TypeVar("ConsumesT") +ProducesT = TypeVar("ProducesT", covariant=True) + +OtherConsumesT = TypeVar("OtherConsumesT") +OtherProducesT = TypeVar("OtherProducesT") -class BaseAgent(ABC, Agent[T]): - def __init__(self, name: str, router: AgentRuntime[T]) -> None: +class BaseAgent(ABC, Agent): + def __init__(self, name: str, router: AgentRuntime) -> None: self._name = name self._router = router @@ -21,26 +25,47 @@ class BaseAgent(ABC, Agent[T]): @property @abstractmethod - def subscriptions(self) -> Sequence[Type[T]]: + def subscriptions(self) -> Sequence[type]: return [] @abstractmethod - async def on_message(self, message: T, cancellation_token: CancellationToken) -> T: ... + async def on_message( + self, message: Any, require_response: bool, cancellation_token: CancellationToken + ) -> Any | None: ... + # Returns the response of the message def _send_message( - self, message: T, recipient: Agent[T], cancellation_token: CancellationToken | None = None - ) -> Future[T]: + self, + message: Any, + recipient: Agent, + *, + require_response: bool = True, + cancellation_token: CancellationToken | None = None, + ) -> Future[Any | None]: if cancellation_token is None: cancellation_token = CancellationToken() + future = self._router.send_message( - message, sender=self, recipient=recipient, cancellation_token=cancellation_token + message, + sender=self, + recipient=recipient, + require_response=require_response, + cancellation_token=cancellation_token, ) cancellation_token.link_future(future) return future - def _broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]: + # Returns the response of all handling agents + def _broadcast_message( + self, + message: Any, + *, + require_response: bool = True, + cancellation_token: CancellationToken | None = None, + ) -> Future[Sequence[Any] | None]: if cancellation_token is None: cancellation_token = CancellationToken() - future = self._router.broadcast_message(message, sender=self, cancellation_token=cancellation_token) - cancellation_token.link_future(future) + future = self._router.broadcast_message( + message, sender=self, require_response=require_response, cancellation_token=cancellation_token + ) return future diff --git a/src/agnext/core/intervention.py b/src/agnext/core/intervention.py index a8e5833eb..5a002c331 100644 --- a/src/agnext/core/intervention.py +++ b/src/agnext/core/intervention.py @@ -1,4 +1,4 @@ -from typing import Awaitable, Callable, Protocol, Sequence, TypeVar, final +from typing import Any, Awaitable, Callable, Protocol, Sequence, final from agnext.core.agent import Agent @@ -7,33 +7,29 @@ from agnext.core.agent import Agent class DropMessage: ... -T = TypeVar("T") - -InterventionFunction = Callable[[T], T | Awaitable[type[DropMessage]]] +InterventionFunction = Callable[[Any], Any | Awaitable[type[DropMessage]]] -class InterventionHandler(Protocol[T]): - async def on_send(self, message: T, *, sender: Agent[T] | None, recipient: Agent[T]) -> T | type[DropMessage]: ... - async def on_broadcast(self, message: T, *, sender: Agent[T] | None) -> T | type[DropMessage]: ... - async def on_response( - self, message: T, *, sender: Agent[T], recipient: Agent[T] | None - ) -> T | type[DropMessage]: ... +class InterventionHandler(Protocol): + async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: ... + async def on_broadcast(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: ... + async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: ... async def on_broadcast_response( - self, message: Sequence[T], *, recipient: Agent[T] | None - ) -> Sequence[T] | type[DropMessage]: ... + self, message: Sequence[Any], *, recipient: Agent | None + ) -> Sequence[Any] | type[DropMessage]: ... -class DefaultInterventionHandler(InterventionHandler[T]): - async def on_send(self, message: T, *, sender: Agent[T] | None, recipient: Agent[T]) -> T | type[DropMessage]: +class DefaultInterventionHandler(InterventionHandler): + async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: return message - async def on_broadcast(self, message: T, *, sender: Agent[T] | None) -> T | type[DropMessage]: + async def on_broadcast(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: return message - async def on_response(self, message: T, *, sender: Agent[T], recipient: Agent[T] | None) -> T | type[DropMessage]: + async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: return message async def on_broadcast_response( - self, message: Sequence[T], *, recipient: Agent[T] | None - ) -> Sequence[T] | type[DropMessage]: + self, message: Sequence[Any], *, recipient: Agent | None + ) -> Sequence[Any] | type[DropMessage]: return message diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py index 253806043..de256ce40 100644 --- a/tests/test_cancellation.py +++ b/tests/test_cancellation.py @@ -16,14 +16,14 @@ class MessageType: # To do cancellation, only the token should be interacted with as a user # If you cancel a future, it may not work as you expect. -class LongRunningAgent(TypeRoutedAgent[MessageType]): - def __init__(self, name: str, router: AgentRuntime[MessageType]) -> None: +class LongRunningAgent(TypeRoutedAgent): + def __init__(self, name: str, router: AgentRuntime) -> None: super().__init__(name, router) self.called = False self.cancelled = False @message_handler(MessageType) - async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: + async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType: self.called = True sleep = asyncio.ensure_future(asyncio.sleep(100)) cancellation_token.link_future(sleep) @@ -34,19 +34,22 @@ class LongRunningAgent(TypeRoutedAgent[MessageType]): self.cancelled = True raise -class NestingLongRunningAgent(TypeRoutedAgent[MessageType]): - def __init__(self, name: str, router: AgentRuntime[MessageType], nested_agent: Agent[MessageType]) -> None: +class NestingLongRunningAgent(TypeRoutedAgent): + def __init__(self, name: str, router: AgentRuntime, nested_agent: Agent) -> None: super().__init__(name, router) self.called = False self.cancelled = False self._nested_agent = nested_agent @message_handler(MessageType) - async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: + async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType: + assert require_response == True self.called = True - response = self._send_message(message, self._nested_agent, cancellation_token) + response = self._send_message(message, self._nested_agent, require_response=require_response, cancellation_token=cancellation_token) try: - return await response + val = await response + assert isinstance(val, MessageType) + return val except asyncio.CancelledError: self.cancelled = True raise @@ -54,7 +57,7 @@ class NestingLongRunningAgent(TypeRoutedAgent[MessageType]): @pytest.mark.asyncio async def test_cancellation_with_token() -> None: - router = SingleThreadedAgentRuntime[MessageType]() + router = SingleThreadedAgentRuntime() long_running = LongRunningAgent("name", router) token = CancellationToken() @@ -75,7 +78,7 @@ async def test_cancellation_with_token() -> None: @pytest.mark.asyncio async def test_nested_cancellation_only_outer_called() -> None: - router = SingleThreadedAgentRuntime[MessageType]() + router = SingleThreadedAgentRuntime() long_running = LongRunningAgent("name", router) nested = NestingLongRunningAgent("nested", router, long_running) @@ -98,7 +101,7 @@ async def test_nested_cancellation_only_outer_called() -> None: @pytest.mark.asyncio async def test_nested_cancellation_inner_called() -> None: - router = SingleThreadedAgentRuntime[MessageType]() + router = SingleThreadedAgentRuntime() long_running = LongRunningAgent("name", router) nested = NestingLongRunningAgent("nested", router, long_running) diff --git a/tests/test_intervention.py b/tests/test_intervention.py index 79a5c5b03..7750eedf7 100644 --- a/tests/test_intervention.py +++ b/tests/test_intervention.py @@ -13,30 +13,30 @@ from agnext.core.intervention import DefaultInterventionHandler, DropMessage class MessageType: ... -class LoopbackAgent(TypeRoutedAgent[MessageType]): - def __init__(self, name: str, router: AgentRuntime[MessageType]) -> None: +class LoopbackAgent(TypeRoutedAgent): + def __init__(self, name: str, router: AgentRuntime) -> None: super().__init__(name, router) self.num_calls = 0 @message_handler(MessageType) - async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: + async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType: self.num_calls += 1 return message @pytest.mark.asyncio async def test_intervention_count_messages() -> None: - class DebugInterventionHandler(DefaultInterventionHandler[MessageType]): + class DebugInterventionHandler(DefaultInterventionHandler): def __init__(self): self.num_messages = 0 - async def on_send(self, message: MessageType, *, sender: Agent[MessageType] | None, recipient: Agent[MessageType]) -> MessageType: + async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType: self.num_messages += 1 return message handler = DebugInterventionHandler() - router = SingleThreadedAgentRuntime[MessageType](before_send=handler) + router = SingleThreadedAgentRuntime(before_send=handler) long_running = LoopbackAgent("name", router) response = router.send_message(MessageType(), recipient=long_running) @@ -50,12 +50,12 @@ async def test_intervention_count_messages() -> None: @pytest.mark.asyncio async def test_intervention_drop_send() -> None: - class DropSendInterventionHandler(DefaultInterventionHandler[MessageType]): - async def on_send(self, message: MessageType, *, sender: Agent[MessageType] | None, recipient: Agent[MessageType]) -> MessageType | type[DropMessage]: + class DropSendInterventionHandler(DefaultInterventionHandler): + async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType | type[DropMessage]: return DropMessage handler = DropSendInterventionHandler() - router = SingleThreadedAgentRuntime[MessageType](before_send=handler) + router = SingleThreadedAgentRuntime(before_send=handler) long_running = LoopbackAgent("name", router) response = router.send_message(MessageType(), recipient=long_running) @@ -72,12 +72,12 @@ async def test_intervention_drop_send() -> None: @pytest.mark.asyncio async def test_intervention_drop_response() -> None: - class DropResponseInterventionHandler(DefaultInterventionHandler[MessageType]): - async def on_response(self, message: MessageType, *, sender: Agent[MessageType], recipient: Agent[MessageType] | None) -> MessageType | type[DropMessage]: + class DropResponseInterventionHandler(DefaultInterventionHandler): + async def on_response(self, message: MessageType, *, sender: Agent, recipient: Agent | None) -> MessageType | type[DropMessage]: return DropMessage handler = DropResponseInterventionHandler() - router = SingleThreadedAgentRuntime[MessageType](before_send=handler) + router = SingleThreadedAgentRuntime(before_send=handler) long_running = LoopbackAgent("name", router) response = router.send_message(MessageType(), recipient=long_running)