mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
use model context for assistant agent, refactor model context (#4681)
* Decouple model_context from AssistantAgent * add UnboundedBufferedChatCompletionContext to mimic pervious model_context behaviour on AssistantAgent * moving unbounded buffered chat to a different file * fix model_context assertions in test_group_chat * Refactor model context, introduce states * fixes * update --------- Co-authored-by: aditya.kurniawan <aditya.kurniawan@core42.ai> Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> Co-authored-by: Victor Dibia <victordibia@microsoft.com>
This commit is contained in:
@@ -454,17 +454,17 @@
|
||||
"\n",
|
||||
"The above `SimpleAgent` always responds with a fresh context that contains only\n",
|
||||
"the system message and the latest user's message.\n",
|
||||
"We can use model context classes from {py:mod}`autogen_core.components.model_context`\n",
|
||||
"We can use model context classes from {py:mod}`autogen_core.model_context`\n",
|
||||
"to make the agent \"remember\" previous conversations.\n",
|
||||
"A model context supports storage and retrieval of Chat Completion messages.\n",
|
||||
"It is always used together with a model client to generate LLM-based responses.\n",
|
||||
"\n",
|
||||
"For example, {py:mod}`~autogen_core.components.model_context.BufferedChatCompletionContext`\n",
|
||||
"For example, {py:mod}`~autogen_core.model_context.BufferedChatCompletionContext`\n",
|
||||
"is a most-recent-used (MRU) context that stores the most recent `buffer_size`\n",
|
||||
"number of messages. This is useful to avoid context overflow in many LLMs.\n",
|
||||
"\n",
|
||||
"Let's update the previous example to use\n",
|
||||
"{py:mod}`~autogen_core.components.model_context.BufferedChatCompletionContext`."
|
||||
"{py:mod}`~autogen_core.model_context.BufferedChatCompletionContext`."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -473,7 +473,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from autogen_core.components.model_context import BufferedChatCompletionContext\n",
|
||||
"from autogen_core.model_context import BufferedChatCompletionContext\n",
|
||||
"from autogen_core.models import AssistantMessage\n",
|
||||
"\n",
|
||||
"\n",
|
||||
@@ -615,7 +615,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -254,10 +254,10 @@ class ChatCompletionAgent(RoutedAgent):
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
"memory": self._model_context.save_state(),
|
||||
"chat_history": await self._model_context.save_state(),
|
||||
"system_messages": self._system_messages,
|
||||
}
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._model_context.load_state(state["memory"])
|
||||
await self._model_context.load_state(state["chat_history"])
|
||||
self._system_messages = state["system_messages"]
|
||||
|
||||
@@ -143,10 +143,12 @@ class GroupChatManager(RoutedAgent):
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
"chat_history": self._model_context.save_state(),
|
||||
"chat_history": await self._model_context.save_state(),
|
||||
"termination_word": self._termination_word,
|
||||
}
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._model_context.load_state(state["chat_history"])
|
||||
# Load the chat history.
|
||||
await self._model_context.load_state(state["chat_history"])
|
||||
# Load the termination word.
|
||||
self._termination_word = state["termination_word"]
|
||||
|
||||
@@ -114,7 +114,7 @@ class SlowUserProxyAgent(RoutedAgent):
|
||||
return state_to_save
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
|
||||
await self._model_context.load_state(state["memory"])
|
||||
|
||||
|
||||
class ScheduleMeetingInput(BaseModel):
|
||||
@@ -200,11 +200,11 @@ Today's date is {datetime.datetime.now().strftime("%Y-%m-%d")}
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
"memory": self._model_context.save_state(),
|
||||
"memory": await self._model_context.save_state(),
|
||||
}
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
|
||||
await self._model_context.load_state(state["memory"])
|
||||
|
||||
|
||||
class NeedsUserInputHandler(DefaultInterventionHandler):
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
from ._buffered_chat_completion_context import BufferedChatCompletionContext
|
||||
from ._chat_completion_context import ChatCompletionContext
|
||||
from ._chat_completion_context import ChatCompletionContext, ChatCompletionContextState
|
||||
from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
|
||||
from ._unbounded_chat_completion_context import (
|
||||
UnboundedChatCompletionContext,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ChatCompletionContext",
|
||||
"ChatCompletionContextState",
|
||||
"UnboundedChatCompletionContext",
|
||||
"BufferedChatCompletionContext",
|
||||
"HeadAndTailChatCompletionContext",
|
||||
]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Mapping
|
||||
from typing import List
|
||||
|
||||
from ..models import FunctionExecutionResultMessage, LLMMessage
|
||||
from ._chat_completion_context import ChatCompletionContext
|
||||
@@ -10,17 +10,15 @@ class BufferedChatCompletionContext(ChatCompletionContext):
|
||||
|
||||
Args:
|
||||
buffer_size (int): The size of the buffer.
|
||||
|
||||
initial_messages (List[LLMMessage] | None): The initial messages.
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
|
||||
self._messages: List[LLMMessage] = initial_messages or []
|
||||
super().__init__(initial_messages)
|
||||
if buffer_size <= 0:
|
||||
raise ValueError("buffer_size must be greater than 0.")
|
||||
self._buffer_size = buffer_size
|
||||
|
||||
async def add_message(self, message: LLMMessage) -> None:
|
||||
"""Add a message to the memory."""
|
||||
self._messages.append(message)
|
||||
|
||||
async def get_messages(self) -> List[LLMMessage]:
|
||||
"""Get at most `buffer_size` recent messages."""
|
||||
messages = self._messages[-self._buffer_size :]
|
||||
@@ -29,17 +27,3 @@ class BufferedChatCompletionContext(ChatCompletionContext):
|
||||
# Remove the first message from the list.
|
||||
messages = messages[1:]
|
||||
return messages
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear the message memory."""
|
||||
self._messages = []
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
"messages": [message for message in self._messages],
|
||||
"buffer_size": self._buffer_size,
|
||||
}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._messages = state["messages"]
|
||||
self._buffer_size = state["buffer_size"]
|
||||
|
||||
@@ -1,19 +1,40 @@
|
||||
from typing import List, Mapping, Protocol
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Mapping
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..models import LLMMessage
|
||||
|
||||
|
||||
class ChatCompletionContext(Protocol):
|
||||
"""A protocol for defining the interface of a chat completion context.
|
||||
class ChatCompletionContext(ABC):
|
||||
"""An abstract base class for defining the interface of a chat completion context.
|
||||
A chat completion context lets agents store and retrieve LLM messages.
|
||||
It can be implemented with different recall strategies."""
|
||||
It can be implemented with different recall strategies.
|
||||
|
||||
async def add_message(self, message: LLMMessage) -> None: ...
|
||||
Args:
|
||||
initial_messages (List[LLMMessage] | None): The initial messages.
|
||||
"""
|
||||
|
||||
def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None:
|
||||
self._messages: List[LLMMessage] = initial_messages or []
|
||||
|
||||
async def add_message(self, message: LLMMessage) -> None:
|
||||
"""Add a message to the context."""
|
||||
self._messages.append(message)
|
||||
|
||||
@abstractmethod
|
||||
async def get_messages(self) -> List[LLMMessage]: ...
|
||||
|
||||
async def clear(self) -> None: ...
|
||||
async def clear(self) -> None:
|
||||
"""Clear the context."""
|
||||
self._messages = []
|
||||
|
||||
def save_state(self) -> Mapping[str, LLMMessage]: ...
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
return ChatCompletionContextState(messages=self._messages).model_dump()
|
||||
|
||||
def load_state(self, state: Mapping[str, LLMMessage]) -> None: ...
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._messages = ChatCompletionContextState.model_validate(state).messages
|
||||
|
||||
|
||||
class ChatCompletionContextState(BaseModel):
|
||||
messages: List[LLMMessage] = Field(default_factory=list)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Mapping
|
||||
from typing import List
|
||||
|
||||
from .._types import FunctionCall
|
||||
from ..models import AssistantMessage, FunctionExecutionResultMessage, LLMMessage, UserMessage
|
||||
@@ -13,17 +13,18 @@ class HeadAndTailChatCompletionContext(ChatCompletionContext):
|
||||
Args:
|
||||
head_size (int): The size of the head.
|
||||
tail_size (int): The size of the tail.
|
||||
initial_messages (List[LLMMessage] | None): The initial messages.
|
||||
"""
|
||||
|
||||
def __init__(self, head_size: int, tail_size: int) -> None:
|
||||
self._messages: List[LLMMessage] = []
|
||||
def __init__(self, head_size: int, tail_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
|
||||
super().__init__(initial_messages)
|
||||
if head_size <= 0:
|
||||
raise ValueError("head_size must be greater than 0.")
|
||||
if tail_size <= 0:
|
||||
raise ValueError("tail_size must be greater than 0.")
|
||||
self._head_size = head_size
|
||||
self._tail_size = tail_size
|
||||
|
||||
async def add_message(self, message: LLMMessage) -> None:
|
||||
"""Add a message to the memory."""
|
||||
self._messages.append(message)
|
||||
|
||||
async def get_messages(self) -> List[LLMMessage]:
|
||||
"""Get at most `head_size` recent messages and `tail_size` oldest messages."""
|
||||
head_messages = self._messages[: self._head_size]
|
||||
@@ -51,21 +52,3 @@ class HeadAndTailChatCompletionContext(ChatCompletionContext):
|
||||
|
||||
placeholder_messages = [UserMessage(content=f"Skipped {num_skipped} messages.", source="System")]
|
||||
return head_messages + placeholder_messages + tail_messages
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear the message memory."""
|
||||
self._messages = []
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
"messages": [message for message in self._messages],
|
||||
"head_size": self._head_size,
|
||||
"tail_size": self._tail_size,
|
||||
"placeholder_message": self._placeholder_message,
|
||||
}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._messages = state["messages"]
|
||||
self._head_size = state["head_size"]
|
||||
self._tail_size = state["tail_size"]
|
||||
self._placeholder_message = state["placeholder_message"]
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
from typing import List
|
||||
|
||||
from ..models import LLMMessage
|
||||
from ._chat_completion_context import ChatCompletionContext
|
||||
|
||||
|
||||
class UnboundedChatCompletionContext(ChatCompletionContext):
|
||||
"""An unbounded chat completion context that keeps a view of the all the messages."""
|
||||
|
||||
async def get_messages(self) -> List[LLMMessage]:
|
||||
"""Get at most `buffer_size` recent messages."""
|
||||
return self._messages
|
||||
@@ -1,7 +1,11 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from autogen_core.model_context import BufferedChatCompletionContext, HeadAndTailChatCompletionContext
|
||||
from autogen_core.model_context import (
|
||||
BufferedChatCompletionContext,
|
||||
HeadAndTailChatCompletionContext,
|
||||
UnboundedChatCompletionContext,
|
||||
)
|
||||
from autogen_core.models import AssistantMessage, LLMMessage, UserMessage
|
||||
|
||||
|
||||
@@ -26,6 +30,17 @@ async def test_buffered_model_context() -> None:
|
||||
retrieved = await model_context.get_messages()
|
||||
assert len(retrieved) == 0
|
||||
|
||||
# Test saving and loading state.
|
||||
await model_context.add_message(messages[0])
|
||||
await model_context.add_message(messages[1])
|
||||
state = await model_context.save_state()
|
||||
await model_context.clear()
|
||||
await model_context.load_state(state)
|
||||
retrieved = await model_context.get_messages()
|
||||
assert len(retrieved) == 2
|
||||
assert retrieved[0] == messages[0]
|
||||
assert retrieved[1] == messages[1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_head_and_tail_model_context() -> None:
|
||||
@@ -48,3 +63,44 @@ async def test_head_and_tail_model_context() -> None:
|
||||
await model_context.clear()
|
||||
retrieved = await model_context.get_messages()
|
||||
assert len(retrieved) == 0
|
||||
|
||||
# Test saving and loading state.
|
||||
for msg in messages:
|
||||
await model_context.add_message(msg)
|
||||
state = await model_context.save_state()
|
||||
await model_context.clear()
|
||||
await model_context.load_state(state)
|
||||
retrived = await model_context.get_messages()
|
||||
assert len(retrived) == 3
|
||||
assert retrived[0] == messages[0]
|
||||
assert retrived[2] == messages[-1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unbounded_model_context() -> None:
|
||||
model_context = UnboundedChatCompletionContext()
|
||||
messages: List[LLMMessage] = [
|
||||
UserMessage(content="Hello!", source="user"),
|
||||
AssistantMessage(content="What can I do for you?", source="assistant"),
|
||||
UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
|
||||
]
|
||||
for msg in messages:
|
||||
await model_context.add_message(msg)
|
||||
|
||||
retrieved = await model_context.get_messages()
|
||||
assert len(retrieved) == 3
|
||||
assert retrieved == messages
|
||||
|
||||
await model_context.clear()
|
||||
retrieved = await model_context.get_messages()
|
||||
assert len(retrieved) == 0
|
||||
|
||||
# Test saving and loading state.
|
||||
for msg in messages:
|
||||
await model_context.add_message(msg)
|
||||
state = await model_context.save_state()
|
||||
await model_context.clear()
|
||||
await model_context.load_state(state)
|
||||
retrieved = await model_context.get_messages()
|
||||
assert len(retrieved) == 3
|
||||
assert retrieved == messages
|
||||
|
||||
Reference in New Issue
Block a user