Make save/load state for agent async (#4195)

This commit is contained in:
Jack Gerrits
2024-11-15 10:38:01 -05:00
committed by GitHub
parent 88229513e5
commit 2b565713d0
8 changed files with 22 additions and 22 deletions

View File

@@ -250,12 +250,12 @@ class ChatCompletionAgent(RoutedAgent):
result_as_str = f"Error: {str(e)}"
return (result_as_str, call_id)
def save_state(self) -> Mapping[str, Any]:
async def save_state(self) -> Mapping[str, Any]:
return {
"memory": self._model_context.save_state(),
"system_messages": self._system_messages,
}
def load_state(self, state: Mapping[str, Any]) -> None:
async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state(state["memory"])
self._system_messages = state["system_messages"]

View File

@@ -142,12 +142,12 @@ class GroupChatManager(RoutedAgent):
# Send the message to the selected speaker to ask it to publish a response.
await self.send_message(PublishNow(), speaker)
def save_state(self) -> Mapping[str, Any]:
async def save_state(self) -> Mapping[str, Any]:
return {
"chat_history": self._model_context.save_state(),
"termination_word": self._termination_word,
}
def load_state(self, state: Mapping[str, Any]) -> None:
async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state(state["chat_history"])
self._termination_word = state["termination_word"]

View File

@@ -98,13 +98,13 @@ class SlowUserProxyAgent(RoutedAgent):
GetSlowUserMessage(content=message.content), topic_id=DefaultTopicId("scheduling_assistant_conversation")
)
def save_state(self) -> Mapping[str, Any]:
async def save_state(self) -> Mapping[str, Any]:
state_to_save = {
"memory": self._model_context.save_state(),
}
return state_to_save
def load_state(self, state: Mapping[str, Any]) -> None:
async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
@@ -186,12 +186,12 @@ Today's date is {datetime.datetime.now().strftime("%Y-%m-%d")}
await self.publish_message(speech, topic_id=DefaultTopicId("scheduling_assistant_conversation"))
def save_state(self) -> Mapping[str, Any]:
async def save_state(self) -> Mapping[str, Any]:
return {
"memory": self._model_context.save_state(),
}
def load_state(self, state: Mapping[str, Any]) -> None:
async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})

View File

@@ -274,14 +274,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
async def save_state(self) -> Mapping[str, Any]:
state: Dict[str, Dict[str, Any]] = {}
for agent_id in self._instantiated_agents:
state[str(agent_id)] = dict((await self._get_agent(agent_id)).save_state())
state[str(agent_id)] = dict(await (await self._get_agent(agent_id)).save_state())
return state
async def load_state(self, state: Mapping[str, Any]) -> None:
for agent_id_str in state:
agent_id = AgentId.from_str(agent_id_str)
if agent_id.type in self._known_agent_names:
(await self._get_agent(agent_id)).load_state(state[str(agent_id)])
await (await self._get_agent(agent_id)).load_state(state[str(agent_id)])
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
with self._tracer_helper.trace_block("send", message_envelope.recipient, parent=message_envelope.metadata):
@@ -526,10 +526,10 @@ class SingleThreadedAgentRuntime(AgentRuntime):
return (await self._get_agent(agent)).metadata
async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
return (await self._get_agent(agent)).save_state()
return await (await self._get_agent(agent)).save_state()
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
(await self._get_agent(agent)).load_state(state)
await (await self._get_agent(agent)).load_state(state)
@deprecated(
"Use your agent's `register` method directly instead of this method. See documentation for latest usage."

View File

@@ -33,11 +33,11 @@ class Agent(Protocol):
"""
...
def save_state(self) -> Mapping[str, Any]:
async def save_state(self) -> Mapping[str, Any]:
"""Save the state of the agent. The result must be JSON serializable."""
...
def load_state(self, state: Mapping[str, Any]) -> None:
async def load_state(self, state: Mapping[str, Any]) -> None:
"""Load in the state of the agent obtained from `save_state`.
Args:

View File

@@ -133,11 +133,11 @@ class BaseAgent(ABC, Agent):
) -> None:
await self._runtime.publish_message(message, topic_id, sender=self.id, cancellation_token=cancellation_token)
def save_state(self) -> Mapping[str, Any]:
async def save_state(self) -> Mapping[str, Any]:
warnings.warn("save_state not implemented", stacklevel=2)
return {}
def load_state(self, state: Mapping[str, Any]) -> None:
async def load_state(self, state: Mapping[str, Any]) -> None:
warnings.warn("load_state not implemented", stacklevel=2)
pass

View File

@@ -90,10 +90,10 @@ class ClosureAgent(Agent):
)
return await self._closure(self._runtime, self._id, message, ctx)
def save_state(self) -> Mapping[str, Any]:
async def save_state(self) -> Mapping[str, Any]:
raise ValueError("save_state not implemented for ClosureAgent")
def load_state(self, state: Mapping[str, Any]) -> None:
async def load_state(self, state: Mapping[str, Any]) -> None:
raise ValueError("load_state not implemented for ClosureAgent")
@classmethod

View File

@@ -13,10 +13,10 @@ class StatefulAgent(BaseAgent):
async def on_message(self, message: Any, ctx: MessageContext) -> None:
raise NotImplementedError
def save_state(self) -> Mapping[str, Any]:
async def save_state(self) -> Mapping[str, Any]:
return {"state": self.state}
def load_state(self, state: Mapping[str, Any]) -> None:
async def load_state(self, state: Mapping[str, Any]) -> None:
self.state = state["state"]
@@ -31,12 +31,12 @@ async def test_agent_can_save_state() -> None:
agent1.state = 1
assert agent1.state == 1
agent1_state = agent1.save_state()
agent1_state = await agent1.save_state()
agent1.state = 2
assert agent1.state == 2
agent1.load_state(agent1_state)
await agent1.load_state(agent1_state)
assert agent1.state == 1