mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Move ChatMemory protocol to components.memory; rename files to use "_" convention for implementations. (#101)
This commit is contained in:
@@ -1,6 +1,11 @@
|
||||
from .chat_completion_agent import ChatCompletionAgent
|
||||
from .image_generation_agent import ImageGenerationAgent
|
||||
from .oai_assistant import OpenAIAssistantAgent
|
||||
from .user_proxy import UserProxyAgent
|
||||
from ._chat_completion_agent import ChatCompletionAgent
|
||||
from ._image_generation_agent import ImageGenerationAgent
|
||||
from ._oai_assistant import OpenAIAssistantAgent
|
||||
from ._user_proxy import UserProxyAgent
|
||||
|
||||
__all__ = ["ChatCompletionAgent", "OpenAIAssistantAgent", "UserProxyAgent", "ImageGenerationAgent"]
|
||||
__all__ = [
|
||||
"ChatCompletionAgent",
|
||||
"OpenAIAssistantAgent",
|
||||
"UserProxyAgent",
|
||||
"ImageGenerationAgent",
|
||||
]
|
||||
|
||||
@@ -7,6 +7,7 @@ from ...components import (
|
||||
TypeRoutedAgent,
|
||||
message_handler,
|
||||
)
|
||||
from ...components.memory import ChatMemory
|
||||
from ...components.models import (
|
||||
ChatCompletionClient,
|
||||
FunctionExecutionResult,
|
||||
@@ -15,7 +16,6 @@ from ...components.models import (
|
||||
)
|
||||
from ...components.tools import Tool
|
||||
from ...core import AgentId, CancellationToken
|
||||
from ..memory import ChatMemory
|
||||
from ..types import (
|
||||
FunctionCallMessage,
|
||||
Message,
|
||||
@@ -36,12 +36,11 @@ class ChatCompletionAgent(TypeRoutedAgent):
|
||||
responses and execute tools.
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
description (str): The description of the agent.
|
||||
runtime (AgentRuntime): The runtime to register the agent.
|
||||
system_messages (List[SystemMessage]): The system messages to use for
|
||||
the ChatCompletion API.
|
||||
memory (ChatMemory): The memory to store and retrieve messages.
|
||||
memory (ChatMemory[Message]): The memory to store and retrieve messages.
|
||||
model_client (ChatCompletionClient): The client to use for the
|
||||
ChatCompletion API.
|
||||
tools (Sequence[Tool], optional): The tools used by the agent. Defaults
|
||||
@@ -61,7 +60,7 @@ class ChatCompletionAgent(TypeRoutedAgent):
|
||||
self,
|
||||
description: str,
|
||||
system_messages: List[SystemMessage],
|
||||
memory: ChatMemory,
|
||||
memory: ChatMemory[Message],
|
||||
model_client: ChatCompletionClient,
|
||||
tools: Sequence[Tool] = [],
|
||||
tool_approver: AgentId | None = None,
|
||||
@@ -7,9 +7,10 @@ from ...components import (
|
||||
TypeRoutedAgent,
|
||||
message_handler,
|
||||
)
|
||||
from ...components.memory import ChatMemory
|
||||
from ...core import CancellationToken
|
||||
from ..memory import ChatMemory
|
||||
from ..types import (
|
||||
Message,
|
||||
MultiModalMessage,
|
||||
PublishNow,
|
||||
Reset,
|
||||
@@ -18,10 +19,20 @@ from ..types import (
|
||||
|
||||
|
||||
class ImageGenerationAgent(TypeRoutedAgent):
|
||||
"""An agent that generates images using DALL-E models. It publishes the
|
||||
generated images as MultiModalMessage.
|
||||
|
||||
Args:
|
||||
description (str): The description of the agent.
|
||||
memory (ChatMemory[Message]): The memory to store and retrieve messages.
|
||||
client (openai.AsyncClient): The client to use for the OpenAI API.
|
||||
model (Literal["dall-e-2", "dall-e-3"], optional): The DALL-E model to use. Defaults to "dall-e-2".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
description: str,
|
||||
memory: ChatMemory,
|
||||
memory: ChatMemory[Message],
|
||||
client: openai.AsyncClient,
|
||||
model: Literal["dall-e-2", "dall-e-3"] = "dall-e-2",
|
||||
):
|
||||
@@ -32,6 +43,7 @@ class ImageGenerationAgent(TypeRoutedAgent):
|
||||
|
||||
@message_handler
|
||||
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle a text message. This method adds the message to the memory."""
|
||||
await self._memory.add_message(message)
|
||||
|
||||
@message_handler
|
||||
@@ -40,6 +52,10 @@ class ImageGenerationAgent(TypeRoutedAgent):
|
||||
|
||||
@message_handler
|
||||
async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle a publish now message. This method generates an image using a DALL-E model with
|
||||
a prompt. The prompt is a concatenation of all TextMessages in the memory. The generated
|
||||
image is published as a MultiModalMessage."""
|
||||
|
||||
response = await self._generate_response(cancellation_token)
|
||||
self.publish_message(response)
|
||||
|
||||
@@ -14,7 +14,6 @@ class OpenAIAssistantAgent(TypeRoutedAgent):
|
||||
responses.
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
description (str): The description of the agent.
|
||||
runtime (AgentRuntime): The runtime to register the agent.
|
||||
client (openai.AsyncClient): The client to use for the OpenAI API.
|
||||
@@ -10,7 +10,6 @@ class UserProxyAgent(TypeRoutedAgent):
|
||||
method to customize how user input is retrieved.
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
description (str): The description of the agent.
|
||||
runtime (AgentRuntime): The runtime to register the agent.
|
||||
user_input_prompt (str): The console prompt to show to the user when asking for input.
|
||||
@@ -1,5 +1,4 @@
|
||||
from ._base import ChatMemory
|
||||
from ._buffered import BufferedChatMemory
|
||||
from ._head_and_tail import HeadAndTailChatMemory
|
||||
|
||||
__all__ = ["ChatMemory", "BufferedChatMemory", "HeadAndTailChatMemory"]
|
||||
__all__ = ["BufferedChatMemory", "HeadAndTailChatMemory"]
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
from typing import Any, List, Mapping, Protocol
|
||||
|
||||
from ..types import Message
|
||||
|
||||
|
||||
class ChatMemory(Protocol):
|
||||
"""A protocol for defining the interface of a chat memory. A chat memory
|
||||
lets agents to store and retrieve messages. It can be implemented with
|
||||
different memory recall strategies."""
|
||||
|
||||
async def add_message(self, message: Message) -> None: ...
|
||||
|
||||
async def get_messages(self) -> List[Message]: ...
|
||||
|
||||
async def clear(self) -> None: ...
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]: ...
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None: ...
|
||||
@@ -1,11 +1,11 @@
|
||||
from typing import Any, List, Mapping
|
||||
|
||||
from ...components.memory import ChatMemory
|
||||
from ...components.models import FunctionExecutionResultMessage
|
||||
from ..types import Message
|
||||
from ._base import ChatMemory
|
||||
|
||||
|
||||
class BufferedChatMemory(ChatMemory):
|
||||
class BufferedChatMemory(ChatMemory[Message]):
|
||||
"""A buffered chat memory that keeps a view of the last n messages,
|
||||
where n is the buffer size. The buffer size is set at initialization.
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from typing import Any, List, Mapping
|
||||
|
||||
from ...components.memory import ChatMemory
|
||||
from ...components.models import FunctionExecutionResultMessage
|
||||
from ..types import FunctionCallMessage, Message, TextMessage
|
||||
from ._base import ChatMemory
|
||||
|
||||
|
||||
class HeadAndTailChatMemory(ChatMemory):
|
||||
class HeadAndTailChatMemory(ChatMemory[Message]):
|
||||
"""A chat memory that keeps a view of the first n and last m messages,
|
||||
where n is the head size and m is the tail size. The head and tail sizes
|
||||
are set at initialization.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .group_chat_manager import GroupChatManager
|
||||
from ._group_chat_manager import GroupChatManager
|
||||
from ._orchestrator_chat import OrchestratorChat
|
||||
|
||||
__all__ = ["GroupChatManager"]
|
||||
__all__ = ["GroupChatManager", "OrchestratorChat"]
|
||||
|
||||
@@ -2,16 +2,17 @@ import logging
|
||||
from typing import Any, Callable, List, Mapping
|
||||
|
||||
from ...components import TypeRoutedAgent, message_handler
|
||||
from ...components.memory import ChatMemory
|
||||
from ...components.models import ChatCompletionClient
|
||||
from ...core import AgentId, AgentProxy, AgentRuntime, CancellationToken
|
||||
from ..memory import ChatMemory
|
||||
from ..types import (
|
||||
Message,
|
||||
MultiModalMessage,
|
||||
PublishNow,
|
||||
Reset,
|
||||
TextMessage,
|
||||
)
|
||||
from .group_chat_utils import select_speaker
|
||||
from ._group_chat_utils import select_speaker
|
||||
|
||||
logger = logging.getLogger("agnext.events")
|
||||
|
||||
@@ -24,7 +25,7 @@ class GroupChatManager(TypeRoutedAgent):
|
||||
description (str): The description of the agent.
|
||||
runtime (AgentRuntime): The runtime to register the agent.
|
||||
participants (List[AgentId]): The list of participants in the group chat.
|
||||
memory (ChatMemory): The memory to store and retrieve messages.
|
||||
memory (ChatMemory[Message]): The memory to store and retrieve messages.
|
||||
model_client (ChatCompletionClient, optional): The client to use for the model.
|
||||
If provided, the agent will use the model to select the next speaker.
|
||||
If not provided, the agent will select the next speaker from the list of participants
|
||||
@@ -44,7 +45,7 @@ class GroupChatManager(TypeRoutedAgent):
|
||||
description: str,
|
||||
runtime: AgentRuntime,
|
||||
participants: List[AgentId],
|
||||
memory: ChatMemory,
|
||||
memory: ChatMemory[Message],
|
||||
model_client: ChatCompletionClient | None = None,
|
||||
termination_word: str = "TERMINATE",
|
||||
transitions: Mapping[AgentId, List[AgentId]] = {},
|
||||
@@ -3,13 +3,13 @@
|
||||
import re
|
||||
from typing import Dict, List
|
||||
|
||||
from ...components.memory import ChatMemory
|
||||
from ...components.models import ChatCompletionClient, SystemMessage
|
||||
from ...core import AgentProxy
|
||||
from ..memory import ChatMemory
|
||||
from ..types import TextMessage
|
||||
from ..types import Message, TextMessage
|
||||
|
||||
|
||||
async def select_speaker(memory: ChatMemory, client: ChatCompletionClient, agents: List[AgentProxy]) -> int:
|
||||
async def select_speaker(memory: ChatMemory[Message], client: ChatCompletionClient, agents: List[AgentProxy]) -> int:
|
||||
"""Selects the next speaker in a group chat using a ChatCompletion client."""
|
||||
# TODO: Handle multi-modal messages.
|
||||
|
||||
3
python/src/agnext/components/memory/__init__.py
Normal file
3
python/src/agnext/components/memory/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from ._base import ChatMemory
|
||||
|
||||
__all__ = ["ChatMemory"]
|
||||
19
python/src/agnext/components/memory/_base.py
Normal file
19
python/src/agnext/components/memory/_base.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import List, Mapping, Protocol, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ChatMemory(Protocol[T]):
|
||||
"""A protocol for defining the interface of a chat memory. A chat memory
|
||||
lets agents store and retrieve messages. It can be implemented with
|
||||
different memory recall strategies."""
|
||||
|
||||
async def add_message(self, message: T) -> None: ...
|
||||
|
||||
async def get_messages(self) -> List[T]: ...
|
||||
|
||||
async def clear(self) -> None: ...
|
||||
|
||||
def save_state(self) -> Mapping[str, T]: ...
|
||||
|
||||
def load_state(self, state: Mapping[str, T]) -> None: ...
|
||||
Reference in New Issue
Block a user