Mirror of https://github.com/microsoft/autogen.git (synced 2026-04-20 03:02:16 -04:00)
use model context for assistant agent, refactor model context (#4681)
* Decouple model_context from AssistantAgent
* Add UnboundedBufferedChatCompletionContext to mimic previous model_context behaviour on AssistantAgent
* Move unbounded buffered chat to a different file
* Fix model_context assertions in test_group_chat
* Refactor model context, introduce states
* Fixes
* Update

---------

Co-authored-by: aditya.kurniawan <aditya.kurniawan@core42.ai>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Victor Dibia <victordibia@microsoft.com>
@@ -2,15 +2,27 @@ import asyncio
 import json
 import logging
 import warnings
-from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Mapping, Sequence
+from typing import (
+    Any,
+    AsyncGenerator,
+    Awaitable,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Sequence,
+)
 
 from autogen_core import CancellationToken, FunctionCall
+from autogen_core.model_context import (
+    ChatCompletionContext,
+    UnboundedChatCompletionContext,
+)
 from autogen_core.models import (
     AssistantMessage,
     ChatCompletionClient,
     FunctionExecutionResult,
     FunctionExecutionResultMessage,
     LLMMessage,
     SystemMessage,
     UserMessage,
 )
@@ -87,7 +99,6 @@ class AssistantAgent(BaseChatAgent):
     If multiple handoffs are detected, only the first handoff is executed.
 
-
 
     Args:
         name (str): The name of the agent.
         model_client (ChatCompletionClient): The model client to use for inference.
@@ -96,8 +107,9 @@ class AssistantAgent(BaseChatAgent):
             allowing it to transfer to other agents by responding with a :class:`HandoffMessage`.
             The transfer is only executed when the team is in :class:`~autogen_agentchat.teams.Swarm`.
             If a handoff is a string, it should represent the target agent's name.
+        model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving :class:`~autogen_core.models.LLMMessage`. It can be preloaded with initial messages. The initial messages will be cleared when the agent is reset.
         description (str, optional): The description of the agent.
-        system_message (str, optional): The system message for the model.
+        system_message (str, optional): The system message for the model. If provided, it will be prepended to the messages in the model context when making an inference. Set to `None` to disable.
         reflect_on_tool_use (bool, optional): If `True`, the agent will make another model inference using the tool call and result
             to generate a response. If `False`, the tool call result will be returned as the response. Defaults to `False`.
         tool_call_summary_format (str, optional): The format string used to create a tool call summary for every tool call result.
@@ -219,9 +231,11 @@ class AssistantAgent(BaseChatAgent):
         *,
         tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
         handoffs: List[HandoffBase | str] | None = None,
+        model_context: ChatCompletionContext | None = None,
         description: str = "An agent that provides assistance with ability to use tools.",
-        system_message: str
-        | None = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
+        system_message: (
+            str | None
+        ) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
         reflect_on_tool_use: bool = False,
         tool_call_summary_format: str = "{result}",
     ):
@@ -273,7 +287,8 @@ class AssistantAgent(BaseChatAgent):
             raise ValueError(
                 f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
             )
-        self._model_context: List[LLMMessage] = []
+        if not model_context:
+            self._model_context = UnboundedChatCompletionContext()
         self._reflect_on_tool_use = reflect_on_tool_use
         self._tool_call_summary_format = tool_call_summary_format
         self._is_running = False
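Usage sketch (not part of the diff): with the new `model_context` parameter, any `ChatCompletionContext` implementation can be injected into `AssistantAgent`; omitting it falls back to `UnboundedChatCompletionContext` as wired above. The `OpenAIChatCompletionClient` import path is an assumption for autogen_ext at the time of this change and may differ by version.

from autogen_agentchat.agents import AssistantAgent
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_ext.models import OpenAIChatCompletionClient  # assumed import path

agent = AssistantAgent(
    name="assistant",
    model_client=OpenAIChatCompletionClient(model="gpt-4o"),
    # New in this change: keep only the 5 most recent messages in the model's view.
    model_context=BufferedChatCompletionContext(buffer_size=5),
)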
@@ -301,19 +316,19 @@
         for msg in messages:
             if isinstance(msg, MultiModalMessage) and self._model_client.capabilities["vision"] is False:
                 raise ValueError("The model does not support vision.")
-            self._model_context.append(UserMessage(content=msg.content, source=msg.source))
+            await self._model_context.add_message(UserMessage(content=msg.content, source=msg.source))
 
         # Inner messages.
         inner_messages: List[AgentEvent | ChatMessage] = []
 
         # Generate an inference result based on the current model context.
-        llm_messages = self._system_messages + self._model_context
+        llm_messages = self._system_messages + await self._model_context.get_messages()
         result = await self._model_client.create(
             llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
         )
 
         # Add the response to the model context.
-        self._model_context.append(AssistantMessage(content=result.content, source=self.name))
+        await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
 
         # Check if the response is a string and return it.
         if isinstance(result.content, str):
@@ -335,7 +350,7 @@
         results = await asyncio.gather(*[self._execute_tool_call(call, cancellation_token) for call in result.content])
         tool_call_result_msg = ToolCallExecutionEvent(content=results, source=self.name)
         event_logger.debug(tool_call_result_msg)
-        self._model_context.append(FunctionExecutionResultMessage(content=results))
+        await self._model_context.add_message(FunctionExecutionResultMessage(content=results))
         inner_messages.append(tool_call_result_msg)
         yield tool_call_result_msg
@@ -360,11 +375,11 @@
 
         if self._reflect_on_tool_use:
             # Generate another inference result based on the tool call and result.
-            llm_messages = self._system_messages + self._model_context
+            llm_messages = self._system_messages + await self._model_context.get_messages()
             result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
             assert isinstance(result.content, str)
             # Add the response to the model context.
-            self._model_context.append(AssistantMessage(content=result.content, source=self.name))
+            await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
             # Yield the response.
             yield Response(
                 chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
@@ -406,14 +421,15 @@
 
     async def on_reset(self, cancellation_token: CancellationToken) -> None:
         """Reset the assistant agent to its initialization state."""
-        self._model_context.clear()
+        await self._model_context.clear()
 
     async def save_state(self) -> Mapping[str, Any]:
         """Save the current state of the assistant agent."""
-        return AssistantAgentState(llm_messages=self._model_context.copy()).model_dump()
+        model_context_state = await self._model_context.save_state()
+        return AssistantAgentState(llm_context=model_context_state).model_dump()
 
     async def load_state(self, state: Mapping[str, Any]) -> None:
         """Load the state of the assistant agent"""
         assistant_agent_state = AssistantAgentState.model_validate(state)
-        self._model_context.clear()
-        self._model_context.extend(assistant_agent_state.llm_messages)
+        # Load the model context state.
+        await self._model_context.load_state(assistant_agent_state.llm_context)
@@ -1,8 +1,5 @@
 from typing import Any, List, Mapping, Optional
 
-from autogen_core.models import (
-    LLMMessage,
-)
 from pydantic import BaseModel, Field
 
 from ..messages import (
@@ -21,7 +18,7 @@ class BaseState(BaseModel):
 class AssistantAgentState(BaseState):
     """State for an assistant agent."""
 
-    llm_messages: List[LLMMessage] = Field(default_factory=list)
+    llm_context: Mapping[str, Any] = Field(default_factory=lambda: dict([("messages", [])]))
     type: str = Field(default="AssistantAgentState")
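State round-trip sketch (not part of the diff): `AssistantAgentState` now stores the serialized context under `llm_context` instead of a raw `llm_messages` list, and the mapping returned by `save_state()` is plain data. `make_agent` is a hypothetical factory, and JSON dumping assumes the stored messages are JSON-serializable.

import json

async def roundtrip(agent, make_agent):
    state = await agent.save_state()  # now async, because ChatCompletionContext.save_state() is
    # Shape per AssistantAgentState above: {"type": "AssistantAgentState", "llm_context": {"messages": [...]}}
    blob = json.dumps(state)

    fresh = make_agent()  # hypothetical factory that rebuilds an identical agent
    await fresh.load_state(json.loads(blob))
    return fresh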
@@ -239,8 +239,13 @@ async def test_round_robin_group_chat_state() -> None:
     await team2.load_state(state)
     state2 = await team2.save_state()
     assert state == state2
-    assert agent3._model_context == agent1._model_context  # pyright: ignore
-    assert agent4._model_context == agent2._model_context  # pyright: ignore
+
+    agent1_model_ctx_messages = await agent1._model_context.get_messages()  # pyright: ignore
+    agent2_model_ctx_messages = await agent2._model_context.get_messages()  # pyright: ignore
+    agent3_model_ctx_messages = await agent3._model_context.get_messages()  # pyright: ignore
+    agent4_model_ctx_messages = await agent4._model_context.get_messages()  # pyright: ignore
+    assert agent3_model_ctx_messages == agent1_model_ctx_messages
+    assert agent4_model_ctx_messages == agent2_model_ctx_messages
     manager_1 = await team1._runtime.try_get_underlying_agent_instance(  # pyright: ignore
         AgentId("group_chat_manager", team1._team_id),  # pyright: ignore
         RoundRobinGroupChatManager,  # pyright: ignore
@@ -337,7 +342,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"
 
     # Test streaming.
-    tool_use_agent._model_context.clear()  # pyright: ignore
+    await tool_use_agent._model_context.clear()  # pyright: ignore
     mock.reset()
     index = 0
     await team.reset()
@@ -351,7 +356,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
         index += 1
 
     # Test Console.
-    tool_use_agent._model_context.clear()  # pyright: ignore
+    await tool_use_agent._model_context.clear()  # pyright: ignore
     mock.reset()
     index = 0
     await team.reset()
@@ -579,8 +584,13 @@ async def test_selector_group_chat_state() -> None:
     await team2.load_state(state)
     state2 = await team2.save_state()
     assert state == state2
-    assert agent3._model_context == agent1._model_context  # pyright: ignore
-    assert agent4._model_context == agent2._model_context  # pyright: ignore
+
+    agent1_model_ctx_messages = await agent1._model_context.get_messages()  # pyright: ignore
+    agent2_model_ctx_messages = await agent2._model_context.get_messages()  # pyright: ignore
+    agent3_model_ctx_messages = await agent3._model_context.get_messages()  # pyright: ignore
+    agent4_model_ctx_messages = await agent4._model_context.get_messages()  # pyright: ignore
+    assert agent3_model_ctx_messages == agent1_model_ctx_messages
+    assert agent4_model_ctx_messages == agent2_model_ctx_messages
     manager_1 = await team1._runtime.try_get_underlying_agent_instance(  # pyright: ignore
         AgentId("group_chat_manager", team1._team_id),  # pyright: ignore
         SelectorGroupChatManager,  # pyright: ignore
@@ -931,7 +941,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
     assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"
 
     # Test streaming.
-    agent1._model_context.clear()  # pyright: ignore
+    await agent1._model_context.clear()  # pyright: ignore
     mock.reset()
     index = 0
     await team.reset()
@@ -944,7 +954,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
         index += 1
 
     # Test Console
-    agent1._model_context.clear()  # pyright: ignore
+    await agent1._model_context.clear()  # pyright: ignore
     mock.reset()
     index = 0
     await team.reset()
@@ -454,17 +454,17 @@
     "\n",
     "The above `SimpleAgent` always responds with a fresh context that contains only\n",
     "the system message and the latest user's message.\n",
-    "We can use model context classes from {py:mod}`autogen_core.components.model_context`\n",
+    "We can use model context classes from {py:mod}`autogen_core.model_context`\n",
     "to make the agent \"remember\" previous conversations.\n",
     "A model context supports storage and retrieval of Chat Completion messages.\n",
     "It is always used together with a model client to generate LLM-based responses.\n",
     "\n",
-    "For example, {py:mod}`~autogen_core.components.model_context.BufferedChatCompletionContext`\n",
+    "For example, {py:mod}`~autogen_core.model_context.BufferedChatCompletionContext`\n",
     "is a most-recently-used (MRU) context that stores the most recent `buffer_size`\n",
     "number of messages. This is useful to avoid context overflow in many LLMs.\n",
     "\n",
     "Let's update the previous example to use\n",
-    "{py:mod}`~autogen_core.components.model_context.BufferedChatCompletionContext`."
+    "{py:mod}`~autogen_core.model_context.BufferedChatCompletionContext`."
    ]
   },
   {
@@ -473,7 +473,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from autogen_core.components.model_context import BufferedChatCompletionContext\n",
+    "from autogen_core.model_context import BufferedChatCompletionContext\n",
     "from autogen_core.models import AssistantMessage\n",
     "\n",
     "\n",
@@ -615,7 +615,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.5"
+   "version": "3.12.7"
   }
  },
 "nbformat": 4,
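Behaviour sketch for the notebook prose above (not part of the diff): `BufferedChatCompletionContext` bounds what the model sees while `add_message` keeps accepting new messages.

import asyncio

from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.models import AssistantMessage, UserMessage

async def main() -> None:
    context = BufferedChatCompletionContext(buffer_size=2)
    await context.add_message(UserMessage(content="Hi", source="user"))
    await context.add_message(AssistantMessage(content="Hello!", source="assistant"))
    await context.add_message(UserMessage(content="Tell me a joke.", source="user"))

    # Only the most recent `buffer_size` messages are returned for inference.
    recent = await context.get_messages()
    assert len(recent) == 2
    assert recent[0].content == "Hello!"

asyncio.run(main())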
@@ -254,10 +254,10 @@ class ChatCompletionAgent(RoutedAgent):
 
     async def save_state(self) -> Mapping[str, Any]:
         return {
-            "memory": self._model_context.save_state(),
+            "chat_history": await self._model_context.save_state(),
             "system_messages": self._system_messages,
         }
 
     async def load_state(self, state: Mapping[str, Any]) -> None:
-        self._model_context.load_state(state["memory"])
+        await self._model_context.load_state(state["chat_history"])
         self._system_messages = state["system_messages"]
@@ -143,10 +143,12 @@ class GroupChatManager(RoutedAgent):
 
     async def save_state(self) -> Mapping[str, Any]:
         return {
-            "chat_history": self._model_context.save_state(),
+            "chat_history": await self._model_context.save_state(),
             "termination_word": self._termination_word,
         }
 
     async def load_state(self, state: Mapping[str, Any]) -> None:
-        self._model_context.load_state(state["chat_history"])
+        # Load the chat history.
+        await self._model_context.load_state(state["chat_history"])
+        # Load the termination word.
         self._termination_word = state["termination_word"]
@@ -114,7 +114,7 @@ class SlowUserProxyAgent(RoutedAgent):
         return state_to_save
 
     async def load_state(self, state: Mapping[str, Any]) -> None:
-        self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
+        await self._model_context.load_state(state["memory"])
 
 
 class ScheduleMeetingInput(BaseModel):
@@ -200,11 +200,11 @@ Today's date is {datetime.datetime.now().strftime("%Y-%m-%d")}
 
     async def save_state(self) -> Mapping[str, Any]:
         return {
-            "memory": self._model_context.save_state(),
+            "memory": await self._model_context.save_state(),
         }
 
     async def load_state(self, state: Mapping[str, Any]) -> None:
-        self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
+        await self._model_context.load_state(state["memory"])
 
 
 class NeedsUserInputHandler(DefaultInterventionHandler):
@@ -1,9 +1,14 @@
 from ._buffered_chat_completion_context import BufferedChatCompletionContext
-from ._chat_completion_context import ChatCompletionContext
+from ._chat_completion_context import ChatCompletionContext, ChatCompletionContextState
 from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
+from ._unbounded_chat_completion_context import (
+    UnboundedChatCompletionContext,
+)
 
 __all__ = [
     "ChatCompletionContext",
+    "ChatCompletionContextState",
+    "UnboundedChatCompletionContext",
     "BufferedChatCompletionContext",
     "HeadAndTailChatCompletionContext",
 ]
@@ -1,4 +1,4 @@
-from typing import Any, List, Mapping
+from typing import List
 
 from ..models import FunctionExecutionResultMessage, LLMMessage
 from ._chat_completion_context import ChatCompletionContext
@@ -10,17 +10,15 @@ class BufferedChatCompletionContext(ChatCompletionContext):
 
     Args:
         buffer_size (int): The size of the buffer.
 
+        initial_messages (List[LLMMessage] | None): The initial messages.
     """
 
     def __init__(self, buffer_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
-        self._messages: List[LLMMessage] = initial_messages or []
+        super().__init__(initial_messages)
         if buffer_size <= 0:
             raise ValueError("buffer_size must be greater than 0.")
         self._buffer_size = buffer_size
 
-    async def add_message(self, message: LLMMessage) -> None:
-        """Add a message to the memory."""
-        self._messages.append(message)
-
     async def get_messages(self) -> List[LLMMessage]:
         """Get at most `buffer_size` recent messages."""
         messages = self._messages[-self._buffer_size :]
@@ -29,17 +27,3 @@ class BufferedChatCompletionContext(ChatCompletionContext):
             # Remove the first message from the list.
             messages = messages[1:]
         return messages
-
-    async def clear(self) -> None:
-        """Clear the message memory."""
-        self._messages = []
-
-    def save_state(self) -> Mapping[str, Any]:
-        return {
-            "messages": [message for message in self._messages],
-            "buffer_size": self._buffer_size,
-        }
-
-    def load_state(self, state: Mapping[str, Any]) -> None:
-        self._messages = state["messages"]
-        self._buffer_size = state["buffer_size"]
@@ -1,19 +1,40 @@
-from typing import List, Mapping, Protocol
+from abc import ABC, abstractmethod
+from typing import Any, List, Mapping
+
+from pydantic import BaseModel, Field
 
 from ..models import LLMMessage
 
 
-class ChatCompletionContext(Protocol):
-    """A protocol for defining the interface of a chat completion context.
+class ChatCompletionContext(ABC):
+    """An abstract base class for defining the interface of a chat completion context.
     A chat completion context lets agents store and retrieve LLM messages.
-    It can be implemented with different recall strategies."""
+    It can be implemented with different recall strategies.
 
-    async def add_message(self, message: LLMMessage) -> None: ...
+    Args:
+        initial_messages (List[LLMMessage] | None): The initial messages.
+    """
+
+    def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None:
+        self._messages: List[LLMMessage] = initial_messages or []
+
+    async def add_message(self, message: LLMMessage) -> None:
+        """Add a message to the context."""
+        self._messages.append(message)
 
+    @abstractmethod
     async def get_messages(self) -> List[LLMMessage]: ...
 
-    async def clear(self) -> None: ...
+    async def clear(self) -> None:
+        """Clear the context."""
+        self._messages = []
 
-    def save_state(self) -> Mapping[str, LLMMessage]: ...
+    async def save_state(self) -> Mapping[str, Any]:
+        return ChatCompletionContextState(messages=self._messages).model_dump()
 
-    def load_state(self, state: Mapping[str, LLMMessage]) -> None: ...
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        self._messages = ChatCompletionContextState.model_validate(state).messages
+
+
+class ChatCompletionContextState(BaseModel):
+    messages: List[LLMMessage] = Field(default_factory=list)
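Extension sketch (not part of the diff): because the base class now owns `add_message`, `clear`, `save_state`, and `load_state`, a custom recall strategy only needs to override the abstract `get_messages`. The class below is hypothetical.

from typing import List

from autogen_core.model_context import ChatCompletionContext
from autogen_core.models import LLMMessage, UserMessage

class UserOnlyChatCompletionContext(ChatCompletionContext):
    """Hypothetical context that replays only user messages."""

    async def get_messages(self) -> List[LLMMessage]:
        # self._messages is maintained by the inherited add_message()/clear().
        return [m for m in self._messages if isinstance(m, UserMessage)]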
@@ -1,4 +1,4 @@
-from typing import Any, List, Mapping
+from typing import List
 
 from .._types import FunctionCall
 from ..models import AssistantMessage, FunctionExecutionResultMessage, LLMMessage, UserMessage
@@ -13,17 +13,18 @@ class HeadAndTailChatCompletionContext(ChatCompletionContext):
     Args:
         head_size (int): The size of the head.
         tail_size (int): The size of the tail.
+        initial_messages (List[LLMMessage] | None): The initial messages.
     """
 
-    def __init__(self, head_size: int, tail_size: int) -> None:
-        self._messages: List[LLMMessage] = []
+    def __init__(self, head_size: int, tail_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
+        super().__init__(initial_messages)
         if head_size <= 0:
             raise ValueError("head_size must be greater than 0.")
         if tail_size <= 0:
             raise ValueError("tail_size must be greater than 0.")
         self._head_size = head_size
         self._tail_size = tail_size
 
     async def add_message(self, message: LLMMessage) -> None:
         """Add a message to the memory."""
         self._messages.append(message)
 
     async def get_messages(self) -> List[LLMMessage]:
         """Get at most `head_size` oldest messages and `tail_size` recent messages."""
         head_messages = self._messages[: self._head_size]
@@ -51,21 +52,3 @@ class HeadAndTailChatCompletionContext(ChatCompletionContext):
 
         placeholder_messages = [UserMessage(content=f"Skipped {num_skipped} messages.", source="System")]
         return head_messages + placeholder_messages + tail_messages
-
-    async def clear(self) -> None:
-        """Clear the message memory."""
-        self._messages = []
-
-    def save_state(self) -> Mapping[str, Any]:
-        return {
-            "messages": [message for message in self._messages],
-            "head_size": self._head_size,
-            "tail_size": self._tail_size,
-            "placeholder_message": self._placeholder_message,
-        }
-
-    def load_state(self, state: Mapping[str, Any]) -> None:
-        self._messages = state["messages"]
-        self._head_size = state["head_size"]
-        self._tail_size = state["tail_size"]
-        self._placeholder_message = state["placeholder_message"]
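Behaviour sketch (not part of the diff): the head-and-tail strategy keeps the `head_size` oldest and `tail_size` newest messages and bridges them with the "Skipped N messages." placeholder built above.

import asyncio

from autogen_core.model_context import HeadAndTailChatCompletionContext
from autogen_core.models import UserMessage

async def main() -> None:
    context = HeadAndTailChatCompletionContext(head_size=1, tail_size=1)
    for i in range(5):
        await context.add_message(UserMessage(content=f"message {i}", source="user"))

    view = await context.get_messages()
    # Oldest message, "Skipped 3 messages." placeholder, newest message.
    assert len(view) == 3
    assert view[0].content == "message 0"
    assert view[-1].content == "message 4"

asyncio.run(main())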
@@ -0,0 +1,12 @@
+from typing import List
+
+from ..models import LLMMessage
+from ._chat_completion_context import ChatCompletionContext
+
+
+class UnboundedChatCompletionContext(ChatCompletionContext):
+    """An unbounded chat completion context that keeps a view of all the messages."""
+
+    async def get_messages(self) -> List[LLMMessage]:
+        """Return all messages in the context."""
+        return self._messages
@@ -1,7 +1,11 @@
 from typing import List
 
 import pytest
-from autogen_core.model_context import BufferedChatCompletionContext, HeadAndTailChatCompletionContext
+from autogen_core.model_context import (
+    BufferedChatCompletionContext,
+    HeadAndTailChatCompletionContext,
+    UnboundedChatCompletionContext,
+)
 from autogen_core.models import AssistantMessage, LLMMessage, UserMessage
@@ -26,6 +30,17 @@ async def test_buffered_model_context() -> None:
     retrieved = await model_context.get_messages()
     assert len(retrieved) == 0
 
+    # Test saving and loading state.
+    await model_context.add_message(messages[0])
+    await model_context.add_message(messages[1])
+    state = await model_context.save_state()
+    await model_context.clear()
+    await model_context.load_state(state)
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 2
+    assert retrieved[0] == messages[0]
+    assert retrieved[1] == messages[1]
+
 
 @pytest.mark.asyncio
 async def test_head_and_tail_model_context() -> None:
@@ -48,3 +63,44 @@ async def test_head_and_tail_model_context() -> None:
     await model_context.clear()
     retrieved = await model_context.get_messages()
     assert len(retrieved) == 0
+
+    # Test saving and loading state.
+    for msg in messages:
+        await model_context.add_message(msg)
+    state = await model_context.save_state()
+    await model_context.clear()
+    await model_context.load_state(state)
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 3
+    assert retrieved[0] == messages[0]
+    assert retrieved[2] == messages[-1]
+
+
+@pytest.mark.asyncio
+async def test_unbounded_model_context() -> None:
+    model_context = UnboundedChatCompletionContext()
+    messages: List[LLMMessage] = [
+        UserMessage(content="Hello!", source="user"),
+        AssistantMessage(content="What can I do for you?", source="assistant"),
+        UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
+    ]
+    for msg in messages:
+        await model_context.add_message(msg)
+
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 3
+    assert retrieved == messages
+
+    await model_context.clear()
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 0
+
+    # Test saving and loading state.
+    for msg in messages:
+        await model_context.add_message(msg)
+    state = await model_context.save_state()
+    await model_context.clear()
+    await model_context.load_state(state)
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 3
+    assert retrieved == messages