Resolve all type issues in cache module (#1888)

* Resolve all type issues in cache module

* Union[str, int]
This commit is contained in:
Jack Gerrits
2024-03-07 18:55:29 -05:00
committed by GitHub
parent c5e76653cb
commit 2a62ffc566
5 changed files with 84 additions and 38 deletions

View File

@@ -1,4 +1,12 @@
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Any, Optional, Type
import sys
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
class AbstractCache(ABC):
@@ -11,7 +19,7 @@ class AbstractCache(ABC):
"""
@abstractmethod
def get(self, key, default=None):
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
"""
Retrieve an item from the cache.
@@ -31,7 +39,7 @@ class AbstractCache(ABC):
"""
@abstractmethod
def set(self, key, value):
def set(self, key: str, value: Any) -> None:
"""
Set an item in the cache.
@@ -47,7 +55,7 @@ class AbstractCache(ABC):
"""
@abstractmethod
def close(self):
def close(self) -> None:
"""
Close the cache.
@@ -60,7 +68,7 @@ class AbstractCache(ABC):
"""
@abstractmethod
def __enter__(self):
def __enter__(self) -> Self:
"""
Enter the runtime context related to this object.
@@ -72,7 +80,12 @@ class AbstractCache(ABC):
"""
@abstractmethod
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""
Exit the runtime context and close the cache.

View File

@@ -1,6 +1,17 @@
from typing import Dict, Any
from __future__ import annotations
from types import TracebackType
from typing import Dict, Any, Optional, Type, Union
from autogen.cache.cache_factory import CacheFactory
from .abstract_cache_base import AbstractCache
from .cache_factory import CacheFactory
import sys
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
class Cache:
@@ -19,12 +30,12 @@ class Cache:
ALLOWED_CONFIG_KEYS = ["cache_seed", "redis_url", "cache_path_root"]
@staticmethod
def redis(cache_seed=42, redis_url="redis://localhost:6379/0"):
def redis(cache_seed: Union[str, int] = 42, redis_url: str = "redis://localhost:6379/0") -> Cache:
"""
Create a Redis cache instance.
Args:
cache_seed (int, optional): A seed for the cache. Defaults to 42.
cache_seed (Union[str, int], optional): A seed for the cache. Defaults to 42.
redis_url (str, optional): The URL for the Redis server. Defaults to "redis://localhost:6379/0".
Returns:
@@ -33,12 +44,12 @@ class Cache:
return Cache({"cache_seed": cache_seed, "redis_url": redis_url})
@staticmethod
def disk(cache_seed=42, cache_path_root=".cache"):
def disk(cache_seed: Union[str, int] = 42, cache_path_root: str = ".cache") -> Cache:
"""
Create a Disk cache instance.
Args:
cache_seed (int, optional): A seed for the cache. Defaults to 42.
cache_seed (Union[str, int], optional): A seed for the cache. Defaults to 42.
cache_path_root (str, optional): The root path for the disk cache. Defaults to ".cache".
Returns:
@@ -70,7 +81,7 @@ class Cache:
self.config.get("cache_path_root", None),
)
def __enter__(self):
def __enter__(self) -> AbstractCache:
"""
Enter the runtime context related to the cache object.
@@ -79,7 +90,12 @@ class Cache:
"""
return self.cache.__enter__()
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""
Exit the runtime context related to the cache object.
@@ -93,7 +109,7 @@ class Cache:
"""
return self.cache.__exit__(exc_type, exc_value, traceback)
def get(self, key, default=None):
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
"""
Retrieve an item from the cache.
@@ -107,7 +123,7 @@ class Cache:
"""
return self.cache.get(key, default)
def set(self, key, value):
def set(self, key: str, value: Any) -> None:
"""
Set an item in the cache.
@@ -117,7 +133,7 @@ class Cache:
"""
self.cache.set(key, value)
def close(self):
def close(self) -> None:
"""
Close the cache.

View File

@@ -1,14 +1,13 @@
from autogen.cache.disk_cache import DiskCache
try:
from autogen.cache.redis_cache import RedisCache
except ImportError:
RedisCache = None
from typing import Optional, Union, Type
from .abstract_cache_base import AbstractCache
from .disk_cache import DiskCache
class CacheFactory:
@staticmethod
def cache_factory(seed, redis_url=None, cache_path_root=".cache"):
def cache_factory(
seed: Union[str, int], redis_url: Optional[str] = None, cache_path_root: str = ".cache"
) -> AbstractCache:
"""
Factory function for creating cache instances.
@@ -17,7 +16,7 @@ class CacheFactory:
a RedisCache instance is created. Otherwise, a DiskCache instance is used.
Args:
seed (str): A string used as a seed or namespace for the cache.
seed (Union[str, int]): A string or int used as a seed or namespace for the cache.
This could be useful for creating distinct cache instances
or for namespacing keys in the cache.
redis_url (str or None): The URL for the Redis server. If this is None
@@ -40,7 +39,12 @@ class CacheFactory:
disk_cache = cache_factory("myseed", None)
```
"""
if RedisCache is not None and redis_url is not None:
return RedisCache(seed, redis_url)
if redis_url is not None:
try:
from .redis_cache import RedisCache
return RedisCache(seed, redis_url)
except ImportError:
return DiskCache(f"./{cache_path_root}/{seed}")
else:
return DiskCache(f"./{cache_path_root}/{seed}")

View File

@@ -1,5 +1,13 @@
from types import TracebackType
from typing import Any, Optional, Type, Union
import diskcache
from .abstract_cache_base import AbstractCache
import sys
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
class DiskCache(AbstractCache):
@@ -21,18 +29,18 @@ class DiskCache(AbstractCache):
__exit__(self, exc_type, exc_value, traceback): Context management exit.
"""
def __init__(self, seed):
def __init__(self, seed: Union[str, int]):
"""
Initialize the DiskCache instance.
Args:
seed (str): A seed or namespace for the cache. This is used to create
seed (Union[str, int]): A seed or namespace for the cache. This is used to create
a unique storage location for the cache data.
"""
self.cache = diskcache.Cache(seed)
def get(self, key, default=None):
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
"""
Retrieve an item from the cache.
@@ -46,7 +54,7 @@ class DiskCache(AbstractCache):
"""
return self.cache.get(key, default)
def set(self, key, value):
def set(self, key: str, value: Any) -> None:
"""
Set an item in the cache.
@@ -56,7 +64,7 @@ class DiskCache(AbstractCache):
"""
self.cache.set(key, value)
def close(self):
def close(self) -> None:
"""
Close the cache.
@@ -65,7 +73,7 @@ class DiskCache(AbstractCache):
"""
self.cache.close()
def __enter__(self):
def __enter__(self) -> Self:
"""
Enter the runtime context related to the object.
@@ -74,7 +82,12 @@ class DiskCache(AbstractCache):
"""
return self
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""
Exit the runtime context related to the object.

View File

@@ -1,6 +1,6 @@
import pickle
from types import TracebackType
from typing import Any, Optional, Type
from typing import Any, Optional, Type, Union
import redis
import sys
from .abstract_cache_base import AbstractCache
@@ -19,7 +19,7 @@ class RedisCache(AbstractCache):
interface using the Redis database for caching data.
Attributes:
seed (str): A seed or namespace used as a prefix for cache keys.
seed (Union[str, int]): A seed or namespace used as a prefix for cache keys.
cache (redis.Redis): The Redis client used for caching.
Methods:
@@ -32,12 +32,12 @@ class RedisCache(AbstractCache):
__exit__(self, exc_type, exc_value, traceback): Context management exit.
"""
def __init__(self, seed: str, redis_url: str):
def __init__(self, seed: Union[str, int], redis_url: str):
"""
Initialize the RedisCache instance.
Args:
seed (str): A seed or namespace for the cache. This is used as a prefix for all cache keys.
seed (Union[str, int]): A seed or namespace for the cache. This is used as a prefix for all cache keys.
redis_url (str): The URL for the Redis server.
"""
@@ -56,7 +56,7 @@ class RedisCache(AbstractCache):
"""
return f"autogen:{self.seed}:{key}"
def get(self, key: str, default: Optional[Any] = None) -> Any:
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
"""
Retrieve an item from the Redis cache.