From ec654253d254a80e3216ee319f5bc6186d596ec7 Mon Sep 17 00:00:00 2001 From: afourney Date: Mon, 29 Jul 2024 13:09:31 -0700 Subject: [PATCH] Introduces a BaseWorker agent, allowing for a TeamOneBaseAgent (#289) --- .../src/team_one/agents/base_agent.py | 73 ++++--------- .../src/team_one/agents/base_orchestrator.py | 29 ++--- .../src/team_one/agents/base_worker.py | 55 ++++++++++ .../team-one/src/team_one/agents/coder.py | 8 +- .../agents/file_surfer/file_surfer.py | 4 +- .../multimodal_web_surfer.py | 9 +- .../src/team_one/agents/user_proxy.py | 4 +- .../teams/team-one/src/team_one/messages.py | 3 + python/teams/team-one/src/team_one/utils.py | 101 +++++++++--------- 9 files changed, 149 insertions(+), 137 deletions(-) create mode 100644 python/teams/team-one/src/team_one/agents/base_worker.py diff --git a/python/teams/team-one/src/team_one/agents/base_agent.py b/python/teams/team-one/src/team_one/agents/base_agent.py index d24348479..3d1f2457b 100644 --- a/python/teams/team-one/src/team_one/agents/base_agent.py +++ b/python/teams/team-one/src/team_one/agents/base_agent.py @@ -1,14 +1,9 @@ import asyncio import logging -from typing import Any, List, Tuple +from typing import Any from agnext.application.logging import EVENT_LOGGER_NAME from agnext.components import TypeRoutedAgent, message_handler -from agnext.components.models import ( - AssistantMessage, - LLMMessage, - UserMessage, -) from agnext.core import CancellationToken from team_one.messages import ( @@ -17,18 +12,14 @@ from team_one.messages import ( DeactivateMessage, RequestReplyMessage, ResetMessage, - UserContent, + TeamOneMessages, ) -from team_one.utils import message_content_to_str -logger = logging.getLogger(EVENT_LOGGER_NAME + ".orchestrator") +logger = logging.getLogger(EVENT_LOGGER_NAME + ".agent") -PossibleMessages = RequestReplyMessage | BroadcastMessage | ResetMessage | DeactivateMessage - - -class BaseAgent(TypeRoutedAgent): - """An agent that handles the RequestReply and Broadcast messages""" +class TeamOneBaseAgent(TypeRoutedAgent): + """An agent that optionally ensures messages are handled non-concurrently in the order they arrive.""" def __init__( self, @@ -36,14 +27,12 @@ class BaseAgent(TypeRoutedAgent): handle_messages_concurrently: bool = False, ) -> None: super().__init__(description) - self._chat_history: List[LLMMessage] = [] - self._enabled: bool = True - self._handle_messages_concurrently = handle_messages_concurrently + self._enabled = True if not self._handle_messages_concurrently: # TODO: make it possible to stop - self._message_queue = asyncio.Queue[tuple[PossibleMessages, CancellationToken, asyncio.Future[Any]]]() + self._message_queue = asyncio.Queue[tuple[TeamOneMessages, CancellationToken, asyncio.Future[Any]]]() self._processing_task = asyncio.create_task(self._process()) async def _process(self) -> None: @@ -54,13 +43,13 @@ class BaseAgent(TypeRoutedAgent): continue if isinstance(message, RequestReplyMessage): - await self.handle_request_reply(message, cancellation_token) + await self._handle_request_reply(message, cancellation_token) elif isinstance(message, BroadcastMessage): - await self.handle_broadcast(message, cancellation_token) + await self._handle_broadcast(message, cancellation_token) elif isinstance(message, ResetMessage): - await self.handle_reset(message, cancellation_token) + await self._handle_reset(message, cancellation_token) elif isinstance(message, DeactivateMessage): - await self.handle_deactivate(message, cancellation_token) + await self._handle_deactivate(message, cancellation_token) else: raise ValueError("Unknown message type.") @@ -77,27 +66,28 @@ class BaseAgent(TypeRoutedAgent): if self._handle_messages_concurrently: if isinstance(message, RequestReplyMessage): - await self.handle_request_reply(message, cancellation_token) + await self._handle_request_reply(message, cancellation_token) elif isinstance(message, BroadcastMessage): - await self.handle_broadcast(message, cancellation_token) + await self._handle_broadcast(message, cancellation_token) elif isinstance(message, ResetMessage): - await self.handle_reset(message, cancellation_token) + await self._handle_reset(message, cancellation_token) elif isinstance(message, DeactivateMessage): - await self.handle_deactivate(message, cancellation_token) + await self._handle_deactivate(message, cancellation_token) else: future = asyncio.Future[Any]() await self._message_queue.put((message, cancellation_token, future)) await future - async def handle_broadcast(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None: - assert isinstance(message.content, UserMessage) - self._chat_history.append(message.content) + async def _handle_broadcast(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None: + raise NotImplementedError() - async def handle_reset(self, message: ResetMessage, cancellation_token: CancellationToken) -> None: - """Handle a reset message.""" - await self._reset(cancellation_token) + async def _handle_reset(self, message: ResetMessage, cancellation_token: CancellationToken) -> None: + raise NotImplementedError() - async def handle_deactivate(self, message: DeactivateMessage, cancellation_token: CancellationToken) -> None: + async def _handle_request_reply(self, message: RequestReplyMessage, cancellation_token: CancellationToken) -> None: + raise NotImplementedError() + + async def _handle_deactivate(self, message: DeactivateMessage, cancellation_token: CancellationToken) -> None: """Handle a deactivate message.""" self._enabled = False logger.info( @@ -106,20 +96,3 @@ class BaseAgent(TypeRoutedAgent): "", ) ) - - async def handle_request_reply(self, message: RequestReplyMessage, cancellation_token: CancellationToken) -> None: - """Respond to a reply request.""" - request_halt, response = await self._generate_reply(cancellation_token) - - assistant_message = AssistantMessage(content=message_content_to_str(response), source=self.metadata["name"]) - self._chat_history.append(assistant_message) - - user_message = UserMessage(content=response, source=self.metadata["name"]) - await self.publish_message(BroadcastMessage(content=user_message, request_halt=request_halt)) - - async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, UserContent]: - """Returns (request_halt, response_message)""" - raise NotImplementedError() - - async def _reset(self, cancellation_token: CancellationToken) -> None: - self._chat_history = [] diff --git a/python/teams/team-one/src/team_one/agents/base_orchestrator.py b/python/teams/team-one/src/team_one/agents/base_orchestrator.py index 7532d2616..8a86c1a33 100644 --- a/python/teams/team-one/src/team_one/agents/base_orchestrator.py +++ b/python/teams/team-one/src/team_one/agents/base_orchestrator.py @@ -2,36 +2,32 @@ import logging from typing import List, Optional from agnext.application.logging import EVENT_LOGGER_NAME -from agnext.components import TypeRoutedAgent, message_handler from agnext.components.models import AssistantMessage, LLMMessage, UserMessage from agnext.core import AgentProxy, CancellationToken -from ..messages import BroadcastMessage, DeactivateMessage, OrchestrationEvent, RequestReplyMessage +from ..messages import BroadcastMessage, OrchestrationEvent, RequestReplyMessage from ..utils import message_content_to_str +from .base_agent import TeamOneBaseAgent logger = logging.getLogger(EVENT_LOGGER_NAME + ".orchestrator") -class BaseOrchestrator(TypeRoutedAgent): +class BaseOrchestrator(TeamOneBaseAgent): def __init__( self, agents: List[AgentProxy], description: str = "Base orchestrator", max_rounds: int = 20, + handle_messages_concurrently: bool = False, ) -> None: - super().__init__(description) + super().__init__(description, handle_messages_concurrently=handle_messages_concurrently) self._agents = agents self._max_rounds = max_rounds self._num_rounds = 0 - self._enabled = True - @message_handler - async def handle_incoming_message(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None: + async def _handle_broadcast(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None: """Handle an incoming message.""" - if not self._enabled: - return - source = "Unknown" if isinstance(message.content, UserMessage) or isinstance(message.content, AssistantMessage): source = message.content.source @@ -81,19 +77,6 @@ class BaseOrchestrator(TypeRoutedAgent): self._num_rounds += 1 # Call before sending the message await self.send_message(request_reply_message, next_agent.id) - @message_handler - async def handle_deactivate_message( - self, message: DeactivateMessage, cancellation_token: CancellationToken - ) -> None: - """Handle a deactivate message.""" - self._enabled = False - logger.info( - OrchestrationEvent( - f"{self.metadata['name']} (deactivated)", - "", - ) - ) - async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]: raise NotImplementedError() diff --git a/python/teams/team-one/src/team_one/agents/base_worker.py b/python/teams/team-one/src/team_one/agents/base_worker.py new file mode 100644 index 000000000..423280ca9 --- /dev/null +++ b/python/teams/team-one/src/team_one/agents/base_worker.py @@ -0,0 +1,55 @@ +from typing import List, Tuple + +from agnext.components.models import ( + AssistantMessage, + LLMMessage, + UserMessage, +) +from agnext.core import CancellationToken + +from team_one.messages import ( + BroadcastMessage, + RequestReplyMessage, + ResetMessage, + UserContent, +) + +from ..utils import message_content_to_str +from .base_agent import TeamOneBaseAgent + + +class BaseWorker(TeamOneBaseAgent): + """Base agent that handles the TeamOne worker behavior protocol.""" + + def __init__( + self, + description: str, + handle_messages_concurrently: bool = False, + ) -> None: + super().__init__(description, handle_messages_concurrently=handle_messages_concurrently) + self._chat_history: List[LLMMessage] = [] + + async def _handle_broadcast(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None: + assert isinstance(message.content, UserMessage) + self._chat_history.append(message.content) + + async def _handle_reset(self, message: ResetMessage, cancellation_token: CancellationToken) -> None: + """Handle a reset message.""" + await self._reset(cancellation_token) + + async def _handle_request_reply(self, message: RequestReplyMessage, cancellation_token: CancellationToken) -> None: + """Respond to a reply request.""" + request_halt, response = await self._generate_reply(cancellation_token) + + assistant_message = AssistantMessage(content=message_content_to_str(response), source=self.metadata["name"]) + self._chat_history.append(assistant_message) + + user_message = UserMessage(content=response, source=self.metadata["name"]) + await self.publish_message(BroadcastMessage(content=user_message, request_halt=request_halt)) + + async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, UserContent]: + """Returns (request_halt, response_message)""" + raise NotImplementedError() + + async def _reset(self, cancellation_token: CancellationToken) -> None: + self._chat_history = [] diff --git a/python/teams/team-one/src/team_one/agents/coder.py b/python/teams/team-one/src/team_one/agents/coder.py index 03d64687f..6b8d62538 100644 --- a/python/teams/team-one/src/team_one/agents/coder.py +++ b/python/teams/team-one/src/team_one/agents/coder.py @@ -9,12 +9,12 @@ from agnext.components.models import ( ) from agnext.core import CancellationToken -from ..utils import message_content_to_str from ..messages import UserContent -from .base_agent import BaseAgent +from ..utils import message_content_to_str +from .base_worker import BaseWorker -class Coder(BaseAgent): +class Coder(BaseWorker): """An agent that uses tools to write, execute, and debug Python code.""" DEFAULT_DESCRIPTION = "A Python coder assistant." @@ -58,7 +58,7 @@ If the code was executed, and the output appears to indicate that the original p return "TERMINATE" in response.content, response.content -class Executor(BaseAgent): +class Executor(BaseWorker): def __init__( self, description: str, executor: Optional[CodeExecutor] = None, check_last_n_message: int = 5 ) -> None: diff --git a/python/teams/team-one/src/team_one/agents/file_surfer/file_surfer.py b/python/teams/team-one/src/team_one/agents/file_surfer/file_surfer.py index a801193cb..622893466 100644 --- a/python/teams/team-one/src/team_one/agents/file_surfer/file_surfer.py +++ b/python/teams/team-one/src/team_one/agents/file_surfer/file_surfer.py @@ -13,7 +13,7 @@ from agnext.components.models import ( from agnext.core import CancellationToken from ...markdown_browser import RequestsMarkdownBrowser -from ..base_agent import BaseAgent +from ..base_worker import BaseWorker # from typing_extensions import Annotated from ._tools import TOOL_FIND_NEXT, TOOL_FIND_ON_PAGE_CTRL_F, TOOL_OPEN_LOCAL_FILE, TOOL_PAGE_DOWN, TOOL_PAGE_UP @@ -49,7 +49,7 @@ from ._tools import TOOL_FIND_NEXT, TOOL_FIND_ON_PAGE_CTRL_F, TOOL_OPEN_LOCAL_FI # return "\n".join(items) -class FileSurfer(BaseAgent): +class FileSurfer(BaseWorker): """An agent that uses tools to read and navigate local files.""" DEFAULT_DESCRIPTION = "An agent that can handle local files." diff --git a/python/teams/team-one/src/team_one/agents/multimodal_web_surfer/multimodal_web_surfer.py b/python/teams/team-one/src/team_one/agents/multimodal_web_surfer/multimodal_web_surfer.py index c640ac90b..e3541686c 100644 --- a/python/teams/team-one/src/team_one/agents/multimodal_web_surfer/multimodal_web_surfer.py +++ b/python/teams/team-one/src/team_one/agents/multimodal_web_surfer/multimodal_web_surfer.py @@ -29,10 +29,6 @@ from playwright._impl._errors import TimeoutError # from playwright._impl._async_base.AsyncEventInfo from playwright.async_api import BrowserContext, Download, Page, Playwright, async_playwright -from team_one.agents.base_agent import BaseAgent -from team_one.messages import UserContent, WebSurferEvent -from team_one.utils import SentinelMeta, message_content_to_str - # TODO: Fix mdconvert from ...markdown_browser import ( # type: ignore DocumentConverterResult, # type: ignore @@ -40,6 +36,9 @@ from ...markdown_browser import ( # type: ignore MarkdownConverter, # type: ignore UnsupportedFormatException, # type: ignore ) +from ...messages import UserContent, WebSurferEvent +from ...utils import SentinelMeta, message_content_to_str +from ..base_worker import BaseWorker from .set_of_mark import add_set_of_mark from .tool_definitions import ( TOOL_CLICK, @@ -81,7 +80,7 @@ class DEFAULT_CHANNEL(metaclass=SentinelMeta): pass -class MultimodalWebSurfer(BaseAgent): +class MultimodalWebSurfer(BaseWorker): """(In preview) A multimodal agent that acts as a web surfer that can search the web and visit web pages.""" DEFAULT_DESCRIPTION = "A helpful assistant with access to a web browser. Ask them to perform web searches, open pages, and interact with content (e.g., clicking links, scrolling the viewport, etc., filling in form fields, etc.) It can also summarize the entire page, or answer questions based on the content of the page. It can also be asked to sleep and wait for pages to load, in cases where the pages seem to be taking a while to load." diff --git a/python/teams/team-one/src/team_one/agents/user_proxy.py b/python/teams/team-one/src/team_one/agents/user_proxy.py index 89df2e2fc..6d12f90c8 100755 --- a/python/teams/team-one/src/team_one/agents/user_proxy.py +++ b/python/teams/team-one/src/team_one/agents/user_proxy.py @@ -4,10 +4,10 @@ from typing import Tuple from agnext.core import CancellationToken from ..messages import UserContent -from .base_agent import BaseAgent +from .base_worker import BaseWorker -class UserProxy(BaseAgent): +class UserProxy(BaseWorker): """An agent that allows the user to play the role of an agent in the conversation.""" DEFAULT_DESCRIPTION = "A human user." diff --git a/python/teams/team-one/src/team_one/messages.py b/python/teams/team-one/src/team_one/messages.py index e50d10aac..5452b881c 100644 --- a/python/teams/team-one/src/team_one/messages.py +++ b/python/teams/team-one/src/team_one/messages.py @@ -38,6 +38,9 @@ class OrchestrationEvent: message: str +TeamOneMessages = RequestReplyMessage | BroadcastMessage | ResetMessage | DeactivateMessage + + @dataclass class AgentEvent: source: str diff --git a/python/teams/team-one/src/team_one/utils.py b/python/teams/team-one/src/team_one/utils.py index 19d0221bc..bf4c4df7b 100644 --- a/python/teams/team-one/src/team_one/utils.py +++ b/python/teams/team-one/src/team_one/utils.py @@ -107,59 +107,58 @@ class LogHandler(logging.FileHandler): super().__init__(filename) def emit(self, record: logging.LogRecord) -> None: - # try: - ts = datetime.fromtimestamp(record.created).isoformat() - if isinstance(record.msg, OrchestrationEvent): - console_message = ( - f"\n{'-'*75} \n" f"\033[91m[{ts}], {record.msg.source}:\033[0m\n" f"\n{record.msg.message}" - ) - print(console_message, flush=True) - record.msg = json.dumps( - { + try: + ts = datetime.fromtimestamp(record.created).isoformat() + if isinstance(record.msg, OrchestrationEvent): + console_message = ( + f"\n{'-'*75} \n" f"\033[91m[{ts}], {record.msg.source}:\033[0m\n" f"\n{record.msg.message}" + ) + print(console_message, flush=True) + record.msg = json.dumps( + { + "timestamp": ts, + "source": record.msg.source, + "message": record.msg.message, + "type": "OrchestrationEvent", + } + ) + super().emit(record) + elif isinstance(record.msg, AgentEvent): + console_message = ( + f"\n{'-'*75} \n" f"\033[91m[{ts}], {record.msg.source}:\033[0m\n" f"\n{record.msg.message}" + ) + print(console_message, flush=True) + record.msg = json.dumps( + { + "timestamp": ts, + "source": record.msg.source, + "message": record.msg.message, + "type": "AgentEvent", + } + ) + super().emit(record) + elif isinstance(record.msg, WebSurferEvent): + console_message = f"\033[96m[{ts}], {record.msg.source}: {record.msg.message}\033[0m" + print(console_message, flush=True) + payload: Dict[str, Any] = { "timestamp": ts, - "source": record.msg.source, - "message": record.msg.message, - "type": "OrchestrationEvent", + "type": "WebSurferEvent", } - ) - super().emit(record) - elif isinstance(record.msg, AgentEvent): - console_message = ( - f"\n{'-'*75} \n" f"\033[91m[{ts}], {record.msg.source}:\033[0m\n" f"\n{record.msg.message}" - ) - print(console_message, flush=True) - record.msg = json.dumps( - { - "timestamp": ts, - "source": record.msg.source, - "message": record.msg.message, - "type": "AgentEvent", - } - ) - super().emit(record) - elif isinstance(record.msg, WebSurferEvent): - console_message = f"\033[96m[{ts}], {record.msg.source}: {record.msg.message}\033[0m" - print(console_message, flush=True) - payload: Dict[str, Any] = { - "timestamp": ts, - "type": "WebSurferEvent", - } - payload.update(asdict(record.msg)) - record.msg = json.dumps(payload) - super().emit(record) - elif isinstance(record.msg, LLMCallEvent): - record.msg = json.dumps( - { - "timestamp": ts, - "prompt_tokens": record.msg.prompt_tokens, - "completion_tokens": record.msg.completion_tokens, - "type": "LLMCallEvent", - } - ) - super().emit(record) - - # except Exception: - # self.handleError(record) + payload.update(asdict(record.msg)) + record.msg = json.dumps(payload) + super().emit(record) + elif isinstance(record.msg, LLMCallEvent): + record.msg = json.dumps( + { + "timestamp": ts, + "prompt_tokens": record.msg.prompt_tokens, + "completion_tokens": record.msg.completion_tokens, + "type": "LLMCallEvent", + } + ) + super().emit(record) + except Exception: + self.handleError(record) class SentinelMeta(type):