use model context for assistant agent, refactor model context (#4681)

* Decouple model_context from AssistantAgent

* add UnboundedBufferedChatCompletionContext to mimic pervious model_context behaviour on AssistantAgent

* moving unbounded buffered chat to a different file

* fix model_context assertions in test_group_chat

* Refactor model context, introduce states

* fixes

* update

---------

Co-authored-by: aditya.kurniawan <aditya.kurniawan@core42.ai>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Victor Dibia <victordibia@microsoft.com>
This commit is contained in:
Aditya Kurniawan
2024-12-20 10:27:41 +04:00
committed by GitHub
parent a271708a97
commit c989181da2
13 changed files with 183 additions and 97 deletions

View File

@@ -2,15 +2,27 @@ import asyncio
import json
import logging
import warnings
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Mapping, Sequence
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Dict,
List,
Mapping,
Sequence,
)
from autogen_core import CancellationToken, FunctionCall
from autogen_core.model_context import (
ChatCompletionContext,
UnboundedChatCompletionContext,
)
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
SystemMessage,
UserMessage,
)
@@ -87,7 +99,6 @@ class AssistantAgent(BaseChatAgent):
If multiple handoffs are detected, only the first handoff is executed.
Args:
name (str): The name of the agent.
model_client (ChatCompletionClient): The model client to use for inference.
@@ -96,8 +107,9 @@ class AssistantAgent(BaseChatAgent):
allowing it to transfer to other agents by responding with a :class:`HandoffMessage`.
The transfer is only executed when the team is in :class:`~autogen_agentchat.teams.Swarm`.
If a handoff is a string, it should represent the target agent's name.
model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving :class:`~autogen_core.models.LLMMessage`. It can be preloaded with initial messages. The initial messages will be cleared when the agent is reset.
description (str, optional): The description of the agent.
system_message (str, optional): The system message for the model.
system_message (str, optional): The system message for the model. If provided, it will be prepended to the messages in the model context when making an inference. Set to `None` to disable.
reflect_on_tool_use (bool, optional): If `True`, the agent will make another model inference using the tool call and result
to generate a response. If `False`, the tool call result will be returned as the response. Defaults to `False`.
tool_call_summary_format (str, optional): The format string used to create a tool call summary for every tool call result.
@@ -219,9 +231,11 @@ class AssistantAgent(BaseChatAgent):
*,
tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
handoffs: List[HandoffBase | str] | None = None,
model_context: ChatCompletionContext | None = None,
description: str = "An agent that provides assistance with ability to use tools.",
system_message: str
| None = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
system_message: (
str | None
) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
reflect_on_tool_use: bool = False,
tool_call_summary_format: str = "{result}",
):
@@ -273,7 +287,8 @@ class AssistantAgent(BaseChatAgent):
raise ValueError(
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
)
self._model_context: List[LLMMessage] = []
if not model_context:
self._model_context = UnboundedChatCompletionContext()
self._reflect_on_tool_use = reflect_on_tool_use
self._tool_call_summary_format = tool_call_summary_format
self._is_running = False
@@ -301,19 +316,19 @@ class AssistantAgent(BaseChatAgent):
for msg in messages:
if isinstance(msg, MultiModalMessage) and self._model_client.capabilities["vision"] is False:
raise ValueError("The model does not support vision.")
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
await self._model_context.add_message(UserMessage(content=msg.content, source=msg.source))
# Inner messages.
inner_messages: List[AgentEvent | ChatMessage] = []
# Generate an inference result based on the current model context.
llm_messages = self._system_messages + self._model_context
llm_messages = self._system_messages + await self._model_context.get_messages()
result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)
# Add the response to the model context.
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
# Check if the response is a string and return it.
if isinstance(result.content, str):
@@ -335,7 +350,7 @@ class AssistantAgent(BaseChatAgent):
results = await asyncio.gather(*[self._execute_tool_call(call, cancellation_token) for call in result.content])
tool_call_result_msg = ToolCallExecutionEvent(content=results, source=self.name)
event_logger.debug(tool_call_result_msg)
self._model_context.append(FunctionExecutionResultMessage(content=results))
await self._model_context.add_message(FunctionExecutionResultMessage(content=results))
inner_messages.append(tool_call_result_msg)
yield tool_call_result_msg
@@ -360,11 +375,11 @@ class AssistantAgent(BaseChatAgent):
if self._reflect_on_tool_use:
# Generate another inference result based on the tool call and result.
llm_messages = self._system_messages + self._model_context
llm_messages = self._system_messages + await self._model_context.get_messages()
result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
assert isinstance(result.content, str)
# Add the response to the model context.
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
# Yield the response.
yield Response(
chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
@@ -406,14 +421,15 @@ class AssistantAgent(BaseChatAgent):
async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""Reset the assistant agent to its initialization state."""
self._model_context.clear()
await self._model_context.clear()
async def save_state(self) -> Mapping[str, Any]:
"""Save the current state of the assistant agent."""
return AssistantAgentState(llm_messages=self._model_context.copy()).model_dump()
model_context_state = await self._model_context.save_state()
return AssistantAgentState(llm_context=model_context_state).model_dump()
async def load_state(self, state: Mapping[str, Any]) -> None:
"""Load the state of the assistant agent"""
assistant_agent_state = AssistantAgentState.model_validate(state)
self._model_context.clear()
self._model_context.extend(assistant_agent_state.llm_messages)
# Load the model context state.
await self._model_context.load_state(assistant_agent_state.llm_context)

View File

@@ -1,8 +1,5 @@
from typing import Any, List, Mapping, Optional
from autogen_core.models import (
LLMMessage,
)
from pydantic import BaseModel, Field
from ..messages import (
@@ -21,7 +18,7 @@ class BaseState(BaseModel):
class AssistantAgentState(BaseState):
"""State for an assistant agent."""
llm_messages: List[LLMMessage] = Field(default_factory=list)
llm_context: Mapping[str, Any] = Field(default_factory=lambda: dict([("messages", [])]))
type: str = Field(default="AssistantAgentState")

View File

@@ -239,8 +239,13 @@ async def test_round_robin_group_chat_state() -> None:
await team2.load_state(state)
state2 = await team2.save_state()
assert state == state2
assert agent3._model_context == agent1._model_context # pyright: ignore
assert agent4._model_context == agent2._model_context # pyright: ignore
agent1_model_ctx_messages = await agent1._model_context.get_messages() # pyright: ignore
agent2_model_ctx_messages = await agent2._model_context.get_messages() # pyright: ignore
agent3_model_ctx_messages = await agent3._model_context.get_messages() # pyright: ignore
agent4_model_ctx_messages = await agent4._model_context.get_messages() # pyright: ignore
assert agent3_model_ctx_messages == agent1_model_ctx_messages
assert agent4_model_ctx_messages == agent2_model_ctx_messages
manager_1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId("group_chat_manager", team1._team_id), # pyright: ignore
RoundRobinGroupChatManager, # pyright: ignore
@@ -337,7 +342,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"
# Test streaming.
tool_use_agent._model_context.clear() # pyright: ignore
await tool_use_agent._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()
@@ -351,7 +356,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
index += 1
# Test Console.
tool_use_agent._model_context.clear() # pyright: ignore
await tool_use_agent._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()
@@ -579,8 +584,13 @@ async def test_selector_group_chat_state() -> None:
await team2.load_state(state)
state2 = await team2.save_state()
assert state == state2
assert agent3._model_context == agent1._model_context # pyright: ignore
assert agent4._model_context == agent2._model_context # pyright: ignore
agent1_model_ctx_messages = await agent1._model_context.get_messages() # pyright: ignore
agent2_model_ctx_messages = await agent2._model_context.get_messages() # pyright: ignore
agent3_model_ctx_messages = await agent3._model_context.get_messages() # pyright: ignore
agent4_model_ctx_messages = await agent4._model_context.get_messages() # pyright: ignore
assert agent3_model_ctx_messages == agent1_model_ctx_messages
assert agent4_model_ctx_messages == agent2_model_ctx_messages
manager_1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId("group_chat_manager", team1._team_id), # pyright: ignore
SelectorGroupChatManager, # pyright: ignore
@@ -931,7 +941,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"
# Test streaming.
agent1._model_context.clear() # pyright: ignore
await agent1._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()
@@ -944,7 +954,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
index += 1
# Test Console
agent1._model_context.clear() # pyright: ignore
await agent1._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()

View File

@@ -454,17 +454,17 @@
"\n",
"The above `SimpleAgent` always responds with a fresh context that contains only\n",
"the system message and the latest user's message.\n",
"We can use model context classes from {py:mod}`autogen_core.components.model_context`\n",
"We can use model context classes from {py:mod}`autogen_core.model_context`\n",
"to make the agent \"remember\" previous conversations.\n",
"A model context supports storage and retrieval of Chat Completion messages.\n",
"It is always used together with a model client to generate LLM-based responses.\n",
"\n",
"For example, {py:mod}`~autogen_core.components.model_context.BufferedChatCompletionContext`\n",
"For example, {py:mod}`~autogen_core.model_context.BufferedChatCompletionContext`\n",
"is a most-recent-used (MRU) context that stores the most recent `buffer_size`\n",
"number of messages. This is useful to avoid context overflow in many LLMs.\n",
"\n",
"Let's update the previous example to use\n",
"{py:mod}`~autogen_core.components.model_context.BufferedChatCompletionContext`."
"{py:mod}`~autogen_core.model_context.BufferedChatCompletionContext`."
]
},
{
@@ -473,7 +473,7 @@
"metadata": {},
"outputs": [],
"source": [
"from autogen_core.components.model_context import BufferedChatCompletionContext\n",
"from autogen_core.model_context import BufferedChatCompletionContext\n",
"from autogen_core.models import AssistantMessage\n",
"\n",
"\n",
@@ -615,7 +615,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.7"
}
},
"nbformat": 4,

View File

@@ -254,10 +254,10 @@ class ChatCompletionAgent(RoutedAgent):
async def save_state(self) -> Mapping[str, Any]:
return {
"memory": self._model_context.save_state(),
"chat_history": await self._model_context.save_state(),
"system_messages": self._system_messages,
}
async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state(state["memory"])
await self._model_context.load_state(state["chat_history"])
self._system_messages = state["system_messages"]

View File

@@ -143,10 +143,12 @@ class GroupChatManager(RoutedAgent):
async def save_state(self) -> Mapping[str, Any]:
return {
"chat_history": self._model_context.save_state(),
"chat_history": await self._model_context.save_state(),
"termination_word": self._termination_word,
}
async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state(state["chat_history"])
# Load the chat history.
await self._model_context.load_state(state["chat_history"])
# Load the termination word.
self._termination_word = state["termination_word"]

View File

@@ -114,7 +114,7 @@ class SlowUserProxyAgent(RoutedAgent):
return state_to_save
async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
await self._model_context.load_state(state["memory"])
class ScheduleMeetingInput(BaseModel):
@@ -200,11 +200,11 @@ Today's date is {datetime.datetime.now().strftime("%Y-%m-%d")}
async def save_state(self) -> Mapping[str, Any]:
return {
"memory": self._model_context.save_state(),
"memory": await self._model_context.save_state(),
}
async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
await self._model_context.load_state(state["memory"])
class NeedsUserInputHandler(DefaultInterventionHandler):

View File

@@ -1,9 +1,14 @@
from ._buffered_chat_completion_context import BufferedChatCompletionContext
from ._chat_completion_context import ChatCompletionContext
from ._chat_completion_context import ChatCompletionContext, ChatCompletionContextState
from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
from ._unbounded_chat_completion_context import (
UnboundedChatCompletionContext,
)
__all__ = [
"ChatCompletionContext",
"ChatCompletionContextState",
"UnboundedChatCompletionContext",
"BufferedChatCompletionContext",
"HeadAndTailChatCompletionContext",
]

View File

@@ -1,4 +1,4 @@
from typing import Any, List, Mapping
from typing import List
from ..models import FunctionExecutionResultMessage, LLMMessage
from ._chat_completion_context import ChatCompletionContext
@@ -10,17 +10,15 @@ class BufferedChatCompletionContext(ChatCompletionContext):
Args:
buffer_size (int): The size of the buffer.
initial_messages (List[LLMMessage] | None): The initial messages.
"""
def __init__(self, buffer_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
self._messages: List[LLMMessage] = initial_messages or []
super().__init__(initial_messages)
if buffer_size <= 0:
raise ValueError("buffer_size must be greater than 0.")
self._buffer_size = buffer_size
async def add_message(self, message: LLMMessage) -> None:
"""Add a message to the memory."""
self._messages.append(message)
async def get_messages(self) -> List[LLMMessage]:
"""Get at most `buffer_size` recent messages."""
messages = self._messages[-self._buffer_size :]
@@ -29,17 +27,3 @@ class BufferedChatCompletionContext(ChatCompletionContext):
# Remove the first message from the list.
messages = messages[1:]
return messages
async def clear(self) -> None:
"""Clear the message memory."""
self._messages = []
def save_state(self) -> Mapping[str, Any]:
return {
"messages": [message for message in self._messages],
"buffer_size": self._buffer_size,
}
def load_state(self, state: Mapping[str, Any]) -> None:
self._messages = state["messages"]
self._buffer_size = state["buffer_size"]

View File

@@ -1,19 +1,40 @@
from typing import List, Mapping, Protocol
from abc import ABC, abstractmethod
from typing import Any, List, Mapping
from pydantic import BaseModel, Field
from ..models import LLMMessage
class ChatCompletionContext(Protocol):
"""A protocol for defining the interface of a chat completion context.
class ChatCompletionContext(ABC):
"""An abstract base class for defining the interface of a chat completion context.
A chat completion context lets agents store and retrieve LLM messages.
It can be implemented with different recall strategies."""
It can be implemented with different recall strategies.
async def add_message(self, message: LLMMessage) -> None: ...
Args:
initial_messages (List[LLMMessage] | None): The initial messages.
"""
def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None:
self._messages: List[LLMMessage] = initial_messages or []
async def add_message(self, message: LLMMessage) -> None:
"""Add a message to the context."""
self._messages.append(message)
@abstractmethod
async def get_messages(self) -> List[LLMMessage]: ...
async def clear(self) -> None: ...
async def clear(self) -> None:
"""Clear the context."""
self._messages = []
def save_state(self) -> Mapping[str, LLMMessage]: ...
async def save_state(self) -> Mapping[str, Any]:
return ChatCompletionContextState(messages=self._messages).model_dump()
def load_state(self, state: Mapping[str, LLMMessage]) -> None: ...
async def load_state(self, state: Mapping[str, Any]) -> None:
self._messages = ChatCompletionContextState.model_validate(state).messages
class ChatCompletionContextState(BaseModel):
messages: List[LLMMessage] = Field(default_factory=list)

View File

@@ -1,4 +1,4 @@
from typing import Any, List, Mapping
from typing import List
from .._types import FunctionCall
from ..models import AssistantMessage, FunctionExecutionResultMessage, LLMMessage, UserMessage
@@ -13,17 +13,18 @@ class HeadAndTailChatCompletionContext(ChatCompletionContext):
Args:
head_size (int): The size of the head.
tail_size (int): The size of the tail.
initial_messages (List[LLMMessage] | None): The initial messages.
"""
def __init__(self, head_size: int, tail_size: int) -> None:
self._messages: List[LLMMessage] = []
def __init__(self, head_size: int, tail_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
super().__init__(initial_messages)
if head_size <= 0:
raise ValueError("head_size must be greater than 0.")
if tail_size <= 0:
raise ValueError("tail_size must be greater than 0.")
self._head_size = head_size
self._tail_size = tail_size
async def add_message(self, message: LLMMessage) -> None:
"""Add a message to the memory."""
self._messages.append(message)
async def get_messages(self) -> List[LLMMessage]:
"""Get at most `head_size` recent messages and `tail_size` oldest messages."""
head_messages = self._messages[: self._head_size]
@@ -51,21 +52,3 @@ class HeadAndTailChatCompletionContext(ChatCompletionContext):
placeholder_messages = [UserMessage(content=f"Skipped {num_skipped} messages.", source="System")]
return head_messages + placeholder_messages + tail_messages
async def clear(self) -> None:
"""Clear the message memory."""
self._messages = []
def save_state(self) -> Mapping[str, Any]:
return {
"messages": [message for message in self._messages],
"head_size": self._head_size,
"tail_size": self._tail_size,
"placeholder_message": self._placeholder_message,
}
def load_state(self, state: Mapping[str, Any]) -> None:
self._messages = state["messages"]
self._head_size = state["head_size"]
self._tail_size = state["tail_size"]
self._placeholder_message = state["placeholder_message"]

View File

@@ -0,0 +1,12 @@
from typing import List
from ..models import LLMMessage
from ._chat_completion_context import ChatCompletionContext
class UnboundedChatCompletionContext(ChatCompletionContext):
"""An unbounded chat completion context that keeps a view of the all the messages."""
async def get_messages(self) -> List[LLMMessage]:
"""Get at most `buffer_size` recent messages."""
return self._messages

View File

@@ -1,7 +1,11 @@
from typing import List
import pytest
from autogen_core.model_context import BufferedChatCompletionContext, HeadAndTailChatCompletionContext
from autogen_core.model_context import (
BufferedChatCompletionContext,
HeadAndTailChatCompletionContext,
UnboundedChatCompletionContext,
)
from autogen_core.models import AssistantMessage, LLMMessage, UserMessage
@@ -26,6 +30,17 @@ async def test_buffered_model_context() -> None:
retrieved = await model_context.get_messages()
assert len(retrieved) == 0
# Test saving and loading state.
await model_context.add_message(messages[0])
await model_context.add_message(messages[1])
state = await model_context.save_state()
await model_context.clear()
await model_context.load_state(state)
retrieved = await model_context.get_messages()
assert len(retrieved) == 2
assert retrieved[0] == messages[0]
assert retrieved[1] == messages[1]
@pytest.mark.asyncio
async def test_head_and_tail_model_context() -> None:
@@ -48,3 +63,44 @@ async def test_head_and_tail_model_context() -> None:
await model_context.clear()
retrieved = await model_context.get_messages()
assert len(retrieved) == 0
# Test saving and loading state.
for msg in messages:
await model_context.add_message(msg)
state = await model_context.save_state()
await model_context.clear()
await model_context.load_state(state)
retrived = await model_context.get_messages()
assert len(retrived) == 3
assert retrived[0] == messages[0]
assert retrived[2] == messages[-1]
@pytest.mark.asyncio
async def test_unbounded_model_context() -> None:
model_context = UnboundedChatCompletionContext()
messages: List[LLMMessage] = [
UserMessage(content="Hello!", source="user"),
AssistantMessage(content="What can I do for you?", source="assistant"),
UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
]
for msg in messages:
await model_context.add_message(msg)
retrieved = await model_context.get_messages()
assert len(retrieved) == 3
assert retrieved == messages
await model_context.clear()
retrieved = await model_context.get_messages()
assert len(retrieved) == 0
# Test saving and loading state.
for msg in messages:
await model_context.add_message(msg)
state = await model_context.save_state()
await model_context.clear()
await model_context.load_state(state)
retrieved = await model_context.get_messages()
assert len(retrieved) == 3
assert retrieved == messages