Mark cache as a protocol and update type hints to reflect (#2168)

* Mark cache as a protocol and update type hints to reflect

* int

* undo init change
This commit is contained in:
Jack Gerrits
2024-03-27 18:15:24 -04:00
committed by GitHub
parent 1002882f01
commit 95c0118568
8 changed files with 26 additions and 58 deletions

View File

@@ -150,7 +150,7 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
- "recipient": the recipient agent.
- "clear_history" (bool): whether to clear the chat history with the agent. Default is True.
- "silent" (bool or None): (Experimental) whether to print the messages in this conversation. Default is False.
- "cache" (Cache or None): the cache client to use for this conversation. Default is None.
- "cache" (AbstractCache or None): the cache client to use for this conversation. Default is None.
- "max_turns" (int or None): maximum number of turns for the chat. If None, the chat will continue until a termination condition is met. Default is None.
- "summary_method" (str or callable): a string or callable specifying the method to get a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg".
- "summary_args" (dict): a dictionary of arguments to be passed to the summary_method. Default is {}.

View File

@@ -5,7 +5,7 @@ from openai import OpenAI
from PIL.Image import Image
from autogen import Agent, ConversableAgent, code_utils
from autogen.cache import Cache
from autogen.cache import AbstractCache
from autogen.agentchat.contrib import img_utils
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent
@@ -142,7 +142,7 @@ class ImageGeneration(AgentCapability):
def __init__(
self,
image_generator: ImageGenerator,
cache: Optional[Cache] = None,
cache: Optional[AbstractCache] = None,
text_analyzer_llm_config: Optional[Dict] = None,
text_analyzer_instructions: str = PROMPT_INSTRUCTIONS,
verbosity: int = 0,
@@ -151,7 +151,7 @@ class ImageGeneration(AgentCapability):
"""
Args:
image_generator (ImageGenerator): The image generator you would like to use to generate images.
cache (None or Cache): The cache client to use to store and retrieve generated images. If None,
cache (None or AbstractCache): The cache client to use to store and retrieve generated images. If None,
no caching will be used.
text_analyzer_llm_config (Dict or None): The LLM config for the text analyzer. If None, the LLM config will
be retrieved from the agent you're adding the ability to.

View File

@@ -15,7 +15,7 @@ from openai import BadRequestError
from autogen.exception_utils import InvalidCarryOverType, SenderRequired
from .._pydantic import model_dump
from ..cache.cache import Cache
from ..cache.cache import AbstractCache
from ..code_utils import (
UNKNOWN,
check_can_use_docker_or_throw,
@@ -865,7 +865,7 @@ class ConversableAgent(LLMAgent):
recipient: "ConversableAgent",
clear_history: bool = True,
silent: Optional[bool] = False,
cache: Optional[Cache] = None,
cache: Optional[AbstractCache] = None,
max_turns: Optional[int] = None,
summary_method: Optional[Union[str, Callable]] = DEFAULT_SUMMARY_METHOD,
summary_args: Optional[dict] = {},
@@ -882,7 +882,7 @@ class ConversableAgent(LLMAgent):
recipient: the recipient agent.
clear_history (bool): whether to clear the chat history with the agent. Default is True.
silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False.
cache (Cache or None): the cache client to be used for this conversation. Default is None.
cache (AbstractCache or None): the cache client to be used for this conversation. Default is None.
max_turns (int or None): the maximum number of turns for the chat between the two agents. One turn means one conversation round trip. Note that this is different from
[max_consecutive_auto_reply](#max_consecutive_auto_reply) which is the maximum number of consecutive auto replies; and it is also different from [max_rounds in GroupChat](./groupchat#groupchat-objects) which is the maximum number of rounds in a group chat session.
If max_turns is set to None, the chat will continue until a termination condition is met. Default is None.
@@ -1007,7 +1007,7 @@ class ConversableAgent(LLMAgent):
recipient: "ConversableAgent",
clear_history: bool = True,
silent: Optional[bool] = False,
cache: Optional[Cache] = None,
cache: Optional[AbstractCache] = None,
max_turns: Optional[int] = None,
summary_method: Optional[Union[str, Callable]] = DEFAULT_SUMMARY_METHOD,
summary_args: Optional[dict] = {},
@@ -1073,7 +1073,7 @@ class ConversableAgent(LLMAgent):
summary_method,
summary_args,
recipient: Optional[Agent] = None,
cache: Optional[Cache] = None,
cache: Optional[AbstractCache] = None,
) -> str:
"""Get a chat summary from an agent participating in a chat.
@@ -1141,7 +1141,7 @@ class ConversableAgent(LLMAgent):
return summary
def _reflection_with_llm(
self, prompt, messages, llm_agent: Optional[Agent] = None, cache: Optional[Cache] = None
self, prompt, messages, llm_agent: Optional[Agent] = None, cache: Optional[AbstractCache] = None
) -> str:
"""Get a chat summary using reflection with an llm client based on the conversation history.
@@ -1149,7 +1149,7 @@ class ConversableAgent(LLMAgent):
prompt (str): The prompt (in this method it is used as system prompt) used to get the summary.
messages (list): The messages generated as part of a chat conversation.
llm_agent: the agent with an llm client.
cache (Cache or None): the cache client to be used for this conversation.
cache (AbstractCache or None): the cache client to be used for this conversation.
"""
system_msg = [
{

View File

@@ -1,3 +1,3 @@
from .cache import Cache
from .cache import Cache, AbstractCache
__all__ = ["Cache"]
__all__ = ["Cache", "AbstractCache"]

View File

@@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Any, Optional, Type
from typing import Any, Optional, Protocol, Type
import sys
if sys.version_info >= (3, 11):
@@ -9,23 +8,17 @@ else:
from typing_extensions import Self
class AbstractCache(ABC):
class AbstractCache(Protocol):
"""
Abstract base class for cache implementations.
This class defines the basic interface for cache operations.
This protocol defines the basic interface for cache operations.
Implementing classes should provide concrete implementations for
these methods to handle caching mechanisms.
"""
@abstractmethod
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
"""
Retrieve an item from the cache.
Abstract method that must be implemented by subclasses to
retrieve an item from the cache.
Args:
key (str): The key identifying the item in the cache.
default (optional): The default value to return if the key is not found.
@@ -33,53 +26,35 @@ class AbstractCache(ABC):
Returns:
The value associated with the key if found, else the default value.
Raises:
NotImplementedError: If the subclass does not implement this method.
"""
...
@abstractmethod
def set(self, key: str, value: Any) -> None:
"""
Set an item in the cache.
Abstract method that must be implemented by subclasses to
store an item in the cache.
Args:
key (str): The key under which the item is to be stored.
value: The value to be stored in the cache.
Raises:
NotImplementedError: If the subclass does not implement this method.
"""
...
@abstractmethod
def close(self) -> None:
"""
Close the cache.
Abstract method that should be implemented by subclasses to
perform any necessary cleanup, such as closing network connections or
Close the cache. Perform any necessary cleanup, such as closing network connections or
releasing resources.
Raises:
NotImplementedError: If the subclass does not implement this method.
"""
...
@abstractmethod
def __enter__(self) -> Self:
"""
Enter the runtime context related to this object.
The with statement will bind this methods return value to the target(s)
The with statement will bind this method's return value to the target(s)
specified in the as clause of the statement, if any.
Raises:
NotImplementedError: If the subclass does not implement this method.
"""
...
@abstractmethod
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
@@ -89,15 +64,9 @@ class AbstractCache(ABC):
"""
Exit the runtime context and close the cache.
Abstract method that should be implemented by subclasses to handle
the exit from a with statement. It is responsible for resource
release and cleanup.
Args:
exc_type: The exception type if an exception was raised in the context.
exc_value: The exception value if an exception was raised in the context.
traceback: The traceback if an exception was raised in the context.
Raises:
NotImplementedError: If the subclass does not implement this method.
"""
...

View File

@@ -14,7 +14,7 @@ else:
from typing_extensions import Self
class Cache:
class Cache(AbstractCache):
"""
A wrapper class for managing cache configuration and instances.

View File

@@ -10,7 +10,7 @@ from flaml.automl.logger import logger_formatter
from pydantic import BaseModel
from typing import Protocol
from autogen.cache.cache import Cache
from autogen.cache import Cache
from autogen.io.base import IOStream
from autogen.oai.openai_utils import get_key, is_valid_api_key, OAI_PRICE1K
from autogen.token_count_utils import count_token
@@ -517,7 +517,7 @@ class OpenAIWrapper:
The actual prompt will be:
"Complete the following sentence: Today I feel".
More examples can be found at [templating](/docs/Use-Cases/enhanced_inference#templating).
- cache (Cache | None): A Cache object to use for response cache. Default to None.
- cache (AbstractCache | None): A Cache object to use for response cache. Default to None.
Note that the cache argument overrides the legacy cache_seed argument: if this argument is provided,
then the cache_seed argument is ignored. If this argument is not provided or None,
then the cache_seed argument is used.

View File

@@ -6,7 +6,6 @@ import re
import pytest
from autogen import UserProxyAgent, config_list_from_json
from autogen.oai.openai_utils import filter_config
from autogen.cache import Cache
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
from conftest import MOCK_OPEN_AI_API_KEY, skip_openai # noqa: E402