use model context for assistant agent, refactor model context (#4681)

* Decouple model_context from AssistantAgent

* add UnboundedBufferedChatCompletionContext to mimic pervious model_context behaviour on AssistantAgent

* moving unbounded buffered chat to a different file

* fix model_context assertions in test_group_chat

* Refactor model context, introduce states

* fixes

* update

---------

Co-authored-by: aditya.kurniawan <aditya.kurniawan@core42.ai>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Victor Dibia <victordibia@microsoft.com>
This commit is contained in:
Aditya Kurniawan
2024-12-20 10:27:41 +04:00
committed by GitHub
parent a271708a97
commit c989181da2
13 changed files with 183 additions and 97 deletions

View File

@@ -254,10 +254,10 @@ class ChatCompletionAgent(RoutedAgent):
async def save_state(self) -> Mapping[str, Any]:
return {
"memory": self._model_context.save_state(),
"chat_history": await self._model_context.save_state(),
"system_messages": self._system_messages,
}
async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state(state["memory"])
await self._model_context.load_state(state["chat_history"])
self._system_messages = state["system_messages"]

View File

@@ -143,10 +143,12 @@ class GroupChatManager(RoutedAgent):
async def save_state(self) -> Mapping[str, Any]:
return {
"chat_history": self._model_context.save_state(),
"chat_history": await self._model_context.save_state(),
"termination_word": self._termination_word,
}
async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state(state["chat_history"])
# Load the chat history.
await self._model_context.load_state(state["chat_history"])
# Load the termination word.
self._termination_word = state["termination_word"]

View File

@@ -114,7 +114,7 @@ class SlowUserProxyAgent(RoutedAgent):
return state_to_save
async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
await self._model_context.load_state(state["memory"])
class ScheduleMeetingInput(BaseModel):
@@ -200,11 +200,11 @@ Today's date is {datetime.datetime.now().strftime("%Y-%m-%d")}
async def save_state(self) -> Mapping[str, Any]:
return {
"memory": self._model_context.save_state(),
"memory": await self._model_context.save_state(),
}
async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
await self._model_context.load_state(state["memory"])
class NeedsUserInputHandler(DefaultInterventionHandler):