Introduces a BaseWorker agent, allowing for a TeamOneBaseAgent (#289)

This commit is contained in:
afourney
2024-07-29 13:09:31 -07:00
committed by GitHub
parent 2bc0a33f78
commit ec654253d2
9 changed files with 149 additions and 137 deletions

View File

@@ -1,14 +1,9 @@
import asyncio
import logging
from typing import Any, List, Tuple
from typing import Any
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components.models import (
AssistantMessage,
LLMMessage,
UserMessage,
)
from agnext.core import CancellationToken
from team_one.messages import (
@@ -17,18 +12,14 @@ from team_one.messages import (
DeactivateMessage,
RequestReplyMessage,
ResetMessage,
UserContent,
TeamOneMessages,
)
from team_one.utils import message_content_to_str
logger = logging.getLogger(EVENT_LOGGER_NAME + ".orchestrator")
logger = logging.getLogger(EVENT_LOGGER_NAME + ".agent")
PossibleMessages = RequestReplyMessage | BroadcastMessage | ResetMessage | DeactivateMessage
class BaseAgent(TypeRoutedAgent):
"""An agent that handles the RequestReply and Broadcast messages"""
class TeamOneBaseAgent(TypeRoutedAgent):
"""An agent that optionally ensures messages are handled non-concurrently in the order they arrive."""
def __init__(
self,
@@ -36,14 +27,12 @@ class BaseAgent(TypeRoutedAgent):
handle_messages_concurrently: bool = False,
) -> None:
super().__init__(description)
self._chat_history: List[LLMMessage] = []
self._enabled: bool = True
self._handle_messages_concurrently = handle_messages_concurrently
self._enabled = True
if not self._handle_messages_concurrently:
# TODO: make it possible to stop
self._message_queue = asyncio.Queue[tuple[PossibleMessages, CancellationToken, asyncio.Future[Any]]]()
self._message_queue = asyncio.Queue[tuple[TeamOneMessages, CancellationToken, asyncio.Future[Any]]]()
self._processing_task = asyncio.create_task(self._process())
async def _process(self) -> None:
@@ -54,13 +43,13 @@ class BaseAgent(TypeRoutedAgent):
continue
if isinstance(message, RequestReplyMessage):
await self.handle_request_reply(message, cancellation_token)
await self._handle_request_reply(message, cancellation_token)
elif isinstance(message, BroadcastMessage):
await self.handle_broadcast(message, cancellation_token)
await self._handle_broadcast(message, cancellation_token)
elif isinstance(message, ResetMessage):
await self.handle_reset(message, cancellation_token)
await self._handle_reset(message, cancellation_token)
elif isinstance(message, DeactivateMessage):
await self.handle_deactivate(message, cancellation_token)
await self._handle_deactivate(message, cancellation_token)
else:
raise ValueError("Unknown message type.")
@@ -77,27 +66,28 @@ class BaseAgent(TypeRoutedAgent):
if self._handle_messages_concurrently:
if isinstance(message, RequestReplyMessage):
await self.handle_request_reply(message, cancellation_token)
await self._handle_request_reply(message, cancellation_token)
elif isinstance(message, BroadcastMessage):
await self.handle_broadcast(message, cancellation_token)
await self._handle_broadcast(message, cancellation_token)
elif isinstance(message, ResetMessage):
await self.handle_reset(message, cancellation_token)
await self._handle_reset(message, cancellation_token)
elif isinstance(message, DeactivateMessage):
await self.handle_deactivate(message, cancellation_token)
await self._handle_deactivate(message, cancellation_token)
else:
future = asyncio.Future[Any]()
await self._message_queue.put((message, cancellation_token, future))
await future
async def handle_broadcast(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None:
assert isinstance(message.content, UserMessage)
self._chat_history.append(message.content)
async def _handle_broadcast(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None:
raise NotImplementedError()
async def handle_reset(self, message: ResetMessage, cancellation_token: CancellationToken) -> None:
"""Handle a reset message."""
await self._reset(cancellation_token)
async def _handle_reset(self, message: ResetMessage, cancellation_token: CancellationToken) -> None:
raise NotImplementedError()
async def handle_deactivate(self, message: DeactivateMessage, cancellation_token: CancellationToken) -> None:
async def _handle_request_reply(self, message: RequestReplyMessage, cancellation_token: CancellationToken) -> None:
raise NotImplementedError()
async def _handle_deactivate(self, message: DeactivateMessage, cancellation_token: CancellationToken) -> None:
"""Handle a deactivate message."""
self._enabled = False
logger.info(
@@ -106,20 +96,3 @@ class BaseAgent(TypeRoutedAgent):
"",
)
)
async def handle_request_reply(self, message: RequestReplyMessage, cancellation_token: CancellationToken) -> None:
"""Respond to a reply request."""
request_halt, response = await self._generate_reply(cancellation_token)
assistant_message = AssistantMessage(content=message_content_to_str(response), source=self.metadata["name"])
self._chat_history.append(assistant_message)
user_message = UserMessage(content=response, source=self.metadata["name"])
await self.publish_message(BroadcastMessage(content=user_message, request_halt=request_halt))
async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, UserContent]:
"""Returns (request_halt, response_message)"""
raise NotImplementedError()
async def _reset(self, cancellation_token: CancellationToken) -> None:
self._chat_history = []

View File

@@ -2,36 +2,32 @@ import logging
from typing import List, Optional
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components.models import AssistantMessage, LLMMessage, UserMessage
from agnext.core import AgentProxy, CancellationToken
from ..messages import BroadcastMessage, DeactivateMessage, OrchestrationEvent, RequestReplyMessage
from ..messages import BroadcastMessage, OrchestrationEvent, RequestReplyMessage
from ..utils import message_content_to_str
from .base_agent import TeamOneBaseAgent
logger = logging.getLogger(EVENT_LOGGER_NAME + ".orchestrator")
class BaseOrchestrator(TypeRoutedAgent):
class BaseOrchestrator(TeamOneBaseAgent):
def __init__(
self,
agents: List[AgentProxy],
description: str = "Base orchestrator",
max_rounds: int = 20,
handle_messages_concurrently: bool = False,
) -> None:
super().__init__(description)
super().__init__(description, handle_messages_concurrently=handle_messages_concurrently)
self._agents = agents
self._max_rounds = max_rounds
self._num_rounds = 0
self._enabled = True
@message_handler
async def handle_incoming_message(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None:
async def _handle_broadcast(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None:
"""Handle an incoming message."""
if not self._enabled:
return
source = "Unknown"
if isinstance(message.content, UserMessage) or isinstance(message.content, AssistantMessage):
source = message.content.source
@@ -81,19 +77,6 @@ class BaseOrchestrator(TypeRoutedAgent):
self._num_rounds += 1 # Call before sending the message
await self.send_message(request_reply_message, next_agent.id)
@message_handler
async def handle_deactivate_message(
self, message: DeactivateMessage, cancellation_token: CancellationToken
) -> None:
"""Handle a deactivate message."""
self._enabled = False
logger.info(
OrchestrationEvent(
f"{self.metadata['name']} (deactivated)",
"",
)
)
async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]:
raise NotImplementedError()

View File

@@ -0,0 +1,55 @@
from typing import List, Tuple
from agnext.components.models import (
AssistantMessage,
LLMMessage,
UserMessage,
)
from agnext.core import CancellationToken
from team_one.messages import (
BroadcastMessage,
RequestReplyMessage,
ResetMessage,
UserContent,
)
from ..utils import message_content_to_str
from .base_agent import TeamOneBaseAgent
class BaseWorker(TeamOneBaseAgent):
"""Base agent that handles the TeamOne worker behavior protocol."""
def __init__(
self,
description: str,
handle_messages_concurrently: bool = False,
) -> None:
super().__init__(description, handle_messages_concurrently=handle_messages_concurrently)
self._chat_history: List[LLMMessage] = []
async def _handle_broadcast(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None:
assert isinstance(message.content, UserMessage)
self._chat_history.append(message.content)
async def _handle_reset(self, message: ResetMessage, cancellation_token: CancellationToken) -> None:
"""Handle a reset message."""
await self._reset(cancellation_token)
async def _handle_request_reply(self, message: RequestReplyMessage, cancellation_token: CancellationToken) -> None:
"""Respond to a reply request."""
request_halt, response = await self._generate_reply(cancellation_token)
assistant_message = AssistantMessage(content=message_content_to_str(response), source=self.metadata["name"])
self._chat_history.append(assistant_message)
user_message = UserMessage(content=response, source=self.metadata["name"])
await self.publish_message(BroadcastMessage(content=user_message, request_halt=request_halt))
async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, UserContent]:
"""Returns (request_halt, response_message)"""
raise NotImplementedError()
async def _reset(self, cancellation_token: CancellationToken) -> None:
self._chat_history = []

View File

@@ -9,12 +9,12 @@ from agnext.components.models import (
)
from agnext.core import CancellationToken
from ..utils import message_content_to_str
from ..messages import UserContent
from .base_agent import BaseAgent
from ..utils import message_content_to_str
from .base_worker import BaseWorker
class Coder(BaseAgent):
class Coder(BaseWorker):
"""An agent that uses tools to write, execute, and debug Python code."""
DEFAULT_DESCRIPTION = "A Python coder assistant."
@@ -58,7 +58,7 @@ If the code was executed, and the output appears to indicate that the original p
return "TERMINATE" in response.content, response.content
class Executor(BaseAgent):
class Executor(BaseWorker):
def __init__(
self, description: str, executor: Optional[CodeExecutor] = None, check_last_n_message: int = 5
) -> None:

View File

@@ -13,7 +13,7 @@ from agnext.components.models import (
from agnext.core import CancellationToken
from ...markdown_browser import RequestsMarkdownBrowser
from ..base_agent import BaseAgent
from ..base_worker import BaseWorker
# from typing_extensions import Annotated
from ._tools import TOOL_FIND_NEXT, TOOL_FIND_ON_PAGE_CTRL_F, TOOL_OPEN_LOCAL_FILE, TOOL_PAGE_DOWN, TOOL_PAGE_UP
@@ -49,7 +49,7 @@ from ._tools import TOOL_FIND_NEXT, TOOL_FIND_ON_PAGE_CTRL_F, TOOL_OPEN_LOCAL_FI
# return "\n".join(items)
class FileSurfer(BaseAgent):
class FileSurfer(BaseWorker):
"""An agent that uses tools to read and navigate local files."""
DEFAULT_DESCRIPTION = "An agent that can handle local files."

View File

@@ -29,10 +29,6 @@ from playwright._impl._errors import TimeoutError
# from playwright._impl._async_base.AsyncEventInfo
from playwright.async_api import BrowserContext, Download, Page, Playwright, async_playwright
from team_one.agents.base_agent import BaseAgent
from team_one.messages import UserContent, WebSurferEvent
from team_one.utils import SentinelMeta, message_content_to_str
# TODO: Fix mdconvert
from ...markdown_browser import ( # type: ignore
DocumentConverterResult, # type: ignore
@@ -40,6 +36,9 @@ from ...markdown_browser import ( # type: ignore
MarkdownConverter, # type: ignore
UnsupportedFormatException, # type: ignore
)
from ...messages import UserContent, WebSurferEvent
from ...utils import SentinelMeta, message_content_to_str
from ..base_worker import BaseWorker
from .set_of_mark import add_set_of_mark
from .tool_definitions import (
TOOL_CLICK,
@@ -81,7 +80,7 @@ class DEFAULT_CHANNEL(metaclass=SentinelMeta):
pass
class MultimodalWebSurfer(BaseAgent):
class MultimodalWebSurfer(BaseWorker):
"""(In preview) A multimodal agent that acts as a web surfer that can search the web and visit web pages."""
DEFAULT_DESCRIPTION = "A helpful assistant with access to a web browser. Ask them to perform web searches, open pages, and interact with content (e.g., clicking links, scrolling the viewport, etc., filling in form fields, etc.) It can also summarize the entire page, or answer questions based on the content of the page. It can also be asked to sleep and wait for pages to load, in cases where the pages seem to be taking a while to load."

View File

@@ -4,10 +4,10 @@ from typing import Tuple
from agnext.core import CancellationToken
from ..messages import UserContent
from .base_agent import BaseAgent
from .base_worker import BaseWorker
class UserProxy(BaseAgent):
class UserProxy(BaseWorker):
"""An agent that allows the user to play the role of an agent in the conversation."""
DEFAULT_DESCRIPTION = "A human user."

View File

@@ -38,6 +38,9 @@ class OrchestrationEvent:
message: str
TeamOneMessages = RequestReplyMessage | BroadcastMessage | ResetMessage | DeactivateMessage
@dataclass
class AgentEvent:
source: str

View File

@@ -107,59 +107,58 @@ class LogHandler(logging.FileHandler):
super().__init__(filename)
def emit(self, record: logging.LogRecord) -> None:
# try:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, OrchestrationEvent):
console_message = (
f"\n{'-'*75} \n" f"\033[91m[{ts}], {record.msg.source}:\033[0m\n" f"\n{record.msg.message}"
)
print(console_message, flush=True)
record.msg = json.dumps(
{
try:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, OrchestrationEvent):
console_message = (
f"\n{'-'*75} \n" f"\033[91m[{ts}], {record.msg.source}:\033[0m\n" f"\n{record.msg.message}"
)
print(console_message, flush=True)
record.msg = json.dumps(
{
"timestamp": ts,
"source": record.msg.source,
"message": record.msg.message,
"type": "OrchestrationEvent",
}
)
super().emit(record)
elif isinstance(record.msg, AgentEvent):
console_message = (
f"\n{'-'*75} \n" f"\033[91m[{ts}], {record.msg.source}:\033[0m\n" f"\n{record.msg.message}"
)
print(console_message, flush=True)
record.msg = json.dumps(
{
"timestamp": ts,
"source": record.msg.source,
"message": record.msg.message,
"type": "AgentEvent",
}
)
super().emit(record)
elif isinstance(record.msg, WebSurferEvent):
console_message = f"\033[96m[{ts}], {record.msg.source}: {record.msg.message}\033[0m"
print(console_message, flush=True)
payload: Dict[str, Any] = {
"timestamp": ts,
"source": record.msg.source,
"message": record.msg.message,
"type": "OrchestrationEvent",
"type": "WebSurferEvent",
}
)
super().emit(record)
elif isinstance(record.msg, AgentEvent):
console_message = (
f"\n{'-'*75} \n" f"\033[91m[{ts}], {record.msg.source}:\033[0m\n" f"\n{record.msg.message}"
)
print(console_message, flush=True)
record.msg = json.dumps(
{
"timestamp": ts,
"source": record.msg.source,
"message": record.msg.message,
"type": "AgentEvent",
}
)
super().emit(record)
elif isinstance(record.msg, WebSurferEvent):
console_message = f"\033[96m[{ts}], {record.msg.source}: {record.msg.message}\033[0m"
print(console_message, flush=True)
payload: Dict[str, Any] = {
"timestamp": ts,
"type": "WebSurferEvent",
}
payload.update(asdict(record.msg))
record.msg = json.dumps(payload)
super().emit(record)
elif isinstance(record.msg, LLMCallEvent):
record.msg = json.dumps(
{
"timestamp": ts,
"prompt_tokens": record.msg.prompt_tokens,
"completion_tokens": record.msg.completion_tokens,
"type": "LLMCallEvent",
}
)
super().emit(record)
# except Exception:
# self.handleError(record)
payload.update(asdict(record.msg))
record.msg = json.dumps(payload)
super().emit(record)
elif isinstance(record.msg, LLMCallEvent):
record.msg = json.dumps(
{
"timestamp": ts,
"prompt_tokens": record.msg.prompt_tokens,
"completion_tokens": record.msg.completion_tokens,
"type": "LLMCallEvent",
}
)
super().emit(record)
except Exception:
self.handleError(record)
class SentinelMeta(type):