impl state management (#28)

This commit is contained in:
Jack Gerrits
2024-05-27 20:25:25 -04:00
committed by GitHub
parent afc1666d5b
commit 7568a7a447
5 changed files with 93 additions and 4 deletions

View File

@@ -1,7 +1,7 @@
import asyncio
from asyncio import Future
from dataclasses import dataclass
from typing import Any, Awaitable, Dict, List, Set
from typing import Any, Awaitable, Dict, List, Mapping, Set
from ..core import Agent, AgentRuntime, CancellationToken
from ..core.exceptions import MessageDroppedException
@@ -108,6 +108,16 @@ class SingleThreadedAgentRuntime(AgentRuntime):
future.set_result(None)
return future
def save_state(self) -> Mapping[str, Any]:
state: Dict[str, Dict[str, Any]] = {}
for agent in self._agents:
state[agent.name] = dict(agent.save_state())
return state
def load_state(self, state: Mapping[str, Any]) -> None:
for agent in self._agents:
agent.load_state(state[agent.name])
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
recipient = message_envelope.recipient
assert recipient in self._agents

View File

@@ -1,4 +1,4 @@
from typing import Any, Protocol, Sequence, runtime_checkable
from typing import Any, Mapping, Protocol, Sequence, runtime_checkable
from agnext.core._cancellation_token import CancellationToken
@@ -33,3 +33,7 @@ class Agent(Protocol):
If there was a cancellation, this function should raise a `CancelledError`.
"""
...
def save_state(self) -> Mapping[str, Any]: ...
def load_state(self, state: Mapping[str, Any]) -> None: ...

View File

@@ -1,5 +1,5 @@
from asyncio import Future
from typing import Any, Protocol
from typing import Any, Mapping, Protocol
from agnext.core._agent import Agent
from agnext.core._cancellation_token import CancellationToken
@@ -37,3 +37,7 @@ class AgentRuntime(Protocol):
sender: Agent | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[None]: ...
def save_state(self) -> Mapping[str, Any]: ...
def load_state(self, state: Mapping[str, Any]) -> None: ...

View File

@@ -1,6 +1,7 @@
import warnings
from abc import ABC, abstractmethod
from asyncio import Future
from typing import Any, Sequence, TypeVar
from typing import Any, Mapping, Sequence, TypeVar
from agnext.core._agent_runtime import AgentRuntime
from agnext.core._cancellation_token import CancellationToken
@@ -62,3 +63,11 @@ class BaseAgent(ABC, Agent):
cancellation_token = CancellationToken()
future = self._router.publish_message(message, sender=self, cancellation_token=cancellation_token)
return future
def save_state(self) -> Mapping[str, Any]:
warnings.warn("save_state not implemented", stacklevel=2)
return {}
def load_state(self, state: Mapping[str, Any]) -> None:
warnings.warn("load_state not implemented", stacklevel=2)
pass

62
tests/test_state.py Normal file
View File

@@ -0,0 +1,62 @@
from typing import Any, Mapping, Sequence
import pytest
from agnext.application_components import SingleThreadedAgentRuntime
from agnext.core import AgentRuntime
from agnext.core import BaseAgent
from agnext.core import CancellationToken
class StatefulAgent(BaseAgent):
def __init__(self, name: str, runtime: AgentRuntime) -> None:
super().__init__(name, runtime)
self.state = 0
@property
def subscriptions(self) -> Sequence[type]:
return []
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any:
raise NotImplementedError
def save_state(self) -> Mapping[str, Any]:
return {"state": self.state}
def load_state(self, state: Mapping[str, Any]) -> None:
self.state = state["state"]
@pytest.mark.asyncio
async def test_agent_can_save_state() -> None:
runtime = SingleThreadedAgentRuntime()
agent1 = StatefulAgent("name1", runtime)
assert agent1.state == 0
agent1.state = 1
assert agent1.state == 1
agent1_state = agent1.save_state()
agent1.state = 2
assert agent1.state == 2
agent1.load_state(agent1_state)
assert agent1.state == 1
@pytest.mark.asyncio
async def test_runtime_can_save_state() -> None:
runtime = SingleThreadedAgentRuntime()
agent1 = StatefulAgent("name1", runtime)
assert agent1.state == 0
agent1.state = 1
assert agent1.state == 1
runtime_state = runtime.save_state()
runtime2 = SingleThreadedAgentRuntime()
agent2 = StatefulAgent("name1", runtime2)
runtime2.load_state(runtime_state)
assert agent2.state == 1