Move ChatMemory protocol to components.memory; rename files to use "_" convention for implementations. (#101)

This commit is contained in:
Eric Zhu
2024-06-21 04:06:01 -07:00
committed by GitHub
parent d365a588cb
commit 51dde2916f
22 changed files with 110 additions and 111 deletions

View File

@@ -1,6 +1,11 @@
from .chat_completion_agent import ChatCompletionAgent
from .image_generation_agent import ImageGenerationAgent
from .oai_assistant import OpenAIAssistantAgent
from .user_proxy import UserProxyAgent
from ._chat_completion_agent import ChatCompletionAgent
from ._image_generation_agent import ImageGenerationAgent
from ._oai_assistant import OpenAIAssistantAgent
from ._user_proxy import UserProxyAgent
__all__ = ["ChatCompletionAgent", "OpenAIAssistantAgent", "UserProxyAgent", "ImageGenerationAgent"]
__all__ = [
"ChatCompletionAgent",
"OpenAIAssistantAgent",
"UserProxyAgent",
"ImageGenerationAgent",
]

View File

@@ -7,6 +7,7 @@ from ...components import (
TypeRoutedAgent,
message_handler,
)
from ...components.memory import ChatMemory
from ...components.models import (
ChatCompletionClient,
FunctionExecutionResult,
@@ -15,7 +16,6 @@ from ...components.models import (
)
from ...components.tools import Tool
from ...core import AgentId, CancellationToken
from ..memory import ChatMemory
from ..types import (
FunctionCallMessage,
Message,
@@ -36,12 +36,11 @@ class ChatCompletionAgent(TypeRoutedAgent):
responses and execute tools.
Args:
name (str): The name of the agent.
description (str): The description of the agent.
runtime (AgentRuntime): The runtime to register the agent.
system_messages (List[SystemMessage]): The system messages to use for
the ChatCompletion API.
memory (ChatMemory): The memory to store and retrieve messages.
memory (ChatMemory[Message]): The memory to store and retrieve messages.
model_client (ChatCompletionClient): The client to use for the
ChatCompletion API.
tools (Sequence[Tool], optional): The tools used by the agent. Defaults
@@ -61,7 +60,7 @@ class ChatCompletionAgent(TypeRoutedAgent):
self,
description: str,
system_messages: List[SystemMessage],
memory: ChatMemory,
memory: ChatMemory[Message],
model_client: ChatCompletionClient,
tools: Sequence[Tool] = [],
tool_approver: AgentId | None = None,

View File

@@ -7,9 +7,10 @@ from ...components import (
TypeRoutedAgent,
message_handler,
)
from ...components.memory import ChatMemory
from ...core import CancellationToken
from ..memory import ChatMemory
from ..types import (
Message,
MultiModalMessage,
PublishNow,
Reset,
@@ -18,10 +19,20 @@ from ..types import (
class ImageGenerationAgent(TypeRoutedAgent):
"""An agent that generates images using DALL-E models. It publishes the
generated images as MultiModalMessage.
Args:
description (str): The description of the agent.
memory (ChatMemory[Message]): The memory to store and retrieve messages.
client (openai.AsyncClient): The client to use for the OpenAI API.
model (Literal["dall-e-2", "dall-e-3"], optional): The DALL-E model to use. Defaults to "dall-e-2".
"""
def __init__(
self,
description: str,
memory: ChatMemory,
memory: ChatMemory[Message],
client: openai.AsyncClient,
model: Literal["dall-e-2", "dall-e-3"] = "dall-e-2",
):
@@ -32,6 +43,7 @@ class ImageGenerationAgent(TypeRoutedAgent):
@message_handler
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
"""Handle a text message. This method adds the message to the memory."""
await self._memory.add_message(message)
@message_handler
@@ -40,6 +52,10 @@ class ImageGenerationAgent(TypeRoutedAgent):
@message_handler
async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None:
"""Handle a publish now message. This method generates an image using a DALL-E model with
a prompt. The prompt is a concatenation of all TextMessages in the memory. The generated
image is published as a MultiModalMessage."""
response = await self._generate_response(cancellation_token)
self.publish_message(response)

View File

@@ -14,7 +14,6 @@ class OpenAIAssistantAgent(TypeRoutedAgent):
responses.
Args:
name (str): The name of the agent.
description (str): The description of the agent.
runtime (AgentRuntime): The runtime to register the agent.
client (openai.AsyncClient): The client to use for the OpenAI API.

View File

@@ -10,7 +10,6 @@ class UserProxyAgent(TypeRoutedAgent):
method to customize how user input is retrieved.
Args:
name (str): The name of the agent.
description (str): The description of the agent.
runtime (AgentRuntime): The runtime to register the agent.
user_input_prompt (str): The console prompt to show to the user when asking for input.

View File

@@ -1,5 +1,4 @@
from ._base import ChatMemory
from ._buffered import BufferedChatMemory
from ._head_and_tail import HeadAndTailChatMemory
__all__ = ["ChatMemory", "BufferedChatMemory", "HeadAndTailChatMemory"]
__all__ = ["BufferedChatMemory", "HeadAndTailChatMemory"]

View File

@@ -1,19 +0,0 @@
from typing import Any, List, Mapping, Protocol
from ..types import Message
class ChatMemory(Protocol):
"""A protocol for defining the interface of a chat memory. A chat memory
lets agents to store and retrieve messages. It can be implemented with
different memory recall strategies."""
async def add_message(self, message: Message) -> None: ...
async def get_messages(self) -> List[Message]: ...
async def clear(self) -> None: ...
def save_state(self) -> Mapping[str, Any]: ...
def load_state(self, state: Mapping[str, Any]) -> None: ...

View File

@@ -1,11 +1,11 @@
from typing import Any, List, Mapping
from ...components.memory import ChatMemory
from ...components.models import FunctionExecutionResultMessage
from ..types import Message
from ._base import ChatMemory
class BufferedChatMemory(ChatMemory):
class BufferedChatMemory(ChatMemory[Message]):
"""A buffered chat memory that keeps a view of the last n messages,
where n is the buffer size. The buffer size is set at initialization.

View File

@@ -1,11 +1,11 @@
from typing import Any, List, Mapping
from ...components.memory import ChatMemory
from ...components.models import FunctionExecutionResultMessage
from ..types import FunctionCallMessage, Message, TextMessage
from ._base import ChatMemory
class HeadAndTailChatMemory(ChatMemory):
class HeadAndTailChatMemory(ChatMemory[Message]):
"""A chat memory that keeps a view of the first n and last m messages,
where n is the head size and m is the tail size. The head and tail sizes
are set at initialization.

View File

@@ -1,3 +1,4 @@
from .group_chat_manager import GroupChatManager
from ._group_chat_manager import GroupChatManager
from ._orchestrator_chat import OrchestratorChat
__all__ = ["GroupChatManager"]
__all__ = ["GroupChatManager", "OrchestratorChat"]

View File

@@ -2,16 +2,17 @@ import logging
from typing import Any, Callable, List, Mapping
from ...components import TypeRoutedAgent, message_handler
from ...components.memory import ChatMemory
from ...components.models import ChatCompletionClient
from ...core import AgentId, AgentProxy, AgentRuntime, CancellationToken
from ..memory import ChatMemory
from ..types import (
Message,
MultiModalMessage,
PublishNow,
Reset,
TextMessage,
)
from .group_chat_utils import select_speaker
from ._group_chat_utils import select_speaker
logger = logging.getLogger("agnext.events")
@@ -24,7 +25,7 @@ class GroupChatManager(TypeRoutedAgent):
description (str): The description of the agent.
runtime (AgentRuntime): The runtime to register the agent.
participants (List[AgentId]): The list of participants in the group chat.
memory (ChatMemory): The memory to store and retrieve messages.
memory (ChatMemory[Message]): The memory to store and retrieve messages.
model_client (ChatCompletionClient, optional): The client to use for the model.
If provided, the agent will use the model to select the next speaker.
If not provided, the agent will select the next speaker from the list of participants
@@ -44,7 +45,7 @@ class GroupChatManager(TypeRoutedAgent):
description: str,
runtime: AgentRuntime,
participants: List[AgentId],
memory: ChatMemory,
memory: ChatMemory[Message],
model_client: ChatCompletionClient | None = None,
termination_word: str = "TERMINATE",
transitions: Mapping[AgentId, List[AgentId]] = {},

View File

@@ -3,13 +3,13 @@
import re
from typing import Dict, List
from ...components.memory import ChatMemory
from ...components.models import ChatCompletionClient, SystemMessage
from ...core import AgentProxy
from ..memory import ChatMemory
from ..types import TextMessage
from ..types import Message, TextMessage
async def select_speaker(memory: ChatMemory, client: ChatCompletionClient, agents: List[AgentProxy]) -> int:
async def select_speaker(memory: ChatMemory[Message], client: ChatCompletionClient, agents: List[AgentProxy]) -> int:
"""Selects the next speaker in a group chat using a ChatCompletion client."""
# TODO: Handle multi-modal messages.

View File

@@ -0,0 +1,3 @@
from ._base import ChatMemory
__all__ = ["ChatMemory"]

View File

@@ -0,0 +1,19 @@
from typing import List, Mapping, Protocol, TypeVar
T = TypeVar("T")
class ChatMemory(Protocol[T]):
"""A protocol for defining the interface of a chat memory. A chat memory
lets agents store and retrieve messages. It can be implemented with
different memory recall strategies."""
async def add_message(self, message: T) -> None: ...
async def get_messages(self) -> List[T]: ...
async def clear(self) -> None: ...
def save_state(self) -> Mapping[str, T]: ...
def load_state(self, state: Mapping[str, T]) -> None: ...