mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
impl state management (#28)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
from asyncio import Future
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Dict, List, Set
|
||||
from typing import Any, Awaitable, Dict, List, Mapping, Set
|
||||
|
||||
from ..core import Agent, AgentRuntime, CancellationToken
|
||||
from ..core.exceptions import MessageDroppedException
|
||||
@@ -108,6 +108,16 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
future.set_result(None)
|
||||
return future
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
state: Dict[str, Dict[str, Any]] = {}
|
||||
for agent in self._agents:
|
||||
state[agent.name] = dict(agent.save_state())
|
||||
return state
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
for agent in self._agents:
|
||||
agent.load_state(state[agent.name])
|
||||
|
||||
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
|
||||
recipient = message_envelope.recipient
|
||||
assert recipient in self._agents
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Protocol, Sequence, runtime_checkable
|
||||
from typing import Any, Mapping, Protocol, Sequence, runtime_checkable
|
||||
|
||||
from agnext.core._cancellation_token import CancellationToken
|
||||
|
||||
@@ -33,3 +33,7 @@ class Agent(Protocol):
|
||||
If there was a cancellation, this function should raise a `CancelledError`.
|
||||
"""
|
||||
...
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]: ...
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None: ...
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from asyncio import Future
|
||||
from typing import Any, Protocol
|
||||
from typing import Any, Mapping, Protocol
|
||||
|
||||
from agnext.core._agent import Agent
|
||||
from agnext.core._cancellation_token import CancellationToken
|
||||
@@ -37,3 +37,7 @@ class AgentRuntime(Protocol):
|
||||
sender: Agent | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[None]: ...
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]: ...
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None: ...
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import Future
|
||||
from typing import Any, Sequence, TypeVar
|
||||
from typing import Any, Mapping, Sequence, TypeVar
|
||||
|
||||
from agnext.core._agent_runtime import AgentRuntime
|
||||
from agnext.core._cancellation_token import CancellationToken
|
||||
@@ -62,3 +63,11 @@ class BaseAgent(ABC, Agent):
|
||||
cancellation_token = CancellationToken()
|
||||
future = self._router.publish_message(message, sender=self, cancellation_token=cancellation_token)
|
||||
return future
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
warnings.warn("save_state not implemented", stacklevel=2)
|
||||
return {}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
warnings.warn("load_state not implemented", stacklevel=2)
|
||||
pass
|
||||
|
||||
62
tests/test_state.py
Normal file
62
tests/test_state.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import Any, Mapping, Sequence
|
||||
import pytest
|
||||
|
||||
from agnext.application_components import SingleThreadedAgentRuntime
|
||||
from agnext.core import AgentRuntime
|
||||
from agnext.core import BaseAgent
|
||||
from agnext.core import CancellationToken
|
||||
|
||||
class StatefulAgent(BaseAgent):
|
||||
def __init__(self, name: str, runtime: AgentRuntime) -> None:
|
||||
super().__init__(name, runtime)
|
||||
self.state = 0
|
||||
|
||||
@property
|
||||
def subscriptions(self) -> Sequence[type]:
|
||||
return []
|
||||
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
return {"state": self.state}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self.state = state["state"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_can_save_state() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
agent1 = StatefulAgent("name1", runtime)
|
||||
assert agent1.state == 0
|
||||
agent1.state = 1
|
||||
assert agent1.state == 1
|
||||
|
||||
agent1_state = agent1.save_state()
|
||||
|
||||
agent1.state = 2
|
||||
assert agent1.state == 2
|
||||
|
||||
agent1.load_state(agent1_state)
|
||||
assert agent1.state == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_can_save_state() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
agent1 = StatefulAgent("name1", runtime)
|
||||
assert agent1.state == 0
|
||||
agent1.state = 1
|
||||
assert agent1.state == 1
|
||||
|
||||
runtime_state = runtime.save_state()
|
||||
|
||||
runtime2 = SingleThreadedAgentRuntime()
|
||||
agent2 = StatefulAgent("name1", runtime2)
|
||||
runtime2.load_state(runtime_state)
|
||||
assert agent2.state == 1
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user