TeamOne handle multimodal console (#200)

* Added ability to print multimodal messages to console.

* Fixed hatch error
This commit is contained in:
afourney
2024-07-10 00:01:13 -07:00
committed by GitHub
parent 14628f2ae0
commit 5996b452eb
7 changed files with 42 additions and 28 deletions

View File

@@ -1,6 +1,6 @@
from typing import List, Tuple, Union
from typing import List, Tuple
from agnext.components import Image, TypeRoutedAgent, message_handler
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components.models import (
AssistantMessage,
LLMMessage,
@@ -8,10 +8,8 @@ from agnext.components.models import (
)
from agnext.core import CancellationToken
from team_one.messages import BroadcastMessage, RequestReplyMessage
# Convenience type
UserContent = Union[str, List[Union[str, Image]]]
from team_one.messages import BroadcastMessage, RequestReplyMessage, UserContent
from team_one.utils import message_content_to_str
class BaseAgent(TypeRoutedAgent):
@@ -35,21 +33,7 @@ class BaseAgent(TypeRoutedAgent):
"""Respond to a reply request."""
request_halt, response = await self._generate_reply(cancellation_token)
# Convert the response to an acceptable format for the assistant
if isinstance(response, str):
assistant_message = AssistantMessage(content=response, source=self.metadata["name"])
elif isinstance(response, List):
converted: List[str] = list()
for item in response:
if isinstance(item, str):
converted.append(item.rstrip())
elif isinstance(item, Image):
converted.append("<image>")
else:
raise AssertionError("Unexpected response type.")
assistant_message = AssistantMessage(content="\n".join(converted), source=self.metadata["name"])
else:
raise AssertionError("Unexpected response type.")
assistant_message = AssistantMessage(content=message_content_to_str(response), source=self.metadata["name"])
self._chat_history.append(assistant_message)
user_message = UserMessage(content=response, source=self.metadata["name"])

View File

@@ -9,7 +9,8 @@ from agnext.components.models import (
)
from agnext.core import CancellationToken
from .base_agent import BaseAgent, UserContent
from ..messages import UserContent
from .base_agent import BaseAgent
class Coder(BaseAgent):

View File

@@ -14,7 +14,8 @@ from agnext.components.tools import FunctionTool
from agnext.core import CancellationToken
from typing_extensions import Annotated
from .base_agent import BaseAgent, UserContent
from ..messages import UserContent
from .base_agent import BaseAgent
async def read_local_file(file_path: Annotated[str, "relative or absolute path of file to read"]) -> str:

View File

@@ -7,6 +7,7 @@ from agnext.components.models import AssistantMessage, UserMessage
from agnext.core import AgentProxy, CancellationToken
from ..messages import BroadcastMessage, OrchestrationEvent, RequestReplyMessage
from ..utils import message_content_to_str
logger = logging.getLogger(EVENT_LOGGER_NAME + ".orchestrator")
@@ -30,7 +31,7 @@ class RoundRobinOrchestrator(TypeRoutedAgent):
if isinstance(message.content, UserMessage) or isinstance(message.content, AssistantMessage):
source = message.content.source
content = str(message.content.content)
content = message_content_to_str(message.content.content)
logger.info(OrchestrationEvent(source, content))

View File

@@ -3,7 +3,8 @@ from typing import Tuple
from agnext.core import CancellationToken
from .base_agent import BaseAgent, UserContent
from ..messages import UserContent
from .base_agent import BaseAgent
class UserProxy(BaseAgent):

View File

@@ -1,6 +1,14 @@
from dataclasses import dataclass
from typing import List, Union
from agnext.components.models import LLMMessage
from agnext.components import FunctionCall, Image
from agnext.components.models import FunctionExecutionResult, LLMMessage
# Convenience type
UserContent = Union[str, List[Union[str, Image]]]
AssistantContent = Union[str, List[FunctionCall]]
FunctionExecutionContent = List[FunctionExecutionResult]
SystemContent = str
@dataclass

View File

@@ -2,7 +2,7 @@ import json
import logging
import os
from datetime import datetime
from typing import Any, Dict
from typing import Any, Dict, List
from agnext.components.models import (
AzureOpenAIChatCompletionClient,
@@ -11,7 +11,7 @@ from agnext.components.models import (
OpenAIChatCompletionClient,
)
from .messages import OrchestrationEvent
from .messages import AssistantContent, FunctionExecutionContent, OrchestrationEvent, SystemContent, UserContent
ENVIRON_KEY_CHAT_COMPLETION_PROVIDER = "CHAT_COMPLETION_PROVIDER"
ENVIRON_KEY_CHAT_COMPLETION_KWARGS_JSON = "CHAT_COMPLETION_KWARGS_JSON"
@@ -70,6 +70,24 @@ def create_completion_client_from_env(env: Dict[str, str] | None = None, **kwarg
raise ValueError(f"Unknown OAI provider '{_provider}'")
# Convert UserContent to a string
def message_content_to_str(
message_content: UserContent | AssistantContent | SystemContent | FunctionExecutionContent,
) -> str:
if isinstance(message_content, str):
return message_content
elif isinstance(message_content, List):
converted: List[str] = list()
for item in message_content:
if isinstance(item, str):
converted.append(item.rstrip())
else:
converted.append(str(item).rstrip())
return "\n".join(converted)
else:
raise AssertionError("Unexpected response type.")
# TeamOne log event handler
class LogHandler(logging.Handler):
def __init__(self) -> None: