use model context for assistant agent, refactor model context (#4681)

* Decouple model_context from AssistantAgent

* add UnboundedBufferedChatCompletionContext to mimic previous model_context behaviour on AssistantAgent

* moving unbounded buffered chat to a different file

* fix model_context assertions in test_group_chat

* Refactor model context, introduce states

* fixes

* update

---------

Co-authored-by: aditya.kurniawan <aditya.kurniawan@core42.ai>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Victor Dibia <victordibia@microsoft.com>
This commit is contained in:
Aditya Kurniawan
2024-12-20 10:27:41 +04:00
committed by GitHub
parent a271708a97
commit c989181da2
13 changed files with 183 additions and 97 deletions

View File

@@ -454,17 +454,17 @@
"\n",
"The above `SimpleAgent` always responds with a fresh context that contains only\n",
"the system message and the latest user's message.\n",
"We can use model context classes from {py:mod}`autogen_core.components.model_context`\n",
"We can use model context classes from {py:mod}`autogen_core.model_context`\n",
"to make the agent \"remember\" previous conversations.\n",
"A model context supports storage and retrieval of Chat Completion messages.\n",
"It is always used together with a model client to generate LLM-based responses.\n",
"\n",
"For example, {py:mod}`~autogen_core.components.model_context.BufferedChatCompletionContext`\n",
"For example, {py:mod}`~autogen_core.model_context.BufferedChatCompletionContext`\n",
"is a most-recently-used (MRU) context that stores the most recent `buffer_size`\n",
"number of messages. This is useful to avoid context overflow in many LLMs.\n",
"\n",
"Let's update the previous example to use\n",
"{py:mod}`~autogen_core.components.model_context.BufferedChatCompletionContext`."
"{py:mod}`~autogen_core.model_context.BufferedChatCompletionContext`."
]
},
{
@@ -473,7 +473,7 @@
"metadata": {},
"outputs": [],
"source": [
"from autogen_core.components.model_context import BufferedChatCompletionContext\n",
"from autogen_core.model_context import BufferedChatCompletionContext\n",
"from autogen_core.models import AssistantMessage\n",
"\n",
"\n",
@@ -615,7 +615,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.7"
}
},
"nbformat": 4,

View File

@@ -254,10 +254,10 @@ class ChatCompletionAgent(RoutedAgent):
async def save_state(self) -> Mapping[str, Any]:
return {
"memory": self._model_context.save_state(),
"chat_history": await self._model_context.save_state(),
"system_messages": self._system_messages,
}
async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state(state["memory"])
await self._model_context.load_state(state["chat_history"])
self._system_messages = state["system_messages"]

View File

@@ -143,10 +143,12 @@ class GroupChatManager(RoutedAgent):
async def save_state(self) -> Mapping[str, Any]:
return {
"chat_history": self._model_context.save_state(),
"chat_history": await self._model_context.save_state(),
"termination_word": self._termination_word,
}
async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state(state["chat_history"])
# Load the chat history.
await self._model_context.load_state(state["chat_history"])
# Load the termination word.
self._termination_word = state["termination_word"]

View File

@@ -114,7 +114,7 @@ class SlowUserProxyAgent(RoutedAgent):
return state_to_save
async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
await self._model_context.load_state(state["memory"])
class ScheduleMeetingInput(BaseModel):
@@ -200,11 +200,11 @@ Today's date is {datetime.datetime.now().strftime("%Y-%m-%d")}
async def save_state(self) -> Mapping[str, Any]:
return {
"memory": self._model_context.save_state(),
"memory": await self._model_context.save_state(),
}
async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
await self._model_context.load_state(state["memory"])
class NeedsUserInputHandler(DefaultInterventionHandler):

View File

@@ -1,9 +1,14 @@
from ._buffered_chat_completion_context import BufferedChatCompletionContext
from ._chat_completion_context import ChatCompletionContext
from ._chat_completion_context import ChatCompletionContext, ChatCompletionContextState
from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
from ._unbounded_chat_completion_context import (
UnboundedChatCompletionContext,
)
__all__ = [
"ChatCompletionContext",
"ChatCompletionContextState",
"UnboundedChatCompletionContext",
"BufferedChatCompletionContext",
"HeadAndTailChatCompletionContext",
]

View File

@@ -1,4 +1,4 @@
from typing import Any, List, Mapping
from typing import List
from ..models import FunctionExecutionResultMessage, LLMMessage
from ._chat_completion_context import ChatCompletionContext
@@ -10,17 +10,15 @@ class BufferedChatCompletionContext(ChatCompletionContext):
Args:
buffer_size (int): The size of the buffer.
initial_messages (List[LLMMessage] | None): The initial messages.
"""
def __init__(self, buffer_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
self._messages: List[LLMMessage] = initial_messages or []
super().__init__(initial_messages)
if buffer_size <= 0:
raise ValueError("buffer_size must be greater than 0.")
self._buffer_size = buffer_size
async def add_message(self, message: LLMMessage) -> None:
"""Add a message to the memory."""
self._messages.append(message)
async def get_messages(self) -> List[LLMMessage]:
"""Get at most `buffer_size` recent messages."""
messages = self._messages[-self._buffer_size :]
@@ -29,17 +27,3 @@ class BufferedChatCompletionContext(ChatCompletionContext):
# Remove the first message from the list.
messages = messages[1:]
return messages
async def clear(self) -> None:
"""Clear the message memory."""
self._messages = []
def save_state(self) -> Mapping[str, Any]:
return {
"messages": [message for message in self._messages],
"buffer_size": self._buffer_size,
}
def load_state(self, state: Mapping[str, Any]) -> None:
self._messages = state["messages"]
self._buffer_size = state["buffer_size"]

View File

@@ -1,19 +1,40 @@
from typing import List, Mapping, Protocol
from abc import ABC, abstractmethod
from typing import Any, List, Mapping
from pydantic import BaseModel, Field
from ..models import LLMMessage
class ChatCompletionContext(Protocol):
"""A protocol for defining the interface of a chat completion context.
class ChatCompletionContext(ABC):
"""An abstract base class for defining the interface of a chat completion context.
A chat completion context lets agents store and retrieve LLM messages.
It can be implemented with different recall strategies."""
It can be implemented with different recall strategies.
async def add_message(self, message: LLMMessage) -> None: ...
Args:
initial_messages (List[LLMMessage] | None): The initial messages.
"""
def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None:
self._messages: List[LLMMessage] = initial_messages or []
async def add_message(self, message: LLMMessage) -> None:
"""Add a message to the context."""
self._messages.append(message)
@abstractmethod
async def get_messages(self) -> List[LLMMessage]: ...
async def clear(self) -> None: ...
async def clear(self) -> None:
"""Clear the context."""
self._messages = []
def save_state(self) -> Mapping[str, LLMMessage]: ...
async def save_state(self) -> Mapping[str, Any]:
return ChatCompletionContextState(messages=self._messages).model_dump()
def load_state(self, state: Mapping[str, LLMMessage]) -> None: ...
async def load_state(self, state: Mapping[str, Any]) -> None:
self._messages = ChatCompletionContextState.model_validate(state).messages
class ChatCompletionContextState(BaseModel):
    """Serializable state of a :class:`ChatCompletionContext`.

    Used by ``save_state``/``load_state`` to round-trip the stored
    message history via pydantic's ``model_dump``/``model_validate``.
    """

    # The full ordered LLM message history held by the context.
    messages: List[LLMMessage] = Field(default_factory=list)

View File

@@ -1,4 +1,4 @@
from typing import Any, List, Mapping
from typing import List
from .._types import FunctionCall
from ..models import AssistantMessage, FunctionExecutionResultMessage, LLMMessage, UserMessage
@@ -13,17 +13,18 @@ class HeadAndTailChatCompletionContext(ChatCompletionContext):
Args:
head_size (int): The size of the head.
tail_size (int): The size of the tail.
initial_messages (List[LLMMessage] | None): The initial messages.
"""
def __init__(self, head_size: int, tail_size: int) -> None:
self._messages: List[LLMMessage] = []
def __init__(self, head_size: int, tail_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
super().__init__(initial_messages)
if head_size <= 0:
raise ValueError("head_size must be greater than 0.")
if tail_size <= 0:
raise ValueError("tail_size must be greater than 0.")
self._head_size = head_size
self._tail_size = tail_size
async def add_message(self, message: LLMMessage) -> None:
"""Add a message to the memory."""
self._messages.append(message)
async def get_messages(self) -> List[LLMMessage]:
"""Get at most `head_size` recent messages and `tail_size` oldest messages."""
head_messages = self._messages[: self._head_size]
@@ -51,21 +52,3 @@ class HeadAndTailChatCompletionContext(ChatCompletionContext):
placeholder_messages = [UserMessage(content=f"Skipped {num_skipped} messages.", source="System")]
return head_messages + placeholder_messages + tail_messages
async def clear(self) -> None:
"""Clear the message memory."""
self._messages = []
def save_state(self) -> Mapping[str, Any]:
return {
"messages": [message for message in self._messages],
"head_size": self._head_size,
"tail_size": self._tail_size,
"placeholder_message": self._placeholder_message,
}
def load_state(self, state: Mapping[str, Any]) -> None:
self._messages = state["messages"]
self._head_size = state["head_size"]
self._tail_size = state["tail_size"]
self._placeholder_message = state["placeholder_message"]

View File

@@ -0,0 +1,12 @@
from typing import List
from ..models import LLMMessage
from ._chat_completion_context import ChatCompletionContext
class UnboundedChatCompletionContext(ChatCompletionContext):
    """An unbounded chat completion context that keeps a view of all the
    messages added so far — nothing is ever truncated or dropped."""

    async def get_messages(self) -> List[LLMMessage]:
        """Return every message stored in the context, in insertion order."""
        return self._messages

View File

@@ -1,7 +1,11 @@
from typing import List
import pytest
from autogen_core.model_context import BufferedChatCompletionContext, HeadAndTailChatCompletionContext
from autogen_core.model_context import (
BufferedChatCompletionContext,
HeadAndTailChatCompletionContext,
UnboundedChatCompletionContext,
)
from autogen_core.models import AssistantMessage, LLMMessage, UserMessage
@@ -26,6 +30,17 @@ async def test_buffered_model_context() -> None:
retrieved = await model_context.get_messages()
assert len(retrieved) == 0
# Test saving and loading state.
await model_context.add_message(messages[0])
await model_context.add_message(messages[1])
state = await model_context.save_state()
await model_context.clear()
await model_context.load_state(state)
retrieved = await model_context.get_messages()
assert len(retrieved) == 2
assert retrieved[0] == messages[0]
assert retrieved[1] == messages[1]
@pytest.mark.asyncio
async def test_head_and_tail_model_context() -> None:
@@ -48,3 +63,44 @@ async def test_head_and_tail_model_context() -> None:
await model_context.clear()
retrieved = await model_context.get_messages()
assert len(retrieved) == 0
# Test saving and loading state.
for msg in messages:
await model_context.add_message(msg)
state = await model_context.save_state()
await model_context.clear()
await model_context.load_state(state)
retrived = await model_context.get_messages()
assert len(retrived) == 3
assert retrived[0] == messages[0]
assert retrived[2] == messages[-1]
@pytest.mark.asyncio
async def test_unbounded_model_context() -> None:
    """UnboundedChatCompletionContext retains every message: verify
    add/get ordering, clear(), and save/load state round-trip."""
    model_context = UnboundedChatCompletionContext()
    messages: List[LLMMessage] = [
        UserMessage(content="Hello!", source="user"),
        AssistantMessage(content="What can I do for you?", source="assistant"),
        UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
    ]
    for msg in messages:
        await model_context.add_message(msg)
    # All messages are retained in insertion order (no buffer limit).
    retrieved = await model_context.get_messages()
    assert len(retrieved) == 3
    assert retrieved == messages

    # clear() empties the context entirely.
    await model_context.clear()
    retrieved = await model_context.get_messages()
    assert len(retrieved) == 0

    # Test saving and loading state.
    for msg in messages:
        await model_context.add_message(msg)
    state = await model_context.save_state()
    await model_context.clear()
    await model_context.load_state(state)
    # After the round-trip the restored history matches the original.
    retrieved = await model_context.get_messages()
    assert len(retrieved) == 3
    assert retrieved == messages