Add types agnostic to role (#11)

This commit is contained in:
Jack Gerrits
2024-05-23 16:49:01 -04:00
committed by GitHub
parent 8d1f4aedc0
commit 52f6f79591
2 changed files with 110 additions and 0 deletions

42
src/agnext/chat/types.py Normal file
View File

@@ -0,0 +1,42 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Union
from agnext.agent_components.image import Image
from agnext.agent_components.types import FunctionCall
@dataclass(kw_only=True)
class BaseMessage:
# Name of the agent that sent this message
source: str
@dataclass
class TextMessage(BaseMessage):
content: str
@dataclass
class MultiModalMessage(BaseMessage):
content: List[Union[str, Image]]
@dataclass
class FunctionCallMessage(BaseMessage):
content: List[FunctionCall]
@dataclass
class FunctionExecutionResult:
content: str
call_id: str
@dataclass
class FunctionExecutionResultMessage(BaseMessage):
content: List[FunctionExecutionResult]
Message = Union[TextMessage, MultiModalMessage, FunctionCallMessage, FunctionExecutionResultMessage]

68
src/agnext/chat/utils.py Normal file
View File

@@ -0,0 +1,68 @@
from typing import List, Optional, Union
from agnext.agent_components.types import AssistantMessage, LLMMessage, UserMessage
from agnext.chat.types import FunctionCallMessage, Message, MultiModalMessage, TextMessage
from typing_extensions import Literal
def convert_content_message_to_assistant_message(
message: Union[TextMessage, MultiModalMessage, FunctionCallMessage],
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> Optional[AssistantMessage]:
match message:
case TextMessage() | FunctionCallMessage():
return AssistantMessage(content=message.content, source=message.source)
case MultiModalMessage():
if handle_unrepresentable == "error":
raise ValueError("Cannot represent multimodal message as AssistantMessage")
elif handle_unrepresentable == "ignore":
return None
elif handle_unrepresentable == "try_slice":
return AssistantMessage(
content="".join([x for x in message.content if isinstance(x, str)]), source=message.source
)
def convert_content_message_to_user_message(
message: Union[TextMessage, MultiModalMessage, FunctionCallMessage],
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> Optional[UserMessage]:
match message:
case TextMessage() | MultiModalMessage():
return UserMessage(content=message.content, source=message.source)
case FunctionCallMessage():
if handle_unrepresentable == "error":
raise ValueError("Cannot represent multimodal message as UserMessage")
elif handle_unrepresentable == "ignore":
return None
elif handle_unrepresentable == "try_slice":
# TODO: what is a sliced function call?
raise NotImplementedError("Sliced function calls not yet implemented")
def convert_messages_to_llm_messages(
messages: List[Message], self_name: str, handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error"
) -> List[LLMMessage]:
result: List[LLMMessage] = []
for message in messages:
match message:
case (
TextMessage(_, source=source)
| MultiModalMessage(_, source=source)
| FunctionCallMessage(_, source=source)
) if source == self_name:
converted_message_1 = convert_content_message_to_assistant_message(message, handle_unrepresentable)
if converted_message_1 is not None:
result.append(converted_message_1)
case (
TextMessage(_, source=source)
| MultiModalMessage(_, source=source)
| FunctionCallMessage(_, source=source)
) if source != self_name:
converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable)
if converted_message_2 is not None:
result.append(converted_message_2)
case _:
raise AssertionError("unreachable")
return result