mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Make Memory and Team an ABC (#5149)
* make memory and team an ABC * update memory test * update tests
This commit is contained in:
@@ -1,17 +1,20 @@
|
||||
from typing import Any, Mapping, Protocol
|
||||
|
||||
from typing import Any, Mapping
|
||||
from abc import ABC, abstractmethod
|
||||
from ._task import TaskRunner
|
||||
|
||||
|
||||
class Team(TaskRunner, Protocol):
|
||||
class Team(ABC, TaskRunner):
|
||||
@abstractmethod
|
||||
async def reset(self) -> None:
|
||||
"""Reset the team and all its participants to its initial state."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the current state of the team."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Load the state of the team."""
|
||||
...
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Protocol, Union, runtime_checkable
|
||||
from typing import Any, Dict, List, Union
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
@@ -48,8 +49,7 @@ class UpdateContextResult(BaseModel):
|
||||
memories: MemoryQueryResult
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Memory(Protocol):
|
||||
class Memory(ABC):
|
||||
"""Protocol defining the interface for memory implementations.
|
||||
|
||||
A memory is the storage for data that can be used to enrich or modify the model context.
|
||||
@@ -64,6 +64,7 @@ class Memory(Protocol):
|
||||
See :class:`~autogen_core.memory.ListMemory` for an example implementation.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def update_context(
|
||||
self,
|
||||
model_context: ChatCompletionContext,
|
||||
@@ -79,6 +80,7 @@ class Memory(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def query(
|
||||
self,
|
||||
query: str | MemoryContent,
|
||||
@@ -98,6 +100,7 @@ class Memory(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
|
||||
"""
|
||||
Add a new content to memory.
|
||||
@@ -108,10 +111,12 @@ class Memory(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self) -> None:
|
||||
"""Clear all entries from memory."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Clean up any resources used by the memory implementation."""
|
||||
...
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from typing import Any
|
||||
import pytest
|
||||
from autogen_core import CancellationToken
|
||||
from autogen_core.memory import (
|
||||
@@ -21,19 +22,22 @@ def test_memory_protocol_attributes() -> None:
|
||||
assert hasattr(Memory, "close")
|
||||
|
||||
|
||||
def test_memory_protocol_runtime_checkable() -> None:
|
||||
"""Test that Memory protocol is properly runtime-checkable."""
|
||||
def test_memory_abc_implementation() -> None:
|
||||
"""Test that Memory ABC is properly implemented."""
|
||||
|
||||
class ValidMemory:
|
||||
class ValidMemory(Memory):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "test"
|
||||
|
||||
async def update_context(self, context: ChatCompletionContext) -> UpdateContextResult:
|
||||
async def update_context(self, model_context: ChatCompletionContext) -> UpdateContextResult:
|
||||
return UpdateContextResult(memories=MemoryQueryResult(results=[]))
|
||||
|
||||
async def query(
|
||||
self, query: MemoryContent, cancellation_token: CancellationToken | None = None
|
||||
self,
|
||||
query: str | MemoryContent,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
**kwargs: Any,
|
||||
) -> MemoryQueryResult:
|
||||
return MemoryQueryResult(results=[])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user