mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Merge branch 'dev' into swiftyos/secrt-1706-improve-store-search
This commit is contained in:
@@ -1,339 +0,0 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
cast,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
R_co = TypeVar("R_co", covariant=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _make_hashable_key(
|
||||
args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> tuple[Any, ...]:
|
||||
"""
|
||||
Convert args and kwargs into a hashable cache key.
|
||||
|
||||
Handles unhashable types like dict, list, set by converting them to
|
||||
their sorted string representations.
|
||||
"""
|
||||
|
||||
def make_hashable(obj: Any) -> Any:
|
||||
"""Recursively convert an object to a hashable representation."""
|
||||
if isinstance(obj, dict):
|
||||
# Sort dict items to ensure consistent ordering
|
||||
return (
|
||||
"__dict__",
|
||||
tuple(sorted((k, make_hashable(v)) for k, v in obj.items())),
|
||||
)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return ("__list__", tuple(make_hashable(item) for item in obj))
|
||||
elif isinstance(obj, set):
|
||||
return ("__set__", tuple(sorted(make_hashable(item) for item in obj)))
|
||||
elif hasattr(obj, "__dict__"):
|
||||
# Handle objects with __dict__ attribute
|
||||
return ("__obj__", obj.__class__.__name__, make_hashable(obj.__dict__))
|
||||
else:
|
||||
# For basic hashable types (str, int, bool, None, etc.)
|
||||
try:
|
||||
hash(obj)
|
||||
return obj
|
||||
except TypeError:
|
||||
# Fallback: convert to string representation
|
||||
return ("__str__", str(obj))
|
||||
|
||||
hashable_args = tuple(make_hashable(arg) for arg in args)
|
||||
hashable_kwargs = tuple(sorted((k, make_hashable(v)) for k, v in kwargs.items()))
|
||||
return (hashable_args, hashable_kwargs)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class CachedFunction(Protocol[P, R_co]):
|
||||
"""Protocol for cached functions with cache management methods."""
|
||||
|
||||
def cache_clear(self) -> None:
|
||||
"""Clear all cached entries."""
|
||||
return None
|
||||
|
||||
def cache_info(self) -> dict[str, int | None]:
|
||||
"""Get cache statistics."""
|
||||
return {}
|
||||
|
||||
def cache_delete(self, *args: P.args, **kwargs: P.kwargs) -> bool:
|
||||
"""Delete a specific cache entry by its arguments. Returns True if entry existed."""
|
||||
return False
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
|
||||
"""Call the cached function."""
|
||||
return None # type: ignore
|
||||
|
||||
|
||||
def cached(
|
||||
*,
|
||||
maxsize: int = 128,
|
||||
ttl_seconds: int | None = None,
|
||||
) -> Callable[[Callable], CachedFunction]:
|
||||
"""
|
||||
Thundering herd safe cache decorator for both sync and async functions.
|
||||
|
||||
Uses double-checked locking to prevent multiple threads/coroutines from
|
||||
executing the expensive operation simultaneously during cache misses.
|
||||
|
||||
Args:
|
||||
func: The function to cache (when used without parentheses)
|
||||
maxsize: Maximum number of cached entries
|
||||
ttl_seconds: Time to live in seconds. If None, entries never expire
|
||||
|
||||
Returns:
|
||||
Decorated function or decorator
|
||||
|
||||
Example:
|
||||
@cache() # Default: maxsize=128, no TTL
|
||||
def expensive_sync_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
@cache() # Works with async too
|
||||
async def expensive_async_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
@cache(maxsize=1000, ttl_seconds=300) # Custom maxsize and TTL
|
||||
def another_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
"""
|
||||
|
||||
def decorator(target_func):
|
||||
# Cache storage and per-event-loop locks
|
||||
cache_storage = {}
|
||||
_event_loop_locks = {} # Maps event loop to its asyncio.Lock
|
||||
|
||||
if inspect.iscoroutinefunction(target_func):
|
||||
|
||||
def _get_cache_lock():
|
||||
"""Get or create an asyncio.Lock for the current event loop."""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# No event loop, use None as default key
|
||||
loop = None
|
||||
|
||||
if loop not in _event_loop_locks:
|
||||
return _event_loop_locks.setdefault(loop, asyncio.Lock())
|
||||
return _event_loop_locks[loop]
|
||||
|
||||
@wraps(target_func)
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
current_time = time.time()
|
||||
|
||||
# Fast path: check cache without lock
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
logger.debug(f"Cache hit for {target_func.__name__}")
|
||||
return cache_storage[key]
|
||||
else:
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
logger.debug(f"Cache hit for {target_func.__name__}")
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
async with _get_cache_lock():
|
||||
# Double-check: another coroutine might have populated cache
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
return cache_storage[key]
|
||||
else:
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {target_func.__name__}")
|
||||
result = await target_func(*args, **kwargs)
|
||||
|
||||
# Store result
|
||||
if ttl_seconds is None:
|
||||
cache_storage[key] = result
|
||||
else:
|
||||
cache_storage[key] = (result, current_time)
|
||||
|
||||
# Cleanup if needed
|
||||
if len(cache_storage) > maxsize:
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = (
|
||||
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||
)
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
|
||||
return result
|
||||
|
||||
wrapper = async_wrapper
|
||||
|
||||
else:
|
||||
# Sync function with threading.Lock
|
||||
cache_lock = threading.Lock()
|
||||
|
||||
@wraps(target_func)
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
current_time = time.time()
|
||||
|
||||
# Fast path: check cache without lock
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
logger.debug(f"Cache hit for {target_func.__name__}")
|
||||
return cache_storage[key]
|
||||
else:
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
logger.debug(f"Cache hit for {target_func.__name__}")
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
with cache_lock:
|
||||
# Double-check: another thread might have populated cache
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
return cache_storage[key]
|
||||
else:
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {target_func.__name__}")
|
||||
result = target_func(*args, **kwargs)
|
||||
|
||||
# Store result
|
||||
if ttl_seconds is None:
|
||||
cache_storage[key] = result
|
||||
else:
|
||||
cache_storage[key] = (result, current_time)
|
||||
|
||||
# Cleanup if needed
|
||||
if len(cache_storage) > maxsize:
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = (
|
||||
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||
)
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
|
||||
return result
|
||||
|
||||
wrapper = sync_wrapper
|
||||
|
||||
# Add cache management methods
|
||||
def cache_clear() -> None:
|
||||
cache_storage.clear()
|
||||
|
||||
def cache_info() -> dict[str, int | None]:
|
||||
return {
|
||||
"size": len(cache_storage),
|
||||
"maxsize": maxsize,
|
||||
"ttl_seconds": ttl_seconds,
|
||||
}
|
||||
|
||||
def cache_delete(*args, **kwargs) -> bool:
|
||||
"""Delete a specific cache entry. Returns True if entry existed."""
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
if key in cache_storage:
|
||||
del cache_storage[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
setattr(wrapper, "cache_clear", cache_clear)
|
||||
setattr(wrapper, "cache_info", cache_info)
|
||||
setattr(wrapper, "cache_delete", cache_delete)
|
||||
|
||||
return cast(CachedFunction, wrapper)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def thread_cached(func):
|
||||
"""
|
||||
Thread-local cache decorator for both sync and async functions.
|
||||
|
||||
Each thread gets its own cache, which is useful for request-scoped caching
|
||||
in web applications where you want to cache within a single request but
|
||||
not across requests.
|
||||
|
||||
Args:
|
||||
func: The function to cache
|
||||
|
||||
Returns:
|
||||
Decorated function with thread-local caching
|
||||
|
||||
Example:
|
||||
@thread_cached
|
||||
def expensive_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
@thread_cached # Works with async too
|
||||
async def expensive_async_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
"""
|
||||
thread_local = threading.local()
|
||||
|
||||
def _clear():
|
||||
if hasattr(thread_local, "cache"):
|
||||
del thread_local.cache
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
if key not in cache:
|
||||
cache[key] = await func(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
setattr(async_wrapper, "clear_cache", _clear)
|
||||
return async_wrapper
|
||||
|
||||
else:
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
if key not in cache:
|
||||
cache[key] = func(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
setattr(sync_wrapper, "clear_cache", _clear)
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
def clear_thread_cache(func: Callable) -> None:
|
||||
"""Clear thread-local cache for a function."""
|
||||
if clear := getattr(func, "clear_cache", None):
|
||||
clear()
|
||||
@@ -5,7 +5,7 @@ import re
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from backend.util.cache import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=3600)
|
||||
def load_all_blocks() -> dict[str, type["Block"]]:
|
||||
from backend.data.block import Block
|
||||
from backend.util.settings import Config
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Any
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.json import json
|
||||
from backend.util.json import loads
|
||||
|
||||
|
||||
class StepThroughItemsBlock(Block):
|
||||
@@ -68,7 +68,7 @@ class StepThroughItemsBlock(Block):
|
||||
raise ValueError(
|
||||
f"Input too large: {len(data)} bytes > {MAX_ITEM_SIZE} bytes"
|
||||
)
|
||||
items = json.loads(data)
|
||||
items = loads(data)
|
||||
else:
|
||||
items = data
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ from typing import (
|
||||
|
||||
import jsonref
|
||||
import jsonschema
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from prisma.models import AgentBlock
|
||||
from prisma.types import AgentBlockCreateInput
|
||||
from pydantic import BaseModel
|
||||
@@ -28,6 +27,7 @@ from pydantic import BaseModel
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
from backend.util.cache import cached
|
||||
from backend.util.settings import Config
|
||||
|
||||
from .model import (
|
||||
@@ -722,7 +722,7 @@ def get_block(block_id: str) -> Block[BlockSchema, BlockSchema] | None:
|
||||
return cls() if cls else None
|
||||
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=3600)
|
||||
def get_webhook_block_ids() -> Sequence[str]:
|
||||
return [
|
||||
id
|
||||
@@ -731,7 +731,7 @@ def get_webhook_block_ids() -> Sequence[str]:
|
||||
]
|
||||
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=3600)
|
||||
def get_io_block_ids() -> Sequence[str]:
|
||||
return [
|
||||
id
|
||||
|
||||
@@ -36,6 +36,7 @@ from backend.data.user import get_user_by_id, get_user_email_by_id
|
||||
from backend.notifications.notifications import queue_notification_async
|
||||
from backend.server.v2.admin.model import UserHistoryResponse
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
from backend.util.retry import func_retry
|
||||
@@ -993,14 +994,31 @@ class DisabledUserCredit(UserCreditBase):
|
||||
pass
|
||||
|
||||
|
||||
def get_user_credit_model() -> UserCreditBase:
|
||||
async def get_user_credit_model(user_id: str) -> UserCreditBase:
|
||||
"""
|
||||
Get the credit model for a user, considering LaunchDarkly flags.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID to check flags for.
|
||||
|
||||
Returns:
|
||||
UserCreditBase: The appropriate credit model for the user
|
||||
"""
|
||||
if not settings.config.enable_credit:
|
||||
return DisabledUserCredit()
|
||||
|
||||
if settings.config.enable_beta_monthly_credit:
|
||||
return BetaUserCredit(settings.config.num_user_credits_refill)
|
||||
# Check LaunchDarkly flag for payment pilot users
|
||||
# Default to False (beta monthly credit behavior) to maintain current behavior
|
||||
is_payment_enabled = await is_feature_enabled(
|
||||
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
|
||||
)
|
||||
|
||||
return UserCredit()
|
||||
if is_payment_enabled:
|
||||
# Payment enabled users get UserCredit (no monthly refills, enable payments)
|
||||
return UserCredit()
|
||||
else:
|
||||
# Default behavior: users get beta monthly credits
|
||||
return BetaUserCredit(settings.config.num_user_credits_refill)
|
||||
|
||||
|
||||
def get_block_costs() -> dict[str, list["BlockCost"]]:
|
||||
@@ -1090,7 +1108,8 @@ async def admin_get_user_history(
|
||||
)
|
||||
reason = metadata.get("reason", "No reason provided")
|
||||
|
||||
balance, last_update = await get_user_credit_model()._get_credits(tx.userId)
|
||||
user_credit_model = await get_user_credit_model(tx.userId)
|
||||
balance, _ = await user_credit_model._get_credits(tx.userId)
|
||||
|
||||
history.append(
|
||||
UserTransaction(
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Any, Optional
|
||||
|
||||
import prisma
|
||||
import pydantic
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from prisma.enums import OnboardingStep
|
||||
from prisma.models import UserOnboarding
|
||||
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
|
||||
@@ -13,6 +12,7 @@ from backend.data.block import get_blocks
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.server.v2.store.model import StoreAgentDetails
|
||||
from backend.util.cache import cached
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
# Mapping from user reason id to categories to search for when choosing agent to show
|
||||
@@ -26,8 +26,6 @@ REASON_MAPPING: dict[str, list[str]] = {
|
||||
POINTS_AGENT_COUNT = 50 # Number of agents to calculate points for
|
||||
MIN_AGENT_COUNT = 2 # Minimum number of marketplace agents to enable onboarding
|
||||
|
||||
user_credit = get_user_credit_model()
|
||||
|
||||
|
||||
class UserOnboardingUpdate(pydantic.BaseModel):
|
||||
completedSteps: Optional[list[OnboardingStep]] = None
|
||||
@@ -147,7 +145,8 @@ async def reward_user(user_id: str, step: OnboardingStep):
|
||||
return
|
||||
|
||||
onboarding.rewardedFor.append(step)
|
||||
await user_credit.onboarding_reward(user_id, reward, step)
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
await user_credit_model.onboarding_reward(user_id, reward, step)
|
||||
await UserOnboarding.prisma().update(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from autogpt_libs.utils.cache import cached, thread_cached
|
||||
from dotenv import load_dotenv
|
||||
from redis import Redis
|
||||
from redis.asyncio import Redis as AsyncRedis
|
||||
|
||||
from backend.util.cache import cached, thread_cached
|
||||
from backend.util.retry import conn_retry
|
||||
|
||||
load_dotenv()
|
||||
@@ -34,7 +34,7 @@ def disconnect():
|
||||
get_redis().close()
|
||||
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=3600)
|
||||
def get_redis() -> Redis:
|
||||
return connect()
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Optional, cast
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from autogpt_libs.auth.models import DEFAULT_USER_ID
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from fastapi import HTTPException
|
||||
from prisma.enums import NotificationType
|
||||
from prisma.models import User as PrismaUser
|
||||
@@ -16,6 +15,7 @@ from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
|
||||
from backend.data.db import prisma
|
||||
from backend.data.model import User, UserIntegrations, UserMetadata
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.util.cache import cached
|
||||
from backend.util.encryption import JSONCryptor
|
||||
from backend.util.exceptions import DatabaseError
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
@@ -57,7 +57,6 @@ from backend.util.service import (
|
||||
from backend.util.settings import Config
|
||||
|
||||
config = Config()
|
||||
_user_credit_model = get_user_credit_model()
|
||||
logger = logging.getLogger(__name__)
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
@@ -66,11 +65,13 @@ R = TypeVar("R")
|
||||
async def _spend_credits(
|
||||
user_id: str, cost: int, metadata: UsageTransactionMetadata
|
||||
) -> int:
|
||||
return await _user_credit_model.spend_credits(user_id, cost, metadata)
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
return await user_credit_model.spend_credits(user_id, cost, metadata)
|
||||
|
||||
|
||||
async def _get_credits(user_id: str) -> int:
|
||||
return await _user_credit_model.get_credits(user_id)
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
return await user_credit_model.get_credits(user_id)
|
||||
|
||||
|
||||
class DatabaseManager(AppService):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from backend.util.cache import cached
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..providers import ProviderName
|
||||
@@ -8,7 +8,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
# --8<-- [start:load_webhook_managers]
|
||||
@cached()
|
||||
@cached(ttl_seconds=3600)
|
||||
def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]:
|
||||
webhook_managers = {}
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ import pydantic
|
||||
import stripe
|
||||
from autogpt_libs.auth import get_user_id, requires_user
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
@@ -40,6 +39,7 @@ from backend.data.credit import (
|
||||
AutoTopUpConfig,
|
||||
RefundRequest,
|
||||
TransactionHistory,
|
||||
UserCredit,
|
||||
get_auto_top_up,
|
||||
get_user_credit_model,
|
||||
set_auto_top_up,
|
||||
@@ -84,6 +84,7 @@ from backend.server.model import (
|
||||
UpdateTimezoneRequest,
|
||||
UploadFileResponse,
|
||||
)
|
||||
from backend.util.cache import cached
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.cloud_storage import get_cloud_storage_handler
|
||||
from backend.util.exceptions import GraphValidationError, NotFoundError
|
||||
@@ -107,9 +108,6 @@ def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_user_credit_model = get_user_credit_model()
|
||||
|
||||
# Define the API routes
|
||||
v1_router = APIRouter()
|
||||
|
||||
@@ -291,7 +289,7 @@ def _compute_blocks_sync() -> str:
|
||||
return dumps(result)
|
||||
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=3600)
|
||||
async def _get_cached_blocks() -> str:
|
||||
"""
|
||||
Async cached function with thundering herd protection.
|
||||
@@ -478,7 +476,8 @@ async def upload_file(
|
||||
async def get_user_credits(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> dict[str, int]:
|
||||
return {"credits": await _user_credit_model.get_credits(user_id)}
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
return {"credits": await user_credit_model.get_credits(user_id)}
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
@@ -490,9 +489,8 @@ async def get_user_credits(
|
||||
async def request_top_up(
|
||||
request: RequestTopUp, user_id: Annotated[str, Security(get_user_id)]
|
||||
):
|
||||
checkout_url = await _user_credit_model.top_up_intent(
|
||||
user_id, request.credit_amount
|
||||
)
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
checkout_url = await user_credit_model.top_up_intent(user_id, request.credit_amount)
|
||||
return {"checkout_url": checkout_url}
|
||||
|
||||
|
||||
@@ -507,7 +505,8 @@ async def refund_top_up(
|
||||
transaction_key: str,
|
||||
metadata: dict[str, str],
|
||||
) -> int:
|
||||
return await _user_credit_model.top_up_refund(user_id, transaction_key, metadata)
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
return await user_credit_model.top_up_refund(user_id, transaction_key, metadata)
|
||||
|
||||
|
||||
@v1_router.patch(
|
||||
@@ -517,7 +516,8 @@ async def refund_top_up(
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def fulfill_checkout(user_id: Annotated[str, Security(get_user_id)]):
|
||||
await _user_credit_model.fulfill_checkout(user_id=user_id)
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
await user_credit_model.fulfill_checkout(user_id=user_id)
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@@ -537,12 +537,13 @@ async def configure_user_auto_top_up(
|
||||
if request.amount < request.threshold:
|
||||
raise ValueError("Amount must be greater than or equal to threshold")
|
||||
|
||||
current_balance = await _user_credit_model.get_credits(user_id)
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
current_balance = await user_credit_model.get_credits(user_id)
|
||||
|
||||
if current_balance < request.threshold:
|
||||
await _user_credit_model.top_up_credits(user_id, request.amount)
|
||||
await user_credit_model.top_up_credits(user_id, request.amount)
|
||||
else:
|
||||
await _user_credit_model.top_up_credits(user_id, 0)
|
||||
await user_credit_model.top_up_credits(user_id, 0)
|
||||
|
||||
await set_auto_top_up(
|
||||
user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
|
||||
@@ -590,15 +591,13 @@ async def stripe_webhook(request: Request):
|
||||
event["type"] == "checkout.session.completed"
|
||||
or event["type"] == "checkout.session.async_payment_succeeded"
|
||||
):
|
||||
await _user_credit_model.fulfill_checkout(
|
||||
session_id=event["data"]["object"]["id"]
|
||||
)
|
||||
await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
|
||||
|
||||
if event["type"] == "charge.dispute.created":
|
||||
await _user_credit_model.handle_dispute(event["data"]["object"])
|
||||
await UserCredit().handle_dispute(event["data"]["object"])
|
||||
|
||||
if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed":
|
||||
await _user_credit_model.deduct_credits(event["data"]["object"])
|
||||
await UserCredit().deduct_credits(event["data"]["object"])
|
||||
|
||||
return Response(status_code=200)
|
||||
|
||||
@@ -612,7 +611,8 @@ async def stripe_webhook(request: Request):
|
||||
async def manage_payment_method(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> dict[str, str]:
|
||||
return {"url": await _user_credit_model.create_billing_portal_session(user_id)}
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
return {"url": await user_credit_model.create_billing_portal_session(user_id)}
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -630,7 +630,8 @@ async def get_credit_history(
|
||||
if transaction_count_limit < 1 or transaction_count_limit > 1000:
|
||||
raise ValueError("Transaction count limit must be between 1 and 1000")
|
||||
|
||||
return await _user_credit_model.get_transaction_history(
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
return await user_credit_model.get_transaction_history(
|
||||
user_id=user_id,
|
||||
transaction_time_ceiling=transaction_time,
|
||||
transaction_count_limit=transaction_count_limit,
|
||||
@@ -647,7 +648,8 @@ async def get_credit_history(
|
||||
async def get_refund_requests(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[RefundRequest]:
|
||||
return await _user_credit_model.get_refund_requests(user_id)
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
return await user_credit_model.get_refund_requests(user_id)
|
||||
|
||||
|
||||
########################################################
|
||||
@@ -869,7 +871,8 @@ async def execute_graph(
|
||||
graph_version: Optional[int] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
) -> execution_db.GraphExecutionMeta:
|
||||
current_balance = await _user_credit_model.get_credits(user_id)
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
current_balance = await user_credit_model.get_credits(user_id)
|
||||
if current_balance <= 0:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
|
||||
@@ -194,8 +194,12 @@ def test_get_user_credits(
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test get user credits endpoint"""
|
||||
mock_credit_model = mocker.patch("backend.server.routers.v1._user_credit_model")
|
||||
mock_credit_model = Mock()
|
||||
mock_credit_model.get_credits = AsyncMock(return_value=1000)
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
response = client.get("/credits")
|
||||
|
||||
@@ -215,10 +219,14 @@ def test_request_top_up(
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test request top up endpoint"""
|
||||
mock_credit_model = mocker.patch("backend.server.routers.v1._user_credit_model")
|
||||
mock_credit_model = Mock()
|
||||
mock_credit_model.top_up_intent = AsyncMock(
|
||||
return_value="https://checkout.example.com/session123"
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
request_data = {"credit_amount": 500}
|
||||
|
||||
|
||||
@@ -11,8 +11,6 @@ from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_user_credit_model = get_user_credit_model()
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
@@ -33,7 +31,8 @@ async def add_user_credits(
|
||||
logger.info(
|
||||
f"Admin user {admin_user_id} is adding {amount} credits to user {user_id}"
|
||||
)
|
||||
new_balance, transaction_key = await _user_credit_model._add_transaction(
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
new_balance, transaction_key = await user_credit_model._add_transaction(
|
||||
user_id,
|
||||
amount,
|
||||
transaction_type=CreditTransactionType.GRANT,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
@@ -37,12 +37,14 @@ def test_add_user_credits_success(
|
||||
) -> None:
|
||||
"""Test successful credit addition by admin"""
|
||||
# Mock the credit model
|
||||
mock_credit_model = mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes._user_credit_model"
|
||||
)
|
||||
mock_credit_model = Mock()
|
||||
mock_credit_model._add_transaction = AsyncMock(
|
||||
return_value=(1500, "transaction-123-uuid")
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"user_id": target_user_id,
|
||||
@@ -81,12 +83,14 @@ def test_add_user_credits_negative_amount(
|
||||
) -> None:
|
||||
"""Test credit deduction by admin (negative amount)"""
|
||||
# Mock the credit model
|
||||
mock_credit_model = mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes._user_credit_model"
|
||||
)
|
||||
mock_credit_model = Mock()
|
||||
mock_credit_model._add_transaction = AsyncMock(
|
||||
return_value=(200, "transaction-456-uuid")
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"user_id": "target-user-id",
|
||||
|
||||
@@ -2,7 +2,6 @@ import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import prisma
|
||||
from autogpt_libs.utils.cache import cached
|
||||
|
||||
import backend.data.block
|
||||
from backend.blocks import load_all_blocks
|
||||
@@ -18,6 +17,7 @@ from backend.server.v2.builder.model import (
|
||||
ProviderResponse,
|
||||
SearchBlocksResponse,
|
||||
)
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -307,7 +307,7 @@ def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=3600)
|
||||
def _get_all_providers() -> dict[ProviderName, Provider]:
|
||||
providers: dict[ProviderName, Provider] = {}
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from autogpt_libs.utils.cache import cached
|
||||
|
||||
import backend.server.v2.store.db
|
||||
from backend.util.cache import cached
|
||||
|
||||
##############################################
|
||||
############### Caches #######################
|
||||
@@ -17,7 +16,7 @@ def clear_all_caches():
|
||||
|
||||
# Cache store agents list for 5 minutes
|
||||
# Different cache entries for different query combinations
|
||||
@cached(maxsize=5000, ttl_seconds=300)
|
||||
@cached(maxsize=5000, ttl_seconds=300, shared_cache=True)
|
||||
async def _get_cached_store_agents(
|
||||
featured: bool,
|
||||
creator: str | None,
|
||||
@@ -40,7 +39,7 @@ async def _get_cached_store_agents(
|
||||
|
||||
|
||||
# Cache individual agent details for 15 minutes
|
||||
@cached(maxsize=200, ttl_seconds=300)
|
||||
@cached(maxsize=200, ttl_seconds=300, shared_cache=True)
|
||||
async def _get_cached_agent_details(username: str, agent_name: str):
|
||||
"""Cached helper to get agent details."""
|
||||
return await backend.server.v2.store.db.get_store_agent_details(
|
||||
@@ -49,7 +48,7 @@ async def _get_cached_agent_details(username: str, agent_name: str):
|
||||
|
||||
|
||||
# Cache creators list for 5 minutes
|
||||
@cached(maxsize=200, ttl_seconds=300)
|
||||
@cached(maxsize=200, ttl_seconds=300, shared_cache=True)
|
||||
async def _get_cached_store_creators(
|
||||
featured: bool,
|
||||
search_query: str | None,
|
||||
@@ -68,7 +67,7 @@ async def _get_cached_store_creators(
|
||||
|
||||
|
||||
# Cache individual creator details for 5 minutes
|
||||
@cached(maxsize=100, ttl_seconds=300)
|
||||
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
|
||||
async def _get_cached_creator_details(username: str):
|
||||
"""Cached helper to get creator details."""
|
||||
return await backend.server.v2.store.db.get_store_creator_details(
|
||||
|
||||
457
autogpt_platform/backend/backend/util/cache.py
Normal file
457
autogpt_platform/backend/backend/util/cache.py
Normal file
@@ -0,0 +1,457 @@
|
||||
"""
|
||||
Caching utilities for the AutoGPT platform.
|
||||
|
||||
Provides decorators for caching function results with support for:
|
||||
- In-memory caching with TTL
|
||||
- Shared Redis-backed caching across processes
|
||||
- Thread-local caching for request-scoped data
|
||||
- Thundering herd protection
|
||||
- LRU eviction with optional TTL refresh
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import pickle
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, ParamSpec, Protocol, TypeVar, cast, runtime_checkable
|
||||
|
||||
from redis import ConnectionPool, Redis
|
||||
|
||||
from backend.util.retry import conn_retry
|
||||
from backend.util.settings import Settings
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
R_co = TypeVar("R_co", covariant=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
# RECOMMENDED REDIS CONFIGURATION FOR PRODUCTION:
|
||||
# Configure Redis with the following settings for optimal caching performance:
|
||||
# maxmemory-policy allkeys-lru # Evict least recently used keys when memory limit reached
|
||||
# maxmemory 2gb # Set memory limit (adjust based on your needs)
|
||||
# save "" # Disable persistence if using Redis purely for caching
|
||||
|
||||
# Create a dedicated Redis connection pool for caching (binary mode for pickle)
|
||||
_cache_pool: ConnectionPool | None = None
|
||||
|
||||
|
||||
@conn_retry("Redis", "Acquiring cache connection pool")
|
||||
def _get_cache_pool() -> ConnectionPool:
|
||||
"""Get or create a connection pool for cache operations."""
|
||||
global _cache_pool
|
||||
if _cache_pool is None:
|
||||
_cache_pool = ConnectionPool(
|
||||
host=settings.config.redis_host,
|
||||
port=settings.config.redis_port,
|
||||
password=settings.config.redis_password or None,
|
||||
decode_responses=False, # Binary mode for pickle
|
||||
max_connections=50,
|
||||
socket_keepalive=True,
|
||||
socket_connect_timeout=5,
|
||||
retry_on_timeout=True,
|
||||
)
|
||||
return _cache_pool
|
||||
|
||||
|
||||
redis = Redis(connection_pool=_get_cache_pool())
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedValue:
|
||||
"""Wrapper for cached values with timestamp to avoid tuple ambiguity."""
|
||||
|
||||
result: Any
|
||||
timestamp: float
|
||||
|
||||
|
||||
def _make_hashable_key(
|
||||
args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> tuple[Any, ...]:
|
||||
"""
|
||||
Convert args and kwargs into a hashable cache key.
|
||||
|
||||
Handles unhashable types like dict, list, set by converting them to
|
||||
their sorted string representations.
|
||||
"""
|
||||
|
||||
def make_hashable(obj: Any) -> Any:
|
||||
"""Recursively convert an object to a hashable representation."""
|
||||
if isinstance(obj, dict):
|
||||
# Sort dict items to ensure consistent ordering
|
||||
return (
|
||||
"__dict__",
|
||||
tuple(sorted((k, make_hashable(v)) for k, v in obj.items())),
|
||||
)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return ("__list__", tuple(make_hashable(item) for item in obj))
|
||||
elif isinstance(obj, set):
|
||||
return ("__set__", tuple(sorted(make_hashable(item) for item in obj)))
|
||||
elif hasattr(obj, "__dict__"):
|
||||
# Handle objects with __dict__ attribute
|
||||
return ("__obj__", obj.__class__.__name__, make_hashable(obj.__dict__))
|
||||
else:
|
||||
# For basic hashable types (str, int, bool, None, etc.)
|
||||
try:
|
||||
hash(obj)
|
||||
return obj
|
||||
except TypeError:
|
||||
# Fallback: convert to string representation
|
||||
return ("__str__", str(obj))
|
||||
|
||||
hashable_args = tuple(make_hashable(arg) for arg in args)
|
||||
hashable_kwargs = tuple(sorted((k, make_hashable(v)) for k, v in kwargs.items()))
|
||||
return (hashable_args, hashable_kwargs)
|
||||
|
||||
|
||||
def _make_redis_key(key: tuple[Any, ...], func_name: str) -> str:
|
||||
"""Convert a hashable key tuple to a Redis key string."""
|
||||
# Ensure key is already hashable
|
||||
hashable_key = key if isinstance(key, tuple) else (key,)
|
||||
return f"cache:{func_name}:{hash(hashable_key)}"
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class CachedFunction(Protocol[P, R_co]):
|
||||
"""Protocol for cached functions with cache management methods."""
|
||||
|
||||
def cache_clear(self, pattern: str | None = None) -> None:
|
||||
"""Clear cached entries. If pattern provided, clear matching entries only."""
|
||||
return None
|
||||
|
||||
def cache_info(self) -> dict[str, int | None]:
|
||||
"""Get cache statistics."""
|
||||
return {}
|
||||
|
||||
def cache_delete(self, *args: P.args, **kwargs: P.kwargs) -> bool:
|
||||
"""Delete a specific cache entry by its arguments. Returns True if entry existed."""
|
||||
return False
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
|
||||
"""Call the cached function."""
|
||||
return None # type: ignore
|
||||
|
||||
|
||||
def cached(
|
||||
*,
|
||||
maxsize: int = 128,
|
||||
ttl_seconds: int,
|
||||
shared_cache: bool = False,
|
||||
refresh_ttl_on_get: bool = False,
|
||||
) -> Callable[[Callable], CachedFunction]:
|
||||
"""
|
||||
Thundering herd safe cache decorator for both sync and async functions.
|
||||
|
||||
Uses double-checked locking to prevent multiple threads/coroutines from
|
||||
executing the expensive operation simultaneously during cache misses.
|
||||
|
||||
Args:
|
||||
maxsize: Maximum number of cached entries (only for in-memory cache)
|
||||
ttl_seconds: Time to live in seconds. Required - entries must expire.
|
||||
shared_cache: If True, use Redis for cross-process caching
|
||||
refresh_ttl_on_get: If True, refresh TTL when cache entry is accessed (LRU behavior)
|
||||
|
||||
Returns:
|
||||
Decorated function with caching capabilities
|
||||
|
||||
Example:
|
||||
@cached(ttl_seconds=300) # 5 minute TTL
|
||||
def expensive_sync_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
@cached(ttl_seconds=600, shared_cache=True, refresh_ttl_on_get=True)
|
||||
async def expensive_async_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
"""
|
||||
|
||||
def decorator(target_func):
|
||||
cache_storage: dict[tuple, CachedValue] = {}
|
||||
_event_loop_locks: dict[Any, asyncio.Lock] = {}
|
||||
|
||||
def _get_from_redis(redis_key: str) -> Any | None:
|
||||
"""Get value from Redis, optionally refreshing TTL."""
|
||||
try:
|
||||
if refresh_ttl_on_get:
|
||||
# Use GETEX to get value and refresh expiry atomically
|
||||
cached_bytes = redis.getex(redis_key, ex=ttl_seconds)
|
||||
else:
|
||||
cached_bytes = redis.get(redis_key)
|
||||
|
||||
if cached_bytes and isinstance(cached_bytes, bytes):
|
||||
return pickle.loads(cached_bytes)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Redis error during cache check for {target_func.__name__}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def _set_to_redis(redis_key: str, value: Any) -> None:
|
||||
"""Set value in Redis with TTL."""
|
||||
try:
|
||||
pickled_value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
redis.setex(redis_key, ttl_seconds, pickled_value)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Redis error storing cache for {target_func.__name__}: {e}"
|
||||
)
|
||||
|
||||
def _get_from_memory(key: tuple) -> Any | None:
|
||||
"""Get value from in-memory cache, checking TTL."""
|
||||
if key in cache_storage:
|
||||
cached_data = cache_storage[key]
|
||||
if time.time() - cached_data.timestamp < ttl_seconds:
|
||||
logger.debug(
|
||||
f"Cache hit for {target_func.__name__} args: {key[0]} kwargs: {key[1]}"
|
||||
)
|
||||
return cached_data.result
|
||||
return None
|
||||
|
||||
def _set_to_memory(key: tuple, value: Any) -> None:
|
||||
"""Set value in in-memory cache with timestamp."""
|
||||
cache_storage[key] = CachedValue(result=value, timestamp=time.time())
|
||||
|
||||
# Cleanup if needed
|
||||
if len(cache_storage) > maxsize:
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
|
||||
if inspect.iscoroutinefunction(target_func):
|
||||
|
||||
def _get_cache_lock():
|
||||
"""Get or create an asyncio.Lock for the current event loop."""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop not in _event_loop_locks:
|
||||
return _event_loop_locks.setdefault(loop, asyncio.Lock())
|
||||
return _event_loop_locks[loop]
|
||||
|
||||
@wraps(target_func)
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
redis_key = (
|
||||
_make_redis_key(key, target_func.__name__) if shared_cache else ""
|
||||
)
|
||||
|
||||
# Fast path: check cache without lock
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
async with _get_cache_lock():
|
||||
# Double-check: another coroutine might have populated cache
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {target_func.__name__}")
|
||||
result = await target_func(*args, **kwargs)
|
||||
|
||||
# Store result
|
||||
if shared_cache:
|
||||
_set_to_redis(redis_key, result)
|
||||
else:
|
||||
_set_to_memory(key, result)
|
||||
|
||||
return result
|
||||
|
||||
wrapper = async_wrapper
|
||||
|
||||
else:
|
||||
# Sync function with threading.Lock
|
||||
cache_lock = threading.Lock()
|
||||
|
||||
@wraps(target_func)
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
redis_key = (
|
||||
_make_redis_key(key, target_func.__name__) if shared_cache else ""
|
||||
)
|
||||
|
||||
# Fast path: check cache without lock
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
with cache_lock:
|
||||
# Double-check: another thread might have populated cache
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {target_func.__name__}")
|
||||
result = target_func(*args, **kwargs)
|
||||
|
||||
# Store result
|
||||
if shared_cache:
|
||||
_set_to_redis(redis_key, result)
|
||||
else:
|
||||
_set_to_memory(key, result)
|
||||
|
||||
return result
|
||||
|
||||
wrapper = sync_wrapper
|
||||
|
||||
# Add cache management methods
|
||||
def cache_clear(pattern: str | None = None) -> None:
|
||||
"""Clear cache entries. If pattern provided, clear matching entries."""
|
||||
if shared_cache:
|
||||
if pattern:
|
||||
# Clear entries matching pattern
|
||||
keys = list(
|
||||
redis.scan_iter(f"cache:{target_func.__name__}:{pattern}")
|
||||
)
|
||||
else:
|
||||
# Clear all cache keys
|
||||
keys = list(redis.scan_iter(f"cache:{target_func.__name__}:*"))
|
||||
|
||||
if keys:
|
||||
pipeline = redis.pipeline()
|
||||
for key in keys:
|
||||
pipeline.delete(key)
|
||||
pipeline.execute()
|
||||
else:
|
||||
if pattern:
|
||||
# For in-memory cache, pattern matching not supported
|
||||
logger.warning(
|
||||
"Pattern-based clearing not supported for in-memory cache"
|
||||
)
|
||||
else:
|
||||
cache_storage.clear()
|
||||
|
||||
def cache_info() -> dict[str, int | None]:
|
||||
if shared_cache:
|
||||
cache_keys = list(redis.scan_iter(f"cache:{target_func.__name__}:*"))
|
||||
return {
|
||||
"size": len(cache_keys),
|
||||
"maxsize": None, # Redis manages its own size
|
||||
"ttl_seconds": ttl_seconds,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"size": len(cache_storage),
|
||||
"maxsize": maxsize,
|
||||
"ttl_seconds": ttl_seconds,
|
||||
}
|
||||
|
||||
def cache_delete(*args, **kwargs) -> bool:
|
||||
"""Delete a specific cache entry. Returns True if entry existed."""
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
if shared_cache:
|
||||
redis_key = _make_redis_key(key, target_func.__name__)
|
||||
if redis.exists(redis_key):
|
||||
redis.delete(redis_key)
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
if key in cache_storage:
|
||||
del cache_storage[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
setattr(wrapper, "cache_clear", cache_clear)
|
||||
setattr(wrapper, "cache_info", cache_info)
|
||||
setattr(wrapper, "cache_delete", cache_delete)
|
||||
|
||||
return cast(CachedFunction, wrapper)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def thread_cached(func):
|
||||
"""
|
||||
Thread-local cache decorator for both sync and async functions.
|
||||
|
||||
Each thread gets its own cache, which is useful for request-scoped caching
|
||||
in web applications where you want to cache within a single request but
|
||||
not across requests.
|
||||
|
||||
Args:
|
||||
func: The function to cache
|
||||
|
||||
Returns:
|
||||
Decorated function with thread-local caching
|
||||
|
||||
Example:
|
||||
@thread_cached
|
||||
def expensive_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
@thread_cached # Works with async too
|
||||
async def expensive_async_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
"""
|
||||
thread_local = threading.local()
|
||||
|
||||
def _clear():
|
||||
if hasattr(thread_local, "cache"):
|
||||
del thread_local.cache
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
if key not in cache:
|
||||
cache[key] = await func(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
setattr(async_wrapper, "clear_cache", _clear)
|
||||
return async_wrapper
|
||||
|
||||
else:
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
if key not in cache:
|
||||
cache[key] = func(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
setattr(sync_wrapper, "clear_cache", _clear)
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
def clear_thread_cache(func: Callable) -> None:
|
||||
"""Clear thread-local cache for a function."""
|
||||
if clear := getattr(func, "clear_cache", None):
|
||||
clear()
|
||||
@@ -16,7 +16,7 @@ from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from autogpt_libs.utils.cache import cached, clear_thread_cache, thread_cached
|
||||
from backend.util.cache import cached, clear_thread_cache, thread_cached
|
||||
|
||||
|
||||
class TestThreadCached:
|
||||
@@ -332,7 +332,7 @@ class TestCache:
|
||||
"""Test basic sync caching functionality."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=300)
|
||||
def expensive_sync_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -358,7 +358,7 @@ class TestCache:
|
||||
"""Test basic async caching functionality."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=300)
|
||||
async def expensive_async_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -385,7 +385,7 @@ class TestCache:
|
||||
call_count = 0
|
||||
results = []
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=300)
|
||||
def slow_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -412,7 +412,7 @@ class TestCache:
|
||||
"""Test that concurrent async calls don't cause thundering herd."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=300)
|
||||
async def slow_async_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -508,7 +508,7 @@ class TestCache:
|
||||
"""Test cache clearing functionality."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=300)
|
||||
def clearable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -537,7 +537,7 @@ class TestCache:
|
||||
"""Test cache clearing functionality with async function."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=300)
|
||||
async def async_clearable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -567,7 +567,7 @@ class TestCache:
|
||||
"""Test that cached async functions return actual results, not coroutines."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=300)
|
||||
async def async_result_function(x: int) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -593,7 +593,7 @@ class TestCache:
|
||||
"""Test selective cache deletion functionality."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=300)
|
||||
def deletable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -636,7 +636,7 @@ class TestCache:
|
||||
"""Test selective cache deletion functionality with async function."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=300)
|
||||
async def async_deletable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -674,3 +674,450 @@ class TestCache:
|
||||
# Try to delete non-existent entry
|
||||
was_deleted = async_deletable_function.cache_delete(99)
|
||||
assert was_deleted is False
|
||||
|
||||
|
||||
class TestSharedCache:
|
||||
"""Tests for shared_cache (Redis-backed) functionality."""
|
||||
|
||||
def test_sync_shared_cache_basic(self):
|
||||
"""Test basic shared cache functionality with sync function."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def shared_sync_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x + y
|
||||
|
||||
# Clear any existing cache
|
||||
shared_sync_function.cache_clear()
|
||||
|
||||
# First call
|
||||
result1 = shared_sync_function(10, 20)
|
||||
assert result1 == 30
|
||||
assert call_count == 1
|
||||
|
||||
# Second call - should use Redis cache
|
||||
result2 = shared_sync_function(10, 20)
|
||||
assert result2 == 30
|
||||
assert call_count == 1
|
||||
|
||||
# Different args - should call function again
|
||||
result3 = shared_sync_function(15, 25)
|
||||
assert result3 == 40
|
||||
assert call_count == 2
|
||||
|
||||
# Cleanup
|
||||
shared_sync_function.cache_clear()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_shared_cache_basic(self):
|
||||
"""Test basic shared cache functionality with async function."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
async def shared_async_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
return x + y
|
||||
|
||||
# Clear any existing cache
|
||||
shared_async_function.cache_clear()
|
||||
|
||||
# First call
|
||||
result1 = await shared_async_function(10, 20)
|
||||
assert result1 == 30
|
||||
assert call_count == 1
|
||||
|
||||
# Second call - should use Redis cache
|
||||
result2 = await shared_async_function(10, 20)
|
||||
assert result2 == 30
|
||||
assert call_count == 1
|
||||
|
||||
# Different args - should call function again
|
||||
result3 = await shared_async_function(15, 25)
|
||||
assert result3 == 40
|
||||
assert call_count == 2
|
||||
|
||||
# Cleanup
|
||||
shared_async_function.cache_clear()
|
||||
|
||||
def test_shared_cache_ttl_refresh(self):
|
||||
"""Test TTL refresh functionality with shared cache."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=2, shared_cache=True, refresh_ttl_on_get=True)
|
||||
def ttl_refresh_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 10
|
||||
|
||||
# Clear any existing cache
|
||||
ttl_refresh_function.cache_clear()
|
||||
|
||||
# First call
|
||||
result1 = ttl_refresh_function(3)
|
||||
assert result1 == 30
|
||||
assert call_count == 1
|
||||
|
||||
# Wait 1 second
|
||||
time.sleep(1)
|
||||
|
||||
# Second call - should refresh TTL and use cache
|
||||
result2 = ttl_refresh_function(3)
|
||||
assert result2 == 30
|
||||
assert call_count == 1
|
||||
|
||||
# Wait another 1.5 seconds (total 2.5s from first call, 1.5s from second)
|
||||
time.sleep(1.5)
|
||||
|
||||
# Third call - TTL should have been refreshed, so still cached
|
||||
result3 = ttl_refresh_function(3)
|
||||
assert result3 == 30
|
||||
assert call_count == 1
|
||||
|
||||
# Wait 2.1 seconds - now it should expire
|
||||
time.sleep(2.1)
|
||||
|
||||
# Fourth call - should call function again
|
||||
result4 = ttl_refresh_function(3)
|
||||
assert result4 == 30
|
||||
assert call_count == 2
|
||||
|
||||
# Cleanup
|
||||
ttl_refresh_function.cache_clear()
|
||||
|
||||
def test_shared_cache_without_ttl_refresh(self):
|
||||
"""Test that TTL doesn't refresh when refresh_ttl_on_get=False."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=2, shared_cache=True, refresh_ttl_on_get=False)
|
||||
def no_ttl_refresh_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 10
|
||||
|
||||
# Clear any existing cache
|
||||
no_ttl_refresh_function.cache_clear()
|
||||
|
||||
# First call
|
||||
result1 = no_ttl_refresh_function(4)
|
||||
assert result1 == 40
|
||||
assert call_count == 1
|
||||
|
||||
# Wait 1 second
|
||||
time.sleep(1)
|
||||
|
||||
# Second call - should use cache but NOT refresh TTL
|
||||
result2 = no_ttl_refresh_function(4)
|
||||
assert result2 == 40
|
||||
assert call_count == 1
|
||||
|
||||
# Wait another 1.1 seconds (total 2.1s from first call)
|
||||
time.sleep(1.1)
|
||||
|
||||
# Third call - should have expired
|
||||
result3 = no_ttl_refresh_function(4)
|
||||
assert result3 == 40
|
||||
assert call_count == 2
|
||||
|
||||
# Cleanup
|
||||
no_ttl_refresh_function.cache_clear()
|
||||
|
||||
def test_shared_cache_complex_objects(self):
|
||||
"""Test caching complex objects with shared cache (pickle serialization)."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def complex_object_function(x: int) -> dict:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return {
|
||||
"number": x,
|
||||
"squared": x**2,
|
||||
"nested": {"list": [1, 2, x], "tuple": (x, x * 2)},
|
||||
"string": f"value_{x}",
|
||||
}
|
||||
|
||||
# Clear any existing cache
|
||||
complex_object_function.cache_clear()
|
||||
|
||||
# First call
|
||||
result1 = complex_object_function(5)
|
||||
assert result1["number"] == 5
|
||||
assert result1["squared"] == 25
|
||||
assert result1["nested"]["list"] == [1, 2, 5]
|
||||
assert call_count == 1
|
||||
|
||||
# Second call - should use cache
|
||||
result2 = complex_object_function(5)
|
||||
assert result2 == result1
|
||||
assert call_count == 1
|
||||
|
||||
# Cleanup
|
||||
complex_object_function.cache_clear()
|
||||
|
||||
def test_shared_cache_info(self):
|
||||
"""Test cache_info for shared cache."""
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def info_shared_function(x: int) -> int:
|
||||
return x * 2
|
||||
|
||||
# Clear any existing cache
|
||||
info_shared_function.cache_clear()
|
||||
|
||||
# Check initial info
|
||||
info = info_shared_function.cache_info()
|
||||
assert info["size"] == 0
|
||||
assert info["maxsize"] is None # Redis manages size
|
||||
assert info["ttl_seconds"] == 30
|
||||
|
||||
# Add some entries
|
||||
info_shared_function(1)
|
||||
info_shared_function(2)
|
||||
info_shared_function(3)
|
||||
|
||||
info = info_shared_function.cache_info()
|
||||
assert info["size"] == 3
|
||||
|
||||
# Cleanup
|
||||
info_shared_function.cache_clear()
|
||||
|
||||
def test_shared_cache_delete(self):
|
||||
"""Test selective deletion with shared cache."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def delete_shared_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 3
|
||||
|
||||
# Clear any existing cache
|
||||
delete_shared_function.cache_clear()
|
||||
|
||||
# Add entries
|
||||
delete_shared_function(1)
|
||||
delete_shared_function(2)
|
||||
delete_shared_function(3)
|
||||
assert call_count == 3
|
||||
|
||||
# Verify cached
|
||||
delete_shared_function(1)
|
||||
delete_shared_function(2)
|
||||
assert call_count == 3
|
||||
|
||||
# Delete specific entry
|
||||
was_deleted = delete_shared_function.cache_delete(2)
|
||||
assert was_deleted is True
|
||||
|
||||
# Entry for x=2 should be gone
|
||||
delete_shared_function(2)
|
||||
assert call_count == 4
|
||||
|
||||
# Others should still be cached
|
||||
delete_shared_function(1)
|
||||
delete_shared_function(3)
|
||||
assert call_count == 4
|
||||
|
||||
# Try to delete non-existent
|
||||
was_deleted = delete_shared_function.cache_delete(99)
|
||||
assert was_deleted is False
|
||||
|
||||
# Cleanup
|
||||
delete_shared_function.cache_clear()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_shared_cache_thundering_herd(self):
|
||||
"""Test that shared cache prevents thundering herd for async functions."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
async def shared_slow_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.1)
|
||||
return x * x
|
||||
|
||||
# Clear any existing cache
|
||||
shared_slow_function.cache_clear()
|
||||
|
||||
# Launch multiple concurrent tasks
|
||||
tasks = [shared_slow_function(8) for _ in range(10)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# All should return same result
|
||||
assert all(r == 64 for r in results)
|
||||
# Only one should have executed
|
||||
assert call_count == 1
|
||||
|
||||
# Cleanup
|
||||
shared_slow_function.cache_clear()
|
||||
|
||||
def test_shared_cache_clear_pattern(self):
|
||||
"""Test pattern-based cache clearing (Redis feature)."""
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def pattern_function(category: str, item: int) -> str:
|
||||
return f"{category}_{item}"
|
||||
|
||||
# Clear any existing cache
|
||||
pattern_function.cache_clear()
|
||||
|
||||
# Add various entries
|
||||
pattern_function("fruit", 1)
|
||||
pattern_function("fruit", 2)
|
||||
pattern_function("vegetable", 1)
|
||||
pattern_function("vegetable", 2)
|
||||
|
||||
info = pattern_function.cache_info()
|
||||
assert info["size"] == 4
|
||||
|
||||
# Note: Pattern clearing with wildcards requires specific Redis scan
|
||||
# implementation. The current code clears by pattern but needs
|
||||
# adjustment for partial matching. For now, test full clear.
|
||||
pattern_function.cache_clear()
|
||||
info = pattern_function.cache_info()
|
||||
assert info["size"] == 0
|
||||
|
||||
def test_shared_vs_local_cache_isolation(self):
|
||||
"""Test that shared and local caches are isolated."""
|
||||
shared_count = 0
|
||||
local_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def shared_function(x: int) -> int:
|
||||
nonlocal shared_count
|
||||
shared_count += 1
|
||||
return x * 2
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=False)
|
||||
def local_function(x: int) -> int:
|
||||
nonlocal local_count
|
||||
local_count += 1
|
||||
return x * 2
|
||||
|
||||
# Clear caches
|
||||
shared_function.cache_clear()
|
||||
local_function.cache_clear()
|
||||
|
||||
# Call both with same args
|
||||
shared_result = shared_function(5)
|
||||
local_result = local_function(5)
|
||||
|
||||
assert shared_result == local_result == 10
|
||||
assert shared_count == 1
|
||||
assert local_count == 1
|
||||
|
||||
# Call again - both should use their respective caches
|
||||
shared_function(5)
|
||||
local_function(5)
|
||||
assert shared_count == 1
|
||||
assert local_count == 1
|
||||
|
||||
# Clear only shared cache
|
||||
shared_function.cache_clear()
|
||||
|
||||
# Shared should recompute, local should still use cache
|
||||
shared_function(5)
|
||||
local_function(5)
|
||||
assert shared_count == 2
|
||||
assert local_count == 1
|
||||
|
||||
# Cleanup
|
||||
shared_function.cache_clear()
|
||||
local_function.cache_clear()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shared_cache_concurrent_different_keys(self):
|
||||
"""Test that concurrent calls with different keys work correctly."""
|
||||
call_counts = {}
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
async def multi_key_function(key: str) -> str:
|
||||
if key not in call_counts:
|
||||
call_counts[key] = 0
|
||||
call_counts[key] += 1
|
||||
await asyncio.sleep(0.05)
|
||||
return f"result_{key}"
|
||||
|
||||
# Clear cache
|
||||
multi_key_function.cache_clear()
|
||||
|
||||
# Launch concurrent tasks with different keys
|
||||
keys = ["a", "b", "c", "d", "e"]
|
||||
tasks = []
|
||||
for key in keys:
|
||||
# Multiple calls per key
|
||||
tasks.extend([multi_key_function(key) for _ in range(3)])
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Verify results
|
||||
for i, key in enumerate(keys):
|
||||
expected = f"result_{key}"
|
||||
# Each key appears 3 times in results
|
||||
key_results = results[i * 3 : (i + 1) * 3]
|
||||
assert all(r == expected for r in key_results)
|
||||
|
||||
# Each key should only be computed once
|
||||
for key in keys:
|
||||
assert call_counts[key] == 1
|
||||
|
||||
# Cleanup
|
||||
multi_key_function.cache_clear()
|
||||
|
||||
def test_shared_cache_performance_comparison(self):
|
||||
"""Compare performance of shared vs local cache."""
|
||||
import statistics
|
||||
|
||||
shared_times = []
|
||||
local_times = []
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def shared_perf_function(x: int) -> int:
|
||||
time.sleep(0.01) # Simulate work
|
||||
return x * 2
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=False)
|
||||
def local_perf_function(x: int) -> int:
|
||||
time.sleep(0.01) # Simulate work
|
||||
return x * 2
|
||||
|
||||
# Clear caches
|
||||
shared_perf_function.cache_clear()
|
||||
local_perf_function.cache_clear()
|
||||
|
||||
# Warm up both caches
|
||||
for i in range(5):
|
||||
shared_perf_function(i)
|
||||
local_perf_function(i)
|
||||
|
||||
# Measure cache hit times
|
||||
for i in range(5):
|
||||
# Shared cache hit
|
||||
start = time.time()
|
||||
shared_perf_function(i)
|
||||
shared_times.append(time.time() - start)
|
||||
|
||||
# Local cache hit
|
||||
start = time.time()
|
||||
local_perf_function(i)
|
||||
local_times.append(time.time() - start)
|
||||
|
||||
# Local cache should be faster (no Redis round-trip)
|
||||
avg_shared = statistics.mean(shared_times)
|
||||
avg_local = statistics.mean(local_times)
|
||||
|
||||
print(f"Avg shared cache hit time: {avg_shared:.6f}s")
|
||||
print(f"Avg local cache hit time: {avg_local:.6f}s")
|
||||
|
||||
# Local should be significantly faster for cache hits
|
||||
# Redis adds network latency even for cache hits
|
||||
assert avg_local < avg_shared
|
||||
|
||||
# Cleanup
|
||||
shared_perf_function.cache_clear()
|
||||
local_perf_function.cache_clear()
|
||||
@@ -4,8 +4,7 @@ Centralized service client helpers with thread caching.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from autogpt_libs.utils.cache import cached, thread_cached
|
||||
|
||||
from backend.util.cache import cached, thread_cached
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
@@ -120,7 +119,7 @@ def get_integration_credentials_store() -> "IntegrationCredentialsStore":
|
||||
# ============ Supabase Clients ============ #
|
||||
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=3600)
|
||||
def get_supabase() -> "Client":
|
||||
"""Get a process-cached synchronous Supabase client instance."""
|
||||
from supabase import create_client
|
||||
@@ -130,7 +129,7 @@ def get_supabase() -> "Client":
|
||||
)
|
||||
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=3600)
|
||||
async def get_async_supabase() -> "AClient":
|
||||
"""Get a process-cached asynchronous Supabase client instance."""
|
||||
from supabase import create_async_client
|
||||
|
||||
@@ -5,12 +5,12 @@ from functools import wraps
|
||||
from typing import Any, Awaitable, Callable, TypeVar
|
||||
|
||||
import ldclient
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from fastapi import HTTPException
|
||||
from ldclient import Context, LDClient
|
||||
from ldclient.config import Config
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from backend.util.cache import cached
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -35,6 +35,7 @@ class Flag(str, Enum):
|
||||
AI_ACTIVITY_STATUS = "ai-agent-execution-summary"
|
||||
BETA_BLOCKS = "beta-blocks"
|
||||
AGENT_ACTIVITY = "agent-activity"
|
||||
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
|
||||
|
||||
|
||||
def is_configured() -> bool:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Type, TypeGuard, TypeVar, overload
|
||||
|
||||
@@ -14,13 +13,6 @@ from .type import type_match
|
||||
# Removes \u0000-\u0008, \u000B-\u000C, \u000E-\u001F, \u007F (keeps tab \u0009, newline \u000A, carriage return \u000D)
|
||||
POSTGRES_CONTROL_CHARS = re.compile(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]")
|
||||
|
||||
# Comprehensive regex to remove all PostgreSQL-incompatible control character sequences in JSON
|
||||
# Handles both Unicode escapes (\\u0000-\\u0008, \\u000B-\\u000C, \\u000E-\\u001F, \\u007F)
|
||||
# and JSON single-char escapes (\\b, \\f) while preserving legitimate file paths
|
||||
POSTGRES_JSON_ESCAPES = re.compile(
|
||||
r"\\u000[0-8]|\\u000[bB]|\\u000[cC]|\\u00[0-1][0-9a-fA-F]|\\u007[fF]|(?<!\\)\\[bf](?!\\)"
|
||||
)
|
||||
|
||||
|
||||
def to_dict(data) -> dict:
|
||||
if isinstance(data, BaseModel):
|
||||
@@ -130,24 +122,64 @@ def convert_pydantic_to_json(output_data: Any) -> Any:
|
||||
return output_data
|
||||
|
||||
|
||||
def _sanitize_value(value: Any) -> Any:
|
||||
"""
|
||||
Recursively sanitize values by removing PostgreSQL-incompatible control characters.
|
||||
|
||||
This function walks through data structures and removes control characters from strings.
|
||||
It handles:
|
||||
- Strings: Remove control chars directly from the string
|
||||
- Lists: Recursively sanitize each element
|
||||
- Dicts: Recursively sanitize keys and values
|
||||
- Other types: Return as-is
|
||||
|
||||
Args:
|
||||
value: The value to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized version of the value with control characters removed
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
# Remove control characters directly from the string
|
||||
return POSTGRES_CONTROL_CHARS.sub("", value)
|
||||
elif isinstance(value, dict):
|
||||
# Recursively sanitize dictionary keys and values
|
||||
return {_sanitize_value(k): _sanitize_value(v) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
# Recursively sanitize list elements
|
||||
return [_sanitize_value(item) for item in value]
|
||||
elif isinstance(value, tuple):
|
||||
# Recursively sanitize tuple elements
|
||||
return tuple(_sanitize_value(item) for item in value)
|
||||
else:
|
||||
# For other types (int, float, bool, None, etc.), return as-is
|
||||
return value
|
||||
|
||||
|
||||
def SafeJson(data: Any) -> Json:
|
||||
"""
|
||||
Safely serialize data and return Prisma's Json type.
|
||||
Sanitizes null bytes to prevent PostgreSQL 22P05 errors.
|
||||
Sanitizes control characters to prevent PostgreSQL 22P05 errors.
|
||||
|
||||
This function:
|
||||
1. Converts Pydantic models to dicts
|
||||
2. Recursively removes PostgreSQL-incompatible control characters from strings
|
||||
3. Returns a Prisma Json object safe for database storage
|
||||
|
||||
Args:
|
||||
data: Input data to sanitize and convert to Json
|
||||
|
||||
Returns:
|
||||
Prisma Json object with control characters removed
|
||||
|
||||
Examples:
|
||||
>>> SafeJson({"text": "Hello\\x00World"}) # null char removed
|
||||
>>> SafeJson({"path": "C:\\\\temp"}) # backslashes preserved
|
||||
>>> SafeJson({"data": "Text\\\\u0000here"}) # literal backslash-u preserved
|
||||
"""
|
||||
# Convert Pydantic models to dict first
|
||||
if isinstance(data, BaseModel):
|
||||
json_string = data.model_dump_json(
|
||||
warnings="error",
|
||||
exclude_none=True,
|
||||
fallback=lambda v: None,
|
||||
)
|
||||
else:
|
||||
json_string = dumps(data, default=lambda v: None)
|
||||
data = data.model_dump(exclude_none=True)
|
||||
|
||||
# Remove PostgreSQL-incompatible control characters in JSON string
|
||||
# Single comprehensive regex handles all control character sequences
|
||||
sanitized_json = POSTGRES_JSON_ESCAPES.sub("", json_string)
|
||||
|
||||
# Remove any remaining raw control characters (fallback safety net)
|
||||
sanitized_json = POSTGRES_CONTROL_CHARS.sub("", sanitized_json)
|
||||
return Json(json.loads(sanitized_json))
|
||||
# Return as Prisma Json type
|
||||
return Json(_sanitize_value(data))
|
||||
|
||||
@@ -8,18 +8,9 @@ from typing import Optional
|
||||
|
||||
from backend.util.logging import configure_logging
|
||||
from backend.util.metrics import sentry_init
|
||||
from backend.util.settings import set_service_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_SERVICE_NAME = "MainProcess"
|
||||
|
||||
|
||||
def get_service_name():
|
||||
return _SERVICE_NAME
|
||||
|
||||
|
||||
def set_service_name(name: str):
|
||||
global _SERVICE_NAME
|
||||
_SERVICE_NAME = name
|
||||
|
||||
|
||||
class AppProcess(ABC):
|
||||
|
||||
@@ -13,7 +13,7 @@ import idna
|
||||
from aiohttp import FormData, abc
|
||||
from tenacity import retry, retry_if_result, wait_exponential_jitter
|
||||
|
||||
from backend.util.json import json
|
||||
from backend.util.json import loads
|
||||
|
||||
# Retry status codes for which we will automatically retry the request
|
||||
THROTTLE_RETRY_STATUS_CODES: set[int] = {429, 500, 502, 503, 504, 408}
|
||||
@@ -259,7 +259,7 @@ class Response:
|
||||
"""
|
||||
Parse the body as JSON and return the resulting Python object.
|
||||
"""
|
||||
return json.loads(
|
||||
return loads(
|
||||
self.content.decode(encoding or "utf-8", errors="replace"), **kwargs
|
||||
)
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from tenacity import (
|
||||
wait_exponential_jitter,
|
||||
)
|
||||
|
||||
from backend.util.process import get_service_name
|
||||
from backend.util.settings import get_service_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -31,9 +31,9 @@ import backend.util.exceptions as exceptions
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
from backend.util.json import to_dict
|
||||
from backend.util.metrics import sentry_init
|
||||
from backend.util.process import AppProcess, get_service_name
|
||||
from backend.util.process import AppProcess
|
||||
from backend.util.retry import conn_retry, create_retry_decorator
|
||||
from backend.util.settings import Config
|
||||
from backend.util.settings import Config, get_service_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -15,6 +15,17 @@ from backend.util.data import get_data_path
|
||||
|
||||
T = TypeVar("T", bound=BaseSettings)
|
||||
|
||||
_SERVICE_NAME = "MainProcess"
|
||||
|
||||
|
||||
def get_service_name():
|
||||
return _SERVICE_NAME
|
||||
|
||||
|
||||
def set_service_name(name: str):
|
||||
global _SERVICE_NAME
|
||||
_SERVICE_NAME = name
|
||||
|
||||
|
||||
class AppEnvironment(str, Enum):
|
||||
LOCAL = "local"
|
||||
@@ -254,6 +265,7 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default="localhost",
|
||||
description="The host for the RabbitMQ server",
|
||||
)
|
||||
|
||||
rabbitmq_port: int = Field(
|
||||
default=5672,
|
||||
description="The port for the RabbitMQ server",
|
||||
@@ -264,6 +276,21 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="The vhost for the RabbitMQ server",
|
||||
)
|
||||
|
||||
redis_host: str = Field(
|
||||
default="localhost",
|
||||
description="The host for the Redis server",
|
||||
)
|
||||
|
||||
redis_port: int = Field(
|
||||
default=6379,
|
||||
description="The port for the Redis server",
|
||||
)
|
||||
|
||||
redis_password: str = Field(
|
||||
default="",
|
||||
description="The password for the Redis server (empty string if no password)",
|
||||
)
|
||||
|
||||
postmark_sender_email: str = Field(
|
||||
default="invalid@invalid.com",
|
||||
description="The email address to use for sending emails",
|
||||
|
||||
@@ -411,3 +411,71 @@ class TestSafeJson:
|
||||
assert "C:\\temp\\file" in str(file_path_with_null)
|
||||
assert ".txt" in str(file_path_with_null)
|
||||
assert "\x00" not in str(file_path_with_null) # Null removed from path
|
||||
|
||||
def test_invalid_escape_error_prevention(self):
|
||||
"""Test that SafeJson prevents 'Invalid \\escape' errors that occurred in upsert_execution_output."""
|
||||
# This reproduces the exact scenario that was causing the error:
|
||||
# POST /upsert_execution_output failed: Invalid \escape: line 1 column 36404 (char 36403)
|
||||
|
||||
# Create data with various problematic escape sequences that could cause JSON parsing errors
|
||||
problematic_output_data = {
|
||||
"web_content": "Article text\x00with null\x01and control\x08chars\x0C\x1F\x7F",
|
||||
"file_path": "C:\\Users\\test\\file\x00.txt",
|
||||
"json_like_string": '{"text": "data\x00\x08\x1F"}',
|
||||
"escaped_sequences": "Text with \\u0000 and \\u0008 sequences",
|
||||
"mixed_content": "Normal text\tproperly\nformatted\rwith\x00invalid\x08chars\x1Fmixed",
|
||||
"large_text": "A" * 35000
|
||||
+ "\x00\x08\x1F"
|
||||
+ "B" * 5000, # Large text like in the error
|
||||
}
|
||||
|
||||
# This should not raise any JSON parsing errors
|
||||
result = SafeJson(problematic_output_data)
|
||||
assert isinstance(result, Json)
|
||||
|
||||
# Verify the result is a valid Json object that can be safely stored in PostgreSQL
|
||||
result_data = cast(dict[str, Any], result.data)
|
||||
assert isinstance(result_data, dict)
|
||||
|
||||
# Verify problematic characters are removed but safe content preserved
|
||||
web_content = result_data.get("web_content", "")
|
||||
file_path = result_data.get("file_path", "")
|
||||
large_text = result_data.get("large_text", "")
|
||||
|
||||
# Check that control characters are removed
|
||||
assert "\x00" not in str(web_content)
|
||||
assert "\x01" not in str(web_content)
|
||||
assert "\x08" not in str(web_content)
|
||||
assert "\x0C" not in str(web_content)
|
||||
assert "\x1F" not in str(web_content)
|
||||
assert "\x7F" not in str(web_content)
|
||||
|
||||
# Check that legitimate content is preserved
|
||||
assert "Article text" in str(web_content)
|
||||
assert "with null" in str(web_content)
|
||||
assert "and control" in str(web_content)
|
||||
assert "chars" in str(web_content)
|
||||
|
||||
# Check file path handling
|
||||
assert "C:\\Users\\test\\file" in str(file_path)
|
||||
assert ".txt" in str(file_path)
|
||||
assert "\x00" not in str(file_path)
|
||||
|
||||
# Check large text handling (the scenario from the error at char 36403)
|
||||
assert len(str(large_text)) > 35000 # Content preserved
|
||||
assert "A" * 1000 in str(large_text) # A's preserved
|
||||
assert "B" * 1000 in str(large_text) # B's preserved
|
||||
assert "\x00" not in str(large_text) # Control chars removed
|
||||
assert "\x08" not in str(large_text)
|
||||
assert "\x1F" not in str(large_text)
|
||||
|
||||
# Most importantly: ensure the result can be JSON-serialized without errors
|
||||
# This would have failed with the old approach
|
||||
import json
|
||||
|
||||
json_string = json.dumps(result.data) # Should not raise "Invalid \escape"
|
||||
assert len(json_string) > 0
|
||||
|
||||
# And can be parsed back
|
||||
parsed_back = json.loads(json_string)
|
||||
assert isinstance(parsed_back, dict)
|
||||
|
||||
@@ -749,10 +749,11 @@ class TestDataCreator:
|
||||
"""Add credits to users."""
|
||||
print("Adding credits to users...")
|
||||
|
||||
credit_model = get_user_credit_model()
|
||||
|
||||
for user in self.users:
|
||||
try:
|
||||
# Get user-specific credit model
|
||||
credit_model = await get_user_credit_model(user["id"])
|
||||
|
||||
# Skip credits for disabled credit model to avoid errors
|
||||
if (
|
||||
hasattr(credit_model, "__class__")
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
NEXT_PUBLIC_LAUNCHDARKLY_ENABLED=false
|
||||
NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID=687ab1372f497809b131e06e
|
||||
|
||||
NEXT_PUBLIC_SHOW_BILLING_PAGE=false
|
||||
NEXT_PUBLIC_TURNSTILE=disabled
|
||||
NEXT_PUBLIC_REACT_QUERY_DEVTOOL=true
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import WalletRefill from "./components/WalletRefill";
|
||||
import { OnboardingStep } from "@/lib/autogpt-server-api";
|
||||
import { storage, Key as StorageKey } from "@/services/storage/local-storage";
|
||||
import { WalletIcon } from "@phosphor-icons/react";
|
||||
import { useGetFlag, Flag } from "@/services/feature-flags/use-get-flag";
|
||||
|
||||
export interface Task {
|
||||
id: OnboardingStep;
|
||||
@@ -40,6 +41,7 @@ export interface TaskGroup {
|
||||
|
||||
export default function Wallet() {
|
||||
const { state, updateState } = useOnboarding();
|
||||
const isPaymentEnabled = useGetFlag(Flag.ENABLE_PLATFORM_PAYMENT);
|
||||
|
||||
const groups = useMemo<TaskGroup[]>(() => {
|
||||
return [
|
||||
@@ -379,9 +381,7 @@ export default function Wallet() {
|
||||
</div>
|
||||
<ScrollArea className="max-h-[85vh] overflow-y-auto">
|
||||
{/* Top ups */}
|
||||
{process.env.NEXT_PUBLIC_SHOW_BILLING_PAGE === "true" && (
|
||||
<WalletRefill />
|
||||
)}
|
||||
{isPaymentEnabled && <WalletRefill />}
|
||||
{/* Tasks */}
|
||||
<p className="mx-1 my-3 font-sans text-xs font-normal text-zinc-400">
|
||||
Complete the following tasks to earn more credits!
|
||||
|
||||
@@ -5,7 +5,7 @@ import { useFlow } from "./useFlow";
|
||||
import { useShallow } from "zustand/react/shallow";
|
||||
import { useNodeStore } from "../../../stores/nodeStore";
|
||||
import { useMemo } from "react";
|
||||
import { CustomNode } from "../nodes/CustomNode";
|
||||
import { CustomNode } from "../nodes/CustomNode/CustomNode";
|
||||
import { useCustomEdge } from "../edges/useCustomEdge";
|
||||
import { GraphLoadingBox } from "./GraphLoadingBox";
|
||||
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
import React from "react";
|
||||
import { Node as XYNode, NodeProps } from "@xyflow/react";
|
||||
import { FormCreator } from "./FormCreator";
|
||||
import { RJSFSchema } from "@rjsf/utils";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
|
||||
import { Switch } from "@/components/atoms/Switch/Switch";
|
||||
import { preprocessInputSchema } from "../processors/input-schema-pre-processor";
|
||||
import { OutputHandler } from "./OutputHandler";
|
||||
import { useNodeStore } from "../../../stores/nodeStore";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { BlockUIType } from "../../types";
|
||||
import { StickyNoteBlock } from "./StickyNoteBlock";
|
||||
|
||||
export type CustomNodeData = {
|
||||
hardcodedValues: {
|
||||
[key: string]: any;
|
||||
};
|
||||
title: string;
|
||||
block_id: string;
|
||||
description: string;
|
||||
inputSchema: RJSFSchema;
|
||||
outputSchema: RJSFSchema;
|
||||
uiType: BlockUIType;
|
||||
};
|
||||
|
||||
export type CustomNode = XYNode<CustomNodeData, "custom">;
|
||||
|
||||
export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||
({ data, id: nodeId, selected }) => {
|
||||
const showAdvanced = useNodeStore(
|
||||
(state) => state.nodeAdvancedStates[nodeId] || false,
|
||||
);
|
||||
const setShowAdvanced = useNodeStore((state) => state.setShowAdvanced);
|
||||
|
||||
if (data.uiType === BlockUIType.NOTE) {
|
||||
return <StickyNoteBlock selected={selected} data={data} id={nodeId} />;
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"z-12 rounded-xl bg-gradient-to-br from-white to-slate-50/30 shadow-lg shadow-slate-900/5 ring-1 ring-slate-200/60 backdrop-blur-sm",
|
||||
selected && "shadow-2xl ring-2 ring-slate-200",
|
||||
)}
|
||||
>
|
||||
{/* Header */}
|
||||
<div className="flex h-14 items-center justify-center rounded-xl border-b border-slate-200/50 bg-gradient-to-r from-slate-50/80 to-white/90">
|
||||
<Text
|
||||
variant="large-semibold"
|
||||
className="tracking-tight text-slate-800"
|
||||
>
|
||||
{data.title}
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
{/* Input Handles */}
|
||||
<div className="bg-white/40 pb-6 pr-6">
|
||||
<FormCreator
|
||||
jsonSchema={preprocessInputSchema(data.inputSchema)}
|
||||
nodeId={nodeId}
|
||||
uiType={data.uiType}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Advanced Button */}
|
||||
<div className="flex items-center justify-between gap-2 rounded-b-xl border-t border-slate-200/50 bg-gradient-to-r from-slate-50/60 to-white/80 px-5 py-3.5">
|
||||
<Text variant="body" className="font-medium text-slate-700">
|
||||
Advanced
|
||||
</Text>
|
||||
<Switch
|
||||
onCheckedChange={(checked) => setShowAdvanced(nodeId, checked)}
|
||||
checked={showAdvanced}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Output Handles */}
|
||||
<OutputHandler outputSchema={data.outputSchema} nodeId={nodeId} />
|
||||
</div>
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
CustomNode.displayName = "CustomNode";
|
||||
@@ -0,0 +1,45 @@
|
||||
import React from "react";
|
||||
import { Node as XYNode, NodeProps } from "@xyflow/react";
|
||||
import { RJSFSchema } from "@rjsf/utils";
|
||||
import { BlockUIType } from "../../../types";
|
||||
import { StickyNoteBlock } from "./StickyNoteBlock";
|
||||
import { BlockInfoCategoriesItem } from "@/app/api/__generated__/models/blockInfoCategoriesItem";
|
||||
import { StandardNodeBlock } from "./StandardNodeBlock";
|
||||
import { BlockCost } from "@/app/api/__generated__/models/blockCost";
|
||||
|
||||
export type CustomNodeData = {
|
||||
hardcodedValues: {
|
||||
[key: string]: any;
|
||||
};
|
||||
title: string;
|
||||
description: string;
|
||||
inputSchema: RJSFSchema;
|
||||
outputSchema: RJSFSchema;
|
||||
uiType: BlockUIType;
|
||||
block_id: string;
|
||||
// TODO : We need better type safety for the following backend fields.
|
||||
costs: BlockCost[];
|
||||
categories: BlockInfoCategoriesItem[];
|
||||
};
|
||||
|
||||
export type CustomNode = XYNode<CustomNodeData, "custom">;
|
||||
|
||||
export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||
({ data, id: nodeId, selected }) => {
|
||||
if (data.uiType === BlockUIType.NOTE) {
|
||||
return <StickyNoteBlock selected={selected} data={data} id={nodeId} />;
|
||||
}
|
||||
|
||||
if (data.uiType === BlockUIType.STANDARD) {
|
||||
return (
|
||||
<StandardNodeBlock data={data} selected={selected} nodeId={nodeId} />
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<StandardNodeBlock data={data} selected={selected} nodeId={nodeId} />
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
CustomNode.displayName = "CustomNode";
|
||||
@@ -0,0 +1,79 @@
|
||||
import { beautifyString, cn } from "@/lib/utils";
|
||||
import { CustomNodeData } from "./CustomNode";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { FormCreator } from "../FormCreator";
|
||||
import { preprocessInputSchema } from "../../processors/input-schema-pre-processor";
|
||||
import { Switch } from "@/components/atoms/Switch/Switch";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { OutputHandler } from "../OutputHandler";
|
||||
import { NodeCost } from "./components/NodeCost";
|
||||
import { NodeBadges } from "./components/NodeBadges";
|
||||
|
||||
type StandardNodeBlockType = {
|
||||
data: CustomNodeData;
|
||||
selected: boolean;
|
||||
nodeId: string;
|
||||
};
|
||||
export const StandardNodeBlock = ({
|
||||
data,
|
||||
selected,
|
||||
nodeId,
|
||||
}: StandardNodeBlockType) => {
|
||||
const showAdvanced = useNodeStore(
|
||||
(state) => state.nodeAdvancedStates[nodeId] || false,
|
||||
);
|
||||
const setShowAdvanced = useNodeStore((state) => state.setShowAdvanced);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"z-12 rounded-xl bg-gradient-to-br from-white to-slate-50/30 shadow-lg shadow-slate-900/5 ring-1 ring-slate-200/60 backdrop-blur-sm",
|
||||
selected && "shadow-2xl ring-2 ring-slate-200",
|
||||
)}
|
||||
>
|
||||
{/* Header */}
|
||||
<div className="flex h-auto flex-col gap-2 rounded-xl border-b border-slate-200/50 bg-gradient-to-r from-slate-50/80 to-white/90 px-4 py-4">
|
||||
{/* Upper section */}
|
||||
<div className="flex items-center gap-2">
|
||||
<Text
|
||||
variant="large-semibold"
|
||||
className="tracking-tight text-slate-800"
|
||||
>
|
||||
{beautifyString(data.title)}
|
||||
</Text>
|
||||
<Text variant="small" className="!font-medium !text-slate-500">
|
||||
#{nodeId.split("-")[0]}
|
||||
</Text>
|
||||
</div>
|
||||
{/* Lower section */}
|
||||
<div className="flex space-x-2">
|
||||
<NodeCost blockCosts={data.costs} nodeId={nodeId} />
|
||||
<NodeBadges categories={data.categories} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Input Handles */}
|
||||
<div className="bg-white/40 pb-6 pr-6">
|
||||
<FormCreator
|
||||
jsonSchema={preprocessInputSchema(data.inputSchema)}
|
||||
nodeId={nodeId}
|
||||
uiType={data.uiType}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Advanced Button */}
|
||||
<div className="flex items-center justify-between gap-2 rounded-b-xl border-t border-slate-200/50 bg-gradient-to-r from-slate-50/60 to-white/80 px-5 py-3.5">
|
||||
<Text variant="body" className="font-medium text-slate-700">
|
||||
Advanced
|
||||
</Text>
|
||||
<Switch
|
||||
onCheckedChange={(checked) => setShowAdvanced(nodeId, checked)}
|
||||
checked={showAdvanced}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Output Handles */}
|
||||
<OutputHandler outputSchema={data.outputSchema} nodeId={nodeId} />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -1,6 +1,6 @@
|
||||
import { useMemo } from "react";
|
||||
import { FormCreator } from "./FormCreator";
|
||||
import { preprocessInputSchema } from "../processors/input-schema-pre-processor";
|
||||
import { FormCreator } from "../FormCreator";
|
||||
import { preprocessInputSchema } from "../../processors/input-schema-pre-processor";
|
||||
import { CustomNodeData } from "./CustomNode";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
@@ -0,0 +1,20 @@
|
||||
import { BlockInfoCategoriesItem } from "@/app/api/__generated__/models/blockInfoCategoriesItem";
|
||||
import { Badge } from "@/components/__legacy__/ui/badge";
|
||||
import { beautifyString, cn } from "@/lib/utils";
|
||||
|
||||
export const NodeBadges = ({
|
||||
categories,
|
||||
}: {
|
||||
categories: BlockInfoCategoriesItem[];
|
||||
}) => {
|
||||
return categories.map((category) => (
|
||||
<Badge
|
||||
key={category.category}
|
||||
className={cn(
|
||||
"rounded-full border border-slate-500 bg-slate-100 text-black shadow-none",
|
||||
)}
|
||||
>
|
||||
{beautifyString(category.category.toLowerCase())}
|
||||
</Badge>
|
||||
));
|
||||
};
|
||||
@@ -0,0 +1,39 @@
|
||||
import { BlockCost } from "@/app/api/__generated__/models/blockCost";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import useCredits from "@/hooks/useCredits";
|
||||
import { CoinIcon } from "@phosphor-icons/react";
|
||||
import { isCostFilterMatch } from "../../../../helper";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
|
||||
export const NodeCost = ({
|
||||
blockCosts,
|
||||
nodeId,
|
||||
}: {
|
||||
blockCosts: BlockCost[];
|
||||
nodeId: string;
|
||||
}) => {
|
||||
const { formatCredits } = useCredits();
|
||||
const hardcodedValues = useNodeStore((state) =>
|
||||
state.getHardCodedValues(nodeId),
|
||||
);
|
||||
const blockCost =
|
||||
blockCosts &&
|
||||
blockCosts.find((cost) =>
|
||||
isCostFilterMatch(cost.cost_filter, hardcodedValues),
|
||||
);
|
||||
|
||||
if (!blockCost) return null;
|
||||
|
||||
return (
|
||||
<div className="mr-3 flex items-center gap-1 text-base font-light">
|
||||
<CoinIcon className="h-3 w-3" />
|
||||
<Text variant="small" className="!font-medium">
|
||||
{formatCredits(blockCost.cost_amount)}
|
||||
</Text>
|
||||
<Text variant="small">
|
||||
{" \/"}
|
||||
{blockCost.cost_type}
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -8,11 +8,11 @@ import { GraphExecutionID } from "@/lib/autogpt-server-api";
|
||||
// import { ControlPanelButton } from "../ControlPanelButton";
|
||||
import { ArrowUUpLeftIcon, ArrowUUpRightIcon } from "@phosphor-icons/react";
|
||||
// import { GraphSearchMenu } from "../GraphMenu/GraphMenu";
|
||||
import { CustomNode } from "../FlowEditor/nodes/CustomNode";
|
||||
import { history } from "@/app/(platform)/build/components/legacy-builder/history";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { Separator } from "@/components/__legacy__/ui/separator";
|
||||
import { NewSaveControl } from "./NewSaveControl/NewSaveControl";
|
||||
import { CustomNode } from "../FlowEditor/nodes/CustomNode/CustomNode";
|
||||
|
||||
export type Control = {
|
||||
icon: React.ReactNode;
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
|
||||
import { CustomNode, CustomNodeData } from "./FlowEditor/nodes/CustomNode";
|
||||
import {
|
||||
CustomNode,
|
||||
CustomNodeData,
|
||||
} from "./FlowEditor/nodes/CustomNode/CustomNode";
|
||||
import { BlockUIType } from "./types";
|
||||
import { NodeModel } from "@/app/api/__generated__/models/nodeModel";
|
||||
import { NodeModelMetadata } from "@/app/api/__generated__/models/nodeModelMetadata";
|
||||
@@ -14,8 +17,10 @@ export const convertBlockInfoIntoCustomNodeData = (
|
||||
description: block.description,
|
||||
inputSchema: block.inputSchema,
|
||||
outputSchema: block.outputSchema,
|
||||
categories: block.categories,
|
||||
uiType: block.uiType as BlockUIType,
|
||||
block_id: block.id,
|
||||
costs: block.costs,
|
||||
};
|
||||
return customNodeData;
|
||||
};
|
||||
@@ -51,3 +56,38 @@ export const convertNodesPlusBlockInfoIntoCustomNodes = (
|
||||
};
|
||||
return customNode;
|
||||
};
|
||||
|
||||
export enum BlockCategory {
|
||||
AI = "AI",
|
||||
SOCIAL = "SOCIAL",
|
||||
TEXT = "TEXT",
|
||||
SEARCH = "SEARCH",
|
||||
BASIC = "BASIC",
|
||||
INPUT = "INPUT",
|
||||
OUTPUT = "OUTPUT",
|
||||
LOGIC = "LOGIC",
|
||||
COMMUNICATION = "COMMUNICATION",
|
||||
DEVELOPER_TOOLS = "DEVELOPER_TOOLS",
|
||||
DATA = "DATA",
|
||||
HARDWARE = "HARDWARE",
|
||||
AGENT = "AGENT",
|
||||
CRM = "CRM",
|
||||
SAFETY = "SAFETY",
|
||||
PRODUCTIVITY = "PRODUCTIVITY",
|
||||
ISSUE_TRACKING = "ISSUE_TRACKING",
|
||||
MULTIMEDIA = "MULTIMEDIA",
|
||||
MARKETING = "MARKETING",
|
||||
}
|
||||
|
||||
// Cost related helpers
|
||||
export const isCostFilterMatch = (
|
||||
costFilter: any,
|
||||
inputValues: any,
|
||||
): boolean => {
|
||||
return typeof costFilter === "object" && typeof inputValues === "object"
|
||||
? Object.entries(costFilter).every(
|
||||
([k, v]) =>
|
||||
(!v && !inputValues[k]) || isCostFilterMatch(v, inputValues[k]),
|
||||
)
|
||||
: costFilter === inputValues;
|
||||
};
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { create } from "zustand";
|
||||
import { NodeChange, applyNodeChanges } from "@xyflow/react";
|
||||
import { CustomNode } from "../components/FlowEditor/nodes/CustomNode";
|
||||
import { CustomNode } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
|
||||
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
|
||||
import { convertBlockInfoIntoCustomNodeData } from "../components/helper";
|
||||
import { Node } from "@/app/api/__generated__/models/node";
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"use client";
|
||||
|
||||
import * as React from "react";
|
||||
import { Sidebar } from "@/components/__legacy__/Sidebar";
|
||||
import {
|
||||
@@ -8,8 +10,11 @@ import {
|
||||
IconCoin,
|
||||
} from "@/components/__legacy__/ui/icons";
|
||||
import { KeyIcon } from "lucide-react";
|
||||
import { useGetFlag, Flag } from "@/services/feature-flags/use-get-flag";
|
||||
|
||||
export default function Layout({ children }: { children: React.ReactNode }) {
|
||||
const isPaymentEnabled = useGetFlag(Flag.ENABLE_PLATFORM_PAYMENT);
|
||||
|
||||
const sidebarLinkGroups = [
|
||||
{
|
||||
links: [
|
||||
@@ -18,7 +23,7 @@ export default function Layout({ children }: { children: React.ReactNode }) {
|
||||
href: "/profile/dashboard",
|
||||
icon: <IconDashboardLayout className="h-6 w-6" />,
|
||||
},
|
||||
...(process.env.NEXT_PUBLIC_SHOW_BILLING_PAGE === "true"
|
||||
...(isPaymentEnabled
|
||||
? [
|
||||
{
|
||||
text: "Billing",
|
||||
|
||||
@@ -15,6 +15,7 @@ export enum Flag {
|
||||
SHARE_EXECUTION_RESULTS = "share-execution-results",
|
||||
AGENT_FAVORITING = "agent-favoriting",
|
||||
MARKETPLACE_SEARCH_TERMS = "marketplace-search-terms",
|
||||
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment",
|
||||
}
|
||||
|
||||
export type FlagValues = {
|
||||
@@ -28,6 +29,7 @@ export type FlagValues = {
|
||||
[Flag.SHARE_EXECUTION_RESULTS]: boolean;
|
||||
[Flag.AGENT_FAVORITING]: boolean;
|
||||
[Flag.MARKETPLACE_SEARCH_TERMS]: string[];
|
||||
[Flag.ENABLE_PLATFORM_PAYMENT]: boolean;
|
||||
};
|
||||
|
||||
const isPwMockEnabled = process.env.NEXT_PUBLIC_PW_TEST === "true";
|
||||
@@ -43,6 +45,7 @@ const mockFlags = {
|
||||
[Flag.SHARE_EXECUTION_RESULTS]: false,
|
||||
[Flag.AGENT_FAVORITING]: false,
|
||||
[Flag.MARKETPLACE_SEARCH_TERMS]: DEFAULT_SEARCH_TERMS,
|
||||
[Flag.ENABLE_PLATFORM_PAYMENT]: false,
|
||||
};
|
||||
|
||||
export function useGetFlag<T extends Flag>(flag: T): FlagValues[T] | null {
|
||||
@@ -50,7 +53,9 @@ export function useGetFlag<T extends Flag>(flag: T): FlagValues[T] | null {
|
||||
const flagValue = currentFlags[flag];
|
||||
const isCloud = getBehaveAs() === BehaveAs.CLOUD;
|
||||
|
||||
if (isPwMockEnabled && !isCloud) return mockFlags[flag];
|
||||
if ((isPwMockEnabled && !isCloud) || flagValue === undefined) {
|
||||
return mockFlags[flag];
|
||||
}
|
||||
|
||||
return flagValue;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user