From 2b565713d0e265864cd9ee0425ba9a2917b8fe91 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Fri, 15 Nov 2024 10:38:01 -0500 Subject: [PATCH] Make save/load state for agent async (#4195) --- .../samples/common/agents/_chat_completion_agent.py | 4 ++-- .../samples/common/patterns/_group_chat_manager.py | 4 ++-- .../packages/autogen-core/samples/slow_human_in_loop.py | 8 ++++---- .../application/_single_threaded_agent_runtime.py | 8 ++++---- .../packages/autogen-core/src/autogen_core/base/_agent.py | 4 ++-- .../autogen-core/src/autogen_core/base/_base_agent.py | 4 ++-- .../src/autogen_core/components/_closure_agent.py | 4 ++-- python/packages/autogen-core/tests/test_state.py | 8 ++++---- 8 files changed, 22 insertions(+), 22 deletions(-) diff --git a/python/packages/autogen-core/samples/common/agents/_chat_completion_agent.py b/python/packages/autogen-core/samples/common/agents/_chat_completion_agent.py index cf7d1e016..51f53c9a2 100644 --- a/python/packages/autogen-core/samples/common/agents/_chat_completion_agent.py +++ b/python/packages/autogen-core/samples/common/agents/_chat_completion_agent.py @@ -250,12 +250,12 @@ class ChatCompletionAgent(RoutedAgent): result_as_str = f"Error: {str(e)}" return (result_as_str, call_id) - def save_state(self) -> Mapping[str, Any]: + async def save_state(self) -> Mapping[str, Any]: return { "memory": self._model_context.save_state(), "system_messages": self._system_messages, } - def load_state(self, state: Mapping[str, Any]) -> None: + async def load_state(self, state: Mapping[str, Any]) -> None: self._model_context.load_state(state["memory"]) self._system_messages = state["system_messages"] diff --git a/python/packages/autogen-core/samples/common/patterns/_group_chat_manager.py b/python/packages/autogen-core/samples/common/patterns/_group_chat_manager.py index 47c8ecf3b..22d429118 100644 --- a/python/packages/autogen-core/samples/common/patterns/_group_chat_manager.py +++ b/python/packages/autogen-core/samples/common/patterns/_group_chat_manager.py @@ -142,12 +142,12 @@ class GroupChatManager(RoutedAgent): # Send the message to the selected speaker to ask it to publish a response. await self.send_message(PublishNow(), speaker) - def save_state(self) -> Mapping[str, Any]: + async def save_state(self) -> Mapping[str, Any]: return { "chat_history": self._model_context.save_state(), "termination_word": self._termination_word, } - def load_state(self, state: Mapping[str, Any]) -> None: + async def load_state(self, state: Mapping[str, Any]) -> None: self._model_context.load_state(state["chat_history"]) self._termination_word = state["termination_word"] diff --git a/python/packages/autogen-core/samples/slow_human_in_loop.py b/python/packages/autogen-core/samples/slow_human_in_loop.py index 845bbba26..cc8012f60 100644 --- a/python/packages/autogen-core/samples/slow_human_in_loop.py +++ b/python/packages/autogen-core/samples/slow_human_in_loop.py @@ -98,13 +98,13 @@ class SlowUserProxyAgent(RoutedAgent): GetSlowUserMessage(content=message.content), topic_id=DefaultTopicId("scheduling_assistant_conversation") ) - def save_state(self) -> Mapping[str, Any]: + async def save_state(self) -> Mapping[str, Any]: state_to_save = { "memory": self._model_context.save_state(), } return state_to_save - def load_state(self, state: Mapping[str, Any]) -> None: + async def load_state(self, state: Mapping[str, Any]) -> None: self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]}) @@ -186,12 +186,12 @@ Today's date is {datetime.datetime.now().strftime("%Y-%m-%d")} await self.publish_message(speech, topic_id=DefaultTopicId("scheduling_assistant_conversation")) - def save_state(self) -> Mapping[str, Any]: + async def save_state(self) -> Mapping[str, Any]: return { "memory": self._model_context.save_state(), } - def load_state(self, state: Mapping[str, Any]) -> None: + async def load_state(self, state: Mapping[str, Any]) -> None: self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]}) diff --git a/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py index 7feff6bab..f511bd782 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py @@ -274,14 +274,14 @@ class SingleThreadedAgentRuntime(AgentRuntime): async def save_state(self) -> Mapping[str, Any]: state: Dict[str, Dict[str, Any]] = {} for agent_id in self._instantiated_agents: - state[str(agent_id)] = dict((await self._get_agent(agent_id)).save_state()) + state[str(agent_id)] = dict(await (await self._get_agent(agent_id)).save_state()) return state async def load_state(self, state: Mapping[str, Any]) -> None: for agent_id_str in state: agent_id = AgentId.from_str(agent_id_str) if agent_id.type in self._known_agent_names: - (await self._get_agent(agent_id)).load_state(state[str(agent_id)]) + await (await self._get_agent(agent_id)).load_state(state[str(agent_id)]) async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: with self._tracer_helper.trace_block("send", message_envelope.recipient, parent=message_envelope.metadata): @@ -526,10 +526,10 @@ class SingleThreadedAgentRuntime(AgentRuntime): return (await self._get_agent(agent)).metadata async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]: - return (await self._get_agent(agent)).save_state() + return await (await self._get_agent(agent)).save_state() async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None: - (await self._get_agent(agent)).load_state(state) + await (await self._get_agent(agent)).load_state(state) @deprecated( "Use your agent's `register` method directly instead of this method. See documentation for latest usage." diff --git a/python/packages/autogen-core/src/autogen_core/base/_agent.py b/python/packages/autogen-core/src/autogen_core/base/_agent.py index 376efa254..edb5e59b1 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_agent.py +++ b/python/packages/autogen-core/src/autogen_core/base/_agent.py @@ -33,11 +33,11 @@ class Agent(Protocol): """ ... - def save_state(self) -> Mapping[str, Any]: + async def save_state(self) -> Mapping[str, Any]: """Save the state of the agent. The result must be JSON serializable.""" ... - def load_state(self, state: Mapping[str, Any]) -> None: + async def load_state(self, state: Mapping[str, Any]) -> None: """Load in the state of the agent obtained from `save_state`. Args: diff --git a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py index cd113e2f1..5d8e94225 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py @@ -133,11 +133,11 @@ class BaseAgent(ABC, Agent): ) -> None: await self._runtime.publish_message(message, topic_id, sender=self.id, cancellation_token=cancellation_token) - def save_state(self) -> Mapping[str, Any]: + async def save_state(self) -> Mapping[str, Any]: warnings.warn("save_state not implemented", stacklevel=2) return {} - def load_state(self, state: Mapping[str, Any]) -> None: + async def load_state(self, state: Mapping[str, Any]) -> None: warnings.warn("load_state not implemented", stacklevel=2) pass diff --git a/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py b/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py index ca0bb8d33..1123c3ee4 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py +++ b/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py @@ -90,10 +90,10 @@ class ClosureAgent(Agent): ) return await self._closure(self._runtime, self._id, message, ctx) - def save_state(self) -> Mapping[str, Any]: + async def save_state(self) -> Mapping[str, Any]: raise ValueError("save_state not implemented for ClosureAgent") - def load_state(self, state: Mapping[str, Any]) -> None: + async def load_state(self, state: Mapping[str, Any]) -> None: raise ValueError("load_state not implemented for ClosureAgent") @classmethod diff --git a/python/packages/autogen-core/tests/test_state.py b/python/packages/autogen-core/tests/test_state.py index cba5631be..5d6844736 100644 --- a/python/packages/autogen-core/tests/test_state.py +++ b/python/packages/autogen-core/tests/test_state.py @@ -13,10 +13,10 @@ class StatefulAgent(BaseAgent): async def on_message(self, message: Any, ctx: MessageContext) -> None: raise NotImplementedError - def save_state(self) -> Mapping[str, Any]: + async def save_state(self) -> Mapping[str, Any]: return {"state": self.state} - def load_state(self, state: Mapping[str, Any]) -> None: + async def load_state(self, state: Mapping[str, Any]) -> None: self.state = state["state"] @@ -31,12 +31,12 @@ async def test_agent_can_save_state() -> None: agent1.state = 1 assert agent1.state == 1 - agent1_state = agent1.save_state() + agent1_state = await agent1.save_state() agent1.state = 2 assert agent1.state == 2 - agent1.load_state(agent1_state) + await agent1.load_state(agent1_state) assert agent1.state == 1