mirror of
https://github.com/microsoft/autogen.git
synced 2026-02-11 12:54:59 -05:00
Introduces a BaseWorker agent, allowing for a TeamOneBaseAgent (#289)
This commit is contained in:
@@ -1,14 +1,9 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, List, Tuple
|
||||
from typing import Any
|
||||
|
||||
from agnext.application.logging import EVENT_LOGGER_NAME
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components.models import (
|
||||
AssistantMessage,
|
||||
LLMMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from agnext.core import CancellationToken
|
||||
|
||||
from team_one.messages import (
|
||||
@@ -17,18 +12,14 @@ from team_one.messages import (
|
||||
DeactivateMessage,
|
||||
RequestReplyMessage,
|
||||
ResetMessage,
|
||||
UserContent,
|
||||
TeamOneMessages,
|
||||
)
|
||||
from team_one.utils import message_content_to_str
|
||||
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME + ".orchestrator")
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME + ".agent")
|
||||
|
||||
|
||||
PossibleMessages = RequestReplyMessage | BroadcastMessage | ResetMessage | DeactivateMessage
|
||||
|
||||
|
||||
class BaseAgent(TypeRoutedAgent):
|
||||
"""An agent that handles the RequestReply and Broadcast messages"""
|
||||
class TeamOneBaseAgent(TypeRoutedAgent):
|
||||
"""An agent that optionally ensures messages are handled non-concurrently in the order they arrive."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -36,14 +27,12 @@ class BaseAgent(TypeRoutedAgent):
|
||||
handle_messages_concurrently: bool = False,
|
||||
) -> None:
|
||||
super().__init__(description)
|
||||
self._chat_history: List[LLMMessage] = []
|
||||
self._enabled: bool = True
|
||||
|
||||
self._handle_messages_concurrently = handle_messages_concurrently
|
||||
self._enabled = True
|
||||
|
||||
if not self._handle_messages_concurrently:
|
||||
# TODO: make it possible to stop
|
||||
self._message_queue = asyncio.Queue[tuple[PossibleMessages, CancellationToken, asyncio.Future[Any]]]()
|
||||
self._message_queue = asyncio.Queue[tuple[TeamOneMessages, CancellationToken, asyncio.Future[Any]]]()
|
||||
self._processing_task = asyncio.create_task(self._process())
|
||||
|
||||
async def _process(self) -> None:
|
||||
@@ -54,13 +43,13 @@ class BaseAgent(TypeRoutedAgent):
|
||||
continue
|
||||
|
||||
if isinstance(message, RequestReplyMessage):
|
||||
await self.handle_request_reply(message, cancellation_token)
|
||||
await self._handle_request_reply(message, cancellation_token)
|
||||
elif isinstance(message, BroadcastMessage):
|
||||
await self.handle_broadcast(message, cancellation_token)
|
||||
await self._handle_broadcast(message, cancellation_token)
|
||||
elif isinstance(message, ResetMessage):
|
||||
await self.handle_reset(message, cancellation_token)
|
||||
await self._handle_reset(message, cancellation_token)
|
||||
elif isinstance(message, DeactivateMessage):
|
||||
await self.handle_deactivate(message, cancellation_token)
|
||||
await self._handle_deactivate(message, cancellation_token)
|
||||
else:
|
||||
raise ValueError("Unknown message type.")
|
||||
|
||||
@@ -77,27 +66,28 @@ class BaseAgent(TypeRoutedAgent):
|
||||
|
||||
if self._handle_messages_concurrently:
|
||||
if isinstance(message, RequestReplyMessage):
|
||||
await self.handle_request_reply(message, cancellation_token)
|
||||
await self._handle_request_reply(message, cancellation_token)
|
||||
elif isinstance(message, BroadcastMessage):
|
||||
await self.handle_broadcast(message, cancellation_token)
|
||||
await self._handle_broadcast(message, cancellation_token)
|
||||
elif isinstance(message, ResetMessage):
|
||||
await self.handle_reset(message, cancellation_token)
|
||||
await self._handle_reset(message, cancellation_token)
|
||||
elif isinstance(message, DeactivateMessage):
|
||||
await self.handle_deactivate(message, cancellation_token)
|
||||
await self._handle_deactivate(message, cancellation_token)
|
||||
else:
|
||||
future = asyncio.Future[Any]()
|
||||
await self._message_queue.put((message, cancellation_token, future))
|
||||
await future
|
||||
|
||||
async def handle_broadcast(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None:
|
||||
assert isinstance(message.content, UserMessage)
|
||||
self._chat_history.append(message.content)
|
||||
async def _handle_broadcast(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def handle_reset(self, message: ResetMessage, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle a reset message."""
|
||||
await self._reset(cancellation_token)
|
||||
async def _handle_reset(self, message: ResetMessage, cancellation_token: CancellationToken) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def handle_deactivate(self, message: DeactivateMessage, cancellation_token: CancellationToken) -> None:
|
||||
async def _handle_request_reply(self, message: RequestReplyMessage, cancellation_token: CancellationToken) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def _handle_deactivate(self, message: DeactivateMessage, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle a deactivate message."""
|
||||
self._enabled = False
|
||||
logger.info(
|
||||
@@ -106,20 +96,3 @@ class BaseAgent(TypeRoutedAgent):
|
||||
"",
|
||||
)
|
||||
)
|
||||
|
||||
async def handle_request_reply(self, message: RequestReplyMessage, cancellation_token: CancellationToken) -> None:
|
||||
"""Respond to a reply request."""
|
||||
request_halt, response = await self._generate_reply(cancellation_token)
|
||||
|
||||
assistant_message = AssistantMessage(content=message_content_to_str(response), source=self.metadata["name"])
|
||||
self._chat_history.append(assistant_message)
|
||||
|
||||
user_message = UserMessage(content=response, source=self.metadata["name"])
|
||||
await self.publish_message(BroadcastMessage(content=user_message, request_halt=request_halt))
|
||||
|
||||
async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, UserContent]:
|
||||
"""Returns (request_halt, response_message)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def _reset(self, cancellation_token: CancellationToken) -> None:
|
||||
self._chat_history = []
|
||||
|
||||
@@ -2,36 +2,32 @@ import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from agnext.application.logging import EVENT_LOGGER_NAME
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components.models import AssistantMessage, LLMMessage, UserMessage
|
||||
from agnext.core import AgentProxy, CancellationToken
|
||||
|
||||
from ..messages import BroadcastMessage, DeactivateMessage, OrchestrationEvent, RequestReplyMessage
|
||||
from ..messages import BroadcastMessage, OrchestrationEvent, RequestReplyMessage
|
||||
from ..utils import message_content_to_str
|
||||
from .base_agent import TeamOneBaseAgent
|
||||
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME + ".orchestrator")
|
||||
|
||||
|
||||
class BaseOrchestrator(TypeRoutedAgent):
|
||||
class BaseOrchestrator(TeamOneBaseAgent):
|
||||
def __init__(
|
||||
self,
|
||||
agents: List[AgentProxy],
|
||||
description: str = "Base orchestrator",
|
||||
max_rounds: int = 20,
|
||||
handle_messages_concurrently: bool = False,
|
||||
) -> None:
|
||||
super().__init__(description)
|
||||
super().__init__(description, handle_messages_concurrently=handle_messages_concurrently)
|
||||
self._agents = agents
|
||||
self._max_rounds = max_rounds
|
||||
self._num_rounds = 0
|
||||
self._enabled = True
|
||||
|
||||
@message_handler
|
||||
async def handle_incoming_message(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None:
|
||||
async def _handle_broadcast(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle an incoming message."""
|
||||
|
||||
if not self._enabled:
|
||||
return
|
||||
|
||||
source = "Unknown"
|
||||
if isinstance(message.content, UserMessage) or isinstance(message.content, AssistantMessage):
|
||||
source = message.content.source
|
||||
@@ -81,19 +77,6 @@ class BaseOrchestrator(TypeRoutedAgent):
|
||||
self._num_rounds += 1 # Call before sending the message
|
||||
await self.send_message(request_reply_message, next_agent.id)
|
||||
|
||||
@message_handler
|
||||
async def handle_deactivate_message(
|
||||
self, message: DeactivateMessage, cancellation_token: CancellationToken
|
||||
) -> None:
|
||||
"""Handle a deactivate message."""
|
||||
self._enabled = False
|
||||
logger.info(
|
||||
OrchestrationEvent(
|
||||
f"{self.metadata['name']} (deactivated)",
|
||||
"",
|
||||
)
|
||||
)
|
||||
|
||||
async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
55
python/teams/team-one/src/team_one/agents/base_worker.py
Normal file
55
python/teams/team-one/src/team_one/agents/base_worker.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
from agnext.components.models import (
|
||||
AssistantMessage,
|
||||
LLMMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from agnext.core import CancellationToken
|
||||
|
||||
from team_one.messages import (
|
||||
BroadcastMessage,
|
||||
RequestReplyMessage,
|
||||
ResetMessage,
|
||||
UserContent,
|
||||
)
|
||||
|
||||
from ..utils import message_content_to_str
|
||||
from .base_agent import TeamOneBaseAgent
|
||||
|
||||
|
||||
class BaseWorker(TeamOneBaseAgent):
|
||||
"""Base agent that handles the TeamOne worker behavior protocol."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
description: str,
|
||||
handle_messages_concurrently: bool = False,
|
||||
) -> None:
|
||||
super().__init__(description, handle_messages_concurrently=handle_messages_concurrently)
|
||||
self._chat_history: List[LLMMessage] = []
|
||||
|
||||
async def _handle_broadcast(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None:
|
||||
assert isinstance(message.content, UserMessage)
|
||||
self._chat_history.append(message.content)
|
||||
|
||||
async def _handle_reset(self, message: ResetMessage, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle a reset message."""
|
||||
await self._reset(cancellation_token)
|
||||
|
||||
async def _handle_request_reply(self, message: RequestReplyMessage, cancellation_token: CancellationToken) -> None:
|
||||
"""Respond to a reply request."""
|
||||
request_halt, response = await self._generate_reply(cancellation_token)
|
||||
|
||||
assistant_message = AssistantMessage(content=message_content_to_str(response), source=self.metadata["name"])
|
||||
self._chat_history.append(assistant_message)
|
||||
|
||||
user_message = UserMessage(content=response, source=self.metadata["name"])
|
||||
await self.publish_message(BroadcastMessage(content=user_message, request_halt=request_halt))
|
||||
|
||||
async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, UserContent]:
|
||||
"""Returns (request_halt, response_message)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def _reset(self, cancellation_token: CancellationToken) -> None:
|
||||
self._chat_history = []
|
||||
@@ -9,12 +9,12 @@ from agnext.components.models import (
|
||||
)
|
||||
from agnext.core import CancellationToken
|
||||
|
||||
from ..utils import message_content_to_str
|
||||
from ..messages import UserContent
|
||||
from .base_agent import BaseAgent
|
||||
from ..utils import message_content_to_str
|
||||
from .base_worker import BaseWorker
|
||||
|
||||
|
||||
class Coder(BaseAgent):
|
||||
class Coder(BaseWorker):
|
||||
"""An agent that uses tools to write, execute, and debug Python code."""
|
||||
|
||||
DEFAULT_DESCRIPTION = "A Python coder assistant."
|
||||
@@ -58,7 +58,7 @@ If the code was executed, and the output appears to indicate that the original p
|
||||
return "TERMINATE" in response.content, response.content
|
||||
|
||||
|
||||
class Executor(BaseAgent):
|
||||
class Executor(BaseWorker):
|
||||
def __init__(
|
||||
self, description: str, executor: Optional[CodeExecutor] = None, check_last_n_message: int = 5
|
||||
) -> None:
|
||||
|
||||
@@ -13,7 +13,7 @@ from agnext.components.models import (
|
||||
from agnext.core import CancellationToken
|
||||
|
||||
from ...markdown_browser import RequestsMarkdownBrowser
|
||||
from ..base_agent import BaseAgent
|
||||
from ..base_worker import BaseWorker
|
||||
|
||||
# from typing_extensions import Annotated
|
||||
from ._tools import TOOL_FIND_NEXT, TOOL_FIND_ON_PAGE_CTRL_F, TOOL_OPEN_LOCAL_FILE, TOOL_PAGE_DOWN, TOOL_PAGE_UP
|
||||
@@ -49,7 +49,7 @@ from ._tools import TOOL_FIND_NEXT, TOOL_FIND_ON_PAGE_CTRL_F, TOOL_OPEN_LOCAL_FI
|
||||
# return "\n".join(items)
|
||||
|
||||
|
||||
class FileSurfer(BaseAgent):
|
||||
class FileSurfer(BaseWorker):
|
||||
"""An agent that uses tools to read and navigate local files."""
|
||||
|
||||
DEFAULT_DESCRIPTION = "An agent that can handle local files."
|
||||
|
||||
@@ -29,10 +29,6 @@ from playwright._impl._errors import TimeoutError
|
||||
# from playwright._impl._async_base.AsyncEventInfo
|
||||
from playwright.async_api import BrowserContext, Download, Page, Playwright, async_playwright
|
||||
|
||||
from team_one.agents.base_agent import BaseAgent
|
||||
from team_one.messages import UserContent, WebSurferEvent
|
||||
from team_one.utils import SentinelMeta, message_content_to_str
|
||||
|
||||
# TODO: Fix mdconvert
|
||||
from ...markdown_browser import ( # type: ignore
|
||||
DocumentConverterResult, # type: ignore
|
||||
@@ -40,6 +36,9 @@ from ...markdown_browser import ( # type: ignore
|
||||
MarkdownConverter, # type: ignore
|
||||
UnsupportedFormatException, # type: ignore
|
||||
)
|
||||
from ...messages import UserContent, WebSurferEvent
|
||||
from ...utils import SentinelMeta, message_content_to_str
|
||||
from ..base_worker import BaseWorker
|
||||
from .set_of_mark import add_set_of_mark
|
||||
from .tool_definitions import (
|
||||
TOOL_CLICK,
|
||||
@@ -81,7 +80,7 @@ class DEFAULT_CHANNEL(metaclass=SentinelMeta):
|
||||
pass
|
||||
|
||||
|
||||
class MultimodalWebSurfer(BaseAgent):
|
||||
class MultimodalWebSurfer(BaseWorker):
|
||||
"""(In preview) A multimodal agent that acts as a web surfer that can search the web and visit web pages."""
|
||||
|
||||
DEFAULT_DESCRIPTION = "A helpful assistant with access to a web browser. Ask them to perform web searches, open pages, and interact with content (e.g., clicking links, scrolling the viewport, etc., filling in form fields, etc.) It can also summarize the entire page, or answer questions based on the content of the page. It can also be asked to sleep and wait for pages to load, in cases where the pages seem to be taking a while to load."
|
||||
|
||||
@@ -4,10 +4,10 @@ from typing import Tuple
|
||||
from agnext.core import CancellationToken
|
||||
|
||||
from ..messages import UserContent
|
||||
from .base_agent import BaseAgent
|
||||
from .base_worker import BaseWorker
|
||||
|
||||
|
||||
class UserProxy(BaseAgent):
|
||||
class UserProxy(BaseWorker):
|
||||
"""An agent that allows the user to play the role of an agent in the conversation."""
|
||||
|
||||
DEFAULT_DESCRIPTION = "A human user."
|
||||
|
||||
@@ -38,6 +38,9 @@ class OrchestrationEvent:
|
||||
message: str
|
||||
|
||||
|
||||
TeamOneMessages = RequestReplyMessage | BroadcastMessage | ResetMessage | DeactivateMessage
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentEvent:
|
||||
source: str
|
||||
|
||||
@@ -107,59 +107,58 @@ class LogHandler(logging.FileHandler):
|
||||
super().__init__(filename)
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
# try:
|
||||
ts = datetime.fromtimestamp(record.created).isoformat()
|
||||
if isinstance(record.msg, OrchestrationEvent):
|
||||
console_message = (
|
||||
f"\n{'-'*75} \n" f"\033[91m[{ts}], {record.msg.source}:\033[0m\n" f"\n{record.msg.message}"
|
||||
)
|
||||
print(console_message, flush=True)
|
||||
record.msg = json.dumps(
|
||||
{
|
||||
try:
|
||||
ts = datetime.fromtimestamp(record.created).isoformat()
|
||||
if isinstance(record.msg, OrchestrationEvent):
|
||||
console_message = (
|
||||
f"\n{'-'*75} \n" f"\033[91m[{ts}], {record.msg.source}:\033[0m\n" f"\n{record.msg.message}"
|
||||
)
|
||||
print(console_message, flush=True)
|
||||
record.msg = json.dumps(
|
||||
{
|
||||
"timestamp": ts,
|
||||
"source": record.msg.source,
|
||||
"message": record.msg.message,
|
||||
"type": "OrchestrationEvent",
|
||||
}
|
||||
)
|
||||
super().emit(record)
|
||||
elif isinstance(record.msg, AgentEvent):
|
||||
console_message = (
|
||||
f"\n{'-'*75} \n" f"\033[91m[{ts}], {record.msg.source}:\033[0m\n" f"\n{record.msg.message}"
|
||||
)
|
||||
print(console_message, flush=True)
|
||||
record.msg = json.dumps(
|
||||
{
|
||||
"timestamp": ts,
|
||||
"source": record.msg.source,
|
||||
"message": record.msg.message,
|
||||
"type": "AgentEvent",
|
||||
}
|
||||
)
|
||||
super().emit(record)
|
||||
elif isinstance(record.msg, WebSurferEvent):
|
||||
console_message = f"\033[96m[{ts}], {record.msg.source}: {record.msg.message}\033[0m"
|
||||
print(console_message, flush=True)
|
||||
payload: Dict[str, Any] = {
|
||||
"timestamp": ts,
|
||||
"source": record.msg.source,
|
||||
"message": record.msg.message,
|
||||
"type": "OrchestrationEvent",
|
||||
"type": "WebSurferEvent",
|
||||
}
|
||||
)
|
||||
super().emit(record)
|
||||
elif isinstance(record.msg, AgentEvent):
|
||||
console_message = (
|
||||
f"\n{'-'*75} \n" f"\033[91m[{ts}], {record.msg.source}:\033[0m\n" f"\n{record.msg.message}"
|
||||
)
|
||||
print(console_message, flush=True)
|
||||
record.msg = json.dumps(
|
||||
{
|
||||
"timestamp": ts,
|
||||
"source": record.msg.source,
|
||||
"message": record.msg.message,
|
||||
"type": "AgentEvent",
|
||||
}
|
||||
)
|
||||
super().emit(record)
|
||||
elif isinstance(record.msg, WebSurferEvent):
|
||||
console_message = f"\033[96m[{ts}], {record.msg.source}: {record.msg.message}\033[0m"
|
||||
print(console_message, flush=True)
|
||||
payload: Dict[str, Any] = {
|
||||
"timestamp": ts,
|
||||
"type": "WebSurferEvent",
|
||||
}
|
||||
payload.update(asdict(record.msg))
|
||||
record.msg = json.dumps(payload)
|
||||
super().emit(record)
|
||||
elif isinstance(record.msg, LLMCallEvent):
|
||||
record.msg = json.dumps(
|
||||
{
|
||||
"timestamp": ts,
|
||||
"prompt_tokens": record.msg.prompt_tokens,
|
||||
"completion_tokens": record.msg.completion_tokens,
|
||||
"type": "LLMCallEvent",
|
||||
}
|
||||
)
|
||||
super().emit(record)
|
||||
|
||||
# except Exception:
|
||||
# self.handleError(record)
|
||||
payload.update(asdict(record.msg))
|
||||
record.msg = json.dumps(payload)
|
||||
super().emit(record)
|
||||
elif isinstance(record.msg, LLMCallEvent):
|
||||
record.msg = json.dumps(
|
||||
{
|
||||
"timestamp": ts,
|
||||
"prompt_tokens": record.msg.prompt_tokens,
|
||||
"completion_tokens": record.msg.completion_tokens,
|
||||
"type": "LLMCallEvent",
|
||||
}
|
||||
)
|
||||
super().emit(record)
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
|
||||
class SentinelMeta(type):
|
||||
|
||||
Reference in New Issue
Block a user