mirror of
https://github.com/microsoft/autogen.git
synced 2026-02-10 11:45:14 -05:00
Make save/load state for agent async (#4195)
This commit is contained in:
@@ -250,12 +250,12 @@ class ChatCompletionAgent(RoutedAgent):
|
||||
result_as_str = f"Error: {str(e)}"
|
||||
return (result_as_str, call_id)
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
"memory": self._model_context.save_state(),
|
||||
"system_messages": self._system_messages,
|
||||
}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._model_context.load_state(state["memory"])
|
||||
self._system_messages = state["system_messages"]
|
||||
|
||||
@@ -142,12 +142,12 @@ class GroupChatManager(RoutedAgent):
|
||||
# Send the message to the selected speaker to ask it to publish a response.
|
||||
await self.send_message(PublishNow(), speaker)
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
"chat_history": self._model_context.save_state(),
|
||||
"termination_word": self._termination_word,
|
||||
}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._model_context.load_state(state["chat_history"])
|
||||
self._termination_word = state["termination_word"]
|
||||
|
||||
@@ -98,13 +98,13 @@ class SlowUserProxyAgent(RoutedAgent):
|
||||
GetSlowUserMessage(content=message.content), topic_id=DefaultTopicId("scheduling_assistant_conversation")
|
||||
)
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
state_to_save = {
|
||||
"memory": self._model_context.save_state(),
|
||||
}
|
||||
return state_to_save
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
|
||||
|
||||
|
||||
@@ -186,12 +186,12 @@ Today's date is {datetime.datetime.now().strftime("%Y-%m-%d")}
|
||||
|
||||
await self.publish_message(speech, topic_id=DefaultTopicId("scheduling_assistant_conversation"))
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
"memory": self._model_context.save_state(),
|
||||
}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
|
||||
|
||||
|
||||
|
||||
@@ -274,14 +274,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
state: Dict[str, Dict[str, Any]] = {}
|
||||
for agent_id in self._instantiated_agents:
|
||||
state[str(agent_id)] = dict((await self._get_agent(agent_id)).save_state())
|
||||
state[str(agent_id)] = dict(await (await self._get_agent(agent_id)).save_state())
|
||||
return state
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
for agent_id_str in state:
|
||||
agent_id = AgentId.from_str(agent_id_str)
|
||||
if agent_id.type in self._known_agent_names:
|
||||
(await self._get_agent(agent_id)).load_state(state[str(agent_id)])
|
||||
await (await self._get_agent(agent_id)).load_state(state[str(agent_id)])
|
||||
|
||||
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
|
||||
with self._tracer_helper.trace_block("send", message_envelope.recipient, parent=message_envelope.metadata):
|
||||
@@ -526,10 +526,10 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
return (await self._get_agent(agent)).metadata
|
||||
|
||||
async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
|
||||
return (await self._get_agent(agent)).save_state()
|
||||
return await (await self._get_agent(agent)).save_state()
|
||||
|
||||
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
|
||||
(await self._get_agent(agent)).load_state(state)
|
||||
await (await self._get_agent(agent)).load_state(state)
|
||||
|
||||
@deprecated(
|
||||
"Use your agent's `register` method directly instead of this method. See documentation for latest usage."
|
||||
|
||||
@@ -33,11 +33,11 @@ class Agent(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the state of the agent. The result must be JSON serializable."""
|
||||
...
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Load in the state of the agent obtained from `save_state`.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -133,11 +133,11 @@ class BaseAgent(ABC, Agent):
|
||||
) -> None:
|
||||
await self._runtime.publish_message(message, topic_id, sender=self.id, cancellation_token=cancellation_token)
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
warnings.warn("save_state not implemented", stacklevel=2)
|
||||
return {}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
warnings.warn("load_state not implemented", stacklevel=2)
|
||||
pass
|
||||
|
||||
|
||||
@@ -90,10 +90,10 @@ class ClosureAgent(Agent):
|
||||
)
|
||||
return await self._closure(self._runtime, self._id, message, ctx)
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
raise ValueError("save_state not implemented for ClosureAgent")
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
raise ValueError("load_state not implemented for ClosureAgent")
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -13,10 +13,10 @@ class StatefulAgent(BaseAgent):
|
||||
async def on_message(self, message: Any, ctx: MessageContext) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
return {"state": self.state}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self.state = state["state"]
|
||||
|
||||
|
||||
@@ -31,12 +31,12 @@ async def test_agent_can_save_state() -> None:
|
||||
agent1.state = 1
|
||||
assert agent1.state == 1
|
||||
|
||||
agent1_state = agent1.save_state()
|
||||
agent1_state = await agent1.save_state()
|
||||
|
||||
agent1.state = 2
|
||||
assert agent1.state == 2
|
||||
|
||||
agent1.load_state(agent1_state)
|
||||
await agent1.load_state(agent1_state)
|
||||
assert agent1.state == 1
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user