mirror of https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-13 09:08:02 -05:00

Compare commits: 3 commits (swiftyos/s… → fix/databa…)

Commits: d42e144322, 4f9ff74d02, cca84fabe3
.github/workflows/platform-frontend-ci.yml (vendored, 3 changes)

@@ -217,9 +217,6 @@ jobs:
       - name: Install dependencies
         run: pnpm install --frozen-lockfile

-      - name: Generate API client
-        run: pnpm generate:api
-
       - name: Install Browser 'chromium'
         run: pnpm playwright install --with-deps chromium

PR reviewer configuration:

@@ -1,3 +1,6 @@
 [pr_reviewer]
 num_code_suggestions=0

+[pr_code_suggestions]
+commitable_code_suggestions=false
+num_code_suggestions=0
Auth dependencies (autogpt_libs/auth/dependencies.py):

@@ -10,7 +10,7 @@ from .jwt_utils import get_jwt_payload, verify_user
 from .models import User


-async def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
+def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
     """
     FastAPI dependency that requires a valid authenticated user.

@@ -20,9 +20,7 @@ async def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
     return verify_user(jwt_payload, admin_only=False)


-async def requires_admin_user(
-    jwt_payload: dict = fastapi.Security(get_jwt_payload),
-) -> User:
+def requires_admin_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
     """
     FastAPI dependency that requires a valid admin user.

@@ -32,7 +30,7 @@ async def requires_admin_user(
     return verify_user(jwt_payload, admin_only=True)


-async def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
+def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
     """
     FastAPI dependency that returns the ID of the authenticated user.
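The change above turns the auth dependencies into plain functions. FastAPI accepts both forms and runs synchronous dependencies in a worker threadpool, so routes keep working unchanged. A minimal sketch of how these dependencies are typically wired into an endpoint (the app and routes here are illustrative, not part of the diff; the module paths are inferred from the patch targets in the tests below):

```python
import fastapi

from autogpt_libs.auth.dependencies import get_user_id, requires_user
from autogpt_libs.auth.models import User  # assumed location of the User model

app = fastapi.FastAPI()


@app.get("/me")
def read_me(user: User = fastapi.Depends(requires_user)) -> dict:
    # requires_user is now a plain function; FastAPI executes it in a
    # worker thread instead of awaiting it on the event loop.
    return {"user_id": user.user_id, "role": user.role}


@app.get("/me/id")
def read_my_id(user_id: str = fastapi.Depends(get_user_id)) -> dict:
    return {"user_id": user_id}
```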
Tests for the auth dependencies:

@@ -45,7 +45,7 @@ class TestAuthDependencies:
         """Create a test client."""
         return TestClient(app)

-    async def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
+    def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
         """Test requires_user with valid JWT payload."""
         jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}

@@ -53,12 +53,12 @@ class TestAuthDependencies:
         mocker.patch(
             "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
         )
-        user = await requires_user(jwt_payload)
+        user = requires_user(jwt_payload)
         assert isinstance(user, User)
         assert user.user_id == "user-123"
         assert user.role == "user"

-    async def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
+    def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
         """Test requires_user accepts admin users."""
         jwt_payload = {
             "sub": "admin-456",
@@ -69,28 +69,28 @@ class TestAuthDependencies:
         mocker.patch(
             "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
         )
-        user = await requires_user(jwt_payload)
+        user = requires_user(jwt_payload)
         assert user.user_id == "admin-456"
         assert user.role == "admin"

-    async def test_requires_user_missing_sub(self):
+    def test_requires_user_missing_sub(self):
         """Test requires_user with missing user ID."""
         jwt_payload = {"role": "user", "email": "user@example.com"}

         with pytest.raises(HTTPException) as exc_info:
-            await requires_user(jwt_payload)
+            requires_user(jwt_payload)
         assert exc_info.value.status_code == 401
         assert "User ID not found" in exc_info.value.detail

-    async def test_requires_user_empty_sub(self):
+    def test_requires_user_empty_sub(self):
         """Test requires_user with empty user ID."""
         jwt_payload = {"sub": "", "role": "user"}

         with pytest.raises(HTTPException) as exc_info:
-            await requires_user(jwt_payload)
+            requires_user(jwt_payload)
         assert exc_info.value.status_code == 401

-    async def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
+    def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
         """Test requires_admin_user with admin role."""
         jwt_payload = {
             "sub": "admin-789",
@@ -101,51 +101,51 @@ class TestAuthDependencies:
         mocker.patch(
             "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
         )
-        user = await requires_admin_user(jwt_payload)
+        user = requires_admin_user(jwt_payload)
         assert user.user_id == "admin-789"
         assert user.role == "admin"

-    async def test_requires_admin_user_with_regular_user(self):
+    def test_requires_admin_user_with_regular_user(self):
         """Test requires_admin_user rejects regular users."""
         jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}

         with pytest.raises(HTTPException) as exc_info:
-            await requires_admin_user(jwt_payload)
+            requires_admin_user(jwt_payload)
         assert exc_info.value.status_code == 403
         assert "Admin access required" in exc_info.value.detail

-    async def test_requires_admin_user_missing_role(self):
+    def test_requires_admin_user_missing_role(self):
         """Test requires_admin_user with missing role."""
         jwt_payload = {"sub": "user-123", "email": "user@example.com"}

         with pytest.raises(KeyError):
-            await requires_admin_user(jwt_payload)
+            requires_admin_user(jwt_payload)

-    async def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
+    def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
         """Test get_user_id extracts user ID correctly."""
         jwt_payload = {"sub": "user-id-xyz", "role": "user"}

         mocker.patch(
             "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
         )
-        user_id = await get_user_id(jwt_payload)
+        user_id = get_user_id(jwt_payload)
         assert user_id == "user-id-xyz"

-    async def test_get_user_id_missing_sub(self):
+    def test_get_user_id_missing_sub(self):
         """Test get_user_id with missing user ID."""
         jwt_payload = {"role": "user"}

         with pytest.raises(HTTPException) as exc_info:
-            await get_user_id(jwt_payload)
+            get_user_id(jwt_payload)
         assert exc_info.value.status_code == 401
         assert "User ID not found" in exc_info.value.detail

-    async def test_get_user_id_none_sub(self):
+    def test_get_user_id_none_sub(self):
         """Test get_user_id with None user ID."""
         jwt_payload = {"sub": None, "role": "user"}

         with pytest.raises(HTTPException) as exc_info:
-            await get_user_id(jwt_payload)
+            get_user_id(jwt_payload)
         assert exc_info.value.status_code == 401


@@ -170,7 +170,7 @@ class TestAuthDependenciesIntegration:

         return _create_token

-    async def test_endpoint_auth_enabled_no_token(self):
+    def test_endpoint_auth_enabled_no_token(self):
         """Test endpoints require token when auth is enabled."""
         app = FastAPI()

@@ -184,7 +184,7 @@ class TestAuthDependenciesIntegration:
         response = client.get("/test")
         assert response.status_code == 401

-    async def test_endpoint_with_valid_token(self, create_token):
+    def test_endpoint_with_valid_token(self, create_token):
         """Test endpoint with valid JWT token."""
         app = FastAPI()

@@ -203,7 +203,7 @@ class TestAuthDependenciesIntegration:
         assert response.status_code == 200
         assert response.json()["user_id"] == "test-user"

-    async def test_admin_endpoint_requires_admin_role(self, create_token):
+    def test_admin_endpoint_requires_admin_role(self, create_token):
         """Test admin endpoint rejects non-admin users."""
         app = FastAPI()

@@ -240,7 +240,7 @@ class TestAuthDependenciesIntegration:
 class TestAuthDependenciesEdgeCases:
     """Edge case tests for authentication dependencies."""

-    async def test_dependency_with_complex_payload(self):
+    def test_dependency_with_complex_payload(self):
         """Test dependencies handle complex JWT payloads."""
         complex_payload = {
             "sub": "user-123",
@@ -256,14 +256,14 @@ class TestAuthDependenciesEdgeCases:
             "exp": 9999999999,
         }

-        user = await requires_user(complex_payload)
+        user = requires_user(complex_payload)
         assert user.user_id == "user-123"
         assert user.email == "test@example.com"

-        admin = await requires_admin_user(complex_payload)
+        admin = requires_admin_user(complex_payload)
         assert admin.role == "admin"

-    async def test_dependency_with_unicode_in_payload(self):
+    def test_dependency_with_unicode_in_payload(self):
         """Test dependencies handle unicode in JWT payloads."""
         unicode_payload = {
             "sub": "user-😀-123",
@@ -272,11 +272,11 @@ class TestAuthDependenciesEdgeCases:
             "name": "日本語",
         }

-        user = await requires_user(unicode_payload)
+        user = requires_user(unicode_payload)
         assert "😀" in user.user_id
         assert user.email == "测试@example.com"

-    async def test_dependency_with_null_values(self):
+    def test_dependency_with_null_values(self):
         """Test dependencies handle null values in payload."""
         null_payload = {
             "sub": "user-123",
@@ -286,18 +286,18 @@ class TestAuthDependenciesEdgeCases:
             "metadata": None,
         }

-        user = await requires_user(null_payload)
+        user = requires_user(null_payload)
         assert user.user_id == "user-123"
         assert user.email is None

-    async def test_concurrent_requests_isolation(self):
+    def test_concurrent_requests_isolation(self):
         """Test that concurrent requests don't interfere with each other."""
         payload1 = {"sub": "user-1", "role": "user"}
         payload2 = {"sub": "user-2", "role": "admin"}

         # Simulate concurrent processing
-        user1 = await requires_user(payload1)
-        user2 = await requires_admin_user(payload2)
+        user1 = requires_user(payload1)
+        user2 = requires_admin_user(payload2)

         assert user1.user_id == "user-1"
         assert user2.user_id == "user-2"
@@ -314,7 +314,7 @@ class TestAuthDependenciesEdgeCases:
             ({"sub": "user", "role": "user"}, "Admin access required", True),
         ],
     )
-    async def test_dependency_error_cases(
+    def test_dependency_error_cases(
         self, payload, expected_error: str, admin_only: bool
     ):
         """Test that errors propagate correctly through dependencies."""
@@ -325,7 +325,7 @@ class TestAuthDependenciesEdgeCases:
             verify_user(payload, admin_only=admin_only)
         assert expected_error in exc_info.value.detail

-    async def test_dependency_valid_user(self):
+    def test_dependency_valid_user(self):
         """Test valid user case for dependency."""
         # Import verify_user to test it directly since dependencies use FastAPI Security
         from autogpt_libs.auth.jwt_utils import verify_user
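Dropping async from the tests matters for the same reason it matters in the dependencies: plain def tests run under stock pytest, whereas the previous async def versions needed an asyncio plugin to execute at all. For endpoint-level tests that do not care about JWT verification, a dependency override keeps things simple. A hedged sketch (the User constructor arguments are assumed from the attributes asserted above; adapt to the real model):

```python
import fastapi
from fastapi.testclient import TestClient

from autogpt_libs.auth.dependencies import requires_user
from autogpt_libs.auth.models import User  # assumed import path


def test_protected_route_with_overridden_dependency():
    app = fastapi.FastAPI()

    @app.get("/test")
    def protected(user: User = fastapi.Depends(requires_user)) -> dict:
        return {"user_id": user.user_id}

    # Bypass JWT parsing entirely: override the dependency with a stub.
    app.dependency_overrides[requires_user] = lambda: User(
        user_id="test-user", role="user"  # assumed constructor signature
    )

    client = TestClient(app)
    response = client.get("/test")
    assert response.status_code == 200
    assert response.json()["user_id"] == "test-user"
```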
JWT utilities (autogpt_libs/auth/jwt_utils.py):

@@ -16,7 +16,7 @@ bearer_jwt_auth = HTTPBearer(
 )


-async def get_jwt_payload(
+def get_jwt_payload(
     credentials: HTTPAuthorizationCredentials | None = Security(bearer_jwt_auth),
 ) -> dict[str, Any]:
     """
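With get_jwt_payload synchronous, it can be exercised directly, no event loop required, which is exactly what the updated tests below do. A small sketch of a direct call (the token value is a placeholder; a malformed token raises HTTPException with status 401, as the tests assert):

```python
from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials

from autogpt_libs.auth import jwt_utils  # module path taken from the tests below

# The same object FastAPI's HTTPBearer builds from an
# "Authorization: Bearer <token>" request header.
credentials = HTTPAuthorizationCredentials(
    scheme="Bearer", credentials="invalid.token.here"  # placeholder token
)

try:
    payload = jwt_utils.get_jwt_payload(credentials)  # plain call, no await
except HTTPException as e:
    assert e.status_code == 401  # invalid tokens are rejected with 401
```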
Tests for the JWT utilities:

@@ -116,32 +116,32 @@ def test_parse_jwt_token_missing_audience():
     assert "Invalid token" in str(exc_info.value)


-async def test_get_jwt_payload_with_valid_token():
+def test_get_jwt_payload_with_valid_token():
     """Test extracting JWT payload with valid bearer token."""
     token = create_token(TEST_USER_PAYLOAD)
     credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)

-    result = await jwt_utils.get_jwt_payload(credentials)
+    result = jwt_utils.get_jwt_payload(credentials)
     assert result["sub"] == "test-user-id"
     assert result["role"] == "user"


-async def test_get_jwt_payload_no_credentials():
+def test_get_jwt_payload_no_credentials():
     """Test JWT payload when no credentials provided."""
     with pytest.raises(HTTPException) as exc_info:
-        await jwt_utils.get_jwt_payload(None)
+        jwt_utils.get_jwt_payload(None)
     assert exc_info.value.status_code == 401
     assert "Authorization header is missing" in exc_info.value.detail


-async def test_get_jwt_payload_invalid_token():
+def test_get_jwt_payload_invalid_token():
     """Test JWT payload extraction with invalid token."""
     credentials = HTTPAuthorizationCredentials(
         scheme="Bearer", credentials="invalid.token.here"
     )

     with pytest.raises(HTTPException) as exc_info:
-        await jwt_utils.get_jwt_payload(credentials)
+        jwt_utils.get_jwt_payload(credentials)
     assert exc_info.value.status_code == 401
     assert "Invalid token" in exc_info.value.detail
Logging configuration:

@@ -4,7 +4,6 @@ import logging
 import os
 import socket
 import sys
-from logging.handlers import RotatingFileHandler
 from pathlib import Path

 from pydantic import Field, field_validator
@@ -140,13 +139,8 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
     print(f"Log directory: {config.log_dir}")

     # Activity log handler (INFO and above)
-    # Security fix: Use RotatingFileHandler with size limits to prevent disk exhaustion
-    activity_log_handler = RotatingFileHandler(
-        config.log_dir / LOG_FILE,
-        mode="a",
-        encoding="utf-8",
-        maxBytes=10 * 1024 * 1024,  # 10MB per file
-        backupCount=3,  # Keep 3 backup files (40MB total)
+    activity_log_handler = logging.FileHandler(
+        config.log_dir / LOG_FILE, "a", "utf-8"
     )
     activity_log_handler.setLevel(config.level)
     activity_log_handler.setFormatter(
@@ -156,13 +150,8 @@ def configure_logging(force_cloud_logging: bool = False) -> None:

     if config.level == logging.DEBUG:
         # Debug log handler (all levels)
-        # Security fix: Use RotatingFileHandler with size limits
-        debug_log_handler = RotatingFileHandler(
-            config.log_dir / DEBUG_LOG_FILE,
-            mode="a",
-            encoding="utf-8",
-            maxBytes=10 * 1024 * 1024,  # 10MB per file
-            backupCount=3,  # Keep 3 backup files (40MB total)
+        debug_log_handler = logging.FileHandler(
+            config.log_dir / DEBUG_LOG_FILE, "a", "utf-8"
         )
         debug_log_handler.setLevel(logging.DEBUG)
         debug_log_handler.setFormatter(
@@ -171,13 +160,8 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
         log_handlers.append(debug_log_handler)

     # Error log handler (ERROR and above)
-    # Security fix: Use RotatingFileHandler with size limits
-    error_log_handler = RotatingFileHandler(
-        config.log_dir / ERROR_LOG_FILE,
-        mode="a",
-        encoding="utf-8",
-        maxBytes=10 * 1024 * 1024,  # 10MB per file
-        backupCount=3,  # Keep 3 backup files (40MB total)
+    error_log_handler = logging.FileHandler(
+        config.log_dir / ERROR_LOG_FILE, "a", "utf-8"
     )
     error_log_handler.setLevel(logging.ERROR)
     error_log_handler.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT, no_color=True))
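This hunk trades the size-capped RotatingFileHandler back for a plain logging.FileHandler, which removes the disk-exhaustion guard the deleted comments describe: an unbounded log file can grow until the disk fills. For reference, a minimal sketch of the capped configuration being removed (the directory and file name here are illustrative):

```python
import logging
from logging.handlers import RotatingFileHandler
from pathlib import Path

log_dir = Path("/tmp/agpt-logs")  # illustrative location
log_dir.mkdir(parents=True, exist_ok=True)

# Caps total disk usage at roughly 40MB: one active 10MB file
# plus three rotated 10MB backups.
activity_log_handler = RotatingFileHandler(
    log_dir / "activity.log",
    mode="a",
    encoding="utf-8",
    maxBytes=10 * 1024 * 1024,
    backupCount=3,
)
activity_log_handler.setLevel(logging.INFO)
logging.getLogger().addHandler(activity_log_handler)
```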
autogpt_platform/autogpt_libs/autogpt_libs/utils/cache.py (new file, 328 lines added, @@ -0,0 +1,328 @@):

import asyncio
import inspect
import logging
import threading
import time
from functools import wraps
from typing import (
    Any,
    Callable,
    ParamSpec,
    Protocol,
    TypeVar,
    cast,
    runtime_checkable,
)

P = ParamSpec("P")
R = TypeVar("R")
R_co = TypeVar("R_co", covariant=True)

logger = logging.getLogger(__name__)


def _make_hashable_key(
    args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[Any, ...]:
    """
    Convert args and kwargs into a hashable cache key.

    Handles unhashable types like dict, list, set by converting them to
    their sorted string representations.
    """

    def make_hashable(obj: Any) -> Any:
        """Recursively convert an object to a hashable representation."""
        if isinstance(obj, dict):
            # Sort dict items to ensure consistent ordering
            return (
                "__dict__",
                tuple(sorted((k, make_hashable(v)) for k, v in obj.items())),
            )
        elif isinstance(obj, (list, tuple)):
            return ("__list__", tuple(make_hashable(item) for item in obj))
        elif isinstance(obj, set):
            return ("__set__", tuple(sorted(make_hashable(item) for item in obj)))
        elif hasattr(obj, "__dict__"):
            # Handle objects with __dict__ attribute
            return ("__obj__", obj.__class__.__name__, make_hashable(obj.__dict__))
        else:
            # For basic hashable types (str, int, bool, None, etc.)
            try:
                hash(obj)
                return obj
            except TypeError:
                # Fallback: convert to string representation
                return ("__str__", str(obj))

    hashable_args = tuple(make_hashable(arg) for arg in args)
    hashable_kwargs = tuple(sorted((k, make_hashable(v)) for k, v in kwargs.items()))
    return (hashable_args, hashable_kwargs)


@runtime_checkable
class CachedFunction(Protocol[P, R_co]):
    """Protocol for cached functions with cache management methods."""

    def cache_clear(self) -> None:
        """Clear all cached entries."""
        return None

    def cache_info(self) -> dict[str, int | None]:
        """Get cache statistics."""
        return {}

    def cache_delete(self, *args: P.args, **kwargs: P.kwargs) -> bool:
        """Delete a specific cache entry by its arguments. Returns True if entry existed."""
        return False

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
        """Call the cached function."""
        return None  # type: ignore


def cached(
    *,
    maxsize: int = 128,
    ttl_seconds: int | None = None,
) -> Callable[[Callable], CachedFunction]:
    """
    Thundering herd safe cache decorator for both sync and async functions.

    Uses double-checked locking to prevent multiple threads/coroutines from
    executing the expensive operation simultaneously during cache misses.

    Args:
        maxsize: Maximum number of cached entries
        ttl_seconds: Time to live in seconds. If None, entries never expire

    Returns:
        A decorator that applies caching to the target function

    Example:
        @cached()  # Default: maxsize=128, no TTL
        def expensive_sync_operation(param: str) -> dict:
            return {"result": param}

        @cached()  # Works with async too
        async def expensive_async_operation(param: str) -> dict:
            return {"result": param}

        @cached(maxsize=1000, ttl_seconds=300)  # Custom maxsize and TTL
        def another_operation(param: str) -> dict:
            return {"result": param}
    """

    def decorator(target_func):
        # Cache storage and locks
        cache_storage = {}

        if inspect.iscoroutinefunction(target_func):
            # Async function with asyncio.Lock
            cache_lock = asyncio.Lock()

            @wraps(target_func)
            async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
                key = _make_hashable_key(args, kwargs)
                current_time = time.time()

                # Fast path: check cache without lock
                if key in cache_storage:
                    if ttl_seconds is None:
                        logger.debug(f"Cache hit for {target_func.__name__}")
                        return cache_storage[key]
                    else:
                        cached_data = cache_storage[key]
                        if isinstance(cached_data, tuple):
                            result, timestamp = cached_data
                            if current_time - timestamp < ttl_seconds:
                                logger.debug(f"Cache hit for {target_func.__name__}")
                                return result

                # Slow path: acquire lock for cache miss/expiry
                async with cache_lock:
                    # Double-check: another coroutine might have populated cache
                    if key in cache_storage:
                        if ttl_seconds is None:
                            return cache_storage[key]
                        else:
                            cached_data = cache_storage[key]
                            if isinstance(cached_data, tuple):
                                result, timestamp = cached_data
                                if current_time - timestamp < ttl_seconds:
                                    return result

                    # Cache miss - execute function
                    logger.debug(f"Cache miss for {target_func.__name__}")
                    result = await target_func(*args, **kwargs)

                    # Store result
                    if ttl_seconds is None:
                        cache_storage[key] = result
                    else:
                        cache_storage[key] = (result, current_time)

                    # Cleanup if needed
                    if len(cache_storage) > maxsize:
                        cutoff = maxsize // 2
                        oldest_keys = (
                            list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
                        )
                        for old_key in oldest_keys:
                            cache_storage.pop(old_key, None)

                    return result

            wrapper = async_wrapper

        else:
            # Sync function with threading.Lock
            cache_lock = threading.Lock()

            @wraps(target_func)
            def sync_wrapper(*args: P.args, **kwargs: P.kwargs):
                key = _make_hashable_key(args, kwargs)
                current_time = time.time()

                # Fast path: check cache without lock
                if key in cache_storage:
                    if ttl_seconds is None:
                        logger.debug(f"Cache hit for {target_func.__name__}")
                        return cache_storage[key]
                    else:
                        cached_data = cache_storage[key]
                        if isinstance(cached_data, tuple):
                            result, timestamp = cached_data
                            if current_time - timestamp < ttl_seconds:
                                logger.debug(f"Cache hit for {target_func.__name__}")
                                return result

                # Slow path: acquire lock for cache miss/expiry
                with cache_lock:
                    # Double-check: another thread might have populated cache
                    if key in cache_storage:
                        if ttl_seconds is None:
                            return cache_storage[key]
                        else:
                            cached_data = cache_storage[key]
                            if isinstance(cached_data, tuple):
                                result, timestamp = cached_data
                                if current_time - timestamp < ttl_seconds:
                                    return result

                    # Cache miss - execute function
                    logger.debug(f"Cache miss for {target_func.__name__}")
                    result = target_func(*args, **kwargs)

                    # Store result
                    if ttl_seconds is None:
                        cache_storage[key] = result
                    else:
                        cache_storage[key] = (result, current_time)

                    # Cleanup if needed
                    if len(cache_storage) > maxsize:
                        cutoff = maxsize // 2
                        oldest_keys = (
                            list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
                        )
                        for old_key in oldest_keys:
                            cache_storage.pop(old_key, None)

                    return result

            wrapper = sync_wrapper

        # Add cache management methods
        def cache_clear() -> None:
            cache_storage.clear()

        def cache_info() -> dict[str, int | None]:
            return {
                "size": len(cache_storage),
                "maxsize": maxsize,
                "ttl_seconds": ttl_seconds,
            }

        def cache_delete(*args, **kwargs) -> bool:
            """Delete a specific cache entry. Returns True if entry existed."""
            key = _make_hashable_key(args, kwargs)
            if key in cache_storage:
                del cache_storage[key]
                return True
            return False

        setattr(wrapper, "cache_clear", cache_clear)
        setattr(wrapper, "cache_info", cache_info)
        setattr(wrapper, "cache_delete", cache_delete)

        return cast(CachedFunction, wrapper)

    return decorator


def thread_cached(func):
    """
    Thread-local cache decorator for both sync and async functions.

    Each thread gets its own cache, which is useful for request-scoped caching
    in web applications where you want to cache within a single request but
    not across requests.

    Args:
        func: The function to cache

    Returns:
        Decorated function with thread-local caching

    Example:
        @thread_cached
        def expensive_operation(param: str) -> dict:
            return {"result": param}

        @thread_cached  # Works with async too
        async def expensive_async_operation(param: str) -> dict:
            return {"result": param}
    """
    thread_local = threading.local()

    def _clear():
        if hasattr(thread_local, "cache"):
            del thread_local.cache

    if inspect.iscoroutinefunction(func):

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            cache = getattr(thread_local, "cache", None)
            if cache is None:
                cache = thread_local.cache = {}
            key = _make_hashable_key(args, kwargs)
            if key not in cache:
                cache[key] = await func(*args, **kwargs)
            return cache[key]

        setattr(async_wrapper, "clear_cache", _clear)
        return async_wrapper

    else:

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            cache = getattr(thread_local, "cache", None)
            if cache is None:
                cache = thread_local.cache = {}
            key = _make_hashable_key(args, kwargs)
            if key not in cache:
                cache[key] = func(*args, **kwargs)
            return cache[key]

        setattr(sync_wrapper, "clear_cache", _clear)
        return sync_wrapper


def clear_thread_cache(func: Callable) -> None:
    """Clear thread-local cache for a function."""
    if clear := getattr(func, "clear_cache", None):
        clear()
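The decorator keys the cache on the full argument signature, so even unhashable arguments such as dicts work. A short usage sketch of the management methods the decorator attaches:

```python
from autogpt_libs.utils.cache import cached


@cached(maxsize=256, ttl_seconds=60)
def lookup(user: dict, *, verbose: bool = False) -> str:
    # dict arguments are fine: _make_hashable_key converts them into a
    # sorted, hashable tuple before keying the cache.
    return f"{user['id']}:{verbose}"


print(lookup({"id": "u1"}))        # miss: executes the function
print(lookup({"id": "u1"}))        # hit: served from the cache
print(lookup.cache_info())         # {'size': 1, 'maxsize': 256, 'ttl_seconds': 60}
lookup.cache_delete({"id": "u1"})  # evict one entry by its arguments
lookup.cache_clear()               # or drop everything
```

One quirk worth knowing: when the store grows past maxsize, the cleanup above keeps only the newest maxsize // 2 entries by insertion order rather than evicting a single oldest entry.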
Tests for the cache utilities:

@@ -12,11 +12,11 @@ import asyncio
 import threading
 import time
 from concurrent.futures import ThreadPoolExecutor
-from unittest.mock import Mock, patch
+from unittest.mock import Mock

 import pytest

-from backend.util.cache import cached, clear_thread_cache, thread_cached
+from autogpt_libs.utils.cache import cached, clear_thread_cache, thread_cached


 class TestThreadCached:
@@ -332,7 +332,7 @@ class TestCache:
         """Test basic sync caching functionality."""
         call_count = 0

-        @cached(ttl_seconds=300)
+        @cached()
         def expensive_sync_function(x: int, y: int = 0) -> int:
             nonlocal call_count
             call_count += 1
@@ -358,7 +358,7 @@ class TestCache:
         """Test basic async caching functionality."""
         call_count = 0

-        @cached(ttl_seconds=300)
+        @cached()
         async def expensive_async_function(x: int, y: int = 0) -> int:
             nonlocal call_count
             call_count += 1
@@ -385,7 +385,7 @@ class TestCache:
         call_count = 0
         results = []

-        @cached(ttl_seconds=300)
+        @cached()
         def slow_function(x: int) -> int:
             nonlocal call_count
             call_count += 1
@@ -412,7 +412,7 @@ class TestCache:
         """Test that concurrent async calls don't cause thundering herd."""
         call_count = 0

-        @cached(ttl_seconds=300)
+        @cached()
         async def slow_async_function(x: int) -> int:
             nonlocal call_count
             call_count += 1
@@ -508,7 +508,7 @@ class TestCache:
         """Test cache clearing functionality."""
         call_count = 0

-        @cached(ttl_seconds=300)
+        @cached()
         def clearable_function(x: int) -> int:
             nonlocal call_count
             call_count += 1
@@ -537,7 +537,7 @@ class TestCache:
         """Test cache clearing functionality with async function."""
         call_count = 0

-        @cached(ttl_seconds=300)
+        @cached()
         async def async_clearable_function(x: int) -> int:
             nonlocal call_count
             call_count += 1
@@ -567,7 +567,7 @@ class TestCache:
         """Test that cached async functions return actual results, not coroutines."""
         call_count = 0

-        @cached(ttl_seconds=300)
+        @cached()
         async def async_result_function(x: int) -> str:
             nonlocal call_count
             call_count += 1
@@ -593,7 +593,7 @@ class TestCache:
         """Test selective cache deletion functionality."""
         call_count = 0

-        @cached(ttl_seconds=300)
+        @cached()
         def deletable_function(x: int) -> int:
             nonlocal call_count
             call_count += 1
@@ -636,7 +636,7 @@ class TestCache:
         """Test selective cache deletion functionality with async function."""
         call_count = 0

-        @cached(ttl_seconds=300)
+        @cached()
         async def async_deletable_function(x: int) -> int:
             nonlocal call_count
             call_count += 1
@@ -674,333 +674,3 @@ class TestCache:
         # Try to delete non-existent entry
         was_deleted = async_deletable_function.cache_delete(99)
         assert was_deleted is False

The remainder of this hunk deletes the entire TestSharedCache class. The removed class:

class TestSharedCache:
    """Tests for shared_cache functionality using Redis."""

    @pytest.fixture(autouse=True)
    def setup_redis_mock(self):
        """Mock Redis client for testing."""
        with patch("backend.util.cache._get_redis_client") as mock_redis_func:
            # Configure mock to behave like Redis
            mock_redis = Mock()
            self.mock_redis = mock_redis
            self.redis_storage = {}

            def mock_get(key):
                return self.redis_storage.get(key)

            def mock_getex(key, ex=None):
                # GETEX returns value and optionally refreshes TTL
                return self.redis_storage.get(key)

            def mock_set(key, value):
                self.redis_storage[key] = value
                return True

            def mock_setex(key, ttl, value):
                self.redis_storage[key] = value
                return True

            def mock_exists(key):
                return 1 if key in self.redis_storage else 0

            def mock_delete(key):
                if key in self.redis_storage:
                    del self.redis_storage[key]
                    return 1
                return 0

            def mock_scan_iter(pattern, count=None):
                # Pattern is a string like "cache:*", keys in storage are strings
                prefix = pattern.rstrip("*")
                return [
                    k
                    for k in self.redis_storage.keys()
                    if isinstance(k, str) and k.startswith(prefix)
                ]

            def mock_pipeline():
                pipe = Mock()
                deleted_keys = []

                def pipe_delete(key):
                    deleted_keys.append(key)
                    return pipe

                def pipe_execute():
                    # Actually delete the keys when pipeline executes
                    for key in deleted_keys:
                        self.redis_storage.pop(key, None)
                    deleted_keys.clear()
                    return []

                pipe.delete = Mock(side_effect=pipe_delete)
                pipe.execute = Mock(side_effect=pipe_execute)
                return pipe

            mock_redis.get = Mock(side_effect=mock_get)
            mock_redis.getex = Mock(side_effect=mock_getex)
            mock_redis.set = Mock(side_effect=mock_set)
            mock_redis.setex = Mock(side_effect=mock_setex)
            mock_redis.exists = Mock(side_effect=mock_exists)
            mock_redis.delete = Mock(side_effect=mock_delete)
            mock_redis.scan_iter = Mock(side_effect=mock_scan_iter)
            mock_redis.pipeline = Mock(side_effect=mock_pipeline)

            # Make _get_redis_client return the mock
            mock_redis_func.return_value = mock_redis

            yield mock_redis

            # Cleanup
            self.redis_storage.clear()

    def test_sync_shared_cache_basic(self):
        """Test basic shared cache functionality with sync function."""
        call_count = 0

        @cached(shared_cache=True, ttl_seconds=300)
        def shared_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 10

        # First call - should miss cache
        result1 = shared_function(5)
        assert result1 == 50
        assert call_count == 1
        assert self.mock_redis.get.called
        assert self.mock_redis.setex.called  # setex is used for TTL

        # Second call - should hit cache
        result2 = shared_function(5)
        assert result2 == 50
        assert call_count == 1  # Function not called again

    @pytest.mark.asyncio
    async def test_async_shared_cache_basic(self):
        """Test basic shared cache functionality with async function."""
        call_count = 0

        @cached(shared_cache=True, ttl_seconds=300)
        async def async_shared_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            return x * 20

        # First call - should miss cache
        result1 = await async_shared_function(3)
        assert result1 == 60
        assert call_count == 1
        assert self.mock_redis.get.called
        assert self.mock_redis.setex.called  # setex is used for TTL

        # Second call - should hit cache
        result2 = await async_shared_function(3)
        assert result2 == 60
        assert call_count == 1  # Function not called again

    def test_sync_shared_cache_with_ttl(self):
        """Test shared cache with TTL using sync function."""
        call_count = 0

        @cached(shared_cache=True, ttl_seconds=60)
        def shared_ttl_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 30

        # First call
        result1 = shared_ttl_function(2)
        assert result1 == 60
        assert call_count == 1
        assert self.mock_redis.setex.called

        # Second call - should use cache
        result2 = shared_ttl_function(2)
        assert result2 == 60
        assert call_count == 1

    @pytest.mark.asyncio
    async def test_async_shared_cache_with_ttl(self):
        """Test shared cache with TTL using async function."""
        call_count = 0

        @cached(shared_cache=True, ttl_seconds=120)
        async def async_shared_ttl_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            return x * 40

        # First call
        result1 = await async_shared_ttl_function(4)
        assert result1 == 160
        assert call_count == 1
        assert self.mock_redis.setex.called

        # Second call - should use cache
        result2 = await async_shared_ttl_function(4)
        assert result2 == 160
        assert call_count == 1

    def test_shared_cache_clear(self):
        """Test clearing shared cache."""
        call_count = 0

        @cached(shared_cache=True, ttl_seconds=300)
        def clearable_shared_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 50

        # First call
        result1 = clearable_shared_function(1)
        assert result1 == 50
        assert call_count == 1

        # Second call - should use cache
        result2 = clearable_shared_function(1)
        assert result2 == 50
        assert call_count == 1

        # Clear cache
        clearable_shared_function.cache_clear()
        assert self.mock_redis.pipeline.called

        # Third call - should execute function again
        result3 = clearable_shared_function(1)
        assert result3 == 50
        assert call_count == 2

    def test_shared_cache_delete(self):
        """Test deleting specific shared cache entry."""
        call_count = 0

        @cached(shared_cache=True, ttl_seconds=300)
        def deletable_shared_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 60

        # First call for x=1
        result1 = deletable_shared_function(1)
        assert result1 == 60
        assert call_count == 1

        # First call for x=2
        result2 = deletable_shared_function(2)
        assert result2 == 120
        assert call_count == 2

        # Delete entry for x=1
        was_deleted = deletable_shared_function.cache_delete(1)
        assert was_deleted is True

        # Call with x=1 should execute function again
        result3 = deletable_shared_function(1)
        assert result3 == 60
        assert call_count == 3

        # Call with x=2 should still use cache
        result4 = deletable_shared_function(2)
        assert result4 == 120
        assert call_count == 3

    def test_shared_cache_error_handling(self):
        """Test that Redis errors are handled gracefully."""
        call_count = 0

        @cached(shared_cache=True, ttl_seconds=300)
        def error_prone_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 70

        # Simulate Redis error
        self.mock_redis.get.side_effect = Exception("Redis connection error")

        # Function should still work
        result = error_prone_function(1)
        assert result == 70
        assert call_count == 1

    @pytest.mark.asyncio
    async def test_async_shared_cache_error_handling(self):
        """Test that Redis errors are handled gracefully in async functions."""
        call_count = 0

        @cached(shared_cache=True, ttl_seconds=300)
        async def async_error_prone_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            return x * 80

        # Simulate Redis error
        self.mock_redis.get.side_effect = Exception("Redis connection error")

        # Function should still work
        result = await async_error_prone_function(1)
        assert result == 80
        assert call_count == 1

    def test_shared_cache_with_complex_types(self):
        """Test shared cache with complex return types (lists, dicts)."""
        call_count = 0

        @cached(shared_cache=True, ttl_seconds=300)
        def complex_return_function(x: int) -> dict:
            nonlocal call_count
            call_count += 1
            return {"value": x, "squared": x * x, "list": [1, 2, 3]}

        # First call
        result1 = complex_return_function(5)
        assert result1 == {"value": 5, "squared": 25, "list": [1, 2, 3]}
        assert call_count == 1

        # Second call - should use cache
        result2 = complex_return_function(5)
        assert result2 == {"value": 5, "squared": 25, "list": [1, 2, 3]}
        assert call_count == 1

    @pytest.mark.asyncio
    async def test_async_thundering_herd_shared_cache(self):
        """Test thundering herd protection with shared cache."""
        call_count = 0

        @cached(shared_cache=True, ttl_seconds=300)
        async def slow_shared_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.1)
            return x * x

        # Launch concurrent coroutines
        tasks = [slow_shared_function(9) for _ in range(5)]
        results = await asyncio.gather(*tasks)

        # All results should be the same
        assert all(result == 81 for result in results)
        # Only one coroutine should have executed the function
        assert call_count == 1

    def test_shared_cache_info(self):
        """Test cache_info with shared cache."""

        @cached(shared_cache=True, maxsize=100, ttl_seconds=300)
        def info_function(x: int) -> int:
            return x * 90

        # Call the function to populate cache
        info_function(1)

        # Get cache info
        info = info_function.cache_info()
        assert "size" in info
        assert info["maxsize"] is None  # Redis manages its own size
        assert info["ttl_seconds"] == 300
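The deleted TestSharedCache suite covered a Redis-backed shared_cache=True mode that the relocated decorator no longer offers: the cached() in the new autogpt_libs.utils.cache accepts only maxsize and ttl_seconds. The surviving in-process behavior is easy to check directly; a sketch mirroring the thundering-herd tests kept above:

```python
import asyncio

from autogpt_libs.utils.cache import cached

call_count = 0


@cached()
async def slow_square(x: int) -> int:
    global call_count
    call_count += 1
    await asyncio.sleep(0.1)  # stand-in for an expensive operation
    return x * x


async def main() -> None:
    # Five concurrent callers, one execution: the decorator's lock makes
    # the other four wait for, and then reuse, the first result.
    results = await asyncio.gather(*[slow_square(9) for _ in range(5)])
    assert results == [81] * 5
    assert call_count == 1


asyncio.run(main())
```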
autogpt_platform/autogpt_libs/poetry.lock (generated, 18 changes)

@@ -1719,22 +1719,6 @@ files = [
 httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
 strenum = ">=0.4.15,<0.5.0"

-[[package]]
-name = "tenacity"
-version = "9.1.2"
-description = "Retry code until it succeeds"
-optional = false
-python-versions = ">=3.9"
-groups = ["main"]
-files = [
-    {file = "tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138"},
-    {file = "tenacity-9.1.2.tar.gz", hash = "sha256:1169d376c297e7de388d18b4481760d478b0e99a777cad3a9c86e556f4b697cb"},
-]
-
-[package.extras]
-doc = ["reno", "sphinx"]
-test = ["pytest", "tornado (>=4.5)", "typeguard"]
-
 [[package]]
 name = "tomli"
 version = "2.2.1"
@@ -1945,4 +1929,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.10,<4.0"
-content-hash = "5ec9e6cd2ef7524a356586354755215699e7b37b9bbdfbabc9c73b43085915f4"
+content-hash = "0c40b63c3c921846cf05ccfb4e685d4959854b29c2c302245f9832e20aac6954"
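tenacity ("Retry code until it succeeds") disappears from the lock file here and from pyproject.toml just below, so autogpt_libs no longer declares a retry helper as a dependency. For context, this is the kind of API that package provides (an illustrative sketch, not code from this repository):

```python
import random

from tenacity import retry, stop_after_attempt, wait_exponential


@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10))
def flaky_call() -> str:
    # Re-invoked (up to 3 attempts, with exponential backoff) whenever
    # the body raises an exception.
    if random.random() < 0.5:
        raise ConnectionError("transient failure")
    return "ok"


print(flaky_call())
```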
pyproject.toml (autogpt_libs):

@@ -19,7 +19,6 @@ pydantic-settings = "^2.10.1"
 pyjwt = { version = "^2.10.1", extras = ["crypto"] }
 redis = "^6.2.0"
 supabase = "^2.16.0"
-tenacity = "^9.1.2"
 uvicorn = "^0.35.0"

 [tool.poetry.group.dev.dependencies]
Block loader:

@@ -5,7 +5,7 @@ import re
 from pathlib import Path
 from typing import TYPE_CHECKING, TypeVar

-from backend.util.cache import cached
+from autogpt_libs.utils.cache import cached

 logger = logging.getLogger(__name__)

@@ -16,7 +16,7 @@ if TYPE_CHECKING:
 T = TypeVar("T")


-@cached(ttl_seconds=3600)  # Cache blocks for 1 hour
+@cached()
 def load_all_blocks() -> dict[str, type["Block"]]:
     from backend.data.block import Block
     from backend.util.settings import Config
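With the TTL gone, load_all_blocks() is effectively cached for the lifetime of the process; the hourly refresh the removed comment described no longer happens on its own. A sketch of the resulting behavior (the import path of load_all_blocks is assumed; the file is not named in this mirror):

```python
from backend.blocks import load_all_blocks  # assumed import path

blocks = load_all_blocks()  # first call performs the expensive block scan
blocks = load_all_blocks()  # every later call is a cache hit, indefinitely

# With no ttl_seconds, picking up newly added blocks at runtime requires
# an explicit eviction via the method @cached() attaches.
load_all_blocks.cache_clear()
```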
Deleted file (@@ -1,214 +0,0 @@): an AI-powered condition block, removed in its entirety. The removed content:

from typing import Any

from backend.blocks.llm import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    AIBlockBase,
    AICredentials,
    AICredentialsField,
    LlmModel,
    LLMResponse,
    llm_call,
)
from backend.data.block import BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField


class AIConditionBlock(AIBlockBase):
    """
    An AI-powered condition block that uses natural language to evaluate conditions.

    This block allows users to define conditions in plain English (e.g., "the input is an email address",
    "the input is a city in the USA") and uses AI to determine if the input satisfies the condition.
    It provides the same yes/no data pass-through functionality as the standard ConditionBlock.
    """

    class Input(BlockSchema):
        input_value: Any = SchemaField(
            description="The input value to evaluate with the AI condition",
            placeholder="Enter the value to be evaluated (text, number, or any data)",
        )
        condition: str = SchemaField(
            description="A plaintext English description of the condition to evaluate",
            placeholder="E.g., 'the input is the body of an email', 'the input is a City in the USA', 'the input is an error or a refusal'",
        )
        yes_value: Any = SchemaField(
            description="(Optional) Value to output if the condition is true. If not provided, input_value will be used.",
            placeholder="Leave empty to use input_value, or enter a specific value",
            default=None,
        )
        no_value: Any = SchemaField(
            description="(Optional) Value to output if the condition is false. If not provided, input_value will be used.",
            placeholder="Leave empty to use input_value, or enter a specific value",
            default=None,
        )
        model: LlmModel = SchemaField(
            title="LLM Model",
            default=LlmModel.GPT4O,
            description="The language model to use for evaluating the condition.",
            advanced=False,
        )
        credentials: AICredentials = AICredentialsField()

    class Output(BlockSchema):
        result: bool = SchemaField(
            description="The result of the AI condition evaluation (True or False)"
        )
        yes_output: Any = SchemaField(
            description="The output value if the condition is true"
        )
        no_output: Any = SchemaField(
            description="The output value if the condition is false"
        )
        error: str = SchemaField(
            description="Error message if the AI evaluation is uncertain or fails"
        )

    def __init__(self):
        super().__init__(
            id="553ec5b8-6c45-4299-8d75-b394d05f72ff",
            input_schema=AIConditionBlock.Input,
            output_schema=AIConditionBlock.Output,
            description="Uses AI to evaluate natural language conditions and provide conditional outputs",
            categories={BlockCategory.AI, BlockCategory.LOGIC},
            test_input={
                "input_value": "john@example.com",
                "condition": "the input is an email address",
                "yes_value": "Valid email",
                "no_value": "Not an email",
                "model": LlmModel.GPT4O,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("result", True),
                ("yes_output", "Valid email"),
            ],
            test_mock={
                "llm_call": lambda *args, **kwargs: LLMResponse(
                    raw_response="",
                    prompt=[],
                    response="true",
                    tool_calls=None,
                    prompt_tokens=50,
                    completion_tokens=10,
                    reasoning=None,
                )
            },
        )

    async def llm_call(
        self,
        credentials: APIKeyCredentials,
        llm_model: LlmModel,
        prompt: list,
        max_tokens: int,
    ) -> LLMResponse:
        """Wrapper method for llm_call to enable mocking in tests."""
        return await llm_call(
            credentials=credentials,
            llm_model=llm_model,
            prompt=prompt,
            force_json_output=False,
            max_tokens=max_tokens,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Evaluate the AI condition and return appropriate outputs.
        """
        # Prepare the yes and no values, using input_value as default
        yes_value = (
            input_data.yes_value
            if input_data.yes_value is not None
            else input_data.input_value
        )
        no_value = (
            input_data.no_value
            if input_data.no_value is not None
            else input_data.input_value
        )

        # Convert input_value to string for AI evaluation
        input_str = str(input_data.input_value)

        # Create the prompt for AI evaluation
        prompt = [
            {
                "role": "system",
                "content": (
                    "You are an AI assistant that evaluates conditions based on input data. "
                    "You must respond with only 'true' or 'false' (lowercase) to indicate whether "
                    "the given condition is met by the input value. Be accurate and consider the "
                    "context and meaning of both the input and the condition."
                ),
            },
            {
                "role": "user",
                "content": (
                    f"Input value: {input_str}\n"
                    f"Condition to evaluate: {input_data.condition}\n\n"
                    f"Does the input value satisfy the condition? Respond with only 'true' or 'false'."
                ),
            },
        ]

        # Call the LLM
        try:
            response = await self.llm_call(
                credentials=credentials,
                llm_model=input_data.model,
                prompt=prompt,
                max_tokens=10,  # We only expect a true/false response
            )

            # Extract the boolean result from the response
            response_text = response.response.strip().lower()
            if response_text == "true":
                result = True
            elif response_text == "false":
                result = False
            else:
                # If the response is not clear, try to interpret it using word boundaries
                import re

                # Use word boundaries to avoid false positives like 'untrue' or '10'
                tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", response_text))

                if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
                    result = True
                elif tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
                    result = False
                else:
                    # Unclear or conflicting response - default to False and yield error
                    result = False
                    yield "error", f"Unclear AI response: '{response.response}'"

            # Update internal stats
            self.merge_stats(
                NodeExecutionStats(
                    input_token_count=response.prompt_tokens,
                    output_token_count=response.completion_tokens,
                )
            )
            self.prompt = response.prompt

        except Exception as e:
            # In case of any error, default to False to be safe
            result = False
            # Log the error but don't fail the block execution
            import logging

            logger = logging.getLogger(__name__)
            logger.error(f"AI condition evaluation failed: {str(e)}")
            yield "error", f"AI evaluation failed: {str(e)}"

        # Yield results
        yield "result", result

        if result:
            yield "yes_output", yes_value
        else:
            yield "no_output", no_value
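The most reusable detail in the deleted block is its defensive parsing of the model's verdict: exact 'true'/'false' first, then a word-boundary scan so strings like 'untrue' or '10' cannot be misread as a verdict. That step, isolated as a standalone sketch:

```python
import re


def parse_boolean_reply(text: str) -> bool | None:
    """Return True/False for a clear LLM verdict, None when ambiguous."""
    normalized = text.strip().lower()
    if normalized == "true":
        return True
    if normalized == "false":
        return False
    # Word boundaries prevent false positives such as 'untrue' or '10'.
    tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", normalized))
    if tokens in ({"true"}, {"yes"}, {"1"}):
        return True
    if tokens in ({"false"}, {"no"}, {"0"}):
        return False
    return None  # unclear or conflicting reply; caller decides the fallback


assert parse_boolean_reply("True") is True
assert parse_boolean_reply("The answer is no.") is False
assert parse_boolean_reply("It is untrue.") is None  # 'untrue' does not match
```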
@@ -1,10 +1,8 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from e2b_code_interpreter import AsyncSandbox
|
||||
from e2b_code_interpreter import Result as E2BExecutionResult
|
||||
from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
|
||||
from pydantic import BaseModel, JsonValue, SecretStr
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
@@ -38,135 +36,14 @@ class ProgrammingLanguage(Enum):
|
||||
JAVA = "java"
|
||||
|
||||
|
||||
class MainCodeExecutionResult(BaseModel):
|
||||
"""
|
||||
*Pydantic model mirroring `e2b_code_interpreter.Result`*
|
||||
|
||||
Represents the data to be displayed as a result of executing a cell in a Jupyter notebook.
|
||||
The result is similar to the structure returned by ipython kernel: https://ipython.readthedocs.io/en/stable/development/execution.html#execution-semantics
|
||||
|
||||
The result can contain multiple types of data, such as text, images, plots, etc. Each type of data is represented
|
||||
as a string, and the result can contain multiple types of data. The display calls don't have to have text representation,
|
||||
for the actual result the representation is always present for the result, the other representations are always optional.
|
||||
""" # noqa
|
||||
|
||||
class Chart(BaseModel, E2BExecutionResultChart):
|
||||
pass
|
||||
|
||||
text: Optional[str] = None
|
||||
html: Optional[str] = None
|
||||
markdown: Optional[str] = None
|
||||
svg: Optional[str] = None
|
||||
png: Optional[str] = None
|
||||
jpeg: Optional[str] = None
|
||||
pdf: Optional[str] = None
|
||||
latex: Optional[str] = None
|
||||
json: Optional[JsonValue] = None # type: ignore (reportIncompatibleMethodOverride)
|
||||
javascript: Optional[str] = None
|
||||
data: Optional[dict] = None
|
||||
chart: Optional[Chart] = None
|
||||
extra: Optional[dict] = None
|
||||
"""Extra data that can be included. Not part of the standard types."""
|
||||
|
||||
|
||||
class CodeExecutionResult(MainCodeExecutionResult):
|
||||
__doc__ = MainCodeExecutionResult.__doc__
|
||||
|
||||
is_main_result: bool = False
|
||||
"""Whether this data is the main result of the cell. Data can be produced by display calls of which can be multiple in a cell.""" # noqa
|
||||
|
||||
|
||||
class BaseE2BExecutorMixin:
|
||||
"""Shared implementation methods for E2B executor blocks."""
|
||||
|
||||
async def execute_code(
|
||||
self,
|
||||
api_key: str,
|
||||
code: str,
|
||||
language: ProgrammingLanguage,
|
||||
template_id: str = "",
|
||||
setup_commands: Optional[list[str]] = None,
|
||||
timeout: Optional[int] = None,
|
||||
sandbox_id: Optional[str] = None,
|
||||
dispose_sandbox: bool = False,
|
||||
):
|
||||
"""
|
||||
Unified code execution method that handles all three use cases:
|
||||
1. Create new sandbox and execute (ExecuteCodeBlock)
|
||||
2. Create new sandbox, execute, and return sandbox_id (InstantiateCodeSandboxBlock)
|
||||
3. Connect to existing sandbox and execute (ExecuteCodeStepBlock)
|
||||
""" # noqa
|
||||
sandbox = None
|
||||
try:
|
||||
if sandbox_id:
|
||||
# Connect to existing sandbox (ExecuteCodeStepBlock case)
|
||||
sandbox = await AsyncSandbox.connect(
|
||||
sandbox_id=sandbox_id, api_key=api_key
|
||||
)
|
||||
else:
|
||||
# Create new sandbox (ExecuteCodeBlock/InstantiateCodeSandboxBlock case)
|
||||
sandbox = await AsyncSandbox.create(
|
||||
api_key=api_key, template=template_id, timeout=timeout
|
||||
)
|
||||
if setup_commands:
|
||||
for cmd in setup_commands:
|
||||
await sandbox.commands.run(cmd)
|
||||
|
||||
# Execute the code
|
||||
execution = await sandbox.run_code(
|
||||
code,
|
||||
language=language.value,
|
||||
on_error=lambda e: sandbox.kill(), # Kill the sandbox on error
|
||||
)
|
||||
|
||||
if execution.error:
|
||||
raise Exception(execution.error)
|
||||
|
||||
results = execution.results
|
||||
text_output = execution.text
|
||||
stdout_logs = "".join(execution.logs.stdout)
|
||||
stderr_logs = "".join(execution.logs.stderr)
|
||||
|
||||
return results, text_output, stdout_logs, stderr_logs, sandbox.sandbox_id
|
||||
finally:
|
||||
# Dispose of sandbox if requested to reduce usage costs
|
||||
if dispose_sandbox and sandbox:
|
||||
await sandbox.kill()
|
||||
|
||||
def process_execution_results(
|
||||
self, results: list[E2BExecutionResult]
|
||||
) -> tuple[dict[str, Any] | None, list[dict[str, Any]]]:
|
||||
"""Process and filter execution results."""
|
||||
# Filter out empty formats and convert to dicts
|
||||
processed_results = [
|
||||
{
|
||||
f: value
|
||||
for f in [*r.formats(), "extra", "is_main_result"]
|
||||
if (value := getattr(r, f, None)) is not None
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
if main_result := next(
|
||||
(r for r in processed_results if r.get("is_main_result")), None
|
||||
):
|
||||
# Make main_result a copy we can modify & remove is_main_result
|
||||
(main_result := {**main_result}).pop("is_main_result")
|
||||
|
||||
return main_result, processed_results
|
||||
|
||||
|
||||
class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
class CodeExecutionBlock(Block):
    # TODO : Add support to upload and download files
    # NOTE: Currently, you can only customize the CPU and Memory
    # by creating a pre customized sandbox template
    # Currently, You can customized the CPU and Memory, only by creating a pre customized sandbox template
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal[ProviderName.E2B], Literal["api_key"]
        ] = CredentialsField(
            description=(
                "Enter your API key for the E2B platform. "
                "You can get it in here - https://e2b.dev/docs"
            ),
            description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
        )

        # Todo : Option to run commond in background
@@ -199,14 +76,6 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
            description="Execution timeout in seconds", default=300
        )

        dispose_sandbox: bool = SchemaField(
            description=(
                "Whether to dispose of the sandbox immediately after execution. "
                "If disabled, the sandbox will run until its timeout expires."
            ),
            default=True,
        )

        template_id: str = SchemaField(
            description=(
                "You can use an E2B sandbox template by entering its ID here. "
@@ -218,16 +87,7 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
        )

    class Output(BlockSchema):
        main_result: MainCodeExecutionResult = SchemaField(
            title="Main Result", description="The main result from the code execution"
        )
        results: list[CodeExecutionResult] = SchemaField(
            description="List of results from the code execution"
        )
        response: str = SchemaField(
            title="Main Text Output",
            description="Text output (if any) of the main execution result",
        )
        response: str = SchemaField(description="Response from code execution")
        stdout_logs: str = SchemaField(
            description="Standard output logs from execution"
        )
@@ -237,10 +97,10 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
    def __init__(self):
        super().__init__(
            id="0b02b072-abe7-11ef-8372-fb5d162dd712",
            description="Executes code in a sandbox environment with internet access.",
            description="Executes code in an isolated sandbox environment with internet access.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=ExecuteCodeBlock.Input,
            output_schema=ExecuteCodeBlock.Output,
            input_schema=CodeExecutionBlock.Input,
            output_schema=CodeExecutionBlock.Output,
            test_credentials=TEST_CREDENTIALS,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
@@ -251,59 +111,91 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
                "template_id": "",
            },
            test_output=[
                ("results", []),
                ("response", "Hello World"),
                ("stdout_logs", "Hello World\n"),
            ],
            test_mock={
                "execute_code": lambda api_key, code, language, template_id, setup_commands, timeout, dispose_sandbox: (  # noqa
                    [],  # results
                    "Hello World",  # text_output
                    "Hello World\n",  # stdout_logs
                    "",  # stderr_logs
                    "sandbox_id",  # sandbox_id
                "execute_code": lambda code, language, setup_commands, timeout, api_key, template_id: (
                    "Hello World",
                    "Hello World\n",
                    "",
                ),
            },
        )

    async def execute_code(
        self,
        code: str,
        language: ProgrammingLanguage,
        setup_commands: list[str],
        timeout: int,
        api_key: str,
        template_id: str,
    ):
        try:
            sandbox = None
            if template_id:
                sandbox = await AsyncSandbox.create(
                    template=template_id, api_key=api_key, timeout=timeout
                )
            else:
                sandbox = await AsyncSandbox.create(api_key=api_key, timeout=timeout)

            if not sandbox:
                raise Exception("Sandbox not created")

            # Running setup commands
            for cmd in setup_commands:
                await sandbox.commands.run(cmd)

            # Executing the code
            execution = await sandbox.run_code(
                code,
                language=language.value,
                on_error=lambda e: sandbox.kill(),  # Kill the sandbox if there is an error
            )

            if execution.error:
                raise Exception(execution.error)

            response = execution.text
            stdout_logs = "".join(execution.logs.stdout)
            stderr_logs = "".join(execution.logs.stderr)

            return response, stdout_logs, stderr_logs

        except Exception as e:
            raise e

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        try:
            results, text_output, stdout, stderr, _ = await self.execute_code(
                api_key=credentials.api_key.get_secret_value(),
                code=input_data.code,
                language=input_data.language,
                template_id=input_data.template_id,
                setup_commands=input_data.setup_commands,
                timeout=input_data.timeout,
                dispose_sandbox=input_data.dispose_sandbox,
            response, stdout_logs, stderr_logs = await self.execute_code(
                input_data.code,
                input_data.language,
                input_data.setup_commands,
                input_data.timeout,
                credentials.api_key.get_secret_value(),
                input_data.template_id,
            )

            # Determine result object shape & filter out empty formats
            main_result, results = self.process_execution_results(results)
            if main_result:
                yield "main_result", main_result
            yield "results", results
            if text_output:
                yield "response", text_output
            if stdout:
                yield "stdout_logs", stdout
            if stderr:
                yield "stderr_logs", stderr
            if response:
                yield "response", response
            if stdout_logs:
                yield "stdout_logs", stdout_logs
            if stderr_logs:
                yield "stderr_logs", stderr_logs
        except Exception as e:
            yield "error", str(e)

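The two variants above differ in how execute_code is invoked: the mixin version (removed by this revert) uses keyword arguments and a dispose_sandbox flag, while the restored version passes six positional arguments. A small sketch, with a stand-in body, of why keyword calls are the safer choice when several parameters share a type (all names below are illustrative):

import asyncio

# Stand-in for the block's execute_code; with six parameters, three of them
# strings, a positional call can silently swap api_key and template_id.
async def execute_code(code: str, language: str, setup_commands: list[str],
                       timeout: int, api_key: str, template_id: str) -> tuple[str, str]:
    return code, template_id  # placeholder body for illustration

async def main():
    # Keyword style: order-independent, self-documenting, typos raise TypeError.
    out = await execute_code(code="print(1)", language="python",
                             setup_commands=[], timeout=300,
                             api_key="e2b_...", template_id="")
    print(out)

asyncio.run(main())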
class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
class InstantiationBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal[ProviderName.E2B], Literal["api_key"]
        ] = CredentialsField(
            description=(
                "Enter your API key for the E2B platform. "
                "You can get it in here - https://e2b.dev/docs"
            )
            description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
        )

        # Todo : Option to run commond in background
@@ -348,10 +240,7 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):

    class Output(BlockSchema):
        sandbox_id: str = SchemaField(description="ID of the sandbox instance")
        response: str = SchemaField(
            title="Text Result",
            description="Text result (if any) of the setup code execution",
        )
        response: str = SchemaField(description="Response from code execution")
        stdout_logs: str = SchemaField(
            description="Standard output logs from execution"
        )
@@ -361,13 +250,10 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
    def __init__(self):
        super().__init__(
            id="ff0861c9-1726-4aec-9e5b-bf53f3622112",
            description=(
                "Instantiate a sandbox environment with internet access "
                "in which you can execute code with the Execute Code Step block."
            ),
            description="Instantiate an isolated sandbox environment with internet access where to execute code in.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=InstantiateCodeSandboxBlock.Input,
            output_schema=InstantiateCodeSandboxBlock.Output,
            input_schema=InstantiationBlock.Input,
            output_schema=InstantiationBlock.Output,
            test_credentials=TEST_CREDENTIALS,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
@@ -383,12 +269,11 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
                ("stdout_logs", "Hello World\n"),
            ],
            test_mock={
                "execute_code": lambda api_key, code, language, template_id, setup_commands, timeout: (  # noqa
                    [],  # results
                    "Hello World",  # text_output
                    "Hello World\n",  # stdout_logs
                    "",  # stderr_logs
                    "sandbox_id",  # sandbox_id
                "execute_code": lambda setup_code, language, setup_commands, timeout, api_key, template_id: (
                    "sandbox_id",
                    "Hello World",
                    "Hello World\n",
                    "",
                ),
            },
        )
@@ -397,38 +282,78 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        try:
            _, text_output, stdout, stderr, sandbox_id = await self.execute_code(
                api_key=credentials.api_key.get_secret_value(),
                code=input_data.setup_code,
                language=input_data.language,
                template_id=input_data.template_id,
                setup_commands=input_data.setup_commands,
                timeout=input_data.timeout,
            sandbox_id, response, stdout_logs, stderr_logs = await self.execute_code(
                input_data.setup_code,
                input_data.language,
                input_data.setup_commands,
                input_data.timeout,
                credentials.api_key.get_secret_value(),
                input_data.template_id,
            )
            if sandbox_id:
                yield "sandbox_id", sandbox_id
            else:
                yield "error", "Sandbox ID not found"

            if text_output:
                yield "response", text_output
            if stdout:
                yield "stdout_logs", stdout
            if stderr:
                yield "stderr_logs", stderr
            if response:
                yield "response", response
            if stdout_logs:
                yield "stdout_logs", stdout_logs
            if stderr_logs:
                yield "stderr_logs", stderr_logs
        except Exception as e:
            yield "error", str(e)

    async def execute_code(
        self,
        code: str,
        language: ProgrammingLanguage,
        setup_commands: list[str],
        timeout: int,
        api_key: str,
        template_id: str,
    ):
        try:
            sandbox = None
            if template_id:
                sandbox = await AsyncSandbox.create(
                    template=template_id, api_key=api_key, timeout=timeout
                )
            else:
                sandbox = await AsyncSandbox.create(api_key=api_key, timeout=timeout)

class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
            if not sandbox:
                raise Exception("Sandbox not created")

            # Running setup commands
            for cmd in setup_commands:
                await sandbox.commands.run(cmd)

            # Executing the code
            execution = await sandbox.run_code(
                code,
                language=language.value,
                on_error=lambda e: sandbox.kill(),  # Kill the sandbox if there is an error
            )

            if execution.error:
                raise Exception(execution.error)

            response = execution.text
            stdout_logs = "".join(execution.logs.stdout)
            stderr_logs = "".join(execution.logs.stderr)

            return sandbox.sandbox_id, response, stdout_logs, stderr_logs

        except Exception as e:
            raise e

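The instantiate/step split above is a two-phase pattern: one block creates a sandbox and emits its ID, another reconnects by ID later in the graph. A hedged sketch of the flow using the AsyncSandbox calls that appear in this diff (the API key, timeout, and code are placeholders):

import asyncio
from e2b_code_interpreter import AsyncSandbox  # import path is an assumption

async def main():
    api_key = "e2b_..."  # placeholder

    # Phase 1 (InstantiateCodeSandboxBlock): create and keep the sandbox alive.
    sandbox = await AsyncSandbox.create(api_key=api_key, timeout=300)
    sandbox_id = sandbox.sandbox_id  # handed downstream via the "sandbox_id" output

    # Phase 2 (ExecuteCodeStepBlock): reconnect by ID and run more code there.
    sandbox = await AsyncSandbox.connect(sandbox_id=sandbox_id, api_key=api_key)
    execution = await sandbox.run_code("print('hello')", language="python")
    print(execution.text, "".join(execution.logs.stdout))

asyncio.run(main())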
class StepExecutionBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal[ProviderName.E2B], Literal["api_key"]
        ] = CredentialsField(
            description=(
                "Enter your API key for the E2B platform. "
                "You can get it in here - https://e2b.dev/docs"
            ),
            description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
        )

        sandbox_id: str = SchemaField(
@@ -449,22 +374,8 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
            advanced=False,
        )

        dispose_sandbox: bool = SchemaField(
            description="Whether to dispose of the sandbox after executing this code.",
            default=False,
        )

    class Output(BlockSchema):
        main_result: MainCodeExecutionResult = SchemaField(
            title="Main Result", description="The main result from the code execution"
        )
        results: list[CodeExecutionResult] = SchemaField(
            description="List of results from the code execution"
        )
        response: str = SchemaField(
            title="Main Text Output",
            description="Text output (if any) of the main execution result",
        )
        response: str = SchemaField(description="Response from code execution")
        stdout_logs: str = SchemaField(
            description="Standard output logs from execution"
        )
@@ -474,10 +385,10 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
    def __init__(self):
        super().__init__(
            id="82b59b8e-ea10-4d57-9161-8b169b0adba6",
            description="Execute code in a previously instantiated sandbox.",
            description="Execute code in a previously instantiated sandbox environment.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=ExecuteCodeStepBlock.Input,
            output_schema=ExecuteCodeStepBlock.Output,
            input_schema=StepExecutionBlock.Input,
            output_schema=StepExecutionBlock.Output,
            test_credentials=TEST_CREDENTIALS,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
@@ -486,43 +397,61 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
                "language": ProgrammingLanguage.PYTHON.value,
            },
            test_output=[
                ("results", []),
                ("response", "Hello World"),
                ("stdout_logs", "Hello World\n"),
            ],
            test_mock={
                "execute_code": lambda api_key, code, language, sandbox_id, dispose_sandbox: (  # noqa
                    [],  # results
                    "Hello World",  # text_output
                    "Hello World\n",  # stdout_logs
                    "",  # stderr_logs
                    sandbox_id,  # sandbox_id
                "execute_step_code": lambda sandbox_id, step_code, language, api_key: (
                    "Hello World",
                    "Hello World\n",
                    "",
                ),
            },
        )

    async def execute_step_code(
        self,
        sandbox_id: str,
        code: str,
        language: ProgrammingLanguage,
        api_key: str,
    ):
        try:
            sandbox = await AsyncSandbox.connect(sandbox_id=sandbox_id, api_key=api_key)
            if not sandbox:
                raise Exception("Sandbox not found")

            # Executing the code
            execution = await sandbox.run_code(code, language=language.value)

            if execution.error:
                raise Exception(execution.error)

            response = execution.text
            stdout_logs = "".join(execution.logs.stdout)
            stderr_logs = "".join(execution.logs.stderr)

            return response, stdout_logs, stderr_logs

        except Exception as e:
            raise e

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        try:
            results, text_output, stdout, stderr, _ = await self.execute_code(
                api_key=credentials.api_key.get_secret_value(),
                code=input_data.step_code,
                language=input_data.language,
                sandbox_id=input_data.sandbox_id,
                dispose_sandbox=input_data.dispose_sandbox,
            response, stdout_logs, stderr_logs = await self.execute_step_code(
                input_data.sandbox_id,
                input_data.step_code,
                input_data.language,
                credentials.api_key.get_secret_value(),
            )

            # Determine result object shape & filter out empty formats
            main_result, results = self.process_execution_results(results)
            if main_result:
                yield "main_result", main_result
            yield "results", results
            if text_output:
                yield "response", text_output
            if stdout:
                yield "stdout_logs", stdout
            if stderr:
                yield "stderr_logs", stderr
            if response:
                yield "response", response
            if stdout_logs:
                yield "stdout_logs", stdout_logs
            if stderr_logs:
                yield "stderr_logs", stderr_logs
        except Exception as e:
            yield "error", str(e)

@@ -90,7 +90,7 @@ class CodeExtractionBlock(Block):
                for aliases in language_aliases.values()
                for alias in aliases
            )
            + r")[ \t]*\n[\s\S]*?```"
            + r")\s+[\s\S]*?```"
        )

        remaining_text = re.sub(pattern, "", input_data.text).strip()
@@ -103,9 +103,7 @@ class CodeExtractionBlock(Block):
        # Escape special regex characters in the language string
        language = re.escape(language)
        # Extract all code blocks enclosed in ```language``` blocks
        pattern = re.compile(
            rf"```{language}[ \t]*\n(.*?)\n```", re.DOTALL | re.IGNORECASE
        )
        pattern = re.compile(rf"```{language}\s+(.*?)```", re.DOTALL | re.IGNORECASE)
        matches = pattern.finditer(text)
        # Combine all code blocks for this language with newlines between them
        code_blocks = [match.group(1).strip() for match in matches]

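The two patterns above differ in how the fence header may end: `[ \t]*\n` requires a newline right after the language tag, while `\s+` also accepts arbitrary runs of whitespace (the kind of input the security test at the end of this diff exercises). A small sketch of the behavioral difference:

import re

text = "```python\nprint('ok')\n```"
strict = re.compile(r"```python[ \t]*\n(.*?)\n```", re.DOTALL)
loose = re.compile(r"```python\s+(.*?)```", re.DOTALL)

print(strict.findall(text))  # ["print('ok')"]
print(loose.findall(text))   # ["print('ok')\n"] -- \s+ consumes the newline too

# A fence header padded with thousands of spaces and no closing fence matches
# neither pattern, but the loose one must scan the entire whitespace run:
padded = "```python" + " " * 10000
print(bool(strict.search(padded)), bool(loose.search(padded)))  # False False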
@@ -90,7 +90,6 @@ class DataForSeoKeywordSuggestionsBlock(Block):
        seed_keyword: str = SchemaField(
            description="The seed keyword used for the query"
        )
        error: str = SchemaField(description="Error message if the API call failed")

    def __init__(self):
        super().__init__(
@@ -162,52 +161,43 @@ class DataForSeoKeywordSuggestionsBlock(Block):
        **kwargs,
    ) -> BlockOutput:
        """Execute the keyword suggestions query."""
        try:
            client = DataForSeoClient(credentials)
        client = DataForSeoClient(credentials)

            results = await self._fetch_keyword_suggestions(client, input_data)
        results = await self._fetch_keyword_suggestions(client, input_data)

            # Process and format the results
            suggestions = []
            if results and len(results) > 0:
                # results is a list, get the first element
                first_result = results[0] if isinstance(results, list) else results
                items = (
                    first_result.get("items", [])
                    if isinstance(first_result, dict)
                    else []
        # Process and format the results
        suggestions = []
        if results and len(results) > 0:
            # results is a list, get the first element
            first_result = results[0] if isinstance(results, list) else results
            items = (
                first_result.get("items", []) if isinstance(first_result, dict) else []
            )
            for item in items:
                # Create the KeywordSuggestion object
                suggestion = KeywordSuggestion(
                    keyword=item.get("keyword", ""),
                    search_volume=item.get("keyword_info", {}).get("search_volume"),
                    competition=item.get("keyword_info", {}).get("competition"),
                    cpc=item.get("keyword_info", {}).get("cpc"),
                    keyword_difficulty=item.get("keyword_properties", {}).get(
                        "keyword_difficulty"
                    ),
                    serp_info=(
                        item.get("serp_info") if input_data.include_serp_info else None
                    ),
                    clickstream_data=(
                        item.get("clickstream_keyword_info")
                        if input_data.include_clickstream_data
                        else None
                    ),
                )
            if items is None:
                items = []
            for item in items:
                # Create the KeywordSuggestion object
                suggestion = KeywordSuggestion(
                    keyword=item.get("keyword", ""),
                    search_volume=item.get("keyword_info", {}).get("search_volume"),
                    competition=item.get("keyword_info", {}).get("competition"),
                    cpc=item.get("keyword_info", {}).get("cpc"),
                    keyword_difficulty=item.get("keyword_properties", {}).get(
                        "keyword_difficulty"
                    ),
                    serp_info=(
                        item.get("serp_info")
                        if input_data.include_serp_info
                        else None
                    ),
                    clickstream_data=(
                        item.get("clickstream_keyword_info")
                        if input_data.include_clickstream_data
                        else None
                    ),
                )
                yield "suggestion", suggestion
                suggestions.append(suggestion)
                yield "suggestion", suggestion
                suggestions.append(suggestion)

            yield "suggestions", suggestions
            yield "total_count", len(suggestions)
            yield "seed_keyword", input_data.keyword
        except Exception as e:
            yield "error", f"Failed to fetch keyword suggestions: {str(e)}"
        yield "suggestions", suggestions
        yield "total_count", len(suggestions)
        yield "seed_keyword", input_data.keyword


class KeywordSuggestionExtractorBlock(Block):

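The repeated item.get("keyword_info", {}).get(...) chains above are the defensive idiom for optional nested JSON: each .get supplies a fallback so a missing level yields None instead of raising. A small sketch (the sample payload is illustrative, not real DataForSEO output):

item = {"keyword": "coffee", "keyword_info": {"search_volume": 1200}}

search_volume = item.get("keyword_info", {}).get("search_volume")          # 1200
competition = item.get("keyword_info", {}).get("competition")              # None
difficulty = item.get("keyword_properties", {}).get("keyword_difficulty")  # None
print(search_volume, competition, difficulty)  # 1200 None None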
@@ -98,7 +98,6 @@ class DataForSeoRelatedKeywordsBlock(Block):
        seed_keyword: str = SchemaField(
            description="The seed keyword used for the query"
        )
        error: str = SchemaField(description="Error message if the API call failed")

    def __init__(self):
        super().__init__(
@@ -172,60 +171,50 @@ class DataForSeoRelatedKeywordsBlock(Block):
        **kwargs,
    ) -> BlockOutput:
        """Execute the related keywords query."""
        try:
            client = DataForSeoClient(credentials)
        client = DataForSeoClient(credentials)

            results = await self._fetch_related_keywords(client, input_data)
        results = await self._fetch_related_keywords(client, input_data)

            # Process and format the results
            related_keywords = []
            if results and len(results) > 0:
                # results is a list, get the first element
                first_result = results[0] if isinstance(results, list) else results
                items = (
                    first_result.get("items", [])
                    if isinstance(first_result, dict)
                    else []
        # Process and format the results
        related_keywords = []
        if results and len(results) > 0:
            # results is a list, get the first element
            first_result = results[0] if isinstance(results, list) else results
            items = (
                first_result.get("items", []) if isinstance(first_result, dict) else []
            )
            for item in items:
                # Extract keyword_data from the item
                keyword_data = item.get("keyword_data", {})

                # Create the RelatedKeyword object
                keyword = RelatedKeyword(
                    keyword=keyword_data.get("keyword", ""),
                    search_volume=keyword_data.get("keyword_info", {}).get(
                        "search_volume"
                    ),
                    competition=keyword_data.get("keyword_info", {}).get("competition"),
                    cpc=keyword_data.get("keyword_info", {}).get("cpc"),
                    keyword_difficulty=keyword_data.get("keyword_properties", {}).get(
                        "keyword_difficulty"
                    ),
                    serp_info=(
                        keyword_data.get("serp_info")
                        if input_data.include_serp_info
                        else None
                    ),
                    clickstream_data=(
                        keyword_data.get("clickstream_keyword_info")
                        if input_data.include_clickstream_data
                        else None
                    ),
                )
            # Ensure items is never None
            if items is None:
                items = []
            for item in items:
                # Extract keyword_data from the item
                keyword_data = item.get("keyword_data", {})
                yield "related_keyword", keyword
                related_keywords.append(keyword)

                # Create the RelatedKeyword object
                keyword = RelatedKeyword(
                    keyword=keyword_data.get("keyword", ""),
                    search_volume=keyword_data.get("keyword_info", {}).get(
                        "search_volume"
                    ),
                    competition=keyword_data.get("keyword_info", {}).get(
                        "competition"
                    ),
                    cpc=keyword_data.get("keyword_info", {}).get("cpc"),
                    keyword_difficulty=keyword_data.get(
                        "keyword_properties", {}
                    ).get("keyword_difficulty"),
                    serp_info=(
                        keyword_data.get("serp_info")
                        if input_data.include_serp_info
                        else None
                    ),
                    clickstream_data=(
                        keyword_data.get("clickstream_keyword_info")
                        if input_data.include_clickstream_data
                        else None
                    ),
                )
                yield "related_keyword", keyword
                related_keywords.append(keyword)

            yield "related_keywords", related_keywords
            yield "total_count", len(related_keywords)
            yield "seed_keyword", input_data.keyword
        except Exception as e:
            yield "error", f"Failed to fetch related keywords: {str(e)}"
        yield "related_keywords", related_keywords
        yield "total_count", len(related_keywords)
        yield "seed_keyword", input_data.keyword


class RelatedKeywordExtractorBlock(Block):

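Both DataForSEO hunks make the same design change: the try/except around run() (and its "error" yield) is removed, leaving exceptions to propagate to the block executor. A hedged sketch of the two styles side by side; the fetch callable and output names are illustrative:

# Style A (removed by this revert): catch locally, surface an "error" output.
async def run_catching(fetch):
    try:
        yield "result", await fetch()
    except Exception as e:
        yield "error", f"Failed to fetch related keywords: {e}"

# Style B (restored here): let the exception propagate; the framework is
# assumed to catch it and mark the node execution as failed.
async def run_propagating(fetch):
    yield "result", await fetch()

The trade-off: style A keeps failures inside the graph as routable data, style B centralizes error handling and keeps block code shorter.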
@@ -1,12 +0,0 @@
from enum import Enum


class ScrapeFormat(Enum):
    MARKDOWN = "markdown"
    HTML = "html"
    RAW_HTML = "rawHtml"
    LINKS = "links"
    SCREENSHOT = "screenshot"
    SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
    JSON = "json"
    CHANGE_TRACKING = "changeTracking"

@@ -1,28 +0,0 @@
"""Utility functions for converting between our ScrapeFormat enum and firecrawl FormatOption types."""

from typing import List

from firecrawl.v2.types import FormatOption, ScreenshotFormat

from backend.blocks.firecrawl._api import ScrapeFormat


def convert_to_format_options(
    formats: List[ScrapeFormat],
) -> List[FormatOption]:
    """Convert our ScrapeFormat enum values to firecrawl FormatOption types.

    Handles special cases like screenshot@fullPage which needs to be converted
    to a ScreenshotFormat object.
    """
    result: List[FormatOption] = []

    for format_enum in formats:
        if format_enum.value == "screenshot@fullPage":
            # Special case: convert to ScreenshotFormat with full_page=True
            result.append(ScreenshotFormat(type="screenshot", full_page=True))
        else:
            # Regular string literals
            result.append(format_enum.value)

    return result
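A brief usage sketch of the helper deleted above, as it worked before this revert (it assumes the firecrawl v2 types the module imported; the assertions describe expected behavior, not verified output):

from firecrawl.v2.types import ScreenshotFormat

from backend.blocks.firecrawl._api import ScrapeFormat
from backend.blocks.firecrawl._format_utils import convert_to_format_options

opts = convert_to_format_options(
    [ScrapeFormat.MARKDOWN, ScrapeFormat.SCREENSHOT_FULL_PAGE]
)
# Plain formats pass through as string literals; the full-page screenshot
# becomes a structured ScreenshotFormat object.
assert opts[0] == "markdown"
assert isinstance(opts[1], ScreenshotFormat) and opts[1].full_page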
@@ -1,9 +1,8 @@
from enum import Enum
from typing import Any

from firecrawl import FirecrawlApp
from firecrawl.v2.types import ScrapeOptions
from firecrawl import FirecrawlApp, ScrapeOptions

from backend.blocks.firecrawl._api import ScrapeFormat
from backend.sdk import (
    APIKeyCredentials,
    Block,
@@ -15,10 +14,21 @@ from backend.sdk import (
)

from ._config import firecrawl
from ._format_utils import convert_to_format_options


class ScrapeFormat(Enum):
    MARKDOWN = "markdown"
    HTML = "html"
    RAW_HTML = "rawHtml"
    LINKS = "links"
    SCREENSHOT = "screenshot"
    SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
    JSON = "json"
    CHANGE_TRACKING = "changeTracking"


class FirecrawlCrawlBlock(Block):

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = firecrawl.credentials_field()
        url: str = SchemaField(description="The URL to crawl")
@@ -68,17 +78,18 @@ class FirecrawlCrawlBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:

        app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())

        # Sync call
        crawl_result = app.crawl(
        crawl_result = app.crawl_url(
            input_data.url,
            limit=input_data.limit,
            scrape_options=ScrapeOptions(
                formats=convert_to_format_options(input_data.formats),
                only_main_content=input_data.only_main_content,
                max_age=input_data.max_age,
                wait_for=input_data.wait_for,
                formats=[format.value for format in input_data.formats],
                onlyMainContent=input_data.only_main_content,
                maxAge=input_data.max_age,
                waitFor=input_data.wait_for,
            ),
        )
        yield "data", crawl_result.data
@@ -90,7 +101,7 @@ class FirecrawlCrawlBlock(Block):
            elif f == ScrapeFormat.HTML:
                yield "html", data.html
            elif f == ScrapeFormat.RAW_HTML:
                yield "raw_html", data.raw_html
                yield "raw_html", data.rawHtml
            elif f == ScrapeFormat.LINKS:
                yield "links", data.links
            elif f == ScrapeFormat.SCREENSHOT:
@@ -98,6 +109,6 @@ class FirecrawlCrawlBlock(Block):
            elif f == ScrapeFormat.SCREENSHOT_FULL_PAGE:
                yield "screenshot_full_page", data.screenshot
            elif f == ScrapeFormat.CHANGE_TRACKING:
                yield "change_tracking", data.change_tracking
                yield "change_tracking", data.changeTracking
            elif f == ScrapeFormat.JSON:
                yield "json", data.json

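The hunk above contrasts two Firecrawl call styles: the v2 SDK (snake_case options, app.crawl, data.raw_html) versus the older camelCase kwargs and app.crawl_url. A hedged sketch of the v2-style call as this diff's removed code used it; the URL, key, and the page attributes accessed are assumptions based on what the diff itself shows:

from firecrawl import FirecrawlApp
from firecrawl.v2.types import ScrapeOptions

app = FirecrawlApp(api_key="fc-...")  # placeholder key

# v2 style (removed by this revert): snake_case option names, app.crawl(...)
result = app.crawl(
    "https://example.com",
    limit=10,
    scrape_options=ScrapeOptions(
        formats=["markdown"],
        only_main_content=True,
    ),
)
for page in result.data:       # result.data, as yielded by the block above
    print(page.markdown[:80])  # attribute assumed from the requested format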
@@ -20,6 +20,7 @@ from ._config import firecrawl

@cost(BlockCost(2, BlockCostType.RUN))
class FirecrawlExtractBlock(Block):

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = firecrawl.credentials_field()
        urls: list[str] = SchemaField(
@@ -52,6 +53,7 @@ class FirecrawlExtractBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:

        app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())

        extract_result = app.extract(

@@ -1,5 +1,3 @@
from typing import Any

from firecrawl import FirecrawlApp

from backend.sdk import (
@@ -16,16 +14,14 @@ from ._config import firecrawl


class FirecrawlMapWebsiteBlock(Block):

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = firecrawl.credentials_field()

        url: str = SchemaField(description="The website url to map")

    class Output(BlockSchema):
        links: list[str] = SchemaField(description="List of URLs found on the website")
        results: list[dict[str, Any]] = SchemaField(
            description="List of search results with url, title, and description"
        )
        links: list[str] = SchemaField(description="The links of the website")

    def __init__(self):
        super().__init__(
@@ -39,22 +35,12 @@ class FirecrawlMapWebsiteBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:

        app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())

        # Sync call
        map_result = app.map(
        map_result = app.map_url(
            url=input_data.url,
        )

        # Convert SearchResult objects to dicts
        results_data = [
            {
                "url": link.url,
                "title": link.title,
                "description": link.description,
            }
            for link in map_result.links
        ]

        yield "links", [link.url for link in map_result.links]
        yield "results", results_data
        yield "links", map_result.links

@@ -1,8 +1,8 @@
from enum import Enum
from typing import Any

from firecrawl import FirecrawlApp

from backend.blocks.firecrawl._api import ScrapeFormat
from backend.sdk import (
    APIKeyCredentials,
    Block,
@@ -14,10 +14,21 @@ from backend.sdk import (
)

from ._config import firecrawl
from ._format_utils import convert_to_format_options


class ScrapeFormat(Enum):
    MARKDOWN = "markdown"
    HTML = "html"
    RAW_HTML = "rawHtml"
    LINKS = "links"
    SCREENSHOT = "screenshot"
    SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
    JSON = "json"
    CHANGE_TRACKING = "changeTracking"


class FirecrawlScrapeBlock(Block):

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = firecrawl.credentials_field()
        url: str = SchemaField(description="The URL to crawl")
@@ -67,11 +78,12 @@ class FirecrawlScrapeBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:

        app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())

        scrape_result = app.scrape(
        scrape_result = app.scrape_url(
            input_data.url,
            formats=convert_to_format_options(input_data.formats),
            formats=[format.value for format in input_data.formats],
            only_main_content=input_data.only_main_content,
            max_age=input_data.max_age,
            wait_for=input_data.wait_for,
@@ -84,7 +96,7 @@ class FirecrawlScrapeBlock(Block):
            elif f == ScrapeFormat.HTML:
                yield "html", scrape_result.html
            elif f == ScrapeFormat.RAW_HTML:
                yield "raw_html", scrape_result.raw_html
                yield "raw_html", scrape_result.rawHtml
            elif f == ScrapeFormat.LINKS:
                yield "links", scrape_result.links
            elif f == ScrapeFormat.SCREENSHOT:
@@ -92,6 +104,6 @@ class FirecrawlScrapeBlock(Block):
            elif f == ScrapeFormat.SCREENSHOT_FULL_PAGE:
                yield "screenshot_full_page", scrape_result.screenshot
            elif f == ScrapeFormat.CHANGE_TRACKING:
                yield "change_tracking", scrape_result.change_tracking
                yield "change_tracking", scrape_result.changeTracking
            elif f == ScrapeFormat.JSON:
                yield "json", scrape_result.json

@@ -1,9 +1,8 @@
from enum import Enum
from typing import Any

from firecrawl import FirecrawlApp
from firecrawl.v2.types import ScrapeOptions
from firecrawl import FirecrawlApp, ScrapeOptions

from backend.blocks.firecrawl._api import ScrapeFormat
from backend.sdk import (
    APIKeyCredentials,
    Block,
@@ -15,10 +14,21 @@ from backend.sdk import (
)

from ._config import firecrawl
from ._format_utils import convert_to_format_options


class ScrapeFormat(Enum):
    MARKDOWN = "markdown"
    HTML = "html"
    RAW_HTML = "rawHtml"
    LINKS = "links"
    SCREENSHOT = "screenshot"
    SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
    JSON = "json"
    CHANGE_TRACKING = "changeTracking"


class FirecrawlSearchBlock(Block):

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = firecrawl.credentials_field()
        query: str = SchemaField(description="The query to search for")
@@ -51,6 +61,7 @@ class FirecrawlSearchBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:

        app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())

        # Sync call
@@ -58,12 +69,11 @@ class FirecrawlSearchBlock(Block):
            input_data.query,
            limit=input_data.limit,
            scrape_options=ScrapeOptions(
                formats=convert_to_format_options(input_data.formats) or None,
                max_age=input_data.max_age,
                wait_for=input_data.wait_for,
                formats=[format.value for format in input_data.formats],
                maxAge=input_data.max_age,
                waitFor=input_data.wait_for,
            ),
        )
        yield "data", scrape_result
        if hasattr(scrape_result, "web") and scrape_result.web:
            for site in scrape_result.web:
                yield "site", site
        for site in scrape_result.data:
            yield "site", site

@@ -554,89 +554,6 @@ class AgentToggleInputBlock(AgentInputBlock):
    )


class AgentTableInputBlock(AgentInputBlock):
    """
    This block allows users to input data in a table format.

    Configure the table columns at build time, then users can input
    rows of data at runtime. Each row is output as a dictionary
    with column names as keys.
    """

    class Input(AgentInputBlock.Input):
        value: Optional[list[dict[str, Any]]] = SchemaField(
            description="The table data as a list of dictionaries.",
            default=None,
            advanced=False,
            title="Default Value",
        )
        column_headers: list[str] = SchemaField(
            description="Column headers for the table.",
            default_factory=lambda: ["Column 1", "Column 2", "Column 3"],
            advanced=False,
            title="Column Headers",
        )

        def generate_schema(self):
            """Generate schema for the value field with table format."""
            schema = super().generate_schema()
            schema["type"] = "array"
            schema["format"] = "table"
            schema["items"] = {
                "type": "object",
                "properties": {
                    header: {"type": "string"}
                    for header in (
                        self.column_headers or ["Column 1", "Column 2", "Column 3"]
                    )
                },
            }
            if self.value is not None:
                schema["default"] = self.value
            return schema

    class Output(AgentInputBlock.Output):
        result: list[dict[str, Any]] = SchemaField(
            description="The table data as a list of dictionaries with headers as keys."
        )

    def __init__(self):
        super().__init__(
            id="5603b273-f41e-4020-af7d-fbc9c6a8d928",
            description="Block for table data input with customizable headers.",
            disabled=not config.enable_agent_input_subtype_blocks,
            input_schema=AgentTableInputBlock.Input,
            output_schema=AgentTableInputBlock.Output,
            test_input=[
                {
                    "name": "test_table",
                    "column_headers": ["Name", "Age", "City"],
                    "value": [
                        {"Name": "John", "Age": "30", "City": "New York"},
                        {"Name": "Jane", "Age": "25", "City": "London"},
                    ],
                    "description": "Example table input",
                }
            ],
            test_output=[
                (
                    "result",
                    [
                        {"Name": "John", "Age": "30", "City": "New York"},
                        {"Name": "Jane", "Age": "25", "City": "London"},
                    ],
                )
            ],
        )

    async def run(self, input_data: Input, *args, **kwargs) -> BlockOutput:
        """
        Yields the table data as a list of dictionaries.
        """
        # Pass through the value, defaulting to empty list if None
        yield "result", input_data.value if input_data.value is not None else []


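To make generate_schema above concrete, here is a hedged sketch of the JSON schema it produces for a custom header set, hand-built rather than calling the block:

import json

column_headers = ["Name", "Age", "City"]
schema = {
    "type": "array",
    "format": "table",
    "items": {
        "type": "object",
        # One string-typed property per configured column header.
        "properties": {header: {"type": "string"} for header in column_headers},
    },
}
print(json.dumps(schema, indent=2))
# Rows like {"Name": "John", "Age": "30", "City": "New York"} match
# items.properties; every cell is typed as a string.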
IO_BLOCK_IDs = [
    AgentInputBlock().id,
    AgentOutputBlock().id,
@@ -648,5 +565,4 @@ IO_BLOCK_IDs = [
    AgentFileInputBlock().id,
    AgentDropdownInputBlock().id,
    AgentToggleInputBlock().id,
    AgentTableInputBlock().id,
]

@@ -54,43 +54,20 @@ class StepThroughItemsBlock(Block):
    )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Security fix: Add limits to prevent DoS from large iterations
        MAX_ITEMS = 10000  # Maximum items to iterate
        MAX_ITEM_SIZE = 1024 * 1024  # 1MB per item

        for data in [input_data.items, input_data.items_object, input_data.items_str]:
            if not data:
                continue

            # Limit string size before parsing
            if isinstance(data, str):
                if len(data) > MAX_ITEM_SIZE:
                    raise ValueError(
                        f"Input too large: {len(data)} bytes > {MAX_ITEM_SIZE} bytes"
                    )
                items = json.loads(data)
            else:
                items = data

            # Check total item count
            if isinstance(items, (list, dict)):
                if len(items) > MAX_ITEMS:
                    raise ValueError(f"Too many items: {len(items)} > {MAX_ITEMS}")

            iteration_count = 0
            if isinstance(items, dict):
                # If items is a dictionary, iterate over its values
                for key, value in items.items():
                    if iteration_count >= MAX_ITEMS:
                        break
                    yield "item", value
                    yield "key", key  # Fixed: should yield key, not item
                    iteration_count += 1
                for item in items.values():
                    yield "item", item
                    yield "key", item
            else:
                # If items is a list, iterate over the list
                for index, item in enumerate(items):
                    if iteration_count >= MAX_ITEMS:
                        break
                    yield "item", item
                    yield "key", index
                    iteration_count += 1

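The bug the removed code had fixed is visible above: the old dict path yields the value for both the "item" and "key" outputs. A small sketch of the corrected pairing for both container shapes:

items = {"a": 1, "b": 2}
if isinstance(items, dict):
    for key, value in items.items():
        print(("item", value), ("key", key))    # ('item', 1) ('key', 'a') ...
else:
    for index, item in enumerate(items):
        print(("item", item), ("key", index))   # list positions become the keys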
@@ -101,7 +101,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
    CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
    CLAUDE_4_OPUS = "claude-opus-4-20250514"
    CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
    CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
    CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
    CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
    CLAUDE_3_5_HAIKU = "claude-3-5-haiku-latest"
@@ -214,9 +213,6 @@ MODEL_METADATA = {
    LlmModel.CLAUDE_4_SONNET: ModelMetadata(
        "anthropic", 200000, 64000
    ),  # claude-4-sonnet-20250514
    LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
        "anthropic", 200000, 64000
    ),  # claude-sonnet-4-5-20250929
    LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
        "anthropic", 200000, 64000
    ),  # claude-3-7-sonnet-20250219

@@ -1404,27 +1400,11 @@ class AITextSummarizerBlock(AIBlockBase):

    @staticmethod
    def _split_text(text: str, max_tokens: int, overlap: int) -> list[str]:
        # Security fix: Add validation to prevent DoS attacks
        # Limit text size to prevent memory exhaustion
        MAX_TEXT_LENGTH = 1_000_000  # 1MB character limit
        MAX_CHUNKS = 100  # Maximum number of chunks to prevent excessive memory use

        if len(text) > MAX_TEXT_LENGTH:
            text = text[:MAX_TEXT_LENGTH]

        # Ensure chunk_size is at least 1 to prevent infinite loops
        chunk_size = max(1, max_tokens - overlap)

        # Ensure overlap is less than max_tokens to prevent invalid configurations
        if overlap >= max_tokens:
            overlap = max(0, max_tokens - 1)

        words = text.split()
        chunks = []
        chunk_size = max_tokens - overlap

        for i in range(0, len(words), chunk_size):
            if len(chunks) >= MAX_CHUNKS:
                break  # Limit the number of chunks to prevent memory exhaustion
            chunk = " ".join(words[i : i + max_tokens])
            chunks.append(chunk)

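The stride arithmetic above is easy to get wrong: each window is max_tokens words wide but advances by max_tokens - overlap, so consecutive chunks share exactly `overlap` words, and an overlap >= max_tokens makes the stride non-positive. A standalone sketch:

def split_words(text: str, max_tokens: int, overlap: int) -> list[str]:
    """Fixed-width word windows with a shared-overlap stride."""
    words = text.split()
    step = max(1, max_tokens - overlap)  # guard against a zero/negative stride
    return [" ".join(words[i : i + max_tokens]) for i in range(0, len(words), step)]

print(split_words("a b c d e f", max_tokens=4, overlap=2))
# ['a b c d', 'c d e f', 'e f']  -- each chunk repeats the previous 2 words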
@@ -1,7 +1,4 @@
import asyncio
import logging
import urllib.parse
import urllib.request
from datetime import datetime, timedelta, timezone
from typing import Any

@@ -104,38 +101,7 @@ class ReadRSSFeedBlock(Block):

    @staticmethod
    def parse_feed(url: str) -> dict[str, Any]:
        # Security fix: Add protection against memory exhaustion attacks
        MAX_FEED_SIZE = 10 * 1024 * 1024  # 10MB limit for RSS feeds

        # Validate URL
        parsed_url = urllib.parse.urlparse(url)
        if parsed_url.scheme not in ("http", "https"):
            raise ValueError(f"Invalid URL scheme: {parsed_url.scheme}")

        # Download with size limit
        try:
            with urllib.request.urlopen(url, timeout=30) as response:
                # Check content length if available
                content_length = response.headers.get("Content-Length")
                if content_length and int(content_length) > MAX_FEED_SIZE:
                    raise ValueError(
                        f"Feed too large: {content_length} bytes exceeds {MAX_FEED_SIZE} limit"
                    )

                # Read with size limit
                content = response.read(MAX_FEED_SIZE + 1)
                if len(content) > MAX_FEED_SIZE:
                    raise ValueError(
                        f"Feed too large: exceeds {MAX_FEED_SIZE} byte limit"
                    )

                # Parse with feedparser using the validated content
                # feedparser has built-in protection against XML attacks
                return feedparser.parse(content)  # type: ignore
        except Exception as e:
            # Log error and return empty feed
            logging.warning(f"Failed to parse RSS feed from {url}: {e}")
            return {"entries": []}
        return feedparser.parse(url)  # type: ignore

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        keep_going = True

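The read(MAX_FEED_SIZE + 1) trick above deserves a note: reading one byte beyond the limit is what distinguishes "exactly at the limit" from "over it", since Content-Length can be absent or wrong. A minimal sketch of the pattern on an arbitrary byte stream:

import io

MAX_SIZE = 10  # bytes; tiny for illustration

def read_bounded(stream, max_size: int) -> bytes:
    data = stream.read(max_size + 1)  # the extra byte detects overflow
    if len(data) > max_size:
        raise ValueError(f"Payload exceeds {max_size} byte limit")
    return data

print(read_bounded(io.BytesIO(b"0123456789"), MAX_SIZE))   # ok: exactly 10 bytes
# read_bounded(io.BytesIO(b"0123456789!"), MAX_SIZE)       # raises ValueError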
@@ -13,11 +13,6 @@ from backend.data.block import (
    BlockSchema,
    BlockType,
)
from backend.data.dynamic_fields import (
    extract_base_field_name,
    get_dynamic_field_description,
    is_dynamic_field,
)
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util import json
from backend.util.clients import get_database_manager_async_client
@@ -103,22 +98,6 @@ def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
    return {"role": "tool", "tool_call_id": call_id, "content": content}


def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
    """
    Safely convert raw_response to dictionary format for conversation history.
    Handles different response types from different LLM providers.
    """
    if isinstance(raw_response, str):
        # Ollama returns a string, convert to dict format
        return {"role": "assistant", "content": raw_response}
    elif isinstance(raw_response, dict):
        # Already a dict (from tests or some providers)
        return raw_response
    else:
        # OpenAI/Anthropic return objects, convert with json.to_dict
        return json.to_dict(raw_response)


def get_pending_tool_calls(conversation_history: list[Any]) -> dict[str, int]:
    """
    All the tool calls entry in the conversation history requires a response.
@@ -282,7 +261,6 @@ class SmartDecisionMakerBlock(Block):

    @staticmethod
    def cleanup(s: str):
        """Clean up block names for use as tool function names."""
        return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()

    @staticmethod
@@ -310,66 +288,41 @@ class SmartDecisionMakerBlock(Block):
        }
        sink_block_input_schema = block.input_schema
        properties = {}
        field_mapping = {}  # clean_name -> original_name

        for link in links:
            field_name = link.sink_name
            is_dynamic = is_dynamic_field(field_name)
            # Clean property key to ensure Anthropic API compatibility for ALL fields
            clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
            field_mapping[clean_field_name] = field_name
            sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)

            if is_dynamic:
                # For dynamic fields, use cleaned name but preserve original in description
                properties[clean_field_name] = {
            # Handle dynamic fields (e.g., values_#_*, items_$_*, etc.)
            # These are fields that get merged by the executor into their base field
            if (
                "_#_" in link.sink_name
                or "_$_" in link.sink_name
                or "_@_" in link.sink_name
            ):
                # For dynamic fields, provide a generic string schema
                # The executor will handle merging these into the appropriate structure
                properties[sink_name] = {
                    "type": "string",
                    "description": get_dynamic_field_description(field_name),
                    "description": f"Dynamic value for {link.sink_name}",
                }
            else:
                # For regular fields, use the block's schema directly
                # For regular fields, use the block's schema
                try:
                    properties[clean_field_name] = (
                        sink_block_input_schema.get_field_schema(field_name)
                    properties[sink_name] = sink_block_input_schema.get_field_schema(
                        link.sink_name
                    )
                except (KeyError, AttributeError):
                    # If field doesn't exist in schema, provide a generic one
                    properties[clean_field_name] = {
                    # If the field doesn't exist in the schema, provide a generic schema
                    properties[sink_name] = {
                        "type": "string",
                        "description": f"Value for {field_name}",
                        "description": f"Value for {link.sink_name}",
                    }

        # Build the parameters schema using a single unified path
        base_schema = block.input_schema.jsonschema()
        base_required = set(base_schema.get("required", []))

        # Compute required fields at the leaf level:
        # - If a linked field is dynamic and its base is required in the block schema, require the leaf
        # - If a linked field is regular and is required in the block schema, require the leaf
        required_fields: set[str] = set()
        for link in links:
            field_name = link.sink_name
            is_dynamic = is_dynamic_field(field_name)
            # Always use cleaned field name for property key (Anthropic API compliance)
            clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)

            if is_dynamic:
                base_name = extract_base_field_name(field_name)
                if base_name in base_required:
                    required_fields.add(clean_field_name)
            else:
                if field_name in base_required:
                    required_fields.add(clean_field_name)

        tool_function["parameters"] = {
            "type": "object",
            **block.input_schema.jsonschema(),
            "properties": properties,
            "additionalProperties": False,
            "required": sorted(required_fields),
        }

        # Store field mapping for later use in output processing
        tool_function["_field_mapping"] = field_mapping

        return {"type": "function", "function": tool_function}

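A standalone sketch of the cleanup/field_mapping round-trip used above: property keys sent to the LLM are sanitized for API compatibility, and the mapping recovers the original sink names when outputs are emitted (the sink names below are made up):

import re

def cleanup(s: str) -> str:
    """Sanitize a name for use as a tool/parameter identifier."""
    return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()

sink_names = ["values_#_Max Keywords", "query"]
field_mapping = {cleanup(name): name for name in sink_names}
print(field_mapping)
# {'values___max_keywords': 'values_#_Max Keywords', 'query': 'query'}

# Later, when the LLM returns arguments keyed by the clean names:
clean_arg = "values___max_keywords"
original = field_mapping.get(clean_arg, clean_arg)  # round-trip to the real field
print(original)  # values_#_Max Keywords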
    @staticmethod
@@ -413,12 +366,13 @@ class SmartDecisionMakerBlock(Block):
            sink_block_properties = sink_block_input_schema.get("properties", {}).get(
                link.sink_name, {}
            )
            sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
            description = (
                sink_block_properties["description"]
                if "description" in sink_block_properties
                else f"The {link.sink_name} of the tool"
            )
            properties[link.sink_name] = {
            properties[sink_name] = {
                "type": "string",
                "description": description,
                "default": json.dumps(sink_block_properties.get("default", None)),
@@ -434,17 +388,24 @@ class SmartDecisionMakerBlock(Block):
        return {"type": "function", "function": tool_function}

    @staticmethod
    async def _create_function_signature(
        node_id: str,
    ) -> list[dict[str, Any]]:
    async def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
        """
        Creates function signatures for connected tools.
        Creates function signatures for tools linked to a specified node within a graph.

        This method filters the graph links to identify those that are tools and are
        connected to the given node_id. It then constructs function signatures for each
        tool based on the metadata and input schema of the linked nodes.

        Args:
            node_id: The node_id for which to create function signatures.

        Returns:
            List of function signatures for tools
            list[dict[str, Any]]: A list of dictionaries, each representing a function signature
                for a tool, including its name, description, and parameters.

        Raises:
            ValueError: If no tool links are found for the specified node_id, or if a sink node
                or its metadata cannot be found.
        """
        db_client = get_database_manager_async_client()
        tools = [
@@ -469,116 +430,20 @@ class SmartDecisionMakerBlock(Block):
                raise ValueError(f"Sink node not found: {links[0].sink_id}")

            if sink_node.block_id == AgentExecutorBlock().id:
                tool_func = (
                return_tool_functions.append(
                    await SmartDecisionMakerBlock._create_agent_function_signature(
                        sink_node, links
                    )
                )
                return_tool_functions.append(tool_func)
            else:
                tool_func = (
                return_tool_functions.append(
                    await SmartDecisionMakerBlock._create_block_function_signature(
                        sink_node, links
                    )
                )
                return_tool_functions.append(tool_func)

        return return_tool_functions

    async def _attempt_llm_call_with_validation(
        self,
        credentials: llm.APIKeyCredentials,
        input_data: Input,
        current_prompt: list[dict],
        tool_functions: list[dict[str, Any]],
    ):
        """
        Attempt a single LLM call with tool validation.

        Returns the response if successful, raises ValueError if validation fails.
        """
        resp = await llm.llm_call(
            credentials=credentials,
            llm_model=input_data.model,
            prompt=current_prompt,
            max_tokens=input_data.max_tokens,
            tools=tool_functions,
            ollama_host=input_data.ollama_host,
            parallel_tool_calls=input_data.multiple_tool_calls,
        )

        # Track LLM usage stats per call
        self.merge_stats(
            NodeExecutionStats(
                input_token_count=resp.prompt_tokens,
                output_token_count=resp.completion_tokens,
                llm_call_count=1,
            )
        )

        if not resp.tool_calls:
            return resp
        validation_errors_list: list[str] = []
        for tool_call in resp.tool_calls:
            tool_name = tool_call.function.name
            try:
                tool_args = json.loads(tool_call.function.arguments)
            except Exception as e:
                validation_errors_list.append(
                    f"Tool call '{tool_name}' has invalid JSON arguments: {e}"
                )
                continue

            # Find the tool definition to get the expected arguments
            tool_def = next(
                (
                    tool
                    for tool in tool_functions
                    if tool["function"]["name"] == tool_name
                ),
                None,
            )
            if tool_def is None and len(tool_functions) == 1:
                tool_def = tool_functions[0]

            # Get parameters schema from tool definition
            if (
                tool_def
                and "function" in tool_def
                and "parameters" in tool_def["function"]
            ):
                parameters = tool_def["function"]["parameters"]
                expected_args = parameters.get("properties", {})
                required_params = set(parameters.get("required", []))
            else:
                expected_args = {arg: {} for arg in tool_args.keys()}
                required_params = set()

            # Validate tool call arguments
            provided_args = set(tool_args.keys())
            expected_args_set = set(expected_args.keys())

            # Check for unexpected arguments (typos)
            unexpected_args = provided_args - expected_args_set
            # Only check for missing REQUIRED parameters
            missing_required_args = required_params - provided_args

            if unexpected_args or missing_required_args:
                error_msg = f"Tool call '{tool_name}' has parameter errors:"
                if unexpected_args:
                    error_msg += f" Unknown parameters: {sorted(unexpected_args)}."
                if missing_required_args:
                    error_msg += f" Missing required parameters: {sorted(missing_required_args)}."
                error_msg += f" Expected parameters: {sorted(expected_args_set)}."
                if required_params:
                    error_msg += f" Required parameters: {sorted(required_params)}."
                validation_errors_list.append(error_msg)

        if validation_errors_list:
            raise ValueError("; ".join(validation_errors_list))

        return resp

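Stripped of the block machinery, the argument validation above reduces to set arithmetic on the tool's JSON-schema "properties" and "required" lists. A standalone sketch (schema and tool name are made up):

def validate_tool_args(tool_name: str, tool_args: dict, parameters: dict) -> list[str]:
    """Return human-readable errors for unknown or missing-required arguments."""
    expected = set(parameters.get("properties", {}))
    required = set(parameters.get("required", []))
    provided = set(tool_args)

    errors = []
    if unknown := provided - expected:
        errors.append(f"Tool call '{tool_name}': unknown parameters {sorted(unknown)}")
    if missing := required - provided:
        errors.append(f"Tool call '{tool_name}': missing required parameters {sorted(missing)}")
    return errors

schema = {"properties": {"query": {}, "limit": {}}, "required": ["query"]}
print(validate_tool_args("search", {"querry": "cats"}, schema))
# ["Tool call 'search': unknown parameters ['querry']",
#  "Tool call 'search': missing required parameters ['query']"]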
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
@@ -601,19 +466,27 @@ class SmartDecisionMakerBlock(Block):
|
||||
if pending_tool_calls and input_data.last_tool_output is None:
|
||||
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
|
||||
|
||||
# Only assign the last tool output to the first pending tool call
|
||||
tool_output = []
|
||||
if pending_tool_calls and input_data.last_tool_output is not None:
|
||||
# Get the first pending tool call ID
|
||||
first_call_id = next(iter(pending_tool_calls.keys()))
|
||||
tool_output.append(
|
||||
_create_tool_response(first_call_id, input_data.last_tool_output)
|
||||
)
|
||||
|
||||
# Add tool output to prompt right away
|
||||
prompt.extend(tool_output)
|
||||
|
||||
# Check if there are still pending tool calls after handling the first one
|
||||
remaining_pending_calls = get_pending_tool_calls(prompt)
|
||||
|
||||
# If there are still pending tool calls, yield the conversation and return early
|
||||
if remaining_pending_calls:
|
||||
yield "conversations", prompt
|
||||
return
|
||||
|
||||
# Fallback on adding tool output in the conversation history as user prompt.
|
||||
elif input_data.last_tool_output:
|
||||
logger.error(
|
||||
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
|
||||
@@ -646,33 +519,24 @@ class SmartDecisionMakerBlock(Block):
|
||||
):
|
||||
prompt.append({"role": "user", "content": prefix + input_data.prompt})
|
||||
|
||||
current_prompt = list(prompt)
|
||||
max_attempts = max(1, int(input_data.retry))
|
||||
response = None
|
||||
response = await llm.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=input_data.max_tokens,
|
||||
tools=tool_functions,
|
||||
ollama_host=input_data.ollama_host,
|
||||
parallel_tool_calls=input_data.multiple_tool_calls,
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
response = await self._attempt_llm_call_with_validation(
|
||||
credentials, input_data, current_prompt, tool_functions
|
||||
)
|
||||
break
|
||||
|
||||
except ValueError as e:
|
||||
last_error = e
|
||||
error_feedback = (
|
||||
"Your tool call had parameter errors. Please fix the following issues and try again:\n"
|
||||
+ f"- {str(e)}\n"
|
||||
+ "\nPlease make sure to use the exact parameter names as specified in the function schema."
|
||||
)
|
||||
current_prompt = list(current_prompt) + [
|
||||
{"role": "user", "content": error_feedback}
|
||||
]
|
||||
|
||||
if response is None:
|
||||
raise last_error or ValueError(
|
||||
"Failed to get valid response after all retry attempts"
|
||||
# Track LLM usage stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
llm_call_count=1,
|
||||
)
|
||||
)
|
||||
|
||||
if not response.tool_calls:
|
||||
yield "finished", response.response
|
||||
@@ -682,6 +546,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
|
||||
# Find the tool definition to get the expected arguments
|
||||
tool_def = next(
|
||||
(
|
||||
tool
|
||||
@@ -690,6 +555,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if (
|
||||
tool_def
|
||||
and "function" in tool_def
|
||||
@@ -697,38 +563,20 @@ class SmartDecisionMakerBlock(Block):
|
||||
):
|
||||
expected_args = tool_def["function"]["parameters"].get("properties", {})
|
||||
else:
|
||||
expected_args = {arg: {} for arg in tool_args.keys()}
|
||||
expected_args = tool_args.keys()
|
||||
|
||||
# Get field mapping from tool definition
|
||||
field_mapping = (
|
||||
tool_def.get("function", {}).get("_field_mapping", {})
|
||||
if tool_def
|
||||
else {}
|
||||
)
|
||||
|
||||
for clean_arg_name in expected_args:
|
||||
# arg_name is now always the cleaned field name (for Anthropic API compliance)
|
||||
# Get the original field name from field mapping for proper emit key generation
|
||||
original_field_name = field_mapping.get(clean_arg_name, clean_arg_name)
|
||||
arg_value = tool_args.get(clean_arg_name)
|
||||
|
||||
sanitized_tool_name = self.cleanup(tool_name)
|
||||
sanitized_arg_name = self.cleanup(original_field_name)
|
||||
emit_key = f"tools_^_{sanitized_tool_name}_~_{sanitized_arg_name}"
|
||||
|
||||
logger.debug(
|
||||
"[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s",
|
||||
graph_exec_id,
|
||||
node_exec_id,
|
||||
emit_key,
|
||||
)
|
||||
yield emit_key, arg_value
|
||||
# Yield provided arguments and None for missing ones
|
||||
for arg_name in expected_args:
|
||||
if arg_name in tool_args:
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", tool_args[arg_name]
|
||||
else:
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", None
|
||||
|
||||
# Add reasoning to conversation history if available
|
||||
if response.reasoning:
|
||||
prompt.append(
|
||||
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
|
||||
)
|
||||
|
||||
prompt.append(_convert_raw_response_to_dict(response.raw_response))
|
||||
|
||||
prompt.append(response.raw_response)
|
||||
yield "conversations", prompt
@@ -19,7 +19,7 @@ async def test_block_ids_valid(block: Type[Block]):
    # Skip list for blocks with known invalid UUIDs
    skip_blocks = {
        "GetWeatherInformationBlock",
        "ExecuteCodeBlock",
        "CodeExecutionBlock",
        "CountdownTimerBlock",
        "TwitterGetListTweetsBlock",
        "TwitterRemoveListMemberBlock",

@@ -1,269 +0,0 @@
"""
Test security fixes for various DoS vulnerabilities.
"""

import asyncio
from unittest.mock import patch

import pytest

from backend.blocks.code_extraction_block import CodeExtractionBlock
from backend.blocks.iteration import StepThroughItemsBlock
from backend.blocks.llm import AITextSummarizerBlock
from backend.blocks.text import ExtractTextInformationBlock
from backend.blocks.xml_parser import XMLParserBlock
from backend.util.file import store_media_file
from backend.util.type import MediaFileType


class TestCodeExtractionBlockSecurity:
    """Test ReDoS fixes in CodeExtractionBlock."""

    async def test_redos_protection(self):
        """Test that the regex patterns don't cause ReDoS."""
        block = CodeExtractionBlock()

        # Test with input that would previously cause ReDoS
        malicious_input = "```python" + " " * 10000  # Large spaces

        result = []
        async for output_name, output_data in block.run(
            CodeExtractionBlock.Input(text=malicious_input)
        ):
            result.append((output_name, output_data))

        # Should complete without hanging
        assert len(result) >= 1
        assert any(name == "remaining_text" for name, _ in result)


class TestAITextSummarizerBlockSecurity:
    """Test memory exhaustion fixes in AITextSummarizerBlock."""

    def test_split_text_limits(self):
        """Test that _split_text has proper limits."""
        # Test text size limit
        large_text = "a" * 2_000_000  # 2MB text
        result = AITextSummarizerBlock._split_text(large_text, 1000, 100)

        # Should be truncated to 1MB
        total_chars = sum(len(chunk) for chunk in result)
        assert total_chars <= 1_000_000 + 1000  # Allow for chunk boundary

        # Test chunk count limit
        result = AITextSummarizerBlock._split_text("word " * 10000, 10, 9)
        assert len(result) <= 100  # MAX_CHUNKS limit

        # Test parameter validation
        result = AITextSummarizerBlock._split_text(
            "test", 10, 15
        )  # overlap > max_tokens
        assert len(result) >= 1  # Should still work
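The assertions above pin down three limits: a 1MB input cap, a 100-chunk cap, and tolerance for overlap >= max_tokens. A hedged sketch of a splitter with those properties, character-based for brevity (the real _split_text in backend.blocks.llm works on tokens and may differ):

# Sketch under stated assumptions; not the actual _split_text implementation.
def split_text_bounded(text: str, chunk_size: int, overlap: int,
                       max_length: int = 1_000_000, max_chunks: int = 100) -> list[str]:
    text = text[:max_length]                # cap total input size
    overlap = min(overlap, chunk_size - 1)  # keep the window advancing
    step = max(1, chunk_size - overlap)
    chunks: list[str] = []
    for start in range(0, len(text), step):
        if len(chunks) >= max_chunks:       # cap chunk count
            break
        chunks.append(text[start:start + chunk_size])
    return chunks or [text]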
class TestExtractTextInformationBlockSecurity:
    """Test ReDoS and memory exhaustion fixes in ExtractTextInformationBlock."""

    async def test_text_size_limits(self):
        """Test text size limits."""
        block = ExtractTextInformationBlock()

        # Test with large input
        large_text = "a" * 2_000_000  # 2MB

        results = []
        async for output_name, output_data in block.run(
            ExtractTextInformationBlock.Input(
                text=large_text, pattern=r"a+", find_all=True, group=0
            )
        ):
            results.append((output_name, output_data))

        # Should complete and have limits applied
        matched_results = [r for name, r in results if name == "matched_results"]
        if matched_results:
            assert len(matched_results[0]) <= 1000  # MAX_MATCHES limit

    async def test_dangerous_pattern_timeout(self):
        """Test timeout protection for dangerous patterns."""
        block = ExtractTextInformationBlock()

        # Test with potentially dangerous lookahead pattern
        test_input = "a" * 1000

        # This should complete quickly due to timeout protection
        start_time = asyncio.get_event_loop().time()
        results = []
        async for output_name, output_data in block.run(
            ExtractTextInformationBlock.Input(
                text=test_input, pattern=r"(?=.+)", find_all=True, group=0
            )
        ):
            results.append((output_name, output_data))

        end_time = asyncio.get_event_loop().time()
        # Should complete within reasonable time (much less than 5s timeout)
        assert (end_time - start_time) < 10

    async def test_redos_catastrophic_backtracking(self):
        """Test that ReDoS patterns with catastrophic backtracking are handled."""
        block = ExtractTextInformationBlock()

        # Pattern that causes catastrophic backtracking: (a+)+b
        # With input "aaaaaaaaaaaaaaaaaaaaaaaaaaaa" (no 'b'), this causes exponential time
        dangerous_pattern = r"(a+)+b"
        test_input = "a" * 30  # 30 'a's without a 'b' at the end

        # This should be handled by timeout protection or pattern detection
        start_time = asyncio.get_event_loop().time()
        results = []

        async for output_name, output_data in block.run(
            ExtractTextInformationBlock.Input(
                text=test_input, pattern=dangerous_pattern, find_all=True, group=0
            )
        ):
            results.append((output_name, output_data))

        end_time = asyncio.get_event_loop().time()
        elapsed = end_time - start_time

        # Should complete within timeout (6 seconds to be safe)
        # The current threading.Timer approach doesn't work, so this will likely fail,
        # demonstrating the need for a fix
        assert elapsed < 6, f"Regex took {elapsed}s, timeout mechanism failed"

        # Should return empty results on timeout or no match
        matched_results = [r for name, r in results if name == "matched_results"]
        assert matched_results[0] == []  # No matches expected


class TestStepThroughItemsBlockSecurity:
    """Test iteration limits in StepThroughItemsBlock."""

    async def test_item_count_limits(self):
        """Test maximum item count limits."""
        block = StepThroughItemsBlock()

        # Test with too many items
        large_list = list(range(20000))  # Exceeds MAX_ITEMS (10000)

        with pytest.raises(ValueError, match="Too many items"):
            async for _ in block.run(StepThroughItemsBlock.Input(items=large_list)):
                pass

    async def test_string_size_limits(self):
        """Test string input size limits."""
        block = StepThroughItemsBlock()

        # Test with large JSON string
        large_string = '["item"]' * 200000  # Large JSON string

        with pytest.raises(ValueError, match="Input too large"):
            async for _ in block.run(
                StepThroughItemsBlock.Input(items_str=large_string)
            ):
                pass

    async def test_normal_iteration_works(self):
        """Test that normal iteration still works."""
        block = StepThroughItemsBlock()

        results = []
        async for output_name, output_data in block.run(
            StepThroughItemsBlock.Input(items=[1, 2, 3])
        ):
            results.append((output_name, output_data))

        # Should have 6 outputs (item, key for each of 3 items)
        assert len(results) == 6
        items = [data for name, data in results if name == "item"]
        assert items == [1, 2, 3]


class TestXMLParserBlockSecurity:
    """Test XML size limits in XMLParserBlock."""

    async def test_xml_size_limits(self):
        """Test XML input size limits."""
        block = XMLParserBlock()

        # Test with large XML - need to exceed 10MB limit
        # Each "<item>data</item>" is 17 chars, need ~620K items for >10MB
        large_xml = "<root>" + "<item>data</item>" * 620000 + "</root>"

        with pytest.raises(ValueError, match="XML too large"):
            async for _ in block.run(XMLParserBlock.Input(input_xml=large_xml)):
                pass


class TestStoreMediaFileSecurity:
    """Test file storage security limits."""

    @patch("backend.util.file.scan_content_safe")
    @patch("backend.util.file.get_cloud_storage_handler")
    async def test_file_size_limits(self, mock_cloud_storage, mock_scan):
        """Test file size limits."""
        # Mock cloud storage handler - get_cloud_storage_handler is async
        # but is_cloud_path and parse_cloud_path are sync methods
        from unittest.mock import MagicMock

        mock_handler = MagicMock()
        mock_handler.is_cloud_path.return_value = False

        # Make get_cloud_storage_handler an async function that returns the mock handler
        async def async_get_handler():
            return mock_handler

        mock_cloud_storage.side_effect = async_get_handler
        mock_scan.return_value = None

        # Test with large base64 content
        large_content = "a" * (200 * 1024 * 1024)  # 200MB
        large_data_uri = f"data:text/plain;base64,{large_content}"

        with pytest.raises(ValueError, match="File too large"):
            await store_media_file(
                graph_exec_id="test",
                file=MediaFileType(large_data_uri),
                user_id="test_user",
            )

    @patch("backend.util.file.Path")
    @patch("backend.util.file.scan_content_safe")
    @patch("backend.util.file.get_cloud_storage_handler")
    async def test_directory_size_limits(self, mock_cloud_storage, mock_scan, MockPath):
        """Test directory size limits."""
        from unittest.mock import MagicMock

        mock_handler = MagicMock()
        mock_handler.is_cloud_path.return_value = False

        async def async_get_handler():
            return mock_handler

        mock_cloud_storage.side_effect = async_get_handler
        mock_scan.return_value = None

        # Create mock path instance for the execution directory
        mock_path_instance = MagicMock()
        mock_path_instance.exists.return_value = True

        # Mock glob to return files that total > 1GB
        mock_file = MagicMock()
        mock_file.is_file.return_value = True
        mock_file.stat.return_value.st_size = 2 * 1024 * 1024 * 1024  # 2GB
        mock_path_instance.glob.return_value = [mock_file]

        # Make Path() return our mock
        MockPath.return_value = mock_path_instance

        # Should raise an error when directory size exceeds limit
        with pytest.raises(ValueError, match="Disk usage limit exceeded"):
            await store_media_file(
                graph_exec_id="test",
                file=MediaFileType(
                    "data:text/plain;base64,dGVzdA=="
                ),  # Small test file
                user_id="test_user",
            )
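The catastrophic-backtracking test above observes that a threading.Timer cannot interrupt `re` once it is inside the C-level matching loop. The third-party `regex` package, which a later hunk in this diff imports with the comment "Has built-in timeout support", can; a minimal sketch:

# Sketch: aborting a runaway match with the `regex` package's timeout support.
import regex

try:
    match = regex.search(r"(a+)+b", "a" * 40, timeout=5)  # seconds
    matches = [match.group(0)] if match else []
except TimeoutError:
    matches = []  # treat a timed-out pattern as "no matches"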
@@ -216,17 +216,8 @@ async def test_smart_decision_maker_tracks_llm_stats():
    }

    # Mock the _create_function_signature method to avoid database calls
    from unittest.mock import AsyncMock

    with patch(
        "backend.blocks.llm.llm_call",
        new_callable=AsyncMock,
        return_value=mock_response,
    ), patch.object(
        SmartDecisionMakerBlock,
        "_create_function_signature",
        new_callable=AsyncMock,
        return_value=[],
    with patch("backend.blocks.llm.llm_call", return_value=mock_response), patch.object(
        SmartDecisionMakerBlock, "_create_function_signature", return_value=[]
    ):

        # Create test input
@@ -258,471 +249,3 @@ async def test_smart_decision_maker_tracks_llm_stats():
    # Verify outputs
    assert "finished" in outputs  # Should have finished since no tool calls
    assert outputs["finished"] == "I need to think about this."
@pytest.mark.asyncio
async def test_smart_decision_maker_parameter_validation():
    """Test that SmartDecisionMakerBlock correctly validates tool call parameters."""
    from unittest.mock import MagicMock, patch

    import backend.blocks.llm as llm_module
    from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock

    block = SmartDecisionMakerBlock()

    # Mock tool functions with specific parameter schema
    mock_tool_functions = [
        {
            "type": "function",
            "function": {
                "name": "search_keywords",
                "description": "Search for keywords with difficulty filtering",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {"type": "string", "description": "Search query"},
                        "max_keyword_difficulty": {
                            "type": "integer",
                            "description": "Maximum keyword difficulty (required)",
                        },
                        "optional_param": {
                            "type": "string",
                            "description": "Optional parameter with default",
                            "default": "default_value",
                        },
                    },
                    "required": ["query", "max_keyword_difficulty"],
                },
            },
        }
    ]

    # Test case 1: Tool call with TYPO in parameter name (should retry and eventually fail)
    mock_tool_call_with_typo = MagicMock()
    mock_tool_call_with_typo.function.name = "search_keywords"
    mock_tool_call_with_typo.function.arguments = '{"query": "test", "maximum_keyword_difficulty": 50}'  # TYPO: maximum instead of max

    mock_response_with_typo = MagicMock()
    mock_response_with_typo.response = None
    mock_response_with_typo.tool_calls = [mock_tool_call_with_typo]
    mock_response_with_typo.prompt_tokens = 50
    mock_response_with_typo.completion_tokens = 25
    mock_response_with_typo.reasoning = None
    mock_response_with_typo.raw_response = {"role": "assistant", "content": None}

    from unittest.mock import AsyncMock

    with patch(
        "backend.blocks.llm.llm_call",
        new_callable=AsyncMock,
        return_value=mock_response_with_typo,
    ) as mock_llm_call, patch.object(
        SmartDecisionMakerBlock,
        "_create_function_signature",
        new_callable=AsyncMock,
        return_value=mock_tool_functions,
    ):

        input_data = SmartDecisionMakerBlock.Input(
            prompt="Search for keywords",
            model=llm_module.LlmModel.GPT4O,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
            retry=2,  # Set retry to 2 for testing
        )

        # Should raise ValueError after retries due to typo'd parameter name
        with pytest.raises(ValueError) as exc_info:
            outputs = {}
            async for output_name, output_data in block.run(
                input_data,
                credentials=llm_module.TEST_CREDENTIALS,
                graph_id="test-graph-id",
                node_id="test-node-id",
                graph_exec_id="test-exec-id",
                node_exec_id="test-node-exec-id",
                user_id="test-user-id",
            ):
                outputs[output_name] = output_data

        # Verify error message contains details about the typo
        error_msg = str(exc_info.value)
        assert "Tool call 'search_keywords' has parameter errors" in error_msg
        assert "Unknown parameters: ['maximum_keyword_difficulty']" in error_msg

        # Verify that LLM was called the expected number of times (retries)
        assert mock_llm_call.call_count == 2  # Should retry based on input_data.retry

    # Test case 2: Tool call missing REQUIRED parameter (should raise ValueError)
    mock_tool_call_missing_required = MagicMock()
    mock_tool_call_missing_required.function.name = "search_keywords"
    mock_tool_call_missing_required.function.arguments = (
        '{"query": "test"}'  # Missing required max_keyword_difficulty
    )

    mock_response_missing_required = MagicMock()
    mock_response_missing_required.response = None
    mock_response_missing_required.tool_calls = [mock_tool_call_missing_required]
    mock_response_missing_required.prompt_tokens = 50
    mock_response_missing_required.completion_tokens = 25
    mock_response_missing_required.reasoning = None
    mock_response_missing_required.raw_response = {"role": "assistant", "content": None}

    from unittest.mock import AsyncMock

    with patch(
        "backend.blocks.llm.llm_call",
        new_callable=AsyncMock,
        return_value=mock_response_missing_required,
    ), patch.object(
        SmartDecisionMakerBlock,
        "_create_function_signature",
        new_callable=AsyncMock,
        return_value=mock_tool_functions,
    ):

        input_data = SmartDecisionMakerBlock.Input(
            prompt="Search for keywords",
            model=llm_module.LlmModel.GPT4O,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
        )

        # Should raise ValueError due to missing required parameter
        with pytest.raises(ValueError) as exc_info:
            outputs = {}
            async for output_name, output_data in block.run(
                input_data,
                credentials=llm_module.TEST_CREDENTIALS,
                graph_id="test-graph-id",
                node_id="test-node-id",
                graph_exec_id="test-exec-id",
                node_exec_id="test-node-exec-id",
                user_id="test-user-id",
            ):
                outputs[output_name] = output_data

        error_msg = str(exc_info.value)
        assert "Tool call 'search_keywords' has parameter errors" in error_msg
        assert "Missing required parameters: ['max_keyword_difficulty']" in error_msg

    # Test case 3: Valid tool call with OPTIONAL parameter missing (should succeed)
    mock_tool_call_valid = MagicMock()
    mock_tool_call_valid.function.name = "search_keywords"
    mock_tool_call_valid.function.arguments = '{"query": "test", "max_keyword_difficulty": 50}'  # optional_param missing, but that's OK

    mock_response_valid = MagicMock()
    mock_response_valid.response = None
    mock_response_valid.tool_calls = [mock_tool_call_valid]
    mock_response_valid.prompt_tokens = 50
    mock_response_valid.completion_tokens = 25
    mock_response_valid.reasoning = None
    mock_response_valid.raw_response = {"role": "assistant", "content": None}

    from unittest.mock import AsyncMock

    with patch(
        "backend.blocks.llm.llm_call",
        new_callable=AsyncMock,
        return_value=mock_response_valid,
    ), patch.object(
        SmartDecisionMakerBlock,
        "_create_function_signature",
        new_callable=AsyncMock,
        return_value=mock_tool_functions,
    ):

        input_data = SmartDecisionMakerBlock.Input(
            prompt="Search for keywords",
            model=llm_module.LlmModel.GPT4O,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
        )

        # Should succeed - optional parameter missing is OK
        outputs = {}
        async for output_name, output_data in block.run(
            input_data,
            credentials=llm_module.TEST_CREDENTIALS,
            graph_id="test-graph-id",
            node_id="test-node-id",
            graph_exec_id="test-exec-id",
            node_exec_id="test-node-exec-id",
            user_id="test-user-id",
        ):
            outputs[output_name] = output_data

        # Verify tool outputs were generated correctly
        assert "tools_^_search_keywords_~_query" in outputs
        assert outputs["tools_^_search_keywords_~_query"] == "test"
        assert "tools_^_search_keywords_~_max_keyword_difficulty" in outputs
        assert outputs["tools_^_search_keywords_~_max_keyword_difficulty"] == 50
        # Optional parameter should be None when not provided
        assert "tools_^_search_keywords_~_optional_param" in outputs
        assert outputs["tools_^_search_keywords_~_optional_param"] is None

    # Test case 4: Valid tool call with ALL parameters (should succeed)
    mock_tool_call_all_params = MagicMock()
    mock_tool_call_all_params.function.name = "search_keywords"
    mock_tool_call_all_params.function.arguments = '{"query": "test", "max_keyword_difficulty": 50, "optional_param": "custom_value"}'

    mock_response_all_params = MagicMock()
    mock_response_all_params.response = None
    mock_response_all_params.tool_calls = [mock_tool_call_all_params]
    mock_response_all_params.prompt_tokens = 50
    mock_response_all_params.completion_tokens = 25
    mock_response_all_params.reasoning = None
    mock_response_all_params.raw_response = {"role": "assistant", "content": None}

    from unittest.mock import AsyncMock

    with patch(
        "backend.blocks.llm.llm_call",
        new_callable=AsyncMock,
        return_value=mock_response_all_params,
    ), patch.object(
        SmartDecisionMakerBlock,
        "_create_function_signature",
        new_callable=AsyncMock,
        return_value=mock_tool_functions,
    ):

        input_data = SmartDecisionMakerBlock.Input(
            prompt="Search for keywords",
            model=llm_module.LlmModel.GPT4O,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
        )

        # Should succeed with all parameters
        outputs = {}
        async for output_name, output_data in block.run(
            input_data,
            credentials=llm_module.TEST_CREDENTIALS,
            graph_id="test-graph-id",
            node_id="test-node-id",
            graph_exec_id="test-exec-id",
            node_exec_id="test-node-exec-id",
            user_id="test-user-id",
        ):
            outputs[output_name] = output_data

        # Verify all tool outputs were generated correctly
        assert outputs["tools_^_search_keywords_~_query"] == "test"
        assert outputs["tools_^_search_keywords_~_max_keyword_difficulty"] == 50
        assert outputs["tools_^_search_keywords_~_optional_param"] == "custom_value"


@pytest.mark.asyncio
async def test_smart_decision_maker_raw_response_conversion():
    """Test that SmartDecisionMaker correctly handles different raw_response types with retry mechanism."""
    from unittest.mock import MagicMock, patch

    import backend.blocks.llm as llm_module
    from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock

    block = SmartDecisionMakerBlock()

    # Mock tool functions
    mock_tool_functions = [
        {
            "type": "function",
            "function": {
                "name": "test_tool",
                "parameters": {
                    "type": "object",
                    "properties": {"param": {"type": "string"}},
                    "required": ["param"],
                },
            },
        }
    ]

    # Test case 1: Simulate ChatCompletionMessage raw_response that caused the original error
    class MockChatCompletionMessage:
        """Simulate OpenAI's ChatCompletionMessage object that lacks a .get() method"""

        def __init__(self, role, content, tool_calls=None):
            self.role = role
            self.content = content
            self.tool_calls = tool_calls or []

        # This is what caused the error - no .get() method
        # def get(self, key, default=None):  # Intentionally missing

    # First response: has invalid parameter name (triggers retry)
    mock_tool_call_invalid = MagicMock()
    mock_tool_call_invalid.function.name = "test_tool"
    mock_tool_call_invalid.function.arguments = (
        '{"wrong_param": "test_value"}'  # Invalid parameter name
    )

    mock_response_retry = MagicMock()
    mock_response_retry.response = None
    mock_response_retry.tool_calls = [mock_tool_call_invalid]
    mock_response_retry.prompt_tokens = 50
    mock_response_retry.completion_tokens = 25
    mock_response_retry.reasoning = None
    # This would cause the original error without our fix
    mock_response_retry.raw_response = MockChatCompletionMessage(
        role="assistant", content=None, tool_calls=[mock_tool_call_invalid]
    )

    # Second response: successful (correct parameter name)
    mock_tool_call_valid = MagicMock()
    mock_tool_call_valid.function.name = "test_tool"
    mock_tool_call_valid.function.arguments = (
        '{"param": "test_value"}'  # Correct parameter name
    )

    mock_response_success = MagicMock()
    mock_response_success.response = None
    mock_response_success.tool_calls = [mock_tool_call_valid]
    mock_response_success.prompt_tokens = 50
    mock_response_success.completion_tokens = 25
    mock_response_success.reasoning = None
    mock_response_success.raw_response = MockChatCompletionMessage(
        role="assistant", content=None, tool_calls=[mock_tool_call_valid]
    )

    # Mock llm_call to return different responses on different calls
    from unittest.mock import AsyncMock

    with patch(
        "backend.blocks.llm.llm_call", new_callable=AsyncMock
    ) as mock_llm_call, patch.object(
        SmartDecisionMakerBlock,
        "_create_function_signature",
        new_callable=AsyncMock,
        return_value=mock_tool_functions,
    ):
        # First call returns response that will trigger retry due to validation error
        # Second call returns successful response
        mock_llm_call.side_effect = [mock_response_retry, mock_response_success]

        input_data = SmartDecisionMakerBlock.Input(
            prompt="Test prompt",
            model=llm_module.LlmModel.GPT4O,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
            retry=2,
        )

        # Should succeed after retry, demonstrating our helper function works
        outputs = {}
        async for output_name, output_data in block.run(
            input_data,
            credentials=llm_module.TEST_CREDENTIALS,
            graph_id="test-graph-id",
            node_id="test-node-id",
            graph_exec_id="test-exec-id",
            node_exec_id="test-node-exec-id",
            user_id="test-user-id",
        ):
            outputs[output_name] = output_data

        # Verify the tool output was generated successfully
        assert "tools_^_test_tool_~_param" in outputs
        assert outputs["tools_^_test_tool_~_param"] == "test_value"

        # Verify conversation history was properly maintained
        assert "conversations" in outputs
        conversations = outputs["conversations"]
        assert len(conversations) > 0

        # The conversations should contain properly converted raw_response objects as dicts
        # This would have failed with the original bug due to ChatCompletionMessage.get() error
        for msg in conversations:
            assert isinstance(msg, dict), f"Expected dict, got {type(msg)}"
            if msg.get("role") == "assistant":
                # Should have been converted from ChatCompletionMessage to dict
                assert "role" in msg

        # Verify LLM was called twice (initial + 1 retry)
        assert mock_llm_call.call_count == 2

    # Test case 2: Test with different raw_response types (Ollama string, dict)
    # Test Ollama string response
    mock_response_ollama = MagicMock()
    mock_response_ollama.response = "I'll help you with that."
    mock_response_ollama.tool_calls = None
    mock_response_ollama.prompt_tokens = 30
    mock_response_ollama.completion_tokens = 15
    mock_response_ollama.reasoning = None
    mock_response_ollama.raw_response = (
        "I'll help you with that."  # Ollama returns string
    )

    from unittest.mock import AsyncMock

    with patch(
        "backend.blocks.llm.llm_call",
        new_callable=AsyncMock,
        return_value=mock_response_ollama,
    ), patch.object(
        SmartDecisionMakerBlock,
        "_create_function_signature",
        new_callable=AsyncMock,
        return_value=[],  # No tools for this test
    ):
        input_data = SmartDecisionMakerBlock.Input(
            prompt="Simple prompt",
            model=llm_module.LlmModel.GPT4O,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
        )

        outputs = {}
        async for output_name, output_data in block.run(
            input_data,
            credentials=llm_module.TEST_CREDENTIALS,
            graph_id="test-graph-id",
            node_id="test-node-id",
            graph_exec_id="test-exec-id",
            node_exec_id="test-node-exec-id",
            user_id="test-user-id",
        ):
            outputs[output_name] = output_data

        # Should finish since no tool calls
        assert "finished" in outputs
        assert outputs["finished"] == "I'll help you with that."

    # Test case 3: Test with dict raw_response (some providers/tests)
    mock_response_dict = MagicMock()
    mock_response_dict.response = "Test response"
    mock_response_dict.tool_calls = None
    mock_response_dict.prompt_tokens = 25
    mock_response_dict.completion_tokens = 10
    mock_response_dict.reasoning = None
    mock_response_dict.raw_response = {
        "role": "assistant",
        "content": "Test response",
    }  # Dict format

    from unittest.mock import AsyncMock

    with patch(
        "backend.blocks.llm.llm_call",
        new_callable=AsyncMock,
        return_value=mock_response_dict,
    ), patch.object(
        SmartDecisionMakerBlock,
        "_create_function_signature",
        new_callable=AsyncMock,
        return_value=[],
    ):
        input_data = SmartDecisionMakerBlock.Input(
            prompt="Another test",
            model=llm_module.LlmModel.GPT4O,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
        )

        outputs = {}
        async for output_name, output_data in block.run(
            input_data,
            credentials=llm_module.TEST_CREDENTIALS,
            graph_id="test-graph-id",
            node_id="test-node-id",
            graph_exec_id="test-exec-id",
            node_exec_id="test-node-exec-id",
            user_id="test-user-id",
        ):
            outputs[output_name] = output_data

        assert "finished" in outputs
        assert outputs["finished"] == "Test response"
@@ -48,24 +48,16 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
    assert "parameters" in signature["function"]
    assert "properties" in signature["function"]["parameters"]

    # Check that dynamic fields are handled with original names
    # Check that dynamic fields are handled
    properties = signature["function"]["parameters"]["properties"]
    assert len(properties) == 3  # Should have all three fields

    # Check that field names are cleaned (for Anthropic API compatibility)
    assert "values___name" in properties
    assert "values___age" in properties
    assert "values___city" in properties

    # Each dynamic field should have proper schema with descriptive text
    for field_name, prop_value in properties.items():
    # Each dynamic field should have proper schema
    for prop_value in properties.values():
        assert "type" in prop_value
        assert prop_value["type"] == "string"  # Dynamic fields get string type
        assert "description" in prop_value
        # Check that descriptions properly explain the dynamic field
        if field_name == "values___name":
            assert "Dictionary field 'name'" in prop_value["description"]
            assert "values['name']" in prop_value["description"]
        assert "Dynamic value for" in prop_value["description"]


@pytest.mark.asyncio
@@ -104,18 +96,10 @@ async def test_smart_decision_maker_handles_dynamic_list_fields():
    properties = signature["function"]["parameters"]["properties"]
    assert len(properties) == 2  # Should have both list items

    # Check that field names are cleaned (for Anthropic API compatibility)
    assert "entries___0" in properties
    assert "entries___1" in properties

    # Each dynamic field should have proper schema with descriptive text
    for field_name, prop_value in properties.items():
    # Each dynamic field should have proper schema
    for prop_value in properties.values():
        assert prop_value["type"] == "string"
        assert "description" in prop_value
        # Check that descriptions properly explain the list field
        if field_name == "entries___0":
            assert "List item 0" in prop_value["description"]
            assert "entries[0]" in prop_value["description"]
        assert "Dynamic value for" in prop_value["description"]


@pytest.mark.asyncio

@@ -1,553 +0,0 @@
"""Comprehensive tests for SmartDecisionMakerBlock dynamic field handling."""

import json
from unittest.mock import AsyncMock, Mock, patch

import pytest

from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.text import MatchTextPatternBlock
from backend.data.dynamic_fields import get_dynamic_field_description


@pytest.mark.asyncio
async def test_dynamic_field_description_generation():
    """Test that dynamic field descriptions are generated correctly."""
    # Test dictionary field description
    desc = get_dynamic_field_description("values_#_name")
    assert "Dictionary field 'name' for base field 'values'" in desc
    assert "values['name']" in desc

    # Test list field description
    desc = get_dynamic_field_description("items_$_0")
    assert "List item 0 for base field 'items'" in desc
    assert "items[0]" in desc

    # Test object field description
    desc = get_dynamic_field_description("user_@_email")
    assert "Object attribute 'email' for base field 'user'" in desc
    assert "user.email" in desc

    # Test regular field fallback
    desc = get_dynamic_field_description("regular_field")
    assert desc == "Value for regular_field"
@pytest.mark.asyncio
async def test_create_block_function_signature_with_dict_fields():
    """Test that function signatures are created correctly for dictionary dynamic fields."""
    block = SmartDecisionMakerBlock()

    # Create a mock node for CreateDictionaryBlock
    mock_node = Mock()
    mock_node.block = CreateDictionaryBlock()
    mock_node.block_id = CreateDictionaryBlock().id
    mock_node.input_default = {}

    # Create mock links with dynamic dictionary fields (source sanitized, sink original)
    mock_links = [
        Mock(
            source_name="tools_^_create_dict_~_values___name",  # Sanitized source
            sink_name="values_#_name",  # Original sink
            sink_id="dict_node_id",
            source_id="smart_decision_node_id",
        ),
        Mock(
            source_name="tools_^_create_dict_~_values___age",  # Sanitized source
            sink_name="values_#_age",  # Original sink
            sink_id="dict_node_id",
            source_id="smart_decision_node_id",
        ),
        Mock(
            source_name="tools_^_create_dict_~_values___email",  # Sanitized source
            sink_name="values_#_email",  # Original sink
            sink_id="dict_node_id",
            source_id="smart_decision_node_id",
        ),
    ]

    # Generate function signature
    signature = await block._create_block_function_signature(mock_node, mock_links)  # type: ignore

    # Verify the signature structure
    assert signature["type"] == "function"
    assert "function" in signature
    assert "parameters" in signature["function"]
    assert "properties" in signature["function"]["parameters"]

    # Check that dynamic fields are handled with original names
    properties = signature["function"]["parameters"]["properties"]
    assert len(properties) == 3

    # Check cleaned field names (for Anthropic API compatibility)
    assert "values___name" in properties
    assert "values___age" in properties
    assert "values___email" in properties

    # Check descriptions mention they are dictionary fields
    assert "Dictionary field" in properties["values___name"]["description"]
    assert "values['name']" in properties["values___name"]["description"]

    assert "Dictionary field" in properties["values___age"]["description"]
    assert "values['age']" in properties["values___age"]["description"]

    assert "Dictionary field" in properties["values___email"]["description"]
    assert "values['email']" in properties["values___email"]["description"]


@pytest.mark.asyncio
async def test_create_block_function_signature_with_list_fields():
    """Test that function signatures are created correctly for list dynamic fields."""
    block = SmartDecisionMakerBlock()

    # Create a mock node for AddToListBlock
    mock_node = Mock()
    mock_node.block = AddToListBlock()
    mock_node.block_id = AddToListBlock().id
    mock_node.input_default = {}

    # Create mock links with dynamic list fields
    mock_links = [
        Mock(
            source_name="tools_^_add_list_~_0",
            sink_name="entries_$_0",  # Dynamic list field
            sink_id="list_node_id",
            source_id="smart_decision_node_id",
        ),
        Mock(
            source_name="tools_^_add_list_~_1",
            sink_name="entries_$_1",  # Dynamic list field
            sink_id="list_node_id",
            source_id="smart_decision_node_id",
        ),
        Mock(
            source_name="tools_^_add_list_~_2",
            sink_name="entries_$_2",  # Dynamic list field
            sink_id="list_node_id",
            source_id="smart_decision_node_id",
        ),
    ]

    # Generate function signature
    signature = await block._create_block_function_signature(mock_node, mock_links)  # type: ignore

    # Verify the signature structure
    assert signature["type"] == "function"
    properties = signature["function"]["parameters"]["properties"]

    # Check cleaned field names (for Anthropic API compatibility)
    assert "entries___0" in properties
    assert "entries___1" in properties
    assert "entries___2" in properties

    # Check descriptions mention they are list items
    assert "List item 0" in properties["entries___0"]["description"]
    assert "entries[0]" in properties["entries___0"]["description"]

    assert "List item 1" in properties["entries___1"]["description"]
    assert "entries[1]" in properties["entries___1"]["description"]


@pytest.mark.asyncio
async def test_create_block_function_signature_with_object_fields():
    """Test that function signatures are created correctly for object dynamic fields."""
    block = SmartDecisionMakerBlock()

    # Create a mock node for MatchTextPatternBlock (simulating object fields)
    mock_node = Mock()
    mock_node.block = MatchTextPatternBlock()
    mock_node.block_id = MatchTextPatternBlock().id
    mock_node.input_default = {}

    # Create mock links with dynamic object fields
    mock_links = [
        Mock(
            source_name="tools_^_extract_~_user_name",
            sink_name="data_@_user_name",  # Dynamic object field
            sink_id="extract_node_id",
            source_id="smart_decision_node_id",
        ),
        Mock(
            source_name="tools_^_extract_~_user_email",
            sink_name="data_@_user_email",  # Dynamic object field
            sink_id="extract_node_id",
            source_id="smart_decision_node_id",
        ),
    ]

    # Generate function signature
    signature = await block._create_block_function_signature(mock_node, mock_links)  # type: ignore

    # Verify the signature structure
    properties = signature["function"]["parameters"]["properties"]

    # Check cleaned field names (for Anthropic API compatibility)
    assert "data___user_name" in properties
    assert "data___user_email" in properties

    # Check descriptions mention they are object attributes
    assert "Object attribute" in properties["data___user_name"]["description"]
    assert "data.user_name" in properties["data___user_name"]["description"]


@pytest.mark.asyncio
async def test_create_function_signature():
    """Test that the mapping between sanitized and original field names is built correctly."""
    block = SmartDecisionMakerBlock()

    # Mock the database client and connected nodes
    with patch(
        "backend.blocks.smart_decision_maker.get_database_manager_async_client"
    ) as mock_db:
        mock_client = AsyncMock()
        mock_db.return_value = mock_client

        # Create mock nodes and links
        mock_dict_node = Mock()
        mock_dict_node.block = CreateDictionaryBlock()
        mock_dict_node.block_id = CreateDictionaryBlock().id
        mock_dict_node.input_default = {}

        mock_list_node = Mock()
        mock_list_node.block = AddToListBlock()
        mock_list_node.block_id = AddToListBlock().id
        mock_list_node.input_default = {}

        # Mock links with dynamic fields
        dict_link1 = Mock(
            source_name="tools_^_create_dictionary_~_name",
            sink_name="values_#_name",
            sink_id="dict_node_id",
            source_id="test_node_id",
        )
        dict_link2 = Mock(
            source_name="tools_^_create_dictionary_~_age",
            sink_name="values_#_age",
            sink_id="dict_node_id",
            source_id="test_node_id",
        )
        list_link = Mock(
            source_name="tools_^_add_to_list_~_0",
            sink_name="entries_$_0",
            sink_id="list_node_id",
            source_id="test_node_id",
        )

        mock_client.get_connected_output_nodes.return_value = [
            (dict_link1, mock_dict_node),
            (dict_link2, mock_dict_node),
            (list_link, mock_list_node),
        ]

        # Call the method that builds signatures
        tool_functions = await block._create_function_signature("test_node_id")

        # Verify we got 2 tool functions (one for dict, one for list)
        assert len(tool_functions) == 2

        # Verify the tool functions contain the dynamic field names
        dict_tool = next(
            (
                tool
                for tool in tool_functions
                if tool["function"]["name"] == "createdictionaryblock"
            ),
            None,
        )
        assert dict_tool is not None
        dict_properties = dict_tool["function"]["parameters"]["properties"]
        assert "values___name" in dict_properties
        assert "values___age" in dict_properties

        list_tool = next(
            (
                tool
                for tool in tool_functions
                if tool["function"]["name"] == "addtolistblock"
            ),
            None,
        )
        assert list_tool is not None
        list_properties = list_tool["function"]["parameters"]["properties"]
        assert "entries___0" in list_properties


@pytest.mark.asyncio
async def test_output_yielding_with_dynamic_fields():
    """Test that outputs are yielded correctly with dynamic field names mapped back."""
    block = SmartDecisionMakerBlock()

    # No more sanitized mapping needed since we removed sanitization

    # Mock LLM response with tool calls
    mock_response = Mock()
    mock_response.tool_calls = [
        Mock(
            function=Mock(
                arguments=json.dumps(
                    {
                        "values___name": "Alice",
                        "values___age": 30,
                        "values___email": "alice@example.com",
                    }
                ),
            )
        )
    ]
    # Ensure function name is a real string, not a Mock name
    mock_response.tool_calls[0].function.name = "createdictionaryblock"
    mock_response.reasoning = "Creating a dictionary with user information"
    mock_response.raw_response = {"role": "assistant", "content": "test"}
    mock_response.prompt_tokens = 100
    mock_response.completion_tokens = 50

    # Mock the LLM call
    with patch(
        "backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
    ) as mock_llm:
        mock_llm.return_value = mock_response

        # Mock the function signature creation
        with patch.object(
            block, "_create_function_signature", new_callable=AsyncMock
        ) as mock_sig:
            mock_sig.return_value = [
                {
                    "type": "function",
                    "function": {
                        "name": "createdictionaryblock",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "values___name": {"type": "string"},
                                "values___age": {"type": "number"},
                                "values___email": {"type": "string"},
                            },
                        },
                    },
                }
            ]

            # Create input data
            from backend.blocks import llm

            input_data = block.input_schema(
                prompt="Create a user dictionary",
                credentials=llm.TEST_CREDENTIALS_INPUT,
                model=llm.LlmModel.GPT4O,
            )

            # Run the block
            outputs = {}
            async for output_name, output_value in block.run(
                input_data,
                credentials=llm.TEST_CREDENTIALS,
                graph_id="test_graph",
                node_id="test_node",
                graph_exec_id="test_exec",
                node_exec_id="test_node_exec",
                user_id="test_user",
            ):
                outputs[output_name] = output_value

            # Verify the outputs use sanitized field names (matching frontend normalizeToolName)
            assert "tools_^_createdictionaryblock_~_values___name" in outputs
            assert outputs["tools_^_createdictionaryblock_~_values___name"] == "Alice"

            assert "tools_^_createdictionaryblock_~_values___age" in outputs
            assert outputs["tools_^_createdictionaryblock_~_values___age"] == 30

            assert "tools_^_createdictionaryblock_~_values___email" in outputs
            assert (
                outputs["tools_^_createdictionaryblock_~_values___email"]
                == "alice@example.com"
            )


@pytest.mark.asyncio
async def test_mixed_regular_and_dynamic_fields():
    """Test handling of blocks with both regular and dynamic fields."""
    block = SmartDecisionMakerBlock()

    # Create a mock node
    mock_node = Mock()
    mock_node.block = Mock()
    mock_node.block.name = "TestBlock"
    mock_node.block.description = "A test block"
    mock_node.block.input_schema = Mock()

    # Mock the get_field_schema to return a proper schema for regular fields
    def get_field_schema(field_name):
        if field_name == "regular_field":
            return {"type": "string", "description": "A regular field"}
        elif field_name == "values":
            return {"type": "object", "description": "A dictionary field"}
        else:
            raise KeyError(f"Field {field_name} not found")

    mock_node.block.input_schema.get_field_schema = get_field_schema
    mock_node.block.input_schema.jsonschema = Mock(
        return_value={"properties": {}, "required": []}
    )

    # Create links with both regular and dynamic fields
    mock_links = [
        Mock(
            source_name="tools_^_test_~_regular",
            sink_name="regular_field",  # Regular field
            sink_id="test_node_id",
            source_id="smart_decision_node_id",
        ),
        Mock(
            source_name="tools_^_test_~_dict_key",
            sink_name="values_#_key1",  # Dynamic dict field
            sink_id="test_node_id",
            source_id="smart_decision_node_id",
        ),
        Mock(
            source_name="tools_^_test_~_dict_key2",
            sink_name="values_#_key2",  # Dynamic dict field
            sink_id="test_node_id",
            source_id="smart_decision_node_id",
        ),
    ]

    # Generate function signature
    signature = await block._create_block_function_signature(mock_node, mock_links)  # type: ignore

    # Check properties
    properties = signature["function"]["parameters"]["properties"]
    assert len(properties) == 3

    # Regular field should have its original schema
    assert "regular_field" in properties
    assert properties["regular_field"]["description"] == "A regular field"

    # Dynamic fields should have generated descriptions
    assert "values___key1" in properties
    assert "Dictionary field" in properties["values___key1"]["description"]

    assert "values___key2" in properties
    assert "Dictionary field" in properties["values___key2"]["description"]


@pytest.mark.asyncio
async def test_validation_errors_dont_pollute_conversation():
    """Test that validation errors are only used during retries and don't pollute the conversation."""
    block = SmartDecisionMakerBlock()

    # Track conversation history changes
    conversation_snapshots = []

    # Mock response with invalid tool call (missing required parameter)
    invalid_response = Mock()
    invalid_response.tool_calls = [
        Mock(
            function=Mock(
                arguments=json.dumps({"wrong_param": "value"}),  # Wrong parameter name
            )
        )
    ]
    # Ensure function name is a real string, not a Mock name
    invalid_response.tool_calls[0].function.name = "test_tool"
    invalid_response.reasoning = None
    invalid_response.raw_response = {"role": "assistant", "content": "invalid"}
    invalid_response.prompt_tokens = 100
    invalid_response.completion_tokens = 50

    # Mock valid response after retry
    valid_response = Mock()
    valid_response.tool_calls = [
        Mock(function=Mock(arguments=json.dumps({"correct_param": "value"})))
    ]
    # Ensure function name is a real string, not a Mock name
    valid_response.tool_calls[0].function.name = "test_tool"
    valid_response.reasoning = None
    valid_response.raw_response = {"role": "assistant", "content": "valid"}
    valid_response.prompt_tokens = 100
    valid_response.completion_tokens = 50

    call_count = 0

    async def mock_llm_call(**kwargs):
        nonlocal call_count
        # Capture conversation state
        conversation_snapshots.append(kwargs.get("prompt", []).copy())
        call_count += 1
        if call_count == 1:
            return invalid_response
        else:
            return valid_response

    # Mock the LLM call
    with patch(
        "backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
    ) as mock_llm:
        mock_llm.side_effect = mock_llm_call

        # Mock the function signature creation
        with patch.object(
            block, "_create_function_signature", new_callable=AsyncMock
        ) as mock_sig:
            mock_sig.return_value = [
                {
                    "type": "function",
                    "function": {
                        "name": "test_tool",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "correct_param": {
                                    "type": "string",
                                    "description": "The correct parameter",
                                }
                            },
                            "required": ["correct_param"],
                        },
                    },
                }
            ]

            # Create input data
            from backend.blocks import llm

            input_data = block.input_schema(
                prompt="Test prompt",
                credentials=llm.TEST_CREDENTIALS_INPUT,
                model=llm.LlmModel.GPT4O,
                retry=3,  # Allow retries
            )

            # Run the block
            outputs = {}
            async for output_name, output_value in block.run(
                input_data,
                credentials=llm.TEST_CREDENTIALS,
                graph_id="test_graph",
                node_id="test_node",
                graph_exec_id="test_exec",
                node_exec_id="test_node_exec",
                user_id="test_user",
            ):
                outputs[output_name] = output_value

            # Verify we had 2 LLM calls (initial + retry)
            assert call_count == 2

            # Check the final conversation output
            final_conversation = outputs.get("conversations", [])

            # The final conversation should NOT contain the validation error message
            error_messages = [
                msg
                for msg in final_conversation
                if msg.get("role") == "user"
                and "parameter errors" in msg.get("content", "")
            ]
            assert (
                len(error_messages) == 0
            ), "Validation error leaked into final conversation"

            # The final conversation should only have the successful response
            assert final_conversation[-1]["content"] == "valid"
|
||||
@@ -1,131 +0,0 @@
import pytest

from backend.blocks.io import AgentTableInputBlock
from backend.util.test import execute_block_test


@pytest.mark.asyncio
async def test_table_input_block():
    """Test the AgentTableInputBlock with basic input/output."""
    block = AgentTableInputBlock()
    await execute_block_test(block)


@pytest.mark.asyncio
async def test_table_input_with_data():
    """Test AgentTableInputBlock with actual table data."""
    block = AgentTableInputBlock()

    input_data = block.Input(
        name="test_table",
        column_headers=["Name", "Age", "City"],
        value=[
            {"Name": "John", "Age": "30", "City": "New York"},
            {"Name": "Jane", "Age": "25", "City": "London"},
            {"Name": "Bob", "Age": "35", "City": "Paris"},
        ],
    )

    output_data = []
    async for output_name, output_value in block.run(input_data):
        output_data.append((output_name, output_value))

    assert len(output_data) == 1
    assert output_data[0][0] == "result"

    result = output_data[0][1]
    assert len(result) == 3
    assert result[0]["Name"] == "John"
    assert result[1]["Age"] == "25"
    assert result[2]["City"] == "Paris"


@pytest.mark.asyncio
async def test_table_input_empty_data():
    """Test AgentTableInputBlock with empty data."""
    block = AgentTableInputBlock()

    input_data = block.Input(
        name="empty_table", column_headers=["Col1", "Col2"], value=[]
    )

    output_data = []
    async for output_name, output_value in block.run(input_data):
        output_data.append((output_name, output_value))

    assert len(output_data) == 1
    assert output_data[0][0] == "result"
    assert output_data[0][1] == []


@pytest.mark.asyncio
async def test_table_input_with_missing_columns():
    """Test AgentTableInputBlock passes through data with missing columns as-is."""
    block = AgentTableInputBlock()

    input_data = block.Input(
        name="partial_table",
        column_headers=["Name", "Age", "City"],
        value=[
            {"Name": "John", "Age": "30"},  # Missing City
            {"Name": "Jane", "City": "London"},  # Missing Age
            {"Age": "35", "City": "Paris"},  # Missing Name
        ],
    )

    output_data = []
    async for output_name, output_value in block.run(input_data):
        output_data.append((output_name, output_value))

    result = output_data[0][1]
    assert len(result) == 3

    # Check data is passed through as-is
    assert result[0] == {"Name": "John", "Age": "30"}
    assert result[1] == {"Name": "Jane", "City": "London"}
    assert result[2] == {"Age": "35", "City": "Paris"}


@pytest.mark.asyncio
async def test_table_input_none_value():
    """Test AgentTableInputBlock with None value returns empty list."""
    block = AgentTableInputBlock()

    input_data = block.Input(
        name="none_table", column_headers=["Name", "Age"], value=None
    )

    output_data = []
    async for output_name, output_value in block.run(input_data):
        output_data.append((output_name, output_value))

    assert len(output_data) == 1
    assert output_data[0][0] == "result"
    assert output_data[0][1] == []


@pytest.mark.asyncio
async def test_table_input_with_default_headers():
    """Test AgentTableInputBlock with default column headers."""
    block = AgentTableInputBlock()

    # Don't specify column_headers, should use defaults
    input_data = block.Input(
        name="default_headers_table",
        value=[
            {"Column 1": "A", "Column 2": "B", "Column 3": "C"},
            {"Column 1": "D", "Column 2": "E", "Column 3": "F"},
        ],
    )

    output_data = []
    async for output_name, output_value in block.run(input_data):
        output_data.append((output_name, output_value))

    assert len(output_data) == 1
    assert output_data[0][0] == "result"

    result = output_data[0][1]
    assert len(result) == 2
    assert result[0]["Column 1"] == "A"
    assert result[1]["Column 3"] == "F"
@@ -2,8 +2,6 @@ import re
from pathlib import Path
from typing import Any

import regex  # Has built-in timeout support

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util import json, text

@@ -139,11 +137,6 @@ class ExtractTextInformationBlock(Block):
    )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Security fix: Add limits to prevent ReDoS and memory exhaustion
        MAX_TEXT_LENGTH = 1_000_000  # 1MB character limit
        MAX_MATCHES = 1000  # Maximum number of matches to prevent memory exhaustion
        MAX_MATCH_LENGTH = 10_000  # Maximum length per match

        flags = 0
        if not input_data.case_sensitive:
            flags = flags | re.IGNORECASE

@@ -155,85 +148,20 @@ class ExtractTextInformationBlock(Block):
        else:
            txt = json.dumps(input_data.text)

        # Limit text size to prevent DoS
        if len(txt) > MAX_TEXT_LENGTH:
            txt = txt[:MAX_TEXT_LENGTH]

        # Validate regex pattern to prevent dangerous patterns
        dangerous_patterns = [
            r".*\+.*\+",  # Nested quantifiers
            r".*\*.*\*",  # Nested quantifiers
            r"(?=.*\+)",  # Lookahead with quantifier
            r"(?=.*\*)",  # Lookahead with quantifier
            r"\(.+\)\+",  # Group with nested quantifier
            r"\(.+\)\*",  # Group with nested quantifier
            r"\([^)]+\+\)\+",  # Nested quantifiers like (a+)+
            r"\([^)]+\*\)\*",  # Nested quantifiers like (a*)*
        matches = [
            match.group(input_data.group)
            for match in re.finditer(input_data.pattern, txt, flags)
            if input_data.group <= len(match.groups())
        ]

        # Check if pattern is potentially dangerous
        is_dangerous = any(
            re.search(dangerous, input_data.pattern) for dangerous in dangerous_patterns
        )

        # Use regex module with timeout for dangerous patterns
        # For safe patterns, use standard re module for compatibility
        try:
            matches = []
            match_count = 0

            if is_dangerous:
                # Use regex module with timeout (5 seconds) for dangerous patterns
                # The regex module supports timeout parameter in finditer
                try:
                    for match in regex.finditer(
                        input_data.pattern, txt, flags=flags, timeout=5.0
                    ):
                        if match_count >= MAX_MATCHES:
                            break
                        if input_data.group <= len(match.groups()):
                            match_text = match.group(input_data.group)
                            # Limit match length to prevent memory exhaustion
                            if len(match_text) > MAX_MATCH_LENGTH:
                                match_text = match_text[:MAX_MATCH_LENGTH]
                            matches.append(match_text)
                            match_count += 1
                except regex.error as e:
                    # Timeout occurred or regex error
                    if "timeout" in str(e).lower():
                        # Timeout - return empty results
                        pass
                    else:
                        # Other regex error
                        raise
            else:
                # Use standard re module for non-dangerous patterns
                for match in re.finditer(input_data.pattern, txt, flags):
                    if match_count >= MAX_MATCHES:
                        break
                    if input_data.group <= len(match.groups()):
                        match_text = match.group(input_data.group)
                        # Limit match length to prevent memory exhaustion
                        if len(match_text) > MAX_MATCH_LENGTH:
                            match_text = match_text[:MAX_MATCH_LENGTH]
                        matches.append(match_text)
                        match_count += 1

            if not input_data.find_all:
                matches = matches[:1]

            for match in matches:
                yield "positive", match
            if not matches:
                yield "negative", input_data.text

            yield "matched_results", matches
            yield "matched_count", len(matches)
        except Exception:
            # Return empty results on any regex error
            if not input_data.find_all:
                matches = matches[:1]
            for match in matches:
                yield "positive", match
            if not matches:
                yield "negative", input_data.text
            yield "matched_results", []
            yield "matched_count", 0

        yield "matched_results", matches
        yield "matched_count", len(matches)


class FillTextTemplateBlock(Block):
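# --- Editor's note: illustrative sketch, not part of the diff above. ---
# The removed security code leaned on the third-party `regex` module's
# `timeout=` keyword to bound catastrophic backtracking. A minimal,
# self-contained version of that technique (pattern, input, and function
# name here are hypothetical, not from the repository):

import regex  # third-party module; its matching functions accept timeout=

def find_with_timeout(pattern: str, text: str, seconds: float = 5.0) -> list[str]:
    """Return all matches, giving up (empty list) if matching exceeds the timeout."""
    try:
        return [m.group(0) for m in regex.finditer(pattern, text, timeout=seconds)]
    except (TimeoutError, regex.error):
        # Modern regex versions raise TimeoutError on timeout; the removed
        # code above also guarded regex.error, so both are caught here.
        return []

# Example: a nested-quantifier pattern such as (a+)+$ backtracks
# exponentially on inputs like "a" * 30 + "X", which is exactly what
# the removed dangerous_patterns heuristics tried to detect.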
@@ -270,17 +270,13 @@ class GetCurrentDateBlock(Block):
            test_output=[
                (
                    "date",
                    lambda t: abs(
                        datetime.now().date() - datetime.strptime(t, "%Y-%m-%d").date()
                    )
                    <= timedelta(days=8),  # 7 days difference + 1 day error margin.
                    lambda t: abs(datetime.now() - datetime.strptime(t, "%Y-%m-%d"))
                    < timedelta(days=8),  # 7 days difference + 1 day error margin.
                ),
                (
                    "date",
                    lambda t: abs(
                        datetime.now().date() - datetime.strptime(t, "%m/%d/%Y").date()
                    )
                    <= timedelta(days=8),
                    lambda t: abs(datetime.now() - datetime.strptime(t, "%m/%d/%Y"))
                    < timedelta(days=8),
                    # 7 days difference + 1 day error margin.
                ),
                (

@@ -386,7 +382,7 @@ class GetCurrentDateAndTimeBlock(Block):
                    lambda t: abs(
                        datetime.now().date() - datetime.strptime(t, "%Y/%m/%d").date()
                    )
                    <= timedelta(days=1),  # Date format only, no time component
                    < timedelta(days=1),  # Date format only, no time component
                ),
                (
                    "date_time",
@@ -26,14 +26,6 @@ class XMLParserBlock(Block):
    )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Security fix: Add size limits to prevent XML bomb attacks
        MAX_XML_SIZE = 10 * 1024 * 1024  # 10MB limit for XML input

        if len(input_data.input_xml) > MAX_XML_SIZE:
            raise ValueError(
                f"XML too large: {len(input_data.input_xml)} bytes > {MAX_XML_SIZE} bytes"
            )

        try:
            tokens = tokenize(input_data.input_xml)
            parser = Parser(tokens)
@@ -9,7 +9,6 @@ from prisma.models import APIKey as PrismaAPIKey
from prisma.types import APIKeyWhereUniqueInput
from pydantic import BaseModel, Field

from backend.data.includes import MAX_USER_API_KEYS_FETCH
from backend.util.exceptions import NotAuthorizedError, NotFoundError

logger = logging.getLogger(__name__)

@@ -179,13 +178,9 @@ async def revoke_api_key(key_id: str, user_id: str) -> APIKeyInfo:
    return APIKeyInfo.from_db(updated_api_key)


async def list_user_api_keys(
    user_id: str, limit: int = MAX_USER_API_KEYS_FETCH
) -> list[APIKeyInfo]:
async def list_user_api_keys(user_id: str) -> list[APIKeyInfo]:
    api_keys = await PrismaAPIKey.prisma().find_many(
        where={"userId": user_id},
        order={"createdAt": "desc"},
        take=limit,
        where={"userId": user_id}, order={"createdAt": "desc"}
    )

    return [APIKeyInfo.from_db(key) for key in api_keys]
@@ -20,6 +20,7 @@ from typing import (

import jsonref
import jsonschema
from autogpt_libs.utils.cache import cached
from prisma.models import AgentBlock
from prisma.types import AgentBlockCreateInput
from pydantic import BaseModel

@@ -27,7 +28,6 @@ from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.cache import cached
from backend.util.settings import Config

from .model import (

@@ -722,7 +722,7 @@ def get_block(block_id: str) -> Block[BlockSchema, BlockSchema] | None:
    return cls() if cls else None


@cached(ttl_seconds=3600)
@cached()
def get_webhook_block_ids() -> Sequence[str]:
    return [
        id

@@ -731,7 +731,7 @@ def get_webhook_block_ids() -> Sequence[str]:
    ]


@cached(ttl_seconds=3600)
@cached()
def get_io_block_ids() -> Sequence[str]:
    return [
        id
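# --- Editor's note: illustrative sketch, not part of the diff above. ---
# The hunks above swap the argument-less `@cached()` (from
# autogpt_libs.utils.cache) for `@cached(ttl_seconds=3600)` (from
# backend.util.cache). Neither implementation is shown in this diff; a
# minimal TTL-bounded cache decorator with the same call shape might
# look like the following (entirely hypothetical code):

import time
from functools import wraps

def cached(ttl_seconds: float = 3600):
    def decorator(fn):
        store: dict = {}  # args tuple -> (value, timestamp)

        @wraps(fn)
        def wrapper(*args):
            now = time.monotonic()
            hit = store.get(args)
            if hit is not None and now - hit[1] < ttl_seconds:
                return hit[0]  # still fresh: serve the cached value
            value = fn(*args)
            store[args] = (value, now)
            return value

        return wrapper
    return decorator

# Usage matching the diff:
# @cached(ttl_seconds=3600)
# def get_webhook_block_ids(): ...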
@@ -69,7 +69,6 @@ MODEL_COST: dict[LlmModel, int] = {
    LlmModel.CLAUDE_4_1_OPUS: 21,
    LlmModel.CLAUDE_4_OPUS: 21,
    LlmModel.CLAUDE_4_SONNET: 5,
    LlmModel.CLAUDE_4_5_SONNET: 9,
    LlmModel.CLAUDE_3_7_SONNET: 5,
    LlmModel.CLAUDE_3_5_SONNET: 4,
    LlmModel.CLAUDE_3_5_HAIKU: 1,  # $0.80 / $4.00
@@ -23,7 +23,6 @@ from pydantic import BaseModel

from backend.data import db
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.includes import MAX_CREDIT_REFUND_REQUESTS_FETCH
from backend.data.model import (
    AutoTopUpConfig,
    RefundRequest,

@@ -906,9 +905,7 @@ class UserCredit(UserCreditBase):
        ),
    )

    async def get_refund_requests(
        self, user_id: str, limit: int = MAX_CREDIT_REFUND_REQUESTS_FETCH
    ) -> list[RefundRequest]:
    async def get_refund_requests(self, user_id: str) -> list[RefundRequest]:
        return [
            RefundRequest(
                id=r.id,

@@ -924,7 +921,6 @@ class UserCredit(UserCreditBase):
            for r in await CreditRefundRequest.prisma().find_many(
                where={"userId": user_id},
                order={"createdAt": "desc"},
                take=limit,
            )
        ]
@@ -1,284 +0,0 @@
"""
Utilities for handling dynamic field names with special delimiters.

Dynamic fields allow graphs to connect complex data structures using special delimiters:
- _#_ for dictionary keys (e.g., "values_#_name" → values["name"])
- _$_ for list indices (e.g., "items_$_0" → items[0])
- _@_ for object attributes (e.g., "obj_@_attr" → obj.attr)
"""

from typing import Any

from backend.util.mock import MockObject

# Dynamic field delimiters
LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"

DYNAMIC_DELIMITERS = (LIST_SPLIT, DICT_SPLIT, OBJC_SPLIT)


def extract_base_field_name(field_name: str) -> str:
    """
    Extract the base field name from a dynamic field name by removing all dynamic suffixes.

    Examples:
        extract_base_field_name("values_#_name") → "values"
        extract_base_field_name("items_$_0") → "items"
        extract_base_field_name("obj_@_attr") → "obj"
        extract_base_field_name("regular_field") → "regular_field"

    Args:
        field_name: The field name that may contain dynamic delimiters

    Returns:
        The base field name without any dynamic suffixes
    """
    base_name = field_name
    for delimiter in DYNAMIC_DELIMITERS:
        if delimiter in base_name:
            base_name = base_name.split(delimiter)[0]
    return base_name


def is_dynamic_field(field_name: str) -> bool:
    """
    Check if a field name contains dynamic delimiters.

    Args:
        field_name: The field name to check

    Returns:
        True if the field contains any dynamic delimiters, False otherwise
    """
    return any(delimiter in field_name for delimiter in DYNAMIC_DELIMITERS)


def get_dynamic_field_description(field_name: str) -> str:
    """
    Generate a description for a dynamic field based on its structure.

    Args:
        field_name: The full dynamic field name (e.g., "values_#_name")

    Returns:
        A descriptive string explaining what this dynamic field represents
    """
    base_name = extract_base_field_name(field_name)

    if DICT_SPLIT in field_name:
        # Extract the key part after _#_
        parts = field_name.split(DICT_SPLIT)
        if len(parts) > 1:
            key = parts[1].split("_")[0] if "_" in parts[1] else parts[1]
            return f"Dictionary field '{key}' for base field '{base_name}' ({base_name}['{key}'])"
    elif LIST_SPLIT in field_name:
        # Extract the index part after _$_
        parts = field_name.split(LIST_SPLIT)
        if len(parts) > 1:
            index = parts[1].split("_")[0] if "_" in parts[1] else parts[1]
            return (
                f"List item {index} for base field '{base_name}' ({base_name}[{index}])"
            )
    elif OBJC_SPLIT in field_name:
        # Extract the attribute part after _@_
        parts = field_name.split(OBJC_SPLIT)
        if len(parts) > 1:
            # Get the full attribute name (everything after _@_)
            attr = parts[1]
            return f"Object attribute '{attr}' for base field '{base_name}' ({base_name}.{attr})"

    return f"Value for {field_name}"


# --------------------------------------------------------------------------- #
# Dynamic field parsing and merging utilities
# --------------------------------------------------------------------------- #


def _next_delim(s: str) -> tuple[str | None, int]:
    """
    Return the *earliest* delimiter appearing in `s` and its index.

    If none present → (None, -1).
    """
    first: str | None = None
    pos = len(s)  # sentinel: larger than any real index
    for d in DYNAMIC_DELIMITERS:
        i = s.find(d)
        if 0 <= i < pos:
            first, pos = d, i
    return first, (pos if first else -1)


def _tokenise(path: str) -> list[tuple[str, str]] | None:
    """
    Convert the raw path string (starting with a delimiter) into
    [ (delimiter, identifier), … ] or None if the syntax is malformed.
    """
    tokens: list[tuple[str, str]] = []
    while path:
        # 1. Which delimiter starts this chunk?
        delim = next((d for d in DYNAMIC_DELIMITERS if path.startswith(d)), None)
        if delim is None:
            return None  # invalid syntax

        # 2. Slice off the delimiter, then up to the next delimiter (or EOS)
        path = path[len(delim) :]
        nxt_delim, pos = _next_delim(path)
        token, path = (
            path[: pos if pos != -1 else len(path)],
            path[pos if pos != -1 else len(path) :],
        )
        if token == "":
            return None  # empty identifier is invalid
        tokens.append((delim, token))
    return tokens


def parse_execution_output(output: tuple[str, Any], name: str) -> Any:
    """
    Retrieve a nested value out of `output` using the flattened *name*.

    On any failure (wrong name, wrong type, out-of-range, bad path)
    returns **None**.

    Args:
        output: Tuple of (base_name, data) representing a block output entry
        name: The flattened field name to extract from the output data

    Returns:
        The value at the specified path, or None if not found/invalid
    """
    base_name, data = output

    # Exact match → whole object
    if name == base_name:
        return data

    # Must start with the expected name
    if not name.startswith(base_name):
        return None
    path = name[len(base_name) :]
    if not path:
        return None  # nothing left to parse

    tokens = _tokenise(path)
    if tokens is None:
        return None

    cur: Any = data
    for delim, ident in tokens:
        if delim == LIST_SPLIT:
            # list[index]
            try:
                idx = int(ident)
            except ValueError:
                return None
            if not isinstance(cur, list) or idx >= len(cur):
                return None
            cur = cur[idx]

        elif delim == DICT_SPLIT:
            if not isinstance(cur, dict) or ident not in cur:
                return None
            cur = cur[ident]

        elif delim == OBJC_SPLIT:
            if not hasattr(cur, ident):
                return None
            cur = getattr(cur, ident)

        else:
            return None  # unreachable

    return cur


def _assign(container: Any, tokens: list[tuple[str, str]], value: Any) -> Any:
    """
    Recursive helper that *returns* the (possibly new) container with
    `value` assigned along the remaining `tokens` path.
    """
    if not tokens:
        return value  # leaf reached

    delim, ident = tokens[0]
    rest = tokens[1:]

    # ---------- list ----------
    if delim == LIST_SPLIT:
        try:
            idx = int(ident)
        except ValueError:
            raise ValueError("index must be an integer")

        if container is None:
            container = []
        elif not isinstance(container, list):
            container = list(container) if hasattr(container, "__iter__") else []

        while len(container) <= idx:
            container.append(None)
        container[idx] = _assign(container[idx], rest, value)
        return container

    # ---------- dict ----------
    if delim == DICT_SPLIT:
        if container is None:
            container = {}
        elif not isinstance(container, dict):
            container = dict(container) if hasattr(container, "items") else {}
        container[ident] = _assign(container.get(ident), rest, value)
        return container

    # ---------- object ----------
    if delim == OBJC_SPLIT:
        if container is None:
            container = MockObject()
        elif not hasattr(container, "__dict__"):
            # If it's not an object, create a new one
            container = MockObject()
        setattr(
            container,
            ident,
            _assign(getattr(container, ident, None), rest, value),
        )
        return container

    return value  # unreachable


def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]:
    """
    Reconstruct nested objects from a *flattened* dict of key → value.

    Raises ValueError on syntactically invalid list indices.

    Args:
        data: Dictionary with potentially flattened dynamic field keys

    Returns:
        Dictionary with nested objects reconstructed from flattened keys
    """
    merged: dict[str, Any] = {}

    for key, value in data.items():
        # Split off the base name (before the first delimiter, if any)
        delim, pos = _next_delim(key)
        if delim is None:
            merged[key] = value
            continue

        base, path = key[:pos], key[pos:]
        tokens = _tokenise(path)
        if tokens is None:
            # Invalid key; treat as scalar under the raw name
            merged[key] = value
            continue

        merged[base] = _assign(merged.get(base), tokens, value)

    data.update(merged)
    return data
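# --- Editor's note: illustrative sketch, not part of the diff above. ---
# How the deleted helpers round-trip flattened keys, following the
# docstrings and delimiter examples above (values are made up):
#
# flat = {
#     "values_#_name": "Alice",   # dict key   -> values["name"]
#     "items_$_0": 42,            # list index -> items[0]
#     "plain": True,              # no delimiter, passed through
# }
# merged = merge_execution_input(flat)
# merged["values"] == {"name": "Alice"}
# merged["items"] == [42]
# Note the original flattened keys survive as well, because
# merge_execution_input ends with data.update(merged) on the input dict.
#
# parse_execution_output(("values", {"name": "Alice"}), "values_#_name")
# -> "Alice"
# parse_execution_output(("values", {"name": "Alice"}), "values_#_missing")
# -> None  (any bad path, type, or index returns None)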
@@ -20,8 +20,6 @@ from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import LlmModel
from backend.data.db import prisma as db
from backend.data.dynamic_fields import extract_base_field_name
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
from backend.data.model import (
    CredentialsField,
    CredentialsFieldInfo,

@@ -33,15 +31,7 @@ from backend.util import type as type_utils
from backend.util.json import SafeJson
from backend.util.models import Pagination

from .block import (
    Block,
    BlockInput,
    BlockSchema,
    BlockType,
    EmptySchema,
    get_block,
    get_blocks,
)
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
from .db import BaseDbModel, query_raw_with_schema, transaction
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE

@@ -82,15 +72,12 @@ class Node(BaseDbModel):
    output_links: list[Link] = []

    @property
    def block(self) -> "Block[BlockSchema, BlockSchema] | _UnknownBlockBase":
        """Get the block for this node. Returns UnknownBlock if block is deleted/missing."""
    def block(self) -> Block[BlockSchema, BlockSchema]:
        block = get_block(self.block_id)
        if not block:
            # Log warning but don't raise exception - return a placeholder block for deleted blocks
            logger.warning(
                f"Block #{self.block_id} does not exist for Node #{self.id} (deleted/missing block), using UnknownBlock"
            raise ValueError(
                f"Block #{self.block_id} does not exist -> Node #{self.id} is invalid"
            )
            return _UnknownBlockBase(self.block_id)
        return block

@@ -742,7 +729,7 @@ def _is_tool_pin(name: str) -> bool:


def _sanitize_pin_name(name: str) -> str:
    sanitized_name = extract_base_field_name(name)
    sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
    if _is_tool_pin(sanitized_name):
        return "tools"
    return sanitized_name

@@ -1072,14 +1059,11 @@ async def set_graph_active_version(graph_id: str, version: int, user_id: str) ->
    )


async def get_graph_all_versions(
    graph_id: str, user_id: str, limit: int = MAX_GRAPH_VERSIONS_FETCH
) -> list[GraphModel]:
async def get_graph_all_versions(graph_id: str, user_id: str) -> list[GraphModel]:
    graph_versions = await AgentGraph.prisma().find_many(
        where={"id": graph_id, "userId": user_id},
        order={"version": "desc"},
        include=AGENT_GRAPH_INCLUDE,
        take=limit,
    )

    if not graph_versions:

@@ -1328,34 +1312,3 @@ async def migrate_llm_models(migrate_to: LlmModel):
        id,
        path,
    )


# Simple placeholder class for deleted/missing blocks
class _UnknownBlockBase(Block):
    """
    Placeholder for deleted/missing blocks that inherits from Block
    but uses a name that doesn't end with 'Block' to avoid auto-discovery.
    """

    def __init__(self, block_id: str = "00000000-0000-0000-0000-000000000000"):
        # Initialize with minimal valid Block parameters
        super().__init__(
            id=block_id,
            description=f"Unknown or deleted block (original ID: {block_id})",
            disabled=True,
            input_schema=EmptySchema,
            output_schema=EmptySchema,
            categories=set(),
            contributors=[],
            static_output=False,
            block_type=BlockType.STANDARD,
            webhook_config=None,
        )

    @property
    def name(self):
        return "UnknownBlock"

    async def run(self, input_data, **kwargs):
        """Always yield an error for missing blocks."""
        yield "error", f"Block {self.id} no longer exists"
@@ -14,7 +14,6 @@ AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
    "Nodes": {"include": AGENT_NODE_INCLUDE}
}


EXECUTION_RESULT_ORDER: list[prisma.types.AgentNodeExecutionOrderByInput] = [
    {"queuedTime": "desc"},
    # Fallback: Incomplete execs has no queuedTime.

@@ -29,13 +28,6 @@ EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
}

MAX_NODE_EXECUTIONS_FETCH = 1000
MAX_LIBRARY_AGENT_EXECUTIONS_FETCH = 10

# Default limits for potentially large result sets
MAX_CREDIT_REFUND_REQUESTS_FETCH = 100
MAX_INTEGRATION_WEBHOOKS_FETCH = 100
MAX_USER_API_KEYS_FETCH = 500
MAX_GRAPH_VERSIONS_FETCH = 50

GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
    "NodeExecutions": {

@@ -79,56 +71,13 @@ INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
}


def library_agent_include(
    user_id: str,
    include_nodes: bool = True,
    include_executions: bool = True,
    execution_limit: int = MAX_LIBRARY_AGENT_EXECUTIONS_FETCH,
) -> prisma.types.LibraryAgentInclude:
    """
    Fully configurable includes for library agent queries with performance optimization.

    Args:
        user_id: User ID for filtering user-specific data
        include_nodes: Whether to include graph nodes (default: True, needed for get_sub_graphs)
        include_executions: Whether to include executions (default: True, safe with execution_limit)
        execution_limit: Limit on executions to fetch (default: MAX_LIBRARY_AGENT_EXECUTIONS_FETCH)

    Defaults maintain backward compatibility and safety - includes everything needed for all functionality.
    For performance optimization, explicitly set include_nodes=False and include_executions=False
    for listing views where frontend fetches data separately.

    Performance impact:
    - Default (full nodes + limited executions): Original performance, works everywhere
    - Listing optimization (no nodes/executions): ~2s for 15 agents vs potential timeouts
    - Unlimited executions: varies by user (thousands of executions = timeouts)
    """
    result: prisma.types.LibraryAgentInclude = {
        "Creator": True,  # Always needed for creator info
    }

    # Build AgentGraph include based on requested options
    if include_nodes or include_executions:
        agent_graph_include = {}

        # Add nodes if requested (always full nodes)
        if include_nodes:
            agent_graph_include.update(AGENT_GRAPH_INCLUDE)  # Full nodes

        # Add executions if requested
        if include_executions:
            agent_graph_include["Executions"] = {
                "where": {"userId": user_id},
                "order_by": {"createdAt": "desc"},
                "take": execution_limit,
def library_agent_include(user_id: str) -> prisma.types.LibraryAgentInclude:
    return {
        "AgentGraph": {
            "include": {
                **AGENT_GRAPH_INCLUDE,
                "Executions": {"where": {"userId": user_id}},
            }

        result["AgentGraph"] = cast(
            prisma.types.AgentGraphArgsFromLibraryAgent,
            {"include": agent_graph_include},
        )
    else:
        # Default: Basic metadata only (fast - recommended for most use cases)
        result["AgentGraph"] = True  # Basic graph metadata (name, description, id)

    return result
        },
        "Creator": True,
    }
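# --- Editor's note: illustrative sketch, not part of the diff above. ---
# Under the removed configurable signature, a listing view would opt out
# of the heavy relations, per its own docstring (query code below is a
# hypothetical caller, not from the repository):
#
# include = library_agent_include(
#     user_id,
#     include_nodes=False,       # skip full graph nodes for listing views
#     include_executions=False,  # frontend fetches executions separately
# )
# agents = await LibraryAgent.prisma().find_many(include=include)
#
# The kept version always includes full nodes plus all of the user's
# executions, which the removed docstring warns can time out for users
# with thousands of executions.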
@@ -11,10 +11,7 @@ from prisma.types import (
from pydantic import Field, computed_field

from backend.data.event_bus import AsyncRedisEventBus
from backend.data.includes import (
    INTEGRATION_WEBHOOK_INCLUDE,
    MAX_INTEGRATION_WEBHOOKS_FETCH,
)
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.utils import webhook_ingress_url
from backend.server.v2.library.model import LibraryAgentPreset

@@ -131,36 +128,22 @@ async def get_webhook(

@overload
async def get_all_webhooks_by_creds(
    user_id: str,
    credentials_id: str,
    *,
    include_relations: Literal[True],
    limit: int = MAX_INTEGRATION_WEBHOOKS_FETCH,
    user_id: str, credentials_id: str, *, include_relations: Literal[True]
) -> list[WebhookWithRelations]: ...
@overload
async def get_all_webhooks_by_creds(
    user_id: str,
    credentials_id: str,
    *,
    include_relations: Literal[False] = False,
    limit: int = MAX_INTEGRATION_WEBHOOKS_FETCH,
    user_id: str, credentials_id: str, *, include_relations: Literal[False] = False
) -> list[Webhook]: ...


async def get_all_webhooks_by_creds(
    user_id: str,
    credentials_id: str,
    *,
    include_relations: bool = False,
    limit: int = MAX_INTEGRATION_WEBHOOKS_FETCH,
    user_id: str, credentials_id: str, *, include_relations: bool = False
) -> list[Webhook] | list[WebhookWithRelations]:
    if not credentials_id:
        raise ValueError("credentials_id must not be empty")
    webhooks = await IntegrationWebhook.prisma().find_many(
        where={"userId": user_id, "credentialsId": credentials_id},
        include=INTEGRATION_WEBHOOK_INCLUDE if include_relations else None,
        order={"createdAt": "desc"},
        take=limit,
    )
    return [
        (WebhookWithRelations if include_relations else Webhook).from_db(webhook)
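# --- Editor's note: illustrative sketch, not part of the diff above. ---
# Both versions of get_all_webhooks_by_creds use typing.overload with
# Literal[True] / Literal[False] so the type checker narrows the return
# type from the value of include_relations. The pattern in isolation
# (names and return types here are made up for illustration):

from typing import Literal, overload

@overload
def fetch(*, include_relations: Literal[True]) -> list[str]: ...
@overload
def fetch(*, include_relations: Literal[False] = False) -> list[int]: ...

def fetch(*, include_relations: bool = False) -> list[str] | list[int]:
    # Single runtime implementation; the overloads only guide the type checker.
    return ["with-relations"] if include_relations else [1, 2, 3]

# reveal_type(fetch(include_relations=True))  # a checker reports list[str]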
@@ -270,7 +270,6 @@ def SchemaField(
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
    discriminator: Optional[str] = None,
    format: Optional[str] = None,
    json_schema_extra: Optional[dict[str, Any]] = None,
) -> T:
    if default is PydanticUndefined and default_factory is None:

@@ -286,7 +285,6 @@ def SchemaField(
        "advanced": advanced,
        "hidden": hidden,
        "depends_on": depends_on,
        "format": format,
        **(json_schema_extra or {}),
    }.items()
    if v is not None
@@ -1,10 +1,8 @@
import re
from datetime import datetime
from typing import Any, Optional

import prisma
import pydantic
from autogpt_libs.utils.cache import cached
from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput

@@ -32,7 +30,7 @@ user_credit = get_user_credit_model()

class UserOnboardingUpdate(pydantic.BaseModel):
    completedSteps: Optional[list[OnboardingStep]] = None
    walletShown: Optional[bool] = None
    notificationDot: Optional[bool] = None
    notified: Optional[list[OnboardingStep]] = None
    usageReason: Optional[str] = None
    integrations: Optional[list[str]] = None

@@ -41,8 +39,6 @@ class UserOnboardingUpdate(pydantic.BaseModel):
    agentInput: Optional[dict[str, Any]] = None
    onboardingAgentExecutionId: Optional[str] = None
    agentRuns: Optional[int] = None
    lastRunAt: Optional[datetime] = None
    consecutiveRunDays: Optional[int] = None


async def get_user_onboarding(user_id: str):

@@ -61,22 +57,16 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
        update["completedSteps"] = list(set(data.completedSteps))
        for step in (
            OnboardingStep.AGENT_NEW_RUN,
            OnboardingStep.MARKETPLACE_VISIT,
            OnboardingStep.RUN_AGENTS,
            OnboardingStep.MARKETPLACE_ADD_AGENT,
            OnboardingStep.MARKETPLACE_RUN_AGENT,
            OnboardingStep.BUILDER_SAVE_AGENT,
            OnboardingStep.RE_RUN_AGENT,
            OnboardingStep.SCHEDULE_AGENT,
            OnboardingStep.RUN_AGENTS,
            OnboardingStep.RUN_3_DAYS,
            OnboardingStep.TRIGGER_WEBHOOK,
            OnboardingStep.RUN_14_DAYS,
            OnboardingStep.RUN_AGENTS_100,
            OnboardingStep.BUILDER_RUN_AGENT,
        ):
            if step in data.completedSteps:
                await reward_user(user_id, step)
    if data.walletShown is not None:
        update["walletShown"] = data.walletShown
    if data.notificationDot is not None:
        update["notificationDot"] = data.notificationDot
    if data.notified is not None:
        update["notified"] = list(set(data.notified))
    if data.usageReason is not None:

@@ -93,10 +83,6 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
        update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
    if data.agentRuns is not None:
        update["agentRuns"] = data.agentRuns
    if data.lastRunAt is not None:
        update["lastRunAt"] = data.lastRunAt
    if data.consecutiveRunDays is not None:
        update["consecutiveRunDays"] = data.consecutiveRunDays

    return await UserOnboarding.prisma().upsert(
        where={"userId": user_id},

@@ -115,28 +101,16 @@ async def reward_user(user_id: str, step: OnboardingStep):
        # This is seen as a reward for the GET_RESULTS step in the wallet
        case OnboardingStep.AGENT_NEW_RUN:
            reward = 300
        case OnboardingStep.MARKETPLACE_VISIT:
            reward = 100
        case OnboardingStep.RUN_AGENTS:
            reward = 300
        case OnboardingStep.MARKETPLACE_ADD_AGENT:
            reward = 100
        case OnboardingStep.MARKETPLACE_RUN_AGENT:
            reward = 100
        case OnboardingStep.BUILDER_SAVE_AGENT:
            reward = 100
        case OnboardingStep.RE_RUN_AGENT:
        case OnboardingStep.BUILDER_RUN_AGENT:
            reward = 100
        case OnboardingStep.SCHEDULE_AGENT:
            reward = 100
        case OnboardingStep.RUN_AGENTS:
            reward = 300
        case OnboardingStep.RUN_3_DAYS:
            reward = 100
        case OnboardingStep.TRIGGER_WEBHOOK:
            reward = 100
        case OnboardingStep.RUN_14_DAYS:
            reward = 300
        case OnboardingStep.RUN_AGENTS_100:
            reward = 300

    if reward == 0:
        return

@@ -158,22 +132,6 @@ async def reward_user(user_id: str, step: OnboardingStep):
    )


async def complete_webhook_trigger_step(user_id: str):
    """
    Completes the TRIGGER_WEBHOOK onboarding step for the user if not already completed.
    """

    onboarding = await get_user_onboarding(user_id)
    if OnboardingStep.TRIGGER_WEBHOOK not in onboarding.completedSteps:
        await update_user_onboarding(
            user_id,
            UserOnboardingUpdate(
                completedSteps=onboarding.completedSteps
                + [OnboardingStep.TRIGGER_WEBHOOK]
            ),
        )


def clean_and_split(text: str) -> list[str]:
    """
    Removes all special characters from a string, truncates it to 100 characters,

@@ -375,13 +333,8 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
    ]


@cached(maxsize=1, ttl_seconds=300)  # Cache for 5 minutes since this rarely changes
async def onboarding_enabled() -> bool:
    """
    Check if onboarding should be enabled based on store agent count.
    Cached to prevent repeated slow database queries.
    """
    # Use a more efficient query that stops counting after finding enough agents
    count = await prisma.models.StoreAgent.prisma().count(take=MIN_AGENT_COUNT + 1)
    # Onboarding is enabled if there are at least 2 agents in the store

    # Onboading is enabled if there are at least 2 agents in the store
    return count >= MIN_AGENT_COUNT
@@ -1,24 +1,29 @@
import logging
import os

from autogpt_libs.utils.cache import cached, thread_cached
from dotenv import load_dotenv
from redis import Redis
from redis.asyncio import Redis as AsyncRedis

from backend.util.cache import cached, thread_cached
from backend.util.retry import conn_retry
from backend.util.settings import Settings

settings = Settings()
load_dotenv()

HOST = os.getenv("REDIS_HOST", "localhost")
PORT = int(os.getenv("REDIS_PORT", "6379"))
PASSWORD = os.getenv("REDIS_PASSWORD", None)

logger = logging.getLogger(__name__)


@conn_retry("Redis", "Acquiring connection")
def connect(decode_responses: bool = True) -> Redis:
def connect() -> Redis:
    c = Redis(
        host=settings.config.redis_host,
        port=settings.config.redis_port,
        password=settings.config.redis_password or None,
        decode_responses=decode_responses,
        host=HOST,
        port=PORT,
        password=PASSWORD,
        decode_responses=True,
    )
    c.ping()
    return c

@@ -29,7 +34,7 @@ def disconnect():
    get_redis().close()


@cached(ttl_seconds=3600)
@cached()
def get_redis() -> Redis:
    return connect()

@@ -37,9 +42,9 @@ def get_redis() -> Redis:
@conn_retry("AsyncRedis", "Acquiring connection")
async def connect_async() -> AsyncRedis:
    c = AsyncRedis(
        host=settings.config.redis_host,
        port=settings.config.redis_port,
        password=settings.config.redis_password or None,
        host=HOST,
        port=PORT,
        password=PASSWORD,
        decode_responses=True,
    )
    await c.ping()
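# --- Editor's note: illustrative sketch, not part of the diff above. ---
# The removed signature, connect(decode_responses: bool = True), let
# callers that need raw bytes opt out of string decoding:
#
# text_client = connect()                       # decode_responses=True
# raw_client = connect(decode_responses=False)  # values come back as bytes
#
# With the kept version, decode_responses is hard-coded to True, so
# byte-oriented callers presumably have to build their own Redis(...)
# client, as the ClusterLock test fixture further below does with
# decode_responses=False for ownership checks on raw bytes.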
@@ -7,6 +7,7 @@ from typing import Optional, cast
from urllib.parse import quote_plus

from autogpt_libs.auth.models import DEFAULT_USER_ID
from autogpt_libs.utils.cache import cached
from fastapi import HTTPException
from prisma.enums import NotificationType
from prisma.models import User as PrismaUser

@@ -16,7 +17,6 @@ from backend.data.db import prisma
from backend.data.model import User, UserIntegrations, UserMetadata
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.cache import cached
from backend.util.encryption import JSONCryptor
from backend.util.json import SafeJson
from backend.util.settings import Settings
@@ -4,12 +4,7 @@ Module for generating AI-based activity status for graph executions.

import json
import logging
from typing import TYPE_CHECKING, Any, TypedDict

try:
    from typing import NotRequired
except ImportError:
    from typing_extensions import NotRequired
from typing import TYPE_CHECKING, Any, NotRequired, TypedDict

from pydantic import SecretStr

@@ -151,35 +146,17 @@ async def generate_activity_status_for_execution(
                    "Focus on the ACTUAL TASK the user wanted done, not the internal workflow steps. "
                    "Avoid technical terms like 'workflow', 'execution', 'components', 'nodes', 'processing', etc. "
                    "Keep it to 3 sentences maximum. Be conversational and human-friendly.\n\n"
                    "UNDERSTAND THE INTENDED PURPOSE:\n"
                    "- FIRST: Read the graph description carefully to understand what the user wanted to accomplish\n"
                    "- The graph name and description tell you the main goal/intention of this automation\n"
                    "- Use this intended purpose as your PRIMARY criteria for success/failure evaluation\n"
                    "- Ask yourself: 'Did this execution actually accomplish what the graph was designed to do?'\n\n"
                    "CRITICAL OUTPUT ANALYSIS:\n"
                    "- Check if blocks that should produce user-facing results actually produced outputs\n"
                    "- Blocks with names containing 'Output', 'Post', 'Create', 'Send', 'Publish', 'Generate' are usually meant to produce final results\n"
                    "- If these critical blocks have NO outputs (empty recent_outputs), the task likely FAILED even if status shows 'completed'\n"
                    "- Sub-agents (AgentExecutorBlock) that produce no outputs usually indicate failed sub-tasks\n"
                    "- Most importantly: Does the execution result match what the graph description promised to deliver?\n\n"
                    "SUCCESS EVALUATION BASED ON INTENTION:\n"
                    "- If the graph is meant to 'create blog posts' → check if blog content was actually created\n"
                    "- If the graph is meant to 'send emails' → check if emails were actually sent\n"
                    "- If the graph is meant to 'analyze data' → check if analysis results were produced\n"
                    "- If the graph is meant to 'generate reports' → check if reports were generated\n"
                    "- Technical completion ≠ goal achievement. Focus on whether the USER'S INTENDED OUTCOME was delivered\n\n"
                    "IMPORTANT: Be HONEST about what actually happened:\n"
                    "- If the input was invalid/nonsensical, say so directly\n"
                    "- If the task failed, explain what went wrong in simple terms\n"
                    "- If errors occurred, focus on what the user needs to know\n"
                    "- Only claim success if the INTENDED PURPOSE was genuinely accomplished AND produced expected outputs\n"
                    "- Don't sugar-coat failures or present them as helpful feedback\n"
                    "- ESPECIALLY: If the graph's main purpose wasn't achieved, this is a failure regardless of 'completed' status\n\n"
                    "- Only claim success if the task was genuinely completed\n"
                    "- Don't sugar-coat failures or present them as helpful feedback\n\n"
                    "Understanding Errors:\n"
                    "- Node errors: Individual steps may fail but the overall task might still complete (e.g., one data source fails but others work)\n"
                    "- Graph error (in overall_status.graph_error): This means the entire execution failed and nothing was accomplished\n"
                    "- Missing outputs from critical blocks: Even if no errors, this means the task failed to produce expected results\n"
                    "- Focus on whether the graph's intended purpose was fulfilled, not whether technical steps completed"
                    "- Even if execution shows 'completed', check if critical nodes failed that would prevent the desired outcome\n"
                    "- Focus on the end result the user wanted, not whether technical steps completed"
                ),
            },
            {

@@ -188,28 +165,15 @@ async def generate_activity_status_for_execution(
                    f"A user ran '{graph_name}' to accomplish something. Based on this execution data, "
                    f"write what they achieved in simple, user-friendly terms:\n\n"
                    f"{json.dumps(execution_data, indent=2)}\n\n"
                    "ANALYSIS CHECKLIST:\n"
                    "1. READ graph_info.description FIRST - this tells you what the user intended to accomplish\n"
                    "2. Check overall_status.graph_error - if present, the entire execution failed\n"
                    "3. Look for nodes with 'Output', 'Post', 'Create', 'Send', 'Publish', 'Generate' in their block_name\n"
                    "4. Check if these critical blocks have empty recent_outputs arrays - this indicates failure\n"
                    "5. Look for AgentExecutorBlock (sub-agents) with no outputs - this suggests sub-task failures\n"
                    "6. Count how many nodes produced outputs vs total nodes - low ratio suggests problems\n"
                    "7. MOST IMPORTANT: Does the execution outcome match what graph_info.description promised?\n\n"
                    "INTENTION-BASED EVALUATION:\n"
                    "- If description mentions 'blog writing' → did it create blog content?\n"
                    "- If description mentions 'email automation' → were emails actually sent?\n"
                    "- If description mentions 'data analysis' → were analysis results produced?\n"
                    "- If description mentions 'content generation' → was content actually generated?\n"
                    "- If description mentions 'social media posting' → were posts actually made?\n"
                    "- Match the outputs to the stated intention, not just technical completion\n\n"
                    "CRITICAL: Check overall_status.graph_error FIRST - if present, the entire execution failed.\n"
                    "Then check individual node errors to understand partial failures.\n\n"
                    "Write 1-3 sentences about what the user accomplished, such as:\n"
                    "- 'I analyzed your resume and provided detailed feedback for the IT industry.'\n"
                    "- 'I couldn't complete the task because critical steps failed to produce any results.'\n"
                    "- 'I failed to generate the content you requested due to missing API access.'\n"
                    "- 'I couldn't analyze your resume because the input was just nonsensical text.'\n"
                    "- 'I failed to complete the task due to missing API access.'\n"
                    "- 'I extracted key information from your documents and organized it into a summary.'\n"
                    "- 'The task failed because the blog post creation step didn't produce any output.'\n\n"
                    "BE CRITICAL: If the graph's intended purpose (from description) wasn't achieved, report this as a failure even if status is 'completed'."
                    "- 'The task failed to run due to system configuration issues.'\n\n"
                    "Focus on what ACTUALLY happened, not what was attempted."
                ),
            },
        ]

@@ -233,7 +197,6 @@ async def generate_activity_status_for_execution(
        logger.debug(
            f"Generated activity status for {graph_exec_id}: {activity_status}"
        )

        return activity_status

    except Exception as e:
@@ -1,115 +0,0 @@
"""Redis-based distributed locking for cluster coordination."""

import logging
import time
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from redis import Redis

logger = logging.getLogger(__name__)


class ClusterLock:
    """Simple Redis-based distributed lock for preventing duplicate execution."""

    def __init__(self, redis: "Redis", key: str, owner_id: str, timeout: int = 300):
        self.redis = redis
        self.key = key
        self.owner_id = owner_id
        self.timeout = timeout
        self._last_refresh = 0.0

    def try_acquire(self) -> str | None:
        """Try to acquire the lock.

        Returns:
            - owner_id (self.owner_id) if successfully acquired
            - different owner_id if someone else holds the lock
            - None if Redis is unavailable or other error
        """
        try:
            success = self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout)
            if success:
                self._last_refresh = time.time()
                return self.owner_id  # Successfully acquired

            # Failed to acquire, get current owner
            current_value = self.redis.get(self.key)
            if current_value:
                current_owner = (
                    current_value.decode("utf-8")
                    if isinstance(current_value, bytes)
                    else str(current_value)
                )
                return current_owner

            # Key doesn't exist but we failed to set it - race condition or Redis issue
            return None

        except Exception as e:
            logger.error(f"ClusterLock.try_acquire failed for key {self.key}: {e}")
            return None

    def refresh(self) -> bool:
        """Refresh lock TTL if we still own it.

        Rate limited to at most once every timeout/10 seconds (minimum 1 second).
        During rate limiting, still verifies lock existence but skips TTL extension.
        Setting _last_refresh to 0 bypasses rate limiting for testing.
        """
        # Calculate refresh interval: max(timeout // 10, 1)
        refresh_interval = max(self.timeout // 10, 1)
        current_time = time.time()

        # Check if we're within the rate limit period
        # _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
        is_rate_limited = (
            self._last_refresh > 0
            and (current_time - self._last_refresh) < refresh_interval
        )

        try:
            # Always verify lock existence, even during rate limiting
            current_value = self.redis.get(self.key)
            if not current_value:
                self._last_refresh = 0
                return False

            stored_owner = (
                current_value.decode("utf-8")
                if isinstance(current_value, bytes)
                else str(current_value)
            )
            if stored_owner != self.owner_id:
                self._last_refresh = 0
                return False

            # If rate limited, return True but don't update TTL or timestamp
            if is_rate_limited:
                return True

            # Perform actual refresh
            if self.redis.expire(self.key, self.timeout):
                self._last_refresh = current_time
                return True

            self._last_refresh = 0
            return False

        except Exception as e:
            logger.error(f"ClusterLock.refresh failed for key {self.key}: {e}")
            self._last_refresh = 0
            return False

    def release(self):
        """Release the lock."""
        if self._last_refresh == 0:
            return

        try:
            self.redis.delete(self.key)
        except Exception:
            pass

        self._last_refresh = 0.0
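# --- Editor's note: illustrative sketch, not part of the diff above. ---
# Typical use of the deleted ClusterLock, following its own docstrings
# (the key name, owner id, and work loop below are hypothetical):
#
# lock = ClusterLock(get_redis(), "exec:graph-123", "pod-a", timeout=300)
# if lock.try_acquire() == lock.owner_id:
#     try:
#         while work_remaining():       # hypothetical work loop
#             do_some_work()
#             if not lock.refresh():    # rate-limited TTL extension
#                 break                 # lost ownership; stop safely
#     finally:
#         lock.release()
# else:
#     pass  # another node owns the lock; skip duplicate execution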
@@ -1,509 +0,0 @@
"""
Integration tests for ClusterLock - Redis-based distributed locking.

Tests the complete lock lifecycle without mocking Redis to ensure
real-world behavior is correct. Covers acquisition, refresh, expiry,
contention, and error scenarios.
"""

import logging
import time
import uuid
from threading import Thread

import pytest
import redis

from .cluster_lock import ClusterLock

logger = logging.getLogger(__name__)


@pytest.fixture
def redis_client():
    """Get Redis client for testing using same config as backend."""
    from backend.util.settings import Settings

    settings = Settings()

    # Use same config as backend but without decode_responses since ClusterLock needs raw bytes
    client = redis.Redis(
        host=settings.config.redis_host,
        port=settings.config.redis_port,
        password=settings.config.redis_password or None,
        decode_responses=False,  # ClusterLock needs raw bytes for ownership verification
    )

    # Clean up any existing test keys
    try:
        for key in client.scan_iter(match="test_lock:*"):
            client.delete(key)
    except Exception:
        pass  # Ignore cleanup errors

    return client


@pytest.fixture
def lock_key():
    """Generate unique lock key for each test."""
    return f"test_lock:{uuid.uuid4()}"


@pytest.fixture
def owner_id():
    """Generate unique owner ID for each test."""
    return str(uuid.uuid4())


class TestClusterLockBasic:
    """Basic lock acquisition and release functionality."""

    def test_lock_acquisition_success(self, redis_client, lock_key, owner_id):
        """Test basic lock acquisition succeeds."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        # Lock should be acquired successfully
        result = lock.try_acquire()
        assert result == owner_id  # Returns our owner_id when successfully acquired
        assert lock._last_refresh > 0

        # Lock key should exist in Redis
        assert redis_client.exists(lock_key) == 1
        assert redis_client.get(lock_key).decode("utf-8") == owner_id

    def test_lock_acquisition_contention(self, redis_client, lock_key):
        """Test second acquisition fails when lock is held."""
        owner1 = str(uuid.uuid4())
        owner2 = str(uuid.uuid4())

        lock1 = ClusterLock(redis_client, lock_key, owner1, timeout=60)
        lock2 = ClusterLock(redis_client, lock_key, owner2, timeout=60)

        # First lock should succeed
        result1 = lock1.try_acquire()
        assert result1 == owner1  # Successfully acquired, returns our owner_id

        # Second lock should fail and return the first owner
        result2 = lock2.try_acquire()
        assert result2 == owner1  # Returns the current owner (first owner)
        assert lock2._last_refresh == 0

    def test_lock_release_deletes_redis_key(self, redis_client, lock_key, owner_id):
        """Test lock release deletes Redis key and marks locally as released."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        lock.try_acquire()
        assert lock._last_refresh > 0
        assert redis_client.exists(lock_key) == 1

        # Release should delete Redis key and mark locally as released
        lock.release()
        assert lock._last_refresh == 0
        assert lock._last_refresh == 0.0

        # Redis key should be deleted for immediate release
        assert redis_client.exists(lock_key) == 0

        # Another lock should be able to acquire immediately
        new_owner_id = str(uuid.uuid4())
        new_lock = ClusterLock(redis_client, lock_key, new_owner_id, timeout=60)
        assert new_lock.try_acquire() == new_owner_id


class TestClusterLockRefresh:
    """Lock refresh and TTL management."""

    def test_lock_refresh_success(self, redis_client, lock_key, owner_id):
        """Test lock refresh extends TTL."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        lock.try_acquire()
        original_ttl = redis_client.ttl(lock_key)

        # Wait a bit then refresh
        time.sleep(1)
        lock._last_refresh = 0  # Force refresh past rate limit
        assert lock.refresh() is True

        # TTL should be reset to full timeout (allow for small timing differences)
        new_ttl = redis_client.ttl(lock_key)
        assert new_ttl >= original_ttl or new_ttl >= 58  # Allow for timing variance

    def test_lock_refresh_rate_limiting(self, redis_client, lock_key, owner_id):
        """Test refresh is rate-limited to timeout/10."""
        lock = ClusterLock(
            redis_client, lock_key, owner_id, timeout=100
        )  # 100s timeout

        lock.try_acquire()

        # First refresh should work
        assert lock.refresh() is True
        first_refresh_time = lock._last_refresh

        # Immediate second refresh should be skipped (rate limited) but verify key exists
        assert lock.refresh() is True  # Returns True but skips actual refresh
        assert lock._last_refresh == first_refresh_time  # Time unchanged

    def test_lock_refresh_verifies_existence_during_rate_limit(
        self, redis_client, lock_key, owner_id
    ):
        """Test refresh verifies lock existence even during rate limiting."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=100)

        lock.try_acquire()

        # Manually delete the key (simulates expiry or external deletion)
        redis_client.delete(lock_key)

        # Refresh should detect missing key even during rate limit period
        assert lock.refresh() is False
        assert lock._last_refresh == 0

    def test_lock_refresh_ownership_lost(self, redis_client, lock_key, owner_id):
        """Test refresh fails when ownership is lost."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        lock.try_acquire()

        # Simulate another process taking the lock
        different_owner = str(uuid.uuid4())
        redis_client.set(lock_key, different_owner, ex=60)

        # Force refresh past rate limit and verify it fails
        lock._last_refresh = 0  # Force refresh past rate limit
        assert lock.refresh() is False
        assert lock._last_refresh == 0

    def test_lock_refresh_when_not_acquired(self, redis_client, lock_key, owner_id):
        """Test refresh fails when lock was never acquired."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        # Refresh without acquiring should fail
        assert lock.refresh() is False


class TestClusterLockExpiry:
    """Lock expiry and timeout behavior."""

    def test_lock_natural_expiry(self, redis_client, lock_key, owner_id):
        """Test lock expires naturally via Redis TTL."""
        lock = ClusterLock(
            redis_client, lock_key, owner_id, timeout=2
        )  # 2 second timeout

        lock.try_acquire()
        assert redis_client.exists(lock_key) == 1

        # Wait for expiry
        time.sleep(3)
        assert redis_client.exists(lock_key) == 0

        # New lock with same key should succeed
        new_lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
        assert new_lock.try_acquire() == owner_id

    def test_lock_refresh_prevents_expiry(self, redis_client, lock_key, owner_id):
        """Test refreshing prevents lock from expiring."""
        lock = ClusterLock(
            redis_client, lock_key, owner_id, timeout=3
        )  # 3 second timeout

        lock.try_acquire()

        # Wait and refresh before expiry
        time.sleep(1)
        lock._last_refresh = 0  # Force refresh past rate limit
        assert lock.refresh() is True

        # Wait beyond original timeout
        time.sleep(2.5)
        assert redis_client.exists(lock_key) == 1  # Should still exist


class TestClusterLockConcurrency:
    """Concurrent access patterns."""

    def test_multiple_threads_contention(self, redis_client, lock_key):
        """Test multiple threads competing for same lock."""
        num_threads = 5
        successful_acquisitions = []

        def try_acquire_lock(thread_id):
            owner_id = f"thread_{thread_id}"
            lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
            if lock.try_acquire() == owner_id:
                successful_acquisitions.append(thread_id)
                time.sleep(0.1)  # Hold lock briefly
                lock.release()

        threads = []
        for i in range(num_threads):
            thread = Thread(target=try_acquire_lock, args=(i,))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        # Only one thread should have acquired the lock
        assert len(successful_acquisitions) == 1

    def test_sequential_lock_reuse(self, redis_client, lock_key):
        """Test lock can be reused after natural expiry."""
        owners = [str(uuid.uuid4()) for _ in range(3)]

        for i, owner_id in enumerate(owners):
            lock = ClusterLock(redis_client, lock_key, owner_id, timeout=1)  # 1 second

            assert lock.try_acquire() == owner_id
            time.sleep(1.5)  # Wait for expiry

            # Verify lock expired
            assert redis_client.exists(lock_key) == 0

    def test_refresh_during_concurrent_access(self, redis_client, lock_key):
        """Test lock refresh works correctly during concurrent access attempts."""
        owner1 = str(uuid.uuid4())
        owner2 = str(uuid.uuid4())

        lock1 = ClusterLock(redis_client, lock_key, owner1, timeout=5)
        lock2 = ClusterLock(redis_client, lock_key, owner2, timeout=5)

        # Thread 1 holds lock and refreshes
        assert lock1.try_acquire() == owner1

        def refresh_continuously():
            for _ in range(10):
                lock1._last_refresh = 0  # Force refresh
                lock1.refresh()
                time.sleep(0.1)

        def try_acquire_continuously():
            attempts = 0
            while attempts < 20:
                if lock2.try_acquire() == owner2:
                    return True
                time.sleep(0.1)
                attempts += 1
            return False

        refresh_thread = Thread(target=refresh_continuously)
        acquire_thread = Thread(target=try_acquire_continuously)

        refresh_thread.start()
        acquire_thread.start()
|
||||
|
||||
refresh_thread.join()
|
||||
acquire_thread.join()
|
||||
|
||||
# Lock1 should still own the lock due to refreshes
|
||||
assert lock1._last_refresh > 0
|
||||
assert lock2._last_refresh == 0
|
||||
|
||||
|
||||
class TestClusterLockErrorHandling:
|
||||
"""Error handling and edge cases."""
|
||||
|
||||
def test_redis_connection_failure_on_acquire(self, lock_key, owner_id):
|
||||
"""Test graceful handling when Redis is unavailable during acquisition."""
|
||||
# Use invalid Redis connection
|
||||
bad_redis = redis.Redis(
|
||||
host="invalid_host", port=1234, socket_connect_timeout=1
|
||||
)
|
||||
lock = ClusterLock(bad_redis, lock_key, owner_id, timeout=60)
|
||||
|
||||
# Should return None for Redis connection failures
|
||||
result = lock.try_acquire()
|
||||
assert result is None # Returns None when Redis fails
|
||||
assert lock._last_refresh == 0
|
||||
|
||||
def test_redis_connection_failure_on_refresh(
|
||||
self, redis_client, lock_key, owner_id
|
||||
):
|
||||
"""Test graceful handling when Redis fails during refresh."""
|
||||
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
|
||||
|
||||
# Acquire normally
|
||||
assert lock.try_acquire() == owner_id
|
||||
|
||||
# Replace Redis client with failing one
|
||||
lock.redis = redis.Redis(
|
||||
host="invalid_host", port=1234, socket_connect_timeout=1
|
||||
)
|
||||
|
||||
# Refresh should fail gracefully
|
||||
lock._last_refresh = 0 # Force refresh
|
||||
assert lock.refresh() is False
|
||||
assert lock._last_refresh == 0
|
||||
|
||||
def test_invalid_lock_parameters(self, redis_client):
|
||||
"""Test validation of lock parameters."""
|
||||
owner_id = str(uuid.uuid4())
|
||||
|
||||
# All parameters are now simple - no validation needed
|
||||
# Just test basic construction works
|
||||
lock = ClusterLock(redis_client, "test_key", owner_id, timeout=60)
|
||||
assert lock.key == "test_key"
|
||||
assert lock.owner_id == owner_id
|
||||
assert lock.timeout == 60
|
||||
|
||||
def test_refresh_after_redis_key_deleted(self, redis_client, lock_key, owner_id):
|
||||
"""Test refresh behavior when Redis key is manually deleted."""
|
||||
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
|
||||
|
||||
lock.try_acquire()
|
||||
|
||||
# Manually delete the key (simulates external deletion)
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
# Refresh should fail and mark as not acquired
|
||||
lock._last_refresh = 0 # Force refresh
|
||||
assert lock.refresh() is False
|
||||
assert lock._last_refresh == 0
|
||||
|
||||
|
||||
class TestClusterLockDynamicRefreshInterval:
|
||||
"""Dynamic refresh interval based on timeout."""
|
||||
|
||||
def test_refresh_interval_calculation(self, redis_client, lock_key, owner_id):
|
||||
"""Test refresh interval is calculated as max(timeout/10, 1)."""
|
||||
test_cases = [
|
||||
(5, 1), # 5/10 = 0, but minimum is 1
|
||||
(10, 1), # 10/10 = 1
|
||||
(30, 3), # 30/10 = 3
|
||||
(100, 10), # 100/10 = 10
|
||||
(200, 20), # 200/10 = 20
|
||||
(1000, 100), # 1000/10 = 100
|
||||
]
|
||||
|
||||
for timeout, expected_interval in test_cases:
|
||||
lock = ClusterLock(
|
||||
redis_client, f"{lock_key}_{timeout}", owner_id, timeout=timeout
|
||||
)
|
||||
lock.try_acquire()
|
||||
|
||||
# Calculate expected interval using same logic as implementation
|
||||
refresh_interval = max(timeout // 10, 1)
|
||||
assert refresh_interval == expected_interval
|
||||
|
||||
# Test rate limiting works with calculated interval
|
||||
assert lock.refresh() is True
|
||||
first_refresh_time = lock._last_refresh
|
||||
|
||||
# Sleep less than interval - should be rate limited
|
||||
time.sleep(0.1)
|
||||
assert lock.refresh() is True
|
||||
assert lock._last_refresh == first_refresh_time # No actual refresh
|
||||
|
||||
|
||||
class TestClusterLockRealWorldScenarios:
|
||||
"""Real-world usage patterns."""
|
||||
|
||||
def test_execution_coordination_simulation(self, redis_client):
|
||||
"""Simulate graph execution coordination across multiple pods."""
|
||||
graph_exec_id = str(uuid.uuid4())
|
||||
lock_key = f"execution:{graph_exec_id}"
|
||||
|
||||
# Simulate 3 pods trying to execute same graph
|
||||
pods = [f"pod_{i}" for i in range(3)]
|
||||
execution_results = {}
|
||||
|
||||
def execute_graph(pod_id):
|
||||
"""Simulate graph execution with cluster lock."""
|
||||
lock = ClusterLock(redis_client, lock_key, pod_id, timeout=300)
|
||||
|
||||
if lock.try_acquire() == pod_id:
|
||||
# Simulate execution work
|
||||
execution_results[pod_id] = "executed"
|
||||
time.sleep(0.1)
|
||||
lock.release()
|
||||
else:
|
||||
execution_results[pod_id] = "rejected"
|
||||
|
||||
threads = []
|
||||
for pod_id in pods:
|
||||
thread = Thread(target=execute_graph, args=(pod_id,))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# Only one pod should have executed
|
||||
executed_count = sum(
|
||||
1 for result in execution_results.values() if result == "executed"
|
||||
)
|
||||
rejected_count = sum(
|
||||
1 for result in execution_results.values() if result == "rejected"
|
||||
)
|
||||
|
||||
assert executed_count == 1
|
||||
assert rejected_count == 2
|
||||
|
||||
def test_long_running_execution_with_refresh(
|
||||
self, redis_client, lock_key, owner_id
|
||||
):
|
||||
"""Test lock maintains ownership during long execution with periodic refresh."""
|
||||
lock = ClusterLock(
|
||||
redis_client, lock_key, owner_id, timeout=30
|
||||
) # 30 second timeout, refresh interval = max(30//10, 1) = 3 seconds
|
||||
|
||||
def long_execution_with_refresh():
|
||||
"""Simulate long-running execution with periodic refresh."""
|
||||
assert lock.try_acquire() == owner_id
|
||||
|
||||
# Simulate 10 seconds of work with refreshes every 2 seconds
|
||||
# This respects rate limiting - actual refreshes will happen at 0s, 3s, 6s, 9s
|
||||
try:
|
||||
for i in range(5): # 5 iterations * 2 seconds = 10 seconds total
|
||||
time.sleep(2)
|
||||
refresh_success = lock.refresh()
|
||||
assert refresh_success is True, f"Refresh failed at iteration {i}"
|
||||
return "completed"
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
# Should complete successfully without losing lock
|
||||
result = long_execution_with_refresh()
|
||||
assert result == "completed"
|
||||
|
||||
def test_graceful_degradation_pattern(self, redis_client, lock_key):
|
||||
"""Test graceful degradation when Redis becomes unavailable."""
|
||||
owner_id = str(uuid.uuid4())
|
||||
lock = ClusterLock(
|
||||
redis_client, lock_key, owner_id, timeout=3
|
||||
) # Use shorter timeout
|
||||
|
||||
# Normal operation
|
||||
assert lock.try_acquire() == owner_id
|
||||
lock._last_refresh = 0 # Force refresh past rate limit
|
||||
assert lock.refresh() is True
|
||||
|
||||
# Simulate Redis becoming unavailable
|
||||
original_redis = lock.redis
|
||||
lock.redis = redis.Redis(
|
||||
host="invalid_host",
|
||||
port=1234,
|
||||
socket_connect_timeout=1,
|
||||
decode_responses=False,
|
||||
)
|
||||
|
||||
# Should degrade gracefully
|
||||
lock._last_refresh = 0 # Force refresh past rate limit
|
||||
assert lock.refresh() is False
|
||||
assert lock._last_refresh == 0
|
||||
|
||||
# Restore Redis and verify can acquire again
|
||||
lock.redis = original_redis
|
||||
# Wait for original lock to expire (use longer wait for 3s timeout)
|
||||
time.sleep(4)
|
||||
|
||||
new_lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
|
||||
assert new_lock.try_acquire() == owner_id
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run specific test for quick validation
|
||||
pytest.main([__file__, "-v"])
|
||||
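
For orientation, here is a minimal sketch of the lock contract these tests pin down. It is reconstructed from the test expectations alone, not from the real cluster_lock.py: try_acquire returns the current owner's id (or None when Redis is unreachable), refresh verifies ownership and is rate-limited to once per max(timeout // 10, 1) seconds, and release deletes the key outright.

import time

import redis


class ClusterLock:
    """Illustrative sketch only; the actual implementation may differ."""

    def __init__(self, redis, key: str, owner_id: str, timeout: int):
        self.redis = redis  # shadows the module name only inside __init__
        self.key = key
        self.owner_id = owner_id
        self.timeout = timeout
        self._last_refresh = 0.0

    def try_acquire(self):
        """Return the current owner id, or None if Redis is unreachable."""
        try:
            if self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout):
                self._last_refresh = time.time()
                return self.owner_id
            current = self.redis.get(self.key)
            return current.decode() if isinstance(current, bytes) else current
        except redis.RedisError:
            return None

    def refresh(self) -> bool:
        """Extend the TTL; skip the write if called within the rate-limit window."""
        try:
            current = self.redis.get(self.key)
            owner = current.decode() if isinstance(current, bytes) else current
            if owner != self.owner_id:
                self._last_refresh = 0
                return False  # key gone or taken over by another pod
            if time.time() - self._last_refresh < max(self.timeout // 10, 1):
                return True  # rate-limited, but the key was verified above
            self.redis.expire(self.key, self.timeout)
            self._last_refresh = time.time()
            return True
        except redis.RedisError:
            self._last_refresh = 0
            return False

    def release(self) -> None:
        """Delete the key so another owner can acquire immediately."""
        self.redis.delete(self.key)
        self._last_refresh = 0.0
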
@@ -3,7 +3,6 @@ import logging
import os
import threading
import time
import uuid
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import asynccontextmanager
@@ -11,32 +10,9 @@ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast

from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from prometheus_client import Gauge, start_http_server
from redis.asyncio.lock import Lock as AsyncRedisLock
from redis.asyncio.lock import Lock as RedisLock

from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentOutputBlock
from backend.data import redis_client as redis
from backend.data.block import (
    BlockInput,
    BlockOutput,
    BlockOutputEntry,
    BlockSchema,
    get_block,
)
from backend.data.credit import UsageTransactionMetadata
from backend.data.dynamic_fields import parse_execution_output
from backend.data.execution import (
    ExecutionQueue,
    ExecutionStatus,
    GraphExecution,
    GraphExecutionEntry,
    NodeExecutionEntry,
    NodeExecutionResult,
    NodesInputMasks,
    UserContext,
)
from backend.data.graph import Link, Node
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
    AgentRunData,
@@ -49,21 +25,50 @@ from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.activity_status_generator import (
    generate_activity_status_for_execution,
)
from backend.executor.utils import LogMetadata
from backend.notifications.notifications import queue_notification
from backend.util.exceptions import InsufficientBalanceError, ModerationError

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerClient, DatabaseManagerAsyncClient

from prometheus_client import Gauge, start_http_server

from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis_client as redis
from backend.data.block import (
    BlockInput,
    BlockOutput,
    BlockOutputEntry,
    BlockSchema,
    get_block,
)
from backend.data.credit import UsageTransactionMetadata
from backend.data.execution import (
    ExecutionQueue,
    ExecutionStatus,
    GraphExecution,
    GraphExecutionEntry,
    NodeExecutionEntry,
    NodeExecutionResult,
    NodesInputMasks,
    UserContext,
)
from backend.data.graph import Link, Node
from backend.executor.utils import (
    GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
    GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
    GRAPH_EXECUTION_QUEUE_NAME,
    CancelExecutionEvent,
    ExecutionOutputEntry,
    LogMetadata,
    NodeExecutionProgress,
    block_usage_cost,
    create_execution_queue_config,
    execution_usage_cost,
    parse_execution_output,
    validate_exec,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.notifications.notifications import queue_notification
from backend.server.v2.AutoMod.manager import automod_manager
from backend.util import json
from backend.util.clients import (
@@ -79,7 +84,6 @@ from backend.util.decorator import (
    error_logged,
    time_measured,
)
from backend.util.exceptions import InsufficientBalanceError, ModerationError
from backend.util.file import clean_exec_files
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.metrics import DiscordChannel
@@ -87,12 +91,6 @@ from backend.util.process import AppProcess, set_service_name
from backend.util.retry import continuous_retry, func_retry
from backend.util.settings import Settings

from .cluster_lock import ClusterLock

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient


_logger = logging.getLogger(__name__)
logger = TruncatedLogger(_logger, prefix="[GraphExecutor]")
settings = Settings()
@@ -108,7 +106,6 @@ utilization_gauge = Gauge(
    "Ratio of active graph runs to max graph workers",
)


# Thread-local storage for ExecutionProcessor instances
_tls = threading.local()

@@ -120,14 +117,10 @@ def init_worker():


def execute_graph(
    graph_exec_entry: "GraphExecutionEntry",
    cancel_event: threading.Event,
    cluster_lock: ClusterLock,
    graph_exec_entry: "GraphExecutionEntry", cancel_event: threading.Event
):
    """Execute graph using thread-local ExecutionProcessor instance"""
    return _tls.processor.on_graph_execution(
        graph_exec_entry, cancel_event, cluster_lock
    )
    return _tls.processor.on_graph_execution(graph_exec_entry, cancel_event)


T = TypeVar("T")
@@ -436,7 +429,7 @@ class ExecutionProcessor:
            graph_id=node_exec.graph_id,
            node_eid=node_exec.node_exec_id,
            node_id=node_exec.node_id,
            block_name=b.name if (b := get_block(node_exec.block_id)) else "-",
            block_name="-",
        )
        db_client = get_db_async_client()
        node = await db_client.get_node(node_exec.node_id)
@@ -590,7 +583,6 @@
        self,
        graph_exec: GraphExecutionEntry,
        cancel: threading.Event,
        cluster_lock: ClusterLock,
    ):
        log_metadata = LogMetadata(
            logger=_logger,
@@ -649,7 +641,6 @@
                cancel=cancel,
                log_metadata=log_metadata,
                execution_stats=exec_stats,
                cluster_lock=cluster_lock,
            )
            exec_stats.walltime += timing_info.wall_time
            exec_stats.cputime += timing_info.cpu_time
@@ -751,7 +742,6 @@
        cancel: threading.Event,
        log_metadata: LogMetadata,
        execution_stats: GraphExecutionStats,
        cluster_lock: ClusterLock,
    ) -> ExecutionStatus:
        """
        Returns:
@@ -937,7 +927,7 @@
                and execution_queue.empty()
                and (running_node_execution or running_node_evaluation)
            ):
                cluster_lock.refresh()
                # There is nothing to execute, and no output to process, let's relax for a while.
                time.sleep(0.1)

            # loop done --------------------------------------------------
@@ -1229,7 +1219,6 @@ class ExecutionManager(AppProcess):
        super().__init__()
        self.pool_size = settings.config.num_graph_workers
        self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
        self.executor_id = str(uuid.uuid4())

        self._executor = None
        self._stop_consuming = None
@@ -1239,8 +1228,6 @@ class ExecutionManager(AppProcess):
        self._run_thread = None
        self._run_client = None

        self._execution_locks = {}

    @property
    def cancel_thread(self) -> threading.Thread:
        if self._cancel_thread is None:
@@ -1448,46 +1435,17 @@ class ExecutionManager(AppProcess):
        logger.info(
            f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}"
        )

        # Check for local duplicate execution first
        if graph_exec_id in self.active_graph_runs:
            logger.warning(
                f"[{self.service_name}] Graph {graph_exec_id} already running locally; rejecting duplicate."
        # TODO: Make this check cluster-wide, prevent duplicate runs across executor pods.
            logger.error(
                f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
            )
            _ack_message(reject=True, requeue=True)
            _ack_message(reject=True, requeue=False)
            return

        # Try to acquire cluster-wide execution lock
        cluster_lock = ClusterLock(
            redis=redis.get_redis(),
            key=f"exec_lock:{graph_exec_id}",
            owner_id=self.executor_id,
            timeout=settings.config.cluster_lock_timeout,
        )
        current_owner = cluster_lock.try_acquire()
        if current_owner != self.executor_id:
            # Either someone else has it or Redis is unavailable
            if current_owner is not None:
                logger.warning(
                    f"[{self.service_name}] Graph {graph_exec_id} already running on pod {current_owner}"
                )
            else:
                logger.warning(
                    f"[{self.service_name}] Could not acquire lock for {graph_exec_id} - Redis unavailable"
                )
            _ack_message(reject=True, requeue=True)
            return
        self._execution_locks[graph_exec_id] = cluster_lock

        logger.info(
            f"[{self.service_name}] Acquired cluster lock for {graph_exec_id} with executor {self.executor_id}"
        )

        cancel_event = threading.Event()

        future = self.executor.submit(
            execute_graph, graph_exec_entry, cancel_event, cluster_lock
        )
        future = self.executor.submit(execute_graph, graph_exec_entry, cancel_event)
        self.active_graph_runs[graph_exec_id] = (future, cancel_event)
        self._update_prompt_metrics()

@@ -1506,10 +1464,6 @@ class ExecutionManager(AppProcess):
                    f"[{self.service_name}] Error in run completion callback: {e}"
                )
            finally:
                # Release the cluster-wide execution lock
                if graph_exec_id in self._execution_locks:
                    self._execution_locks[graph_exec_id].release()
                    del self._execution_locks[graph_exec_id]
                self._cleanup_completed_runs()

        future.add_done_callback(_on_run_done)
@@ -1592,10 +1546,6 @@ class ExecutionManager(AppProcess):
                f"{prefix} ⏳ Still waiting for {len(self.active_graph_runs)} executions: {ids}"
            )

            for graph_exec_id in self.active_graph_runs:
                if lock := self._execution_locks.get(graph_exec_id):
                    lock.refresh()

            time.sleep(wait_interval)
            waited += wait_interval

@@ -1613,15 +1563,6 @@ class ExecutionManager(AppProcess):
        except Exception as e:
            logger.error(f"{prefix} ⚠️ Error during executor shutdown: {type(e)} {e}")

        # Release remaining execution locks
        try:
            for lock in self._execution_locks.values():
                lock.release()
            self._execution_locks.clear()
            logger.info(f"{prefix} ✅ Released execution locks")
        except Exception as e:
            logger.warning(f"{prefix} ⚠️ Failed to release all locks: {e}")

        # Disconnect the run execution consumer
        self._stop_message_consumers(
            self.run_thread,
@@ -1727,18 +1668,15 @@ def update_graph_execution_state(


@asynccontextmanager
async def synchronized(key: str, timeout: int = settings.config.cluster_lock_timeout):
async def synchronized(key: str, timeout: int = 60):
    r = await redis.get_redis_async()
    lock: AsyncRedisLock = r.lock(f"lock:{key}", timeout=timeout)
    lock: RedisLock = r.lock(f"lock:{key}", timeout=timeout)
    try:
        await lock.acquire()
        yield
    finally:
        if await lock.locked() and await lock.owned():
            try:
                await lock.release()
            except Exception as e:
                logger.warning(f"Failed to release lock for key {key}: {e}")
        await lock.release()


def increment_execution_count(user_id: str) -> int:
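
As a usage sketch of the synchronized() helper above (the key and body are illustrative, not from the source): at most one coroutine across all pods enters the guarded block per key.

async def apply_credit_change(user_id: str, delta: int) -> None:
    # Hypothetical caller; serializes a read-modify-write across executor pods.
    async with synchronized(f"credit:{user_id}", timeout=30):
        ...  # critical section that must not interleave cluster-wide
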
@@ -4,7 +4,7 @@ import threading
import time
from collections import defaultdict
from concurrent.futures import Future
from typing import Mapping, Optional, cast
from typing import Any, Mapping, Optional, cast

from pydantic import BaseModel, JsonValue, ValidationError

@@ -20,9 +20,6 @@ from backend.data.block import (
)
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.db import prisma

# Import dynamic field utilities from centralized location
from backend.data.dynamic_fields import merge_execution_input
from backend.data.execution import (
    ExecutionStatus,
    GraphExecutionStats,
@@ -42,6 +39,7 @@ from backend.util.clients import (
)
from backend.util.exceptions import GraphValidationError, NotFoundError
from backend.util.logging import TruncatedLogger
from backend.util.mock import MockObject
from backend.util.settings import Config
from backend.util.type import convert

@@ -188,7 +186,195 @@ def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bo

# ============ Execution Input Helpers ============ #

# Dynamic field utilities are now imported from backend.data.dynamic_fields
# --------------------------------------------------------------------------- #
# Delimiters
# --------------------------------------------------------------------------- #

LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"

_DELIMS = (LIST_SPLIT, DICT_SPLIT, OBJC_SPLIT)

# --------------------------------------------------------------------------- #
# Tokenisation utilities
# --------------------------------------------------------------------------- #


def _next_delim(s: str) -> tuple[str | None, int]:
    """
    Return the *earliest* delimiter appearing in `s` and its index.

    If none present → (None, -1).
    """
    first: str | None = None
    pos = len(s)  # sentinel: larger than any real index
    for d in _DELIMS:
        i = s.find(d)
        if 0 <= i < pos:
            first, pos = d, i
    return first, (pos if first else -1)


def _tokenise(path: str) -> list[tuple[str, str]] | None:
    """
    Convert the raw path string (starting with a delimiter) into
    [ (delimiter, identifier), … ] or None if the syntax is malformed.
    """
    tokens: list[tuple[str, str]] = []
    while path:
        # 1. Which delimiter starts this chunk?
        delim = next((d for d in _DELIMS if path.startswith(d)), None)
        if delim is None:
            return None  # invalid syntax

        # 2. Slice off the delimiter, then up to the next delimiter (or EOS)
        path = path[len(delim) :]
        nxt_delim, pos = _next_delim(path)
        token, path = (
            path[: pos if pos != -1 else len(path)],
            path[pos if pos != -1 else len(path) :],
        )
        if token == "":
            return None  # empty identifier is invalid
        tokens.append((delim, token))
    return tokens


# --------------------------------------------------------------------------- #
# Public API – parsing (flattened ➜ concrete)
# --------------------------------------------------------------------------- #


def parse_execution_output(output: BlockOutputEntry, name: str) -> JsonValue | None:
    """
    Retrieve a nested value out of `output` using the flattened *name*.

    On any failure (wrong name, wrong type, out-of-range, bad path)
    returns **None**.
    """
    base_name, data = output

    # Exact match → whole object
    if name == base_name:
        return data

    # Must start with the expected name
    if not name.startswith(base_name):
        return None
    path = name[len(base_name) :]
    if not path:
        return None  # nothing left to parse

    tokens = _tokenise(path)
    if tokens is None:
        return None

    cur: JsonValue = data
    for delim, ident in tokens:
        if delim == LIST_SPLIT:
            # list[index]
            try:
                idx = int(ident)
            except ValueError:
                return None
            if not isinstance(cur, list) or idx >= len(cur):
                return None
            cur = cur[idx]

        elif delim == DICT_SPLIT:
            if not isinstance(cur, dict) or ident not in cur:
                return None
            cur = cur[ident]

        elif delim == OBJC_SPLIT:
            if not hasattr(cur, ident):
                return None
            cur = getattr(cur, ident)

        else:
            return None  # unreachable

    return cur


def _assign(container: Any, tokens: list[tuple[str, str]], value: Any) -> Any:
    """
    Recursive helper that *returns* the (possibly new) container with
    `value` assigned along the remaining `tokens` path.
    """
    if not tokens:
        return value  # leaf reached

    delim, ident = tokens[0]
    rest = tokens[1:]

    # ---------- list ----------
    if delim == LIST_SPLIT:
        try:
            idx = int(ident)
        except ValueError:
            raise ValueError("index must be an integer")

        if container is None:
            container = []
        elif not isinstance(container, list):
            container = list(container) if hasattr(container, "__iter__") else []

        while len(container) <= idx:
            container.append(None)
        container[idx] = _assign(container[idx], rest, value)
        return container

    # ---------- dict ----------
    if delim == DICT_SPLIT:
        if container is None:
            container = {}
        elif not isinstance(container, dict):
            container = dict(container) if hasattr(container, "items") else {}
        container[ident] = _assign(container.get(ident), rest, value)
        return container

    # ---------- object ----------
    if delim == OBJC_SPLIT:
        if container is None or not isinstance(container, MockObject):
            container = MockObject()
        setattr(
            container,
            ident,
            _assign(getattr(container, ident, None), rest, value),
        )
        return container

    return value  # unreachable


def merge_execution_input(data: BlockInput) -> BlockInput:
    """
    Reconstruct nested objects from a *flattened* dict of key → value.

    Raises ValueError on syntactically invalid list indices.
    """
    merged: BlockInput = {}

    for key, value in data.items():
        # Split off the base name (before the first delimiter, if any)
        delim, pos = _next_delim(key)
        if delim is None:
            merged[key] = value
            continue

        base, path = key[:pos], key[pos:]
        tokens = _tokenise(path)
        if tokens is None:
            # Invalid key; treat as scalar under the raw name
            merged[key] = value
            continue

        merged[base] = _assign(merged.get(base), tokens, value)

    data.update(merged)
    return data


def validate_exec(
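
A small worked example of the flattened-key scheme above (inputs invented for illustration):

# _$_ indexes a list, _#_ selects a dict key, _@_ reads an object attribute.
flat = {"items_$_0": "a", "items_$_1": "b", "config_#_retries": 3}
merged = merge_execution_input(flat)
assert merged["items"] == ["a", "b"]
assert merged["config"] == {"retries": 3}

# parse_execution_output walks the same path back out of a (name, data) pair
# and returns None on any mismatch.
assert parse_execution_output(("items", ["a", "b"]), "items_$_1") == "b"
assert parse_execution_output(("items", ["a", "b"]), "items_$_9") is None
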
@@ -3,7 +3,7 @@ from typing import cast
import pytest
from pytest_mock import MockerFixture

from backend.data.dynamic_fields import merge_execution_input, parse_execution_output
from backend.executor.utils import merge_execution_input, parse_execution_output
from backend.util.mock import MockObject


@@ -151,10 +151,7 @@ class IntegrationCredentialsManager:
                fresh_credentials = await oauth_handler.refresh_tokens(credentials)
                await self.store.update_creds(user_id, fresh_credentials)
                if _lock and (await _lock.locked()) and (await _lock.owned()):
                    try:
                        await _lock.release()
                    except Exception as e:
                        logger.warning(f"Failed to release OAuth refresh lock: {e}")
                    await _lock.release()

                credentials = fresh_credentials
        return credentials
@@ -187,10 +184,7 @@ class IntegrationCredentialsManager:
            yield
        finally:
            if (await lock.locked()) and (await lock.owned()):
                try:
                    await lock.release()
                except Exception as e:
                    logger.warning(f"Failed to release credentials lock: {e}")
                await lock.release()

    async def release_all_locks(self):
        """Call this on process termination to ensure all locks are released"""

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING

from backend.util.cache import cached
from autogpt_libs.utils.cache import cached

if TYPE_CHECKING:
    from ..providers import ProviderName
@@ -8,7 +8,7 @@ if TYPE_CHECKING:


# --8<-- [start:load_webhook_managers]
@cached(ttl_seconds=3600)  # Cache webhook managers for 1 hour
@cached()
def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]:
    webhook_managers = {}


@@ -1,86 +0,0 @@
"""
Shared cache configuration constants.

This module defines all page_size defaults used across the application.
By centralizing these values, we ensure that cache invalidation always
uses the same page_size as the routes that populate the cache.

CRITICAL: If you change any of these values, the tests in
test_cache_invalidation_consistency.py will fail to remind you to
update all dependent code.
"""

# V1 API (legacy) page sizes
V1_GRAPHS_PAGE_SIZE = 250
"""Default page size for listing user graphs in v1 API."""

V1_LIBRARY_AGENTS_PAGE_SIZE = 10
"""Default page size for library agents in v1 API."""

V1_GRAPH_EXECUTIONS_PAGE_SIZE = 25
"""Default page size for graph executions in v1 API."""

# V2 Store API page sizes
V2_STORE_AGENTS_PAGE_SIZE = 20
"""Default page size for store agents listing."""

V2_STORE_CREATORS_PAGE_SIZE = 20
"""Default page size for store creators listing."""

V2_STORE_SUBMISSIONS_PAGE_SIZE = 20
"""Default page size for user submissions listing."""

V2_MY_AGENTS_PAGE_SIZE = 20
"""Default page size for user's own agents listing."""

# V2 Library API page sizes
V2_LIBRARY_AGENTS_PAGE_SIZE = 10
"""Default page size for library agents listing in v2 API."""

V2_LIBRARY_PRESETS_PAGE_SIZE = 20
"""Default page size for library presets listing."""

# Alternative page sizes (for backward compatibility or special cases)
V2_LIBRARY_PRESETS_ALT_PAGE_SIZE = 10
"""
Alternative page size for library presets.
Some clients may use this smaller page size, so cache clearing must handle both.
"""

V2_GRAPH_EXECUTIONS_ALT_PAGE_SIZE = 10
"""
Alternative page size for graph executions.
Some clients may use this smaller page size, so cache clearing must handle both.
"""

# Cache clearing configuration
MAX_PAGES_TO_CLEAR = 20
"""
Maximum number of pages to clear when invalidating paginated caches.
This prevents infinite loops while ensuring we clear most cached pages.
For users with more than 20 pages, those pages will expire naturally via TTL.
"""


def get_page_sizes_for_clearing(
    primary_page_size: int, alt_page_size: int | None = None
) -> list[int]:
    """
    Get all page_size values that should be cleared for a given cache.

    Args:
        primary_page_size: The main page_size used by the route
        alt_page_size: Optional alternative page_size if multiple clients use different sizes

    Returns:
        List of page_size values to clear

    Example:
        >>> get_page_sizes_for_clearing(20)
        [20]
        >>> get_page_sizes_for_clearing(20, 10)
        [20, 10]
    """
    if alt_page_size is None:
        return [primary_page_size]
    return [primary_page_size, alt_page_size]
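
The intended consumption pattern is a clearing loop like the following sketch. get_cached_library_presets is a hypothetical cached route helper named for illustration; the cache_delete call mirrors the v1 cache helpers that appear later in this diff.

def clear_library_presets_cache(user_id: str) -> None:
    # Clear every page a client could have cached, for each known page_size.
    for page_size in get_page_sizes_for_clearing(
        V2_LIBRARY_PRESETS_PAGE_SIZE, V2_LIBRARY_PRESETS_ALT_PAGE_SIZE
    ):
        for page in range(1, MAX_PAGES_TO_CLEAR + 1):
            get_cached_library_presets.cache_delete(user_id, page, page_size)
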
@@ -32,7 +32,6 @@ from backend.data.model (
    OAuth2Credentials,
    UserIntegrations,
)
from backend.data.onboarding import complete_webhook_trigger_step
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
@@ -180,7 +179,7 @@ async def callback(
    )


@router.get("/credentials", summary="List Credentials")
@router.get("/credentials")
async def list_credentials(
    user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
@@ -221,9 +220,7 @@ async def list_credentials_by_provider(
    ]


@router.get(
    "/{provider}/credentials/{cred_id}", summary="Get Specific Credential By ID"
)
@router.get("/{provider}/credentials/{cred_id}")
async def get_credential(
    provider: Annotated[
        ProviderName, Path(title="The provider to retrieve credentials for")
@@ -244,7 +241,7 @@ async def get_credential(
    return credential


@router.post("/{provider}/credentials", status_code=201, summary="Create Credentials")
@router.post("/{provider}/credentials", status_code=201)
async def create_credentials(
    user_id: Annotated[str, Security(get_user_id)],
    provider: Annotated[
@@ -370,8 +367,6 @@ async def webhook_ingress_generic(
        return

    executions: list[Awaitable] = []
    await complete_webhook_trigger_step(user_id)

    for node in webhook.triggered_nodes:
        logger.debug(f"Webhook-attached node: {node}")
        if not node.is_triggered_by_event_type(event_type):

@@ -1,10 +1,12 @@
import re
from typing import Set

from starlette.types import ASGIApp, Message, Receive, Scope, Send
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp


class SecurityHeadersMiddleware:
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """
    Middleware to add security headers to responses, with cache control
    disabled by default for all endpoints except those explicitly allowed.
@@ -23,8 +25,6 @@ class SecurityHeadersMiddleware:
        "/api/health",
        "/api/v1/health",
        "/api/status",
        "/api/blocks",
        "/api/v1/blocks",
        # Public store/marketplace pages (read-only)
        "/api/store/agents",
        "/api/v1/store/agents",
@@ -49,7 +49,7 @@ class SecurityHeadersMiddleware:
    }

    def __init__(self, app: ASGIApp):
        self.app = app
        super().__init__(app)
        # Compile regex patterns for wildcard matching
        self.cacheable_patterns = [
            re.compile(pattern.replace("*", "[^/]+"))
@@ -72,42 +72,26 @@ class SecurityHeadersMiddleware:

        return False

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Pure ASGI middleware implementation for better performance than BaseHTTPMiddleware."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
    async def dispatch(self, request: Request, call_next):
        response: Response = await call_next(request)

        # Extract path from scope
        path = scope["path"]
        # Add general security headers
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        response.headers["X-XSS-Protection"] = "1; mode=block"
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"

        async def send_wrapper(message: Message) -> None:
            if message["type"] == "http.response.start":
                # Add security headers to the response
                headers = dict(message.get("headers", []))
        # Add noindex header for shared execution pages
        if "/public/shared" in request.url.path:
            response.headers["X-Robots-Tag"] = "noindex, nofollow"

                # Add general security headers (HTTP spec requires proper capitalization)
                headers[b"X-Content-Type-Options"] = b"nosniff"
                headers[b"X-Frame-Options"] = b"DENY"
                headers[b"X-XSS-Protection"] = b"1; mode=block"
                headers[b"Referrer-Policy"] = b"strict-origin-when-cross-origin"
        # Default: Disable caching for all endpoints
        # Only allow caching for explicitly permitted paths
        if not self.is_cacheable_path(request.url.path):
            response.headers["Cache-Control"] = (
                "no-store, no-cache, must-revalidate, private"
            )
            response.headers["Pragma"] = "no-cache"
            response.headers["Expires"] = "0"

                # Add noindex header for shared execution pages
                if "/public/shared" in path:
                    headers[b"X-Robots-Tag"] = b"noindex, nofollow"

                # Default: Disable caching for all endpoints
                # Only allow caching for explicitly permitted paths
                if not self.is_cacheable_path(path):
                    headers[b"Cache-Control"] = (
                        b"no-store, no-cache, must-revalidate, private"
                    )
                    headers[b"Pragma"] = b"no-cache"
                    headers[b"Expires"] = b"0"

                # Convert headers back to list format
                message["headers"] = list(headers.items())

                await send(message)

        await self.app(scope, receive, send_wrapper)
        return response
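
Either variant of the middleware can be exercised the same way; a quick check sketched with FastAPI's TestClient (app wiring here is hypothetical, the real registration happens in rest_api.py below):

from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()
app.add_middleware(SecurityHeadersMiddleware)

@app.get("/api/health")
def health():
    return {"ok": True}

client = TestClient(app)
resp = client.get("/api/health")
assert resp.headers["X-Content-Type-Options"] == "nosniff"
# /api/health is on the cacheable allowlist, so no-store is not forced here:
assert "no-store" not in resp.headers.get("Cache-Control", "")
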
@@ -1,6 +1,5 @@
import contextlib
import logging
import platform
from enum import Enum
from typing import Any, Optional

@@ -12,7 +11,6 @@ import uvicorn
from autogpt_libs.auth import add_auth_responses_to_openapi
from autogpt_libs.auth import verify_settings as verify_auth_settings
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.routing import APIRoute
from prisma.errors import PrismaError

@@ -72,26 +70,6 @@ async def lifespan_context(app: fastapi.FastAPI):

    await backend.data.db.connect()

    # Configure thread pool for FastAPI sync operation performance
    # CRITICAL: FastAPI automatically runs ALL sync functions in this thread pool:
    # - Any endpoint defined with 'def' (not async def)
    # - Any dependency function defined with 'def' (not async def)
    # - Manual run_in_threadpool() calls (like JWT decoding)
    # Default pool size is only 40 threads, causing bottlenecks under high concurrency
    config = backend.util.settings.Config()
    try:
        import anyio.to_thread

        anyio.to_thread.current_default_thread_limiter().total_tokens = (
            config.fastapi_thread_pool_size
        )
        logger.info(
            f"Thread pool size set to {config.fastapi_thread_pool_size} for sync endpoint/dependency performance"
        )
    except (ImportError, AttributeError) as e:
        logger.warning(f"Could not configure thread pool size: {e}")
        # Continue without thread pool configuration

    # Ensure SDK auto-registration is patched before initializing blocks
    from backend.sdk.registry import AutoRegistry

@@ -162,9 +140,6 @@ app = fastapi.FastAPI(

app.add_middleware(SecurityHeadersMiddleware)

# Add GZip compression middleware for large responses (like /api/blocks)
app.add_middleware(GZipMiddleware, minimum_size=50_000)  # 50KB threshold

# Add 401 responses to authenticated endpoints in OpenAPI spec
add_auth_responses_to_openapi(app)

@@ -298,28 +273,12 @@ class AgentServer(backend.util.service.AppProcess):
            allow_methods=["*"],  # Allows all methods
            allow_headers=["*"],  # Allows all headers
        )
        config = backend.util.settings.Config()

        # Configure uvicorn with performance optimizations from Kludex FastAPI tips
        uvicorn_config = {
            "app": server_app,
            "host": config.agent_api_host,
            "port": config.agent_api_port,
            "log_config": None,
            # Use httptools for HTTP parsing (if available)
            "http": "httptools",
            # Only use uvloop on Unix-like systems (not supported on Windows)
            "loop": "uvloop" if platform.system() != "Windows" else "auto",
        }

        # Only add debug in local environment (not supported in all uvicorn versions)
        if config.app_env == backend.util.settings.AppEnvironment.LOCAL:
            import os

            # Enable asyncio debug mode via environment variable
            os.environ["PYTHONASYNCIODEBUG"] = "1"

        uvicorn.run(**uvicorn_config)
        uvicorn.run(
            server_app,
            host=backend.util.settings.Config().agent_api_host,
            port=backend.util.settings.Config().agent_api_port,
            log_config=None,
        )

    def cleanup(self):
        super().cleanup()

@@ -1,154 +0,0 @@
"""
Cache functions for main V1 API endpoints.

This module contains all caching decorators and helpers for the V1 API,
separated from the main routes for better organization and maintainability.
"""

from typing import Sequence

from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import user as user_db
from backend.data.block import get_blocks
from backend.util.cache import cached

# ===== Block Caches =====


# Cache block definitions with costs - they rarely change
@cached(maxsize=1, ttl_seconds=3600, shared_cache=True)
def get_cached_blocks() -> Sequence[dict]:
    """
    Get cached blocks with thundering herd protection.

    Uses cached decorator to prevent multiple concurrent requests
    from all executing the expensive block loading operation.
    """
    from backend.data.credit import get_block_cost

    block_classes = get_blocks()
    result = []

    for block_class in block_classes.values():
        block_instance = block_class()
        if not block_instance.disabled:
            # Get costs for this specific block class without creating another instance
            costs = get_block_cost(block_instance)
            result.append({**block_instance.to_dict(), "costs": costs})

    return result


# ===== Graph Caches =====


# Cache user's graphs list for 15 minutes
@cached(maxsize=1000, ttl_seconds=900, shared_cache=True)
async def get_cached_graphs(
    user_id: str,
    page: int,
    page_size: int,
):
    """Cached helper to get user's graphs."""
    return await graph_db.list_graphs_paginated(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )


# Cache individual graph details for 30 minutes
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_graph(
    graph_id: str,
    version: int | None,
    user_id: str,
):
    """Cached helper to get graph details."""
    return await graph_db.get_graph(
        graph_id=graph_id,
        version=version,
        user_id=user_id,
        include_subgraphs=True,  # needed to construct full credentials input schema
    )


# Cache graph versions for 30 minutes
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_graph_all_versions(
    graph_id: str,
    user_id: str,
) -> Sequence[graph_db.GraphModel]:
    """Cached helper to get all versions of a graph."""
    return await graph_db.get_graph_all_versions(
        graph_id=graph_id,
        user_id=user_id,
    )


# ===== Execution Caches =====


# Cache graph executions for 10 seconds.
@cached(maxsize=1000, ttl_seconds=10, shared_cache=True)
async def get_cached_graph_executions(
    graph_id: str,
    user_id: str,
    page: int,
    page_size: int,
):
    """Cached helper to get graph executions."""
    return await execution_db.get_graph_executions_paginated(
        graph_id=graph_id,
        user_id=user_id,
        page=page,
        page_size=page_size,
    )


# Cache all user executions for 10 seconds.
@cached(maxsize=500, ttl_seconds=10, shared_cache=True)
async def get_cached_graphs_executions(
    user_id: str,
    page: int,
    page_size: int,
):
    """Cached helper to get all user's graph executions."""
    return await execution_db.get_graph_executions_paginated(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )


# Cache individual execution details for 10 seconds.
@cached(maxsize=1000, ttl_seconds=10, shared_cache=True)
async def get_cached_graph_execution(
    graph_exec_id: str,
    user_id: str,
):
    """Cached helper to get graph execution details."""
    return await execution_db.get_graph_execution(
        user_id=user_id,
        execution_id=graph_exec_id,
        include_node_executions=False,
    )


# ===== User Preference Caches =====


# Cache user timezone for 1 hour
@cached(maxsize=1000, ttl_seconds=3600, shared_cache=True)
async def get_cached_user_timezone(user_id: str):
    """Cached helper to get user timezone."""
    user = await user_db.get_user_by_id(user_id)
    return {"timezone": user.timezone if user else "UTC"}


# Cache user preferences for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_user_preferences(user_id: str):
    """Cached helper to get user notification preferences."""
    return await user_db.get_user_notification_preference(user_id)
@@ -1,376 +0,0 @@
|
||||
"""
|
||||
Tests for cache invalidation in V1 API routes.
|
||||
|
||||
This module tests that caches are properly invalidated when data is modified
|
||||
through POST, PUT, PATCH, and DELETE operations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import backend.server.routers.cache as cache
|
||||
from backend.data import graph as graph_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_id():
|
||||
"""Generate a mock user ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_id():
|
||||
"""Generate a mock graph ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class TestGraphCacheInvalidation:
|
||||
"""Test cache invalidation for graph operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_graph_clears_list_cache(self, mock_user_id):
|
||||
"""Test that creating a graph clears the graphs list cache."""
|
||||
# Setup
|
||||
cache.get_cached_graphs.cache_clear()
|
||||
|
||||
# Pre-populate cache
|
||||
with patch.object(
|
||||
graph_db, "list_graphs_paginated", new_callable=AsyncMock
|
||||
) as mock_list:
|
||||
# Use a simple dict instead of MagicMock to make it pickleable
|
||||
mock_list.return_value = {
|
||||
"graphs": [],
|
||||
"total_count": 0,
|
||||
"page": 1,
|
||||
"page_size": 250,
|
||||
}
|
||||
|
||||
# First call should hit the database
|
||||
await cache.get_cached_graphs(mock_user_id, 1, 250)
|
||||
assert mock_list.call_count == 1
|
||||
|
||||
# Second call should use cache
|
||||
await cache.get_cached_graphs(mock_user_id, 1, 250)
|
||||
assert mock_list.call_count == 1 # Still 1, used cache
|
||||
|
||||
# Simulate cache invalidation (what happens in create_new_graph)
|
||||
cache.get_cached_graphs.cache_delete(mock_user_id, 1, 250)
|
||||
|
||||
# Next call should hit database again
|
||||
await cache.get_cached_graphs(mock_user_id, 1, 250)
|
||||
assert mock_list.call_count == 2 # Incremented, cache was cleared
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_graph_clears_multiple_caches(
|
||||
self, mock_user_id, mock_graph_id
|
||||
):
|
||||
"""Test that deleting a graph clears all related caches."""
|
||||
# Clear all caches first
|
||||
cache.get_cached_graphs.cache_clear()
|
||||
cache.get_cached_graph.cache_clear()
|
||||
cache.get_cached_graph_all_versions.cache_clear()
|
||||
cache.get_cached_graph_executions.cache_clear()
|
||||
|
||||
# Setup mocks
|
||||
with (
|
||||
patch.object(
|
||||
graph_db, "list_graphs_paginated", new_callable=AsyncMock
|
||||
) as mock_list,
|
||||
patch.object(graph_db, "get_graph", new_callable=AsyncMock) as mock_get,
|
||||
patch.object(
|
||||
graph_db, "get_graph_all_versions", new_callable=AsyncMock
|
||||
) as mock_versions,
|
||||
):
|
||||
mock_list.return_value = {
|
||||
"graphs": [],
|
||||
"total_count": 0,
|
||||
"page": 1,
|
||||
"page_size": 250,
|
||||
}
|
||||
mock_get.return_value = {"id": mock_graph_id}
|
||||
mock_versions.return_value = []
|
||||
|
||||
# Pre-populate all caches (use consistent argument style)
|
||||
await cache.get_cached_graphs(mock_user_id, 1, 250)
|
||||
await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
|
||||
await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
|
||||
|
||||
initial_calls = {
|
||||
"list": mock_list.call_count,
|
||||
"get": mock_get.call_count,
|
||||
"versions": mock_versions.call_count,
|
||||
}
|
||||
|
||||
# Use cached values (no additional DB calls)
|
||||
await cache.get_cached_graphs(mock_user_id, 1, 250)
|
||||
            await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
            await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)

            # Verify cache was used
            assert mock_list.call_count == initial_calls["list"]
            assert mock_get.call_count == initial_calls["get"]
            assert mock_versions.call_count == initial_calls["versions"]

            # Simulate delete_graph cache invalidation
            # Use positional arguments for cache_delete to match how we called the functions
            result1 = cache.get_cached_graphs.cache_delete(mock_user_id, 1, 250)
            result2 = cache.get_cached_graph.cache_delete(
                mock_graph_id, None, mock_user_id
            )
            result3 = cache.get_cached_graph_all_versions.cache_delete(
                mock_graph_id, mock_user_id
            )

            # Verify that the cache entries were actually deleted
            assert result1, "Failed to delete graphs cache entry"
            assert result2, "Failed to delete graph cache entry"
            assert result3, "Failed to delete graph versions cache entry"

            # Next calls should hit database
            await cache.get_cached_graphs(mock_user_id, 1, 250)
            await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
            await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)

            # Verify database was called again
            assert mock_list.call_count == initial_calls["list"] + 1
            assert mock_get.call_count == initial_calls["get"] + 1
            assert mock_versions.call_count == initial_calls["versions"] + 1

    @pytest.mark.asyncio
    async def test_update_graph_clears_caches(self, mock_user_id, mock_graph_id):
        """Test that updating a graph clears the appropriate caches."""
        # Clear caches
        cache.get_cached_graph.cache_clear()
        cache.get_cached_graph_all_versions.cache_clear()
        cache.get_cached_graphs.cache_clear()

        with (
            patch.object(graph_db, "get_graph", new_callable=AsyncMock) as mock_get,
            patch.object(
                graph_db, "get_graph_all_versions", new_callable=AsyncMock
            ) as mock_versions,
            patch.object(
                graph_db, "list_graphs_paginated", new_callable=AsyncMock
            ) as mock_list,
        ):
            mock_get.return_value = {"id": mock_graph_id, "version": 1}
            mock_versions.return_value = [{"version": 1}]
            mock_list.return_value = {
                "graphs": [],
                "total_count": 0,
                "page": 1,
                "page_size": 250,
            }

            # Populate caches
            await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
            await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
            await cache.get_cached_graphs(mock_user_id, 1, 250)

            initial_calls = {
                "get": mock_get.call_count,
                "versions": mock_versions.call_count,
                "list": mock_list.call_count,
            }

            # Verify cache is being used
            await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
            await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
            await cache.get_cached_graphs(mock_user_id, 1, 250)

            assert mock_get.call_count == initial_calls["get"]
            assert mock_versions.call_count == initial_calls["versions"]
            assert mock_list.call_count == initial_calls["list"]

            # Simulate update_graph cache invalidation
            cache.get_cached_graph.cache_delete(mock_graph_id, None, mock_user_id)
            cache.get_cached_graph_all_versions.cache_delete(
                mock_graph_id, mock_user_id
            )
            cache.get_cached_graphs.cache_delete(mock_user_id, 1, 250)

            # Next calls should hit database
            await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
            await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
            await cache.get_cached_graphs(mock_user_id, 1, 250)

            assert mock_get.call_count == initial_calls["get"] + 1
            assert mock_versions.call_count == initial_calls["versions"] + 1
            assert mock_list.call_count == initial_calls["list"] + 1

class TestUserPreferencesCacheInvalidation:
    """Test cache invalidation for user preferences operations."""

    @pytest.mark.asyncio
    async def test_update_preferences_clears_cache(self, mock_user_id):
        """Test that updating preferences clears the preferences cache."""
        # Clear cache
        cache.get_cached_user_preferences.cache_clear()

        with patch.object(
            cache.user_db, "get_user_notification_preference", new_callable=AsyncMock
        ) as mock_get_prefs:
            mock_prefs = {"email_notifications": True, "push_notifications": False}
            mock_get_prefs.return_value = mock_prefs

            # First call hits database
            result1 = await cache.get_cached_user_preferences(mock_user_id)
            assert mock_get_prefs.call_count == 1
            assert result1 == mock_prefs

            # Second call uses cache
            result2 = await cache.get_cached_user_preferences(mock_user_id)
            assert mock_get_prefs.call_count == 1  # Still 1
            assert result2 == mock_prefs

            # Simulate update_preferences cache invalidation
            cache.get_cached_user_preferences.cache_delete(mock_user_id)

            # Change the mock return value to simulate updated preferences
            mock_prefs_updated = {
                "email_notifications": False,
                "push_notifications": True,
            }
            mock_get_prefs.return_value = mock_prefs_updated

            # Next call should hit database and get new value
            result3 = await cache.get_cached_user_preferences(mock_user_id)
            assert mock_get_prefs.call_count == 2
            assert result3 == mock_prefs_updated

    @pytest.mark.asyncio
    async def test_timezone_cache_operations(self, mock_user_id):
        """Test timezone cache and its operations."""
        # Clear cache
        cache.get_cached_user_timezone.cache_clear()

        with patch.object(
            cache.user_db, "get_user_by_id", new_callable=AsyncMock
        ) as mock_get_user:
            # Use a simple object that supports attribute access
            class MockUser:
                def __init__(self, timezone):
                    self.timezone = timezone

            mock_user = MockUser("America/New_York")
            mock_get_user.return_value = mock_user

            # First call hits database
            result1 = await cache.get_cached_user_timezone(mock_user_id)
            assert mock_get_user.call_count == 1
            assert result1["timezone"] == "America/New_York"

            # Second call uses cache
            result2 = await cache.get_cached_user_timezone(mock_user_id)
            assert mock_get_user.call_count == 1  # Still 1
            assert result2["timezone"] == "America/New_York"

            # Clear cache manually (simulating what would happen after update)
            cache.get_cached_user_timezone.cache_delete(mock_user_id)

            # Change timezone
            mock_user_updated = MockUser("Europe/London")
            mock_get_user.return_value = mock_user_updated

            # Next call should hit database
            result3 = await cache.get_cached_user_timezone(mock_user_id)
            assert mock_get_user.call_count == 2
            assert result3["timezone"] == "Europe/London"


class TestExecutionCacheInvalidation:
    """Test cache invalidation for execution operations."""

    @pytest.mark.asyncio
    async def test_execution_cache_cleared_on_graph_delete(
        self, mock_user_id, mock_graph_id
    ):
        """Test that execution caches are cleared when a graph is deleted."""
        # Clear cache
        cache.get_cached_graph_executions.cache_clear()

        with patch.object(
            cache.execution_db, "get_graph_executions_paginated", new_callable=AsyncMock
        ) as mock_exec:
            mock_exec.return_value = {
                "executions": [],
                "total_count": 0,
                "page": 1,
                "page_size": 25,
            }

            # Populate cache for multiple pages
            for page in range(1, 4):
                await cache.get_cached_graph_executions(
                    mock_graph_id, mock_user_id, page, 25
                )

            initial_calls = mock_exec.call_count

            # Verify cache is used
            for page in range(1, 4):
                await cache.get_cached_graph_executions(
                    mock_graph_id, mock_user_id, page, 25
                )

            assert mock_exec.call_count == initial_calls  # No new calls

            # Simulate graph deletion clearing execution caches
            for page in range(1, 10):  # Clear more pages as done in delete_graph
                cache.get_cached_graph_executions.cache_delete(
                    mock_graph_id, mock_user_id, page, 25
                )

            # Next calls should hit database
            for page in range(1, 4):
                await cache.get_cached_graph_executions(
                    mock_graph_id, mock_user_id, page, 25
                )

            assert mock_exec.call_count == initial_calls + 3  # 3 new calls


class TestCacheInfo:
    """Test cache information and metrics."""

    def test_cache_info_returns_correct_metrics(self):
        """Test that cache_info returns correct metrics."""
        # Clear all caches
        cache.get_cached_graphs.cache_clear()
        cache.get_cached_graph.cache_clear()

        # Get initial info
        info_graphs = cache.get_cached_graphs.cache_info()
        info_graph = cache.get_cached_graph.cache_info()

        assert info_graphs["size"] == 0
        assert info_graph["size"] == 0

        # Note: We can't directly test cache population without a real async context,
        # but we can verify the cache_info structure
        assert "size" in info_graphs
        assert "maxsize" in info_graphs
        assert "ttl_seconds" in info_graphs

    def test_cache_clear_removes_all_entries(self):
        """Test that cache_clear removes all entries."""
        # This test verifies the cache_clear method exists and can be called
        cache.get_cached_graphs.cache_clear()
        cache.get_cached_graph.cache_clear()
        cache.get_cached_graph_all_versions.cache_clear()
        cache.get_cached_graph_executions.cache_clear()
        cache.get_cached_graphs_executions.cache_clear()
        cache.get_cached_user_preferences.cache_clear()
        cache.get_cached_user_timezone.cache_clear()

        # After clear, all caches should be empty
        assert cache.get_cached_graphs.cache_info()["size"] == 0
        assert cache.get_cached_graph.cache_info()["size"] == 0
        assert cache.get_cached_graph_all_versions.cache_info()["size"] == 0
        assert cache.get_cached_graph_executions.cache_info()["size"] == 0
        assert cache.get_cached_graphs_executions.cache_info()["size"] == 0
        assert cache.get_cached_user_preferences.cache_info()["size"] == 0
        assert cache.get_cached_user_timezone.cache_info()["size"] == 0
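The tests above exercise an assumed decorator contract: cache_info() exposing size/maxsize/ttl_seconds, cache_clear(), and a cache_delete(...) that returns whether an entry was evicted. A minimal sketch of that contract, for orientation only — not the actual autogpt_libs implementation:

import time
from functools import wraps


def cached(maxsize: int = 128, ttl_seconds: float = 300):
    """Minimal TTL-cache sketch exposing the API surface the tests rely on."""

    def decorator(fn):
        store: dict = {}  # key -> (expires_at, value)

        @wraps(fn)
        async def wrapper(*args, **kwargs):
            key = (args, tuple(sorted(kwargs.items())))
            hit = store.get(key)
            if hit and hit[0] > time.monotonic():
                return hit[1]
            value = await fn(*args, **kwargs)
            if len(store) >= maxsize:
                store.pop(next(iter(store)))  # crude eviction of the oldest entry
            store[key] = (time.monotonic() + ttl_seconds, value)
            return value

        def cache_delete(*args, **kwargs) -> bool:
            # Must build the key exactly like wrapper() does, or it misses.
            return store.pop((args, tuple(sorted(kwargs.items()))), None) is not None

        wrapper.cache_delete = cache_delete
        wrapper.cache_clear = store.clear
        wrapper.cache_info = lambda: {
            "size": len(store),
            "maxsize": maxsize,
            "ttl_seconds": ttl_seconds,
        }
        return wrapper

    return decorator

Note how cache_delete only succeeds if it rebuilds the exact key the populating call produced — the property the page_size-consistency tests further down keep hammering on.
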
@@ -11,6 +11,7 @@ import pydantic
import stripe
from autogpt_libs.auth import get_user_id, requires_user
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from autogpt_libs.utils.cache import cached
from fastapi import (
    APIRouter,
    Body,
@@ -23,16 +24,11 @@ from fastapi import (
    Security,
    UploadFile,
)
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict

import backend.server.cache_config as cache_config
import backend.server.integrations.router
import backend.server.routers.analytics
import backend.server.routers.cache as cache
import backend.server.v2.library.cache as library_cache
import backend.server.v2.library.db as library_db
from backend.data import api_key as api_key_db
from backend.data import execution as execution_db
@@ -59,6 +55,7 @@ from backend.data.onboarding import (
from backend.data.user import (
    get_or_create_user,
    get_user_by_id,
    get_user_notification_preference,
    update_user_email,
    update_user_notification_preference,
    update_user_timezone,
@@ -88,7 +85,6 @@ from backend.server.model import (
from backend.util.clients import get_scheduler_client
from backend.util.cloud_storage import get_cloud_storage_handler
from backend.util.exceptions import GraphValidationError, NotFoundError
from backend.util.json import dumps
from backend.util.settings import Settings
from backend.util.timezone_utils import (
    convert_utc_time_to_user_timezone,
@@ -169,9 +165,7 @@ async def get_user_timezone_route(
) -> TimezoneResponse:
    """Get user timezone setting."""
    user = await get_or_create_user(user_data)
    # Use cached timezone for subsequent calls
    result = await cache.get_cached_user_timezone(user.id)
    return TimezoneResponse(timezone=result["timezone"])
    return TimezoneResponse(timezone=user.timezone)


@v1_router.post(
@@ -185,7 +179,6 @@ async def update_user_timezone_route(
) -> TimezoneResponse:
    """Update user timezone. The timezone should be a valid IANA timezone identifier."""
    user = await update_user_timezone(user_id, str(request.timezone))
    cache.get_cached_user_timezone.cache_delete(user_id)
    return TimezoneResponse(timezone=user.timezone)


@@ -198,7 +191,7 @@ async def update_user_timezone_route(
async def get_preferences(
    user_id: Annotated[str, Security(get_user_id)],
) -> NotificationPreference:
    preferences = await cache.get_cached_user_preferences(user_id)
    preferences = await get_user_notification_preference(user_id)
    return preferences


@@ -213,10 +206,6 @@ async def update_preferences(
    preferences: NotificationPreferenceDTO = Body(...),
) -> NotificationPreference:
    output = await update_user_notification_preference(user_id, preferences)

    # Clear preferences cache after update
    cache.get_cached_user_preferences.cache_delete(user_id)

    return output


@@ -274,10 +263,13 @@ async def is_onboarding_enabled():
########################################################


def _compute_blocks_sync() -> str:
@cached()
def _get_cached_blocks() -> Sequence[dict[Any, Any]]:
    """
    Synchronous function to compute blocks data.
    This does the heavy lifting: instantiate 226+ blocks, compute costs, serialize.
    Get cached blocks with thundering herd protection.

    Uses sync_cache decorator to prevent multiple concurrent requests
    from all executing the expensive block loading operation.
    """
    from backend.data.credit import get_block_cost

@@ -287,27 +279,11 @@ def _compute_blocks_sync() -> str:
    for block_class in block_classes.values():
        block_instance = block_class()
        if not block_instance.disabled:
            # Get costs for this specific block class without creating another instance
            costs = get_block_cost(block_instance)
            # Convert BlockCost BaseModel objects to dictionaries for JSON serialization
            costs_dict = [
                cost.model_dump() if isinstance(cost, BaseModel) else cost
                for cost in costs
            ]
            result.append({**block_instance.to_dict(), "costs": costs_dict})
            result.append({**block_instance.to_dict(), "costs": costs})

    # Use our JSON utility which properly handles complex types through to_dict conversion
    return dumps(result)


@cached(ttl_seconds=3600)
async def _get_cached_blocks() -> str:
    """
    Async cached function with thundering herd protection.
    On cache miss: runs heavy work in thread pool
    On cache hit: returns cached string immediately (no thread pool needed)
    """
    # Only run in thread pool on cache miss - cache hits return immediately
    return await run_in_threadpool(_compute_blocks_sync)
    return result


@v1_router.get(
@@ -315,28 +291,9 @@ async def _get_cached_blocks() -> str:
    summary="List available blocks",
    tags=["blocks"],
    dependencies=[Security(requires_user)],
    responses={
        200: {
            "description": "Successful Response",
            "content": {
                "application/json": {
                    "schema": {
                        "items": {"additionalProperties": True, "type": "object"},
                        "type": "array",
                        "title": "Response Getv1List Available Blocks",
                    }
                }
            },
        }
    },
)
async def get_graph_blocks() -> Response:
    # Cache hit: returns immediately, Cache miss: runs in thread pool
    content = await _get_cached_blocks()
    return Response(
        content=content,
        media_type="application/json",
    )
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
    return _get_cached_blocks()


@v1_router.post(
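The hunks above revert the async thundering-herd variant back to a plain synchronous @cached function. For reference, the protection being removed works roughly as below — a sketch assuming a per-cache asyncio lock and a caller-supplied compute function, not the exact autogpt_libs implementation:

import asyncio

from fastapi.concurrency import run_in_threadpool

_blocks_json: dict[str, str] = {}
_blocks_lock = asyncio.Lock()


async def get_blocks_json(compute_sync) -> str:
    """Serve cached JSON; only one coroutine runs the expensive compute."""
    if "blocks" in _blocks_json:
        return _blocks_json["blocks"]  # cache hit: no lock, no thread pool
    async with _blocks_lock:
        if "blocks" not in _blocks_json:  # re-check after acquiring the lock
            _blocks_json["blocks"] = await run_in_threadpool(compute_sync)
    return _blocks_json["blocks"]

Without the lock, a cold cache plus N concurrent requests means N full block-loading passes; with it, one pass and N-1 waiters.
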
@@ -676,10 +633,11 @@ class DeleteGraphResponse(TypedDict):
async def list_graphs(
    user_id: Annotated[str, Security(get_user_id)],
) -> Sequence[graph_db.GraphMeta]:
    paginated_result = await cache.get_cached_graphs(
    paginated_result = await graph_db.list_graphs_paginated(
        user_id=user_id,
        page=1,
        page_size=250,
        filter_by="active",
    )
    return paginated_result.graphs

@@ -702,26 +660,13 @@ async def get_graph(
    version: int | None = None,
    for_export: bool = False,
) -> graph_db.GraphModel:
    # Use cache for non-export requests
    if not for_export:
        graph = await cache.get_cached_graph(
            graph_id=graph_id,
            version=version,
            user_id=user_id,
        )
        # If graph not found, clear cache entry as permissions may have changed
        if not graph:
            cache.get_cached_graph.cache_delete(
                graph_id=graph_id, version=version, user_id=user_id
            )
    else:
        graph = await graph_db.get_graph(
            graph_id,
            version,
            user_id=user_id,
            for_export=for_export,
            include_subgraphs=True,  # needed to construct full credentials input schema
        )
    graph = await graph_db.get_graph(
        graph_id,
        version,
        user_id=user_id,
        for_export=for_export,
        include_subgraphs=True,  # needed to construct full credentials input schema
    )
    if not graph:
        raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
    return graph
@@ -736,7 +681,7 @@ async def get_graph(
async def get_graph_all_versions(
    graph_id: str, user_id: Annotated[str, Security(get_user_id)]
) -> Sequence[graph_db.GraphModel]:
    graphs = await cache.get_cached_graph_all_versions(graph_id, user_id=user_id)
    graphs = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
    if not graphs:
        raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
    return graphs
@@ -760,26 +705,6 @@ async def create_new_graph(
    # as the graph is already valid and no sub-graphs are returned back.
    await graph_db.create_graph(graph, user_id=user_id)
    await library_db.create_library_agent(graph, user_id=user_id)

    # Clear graphs list cache after creating new graph
    cache.get_cached_graphs.cache_delete(
        user_id=user_id,
        page=1,
        page_size=cache_config.V1_GRAPHS_PAGE_SIZE,
    )
    for page in range(1, cache_config.MAX_PAGES_TO_CLEAR):
        library_cache.get_cached_library_agents.cache_delete(
            user_id=user_id,
            page=page,
            page_size=cache_config.V1_LIBRARY_AGENTS_PAGE_SIZE,
        )

    # Clear my agents cache so the user sees the new agent immediately
    import backend.server.v2.store.cache

    backend.server.v2.store.cache._clear_my_agents_cache(user_id)

    return await on_graph_activate(graph, user_id=user_id)


@@ -795,32 +720,7 @@ async def delete_graph(
    if active_version := await graph_db.get_graph(graph_id, user_id=user_id):
        await on_graph_deactivate(active_version, user_id=user_id)

    result = DeleteGraphResponse(
        version_counts=await graph_db.delete_graph(graph_id, user_id=user_id)
    )

    # Clear caches after deleting graph
    cache.get_cached_graphs.cache_delete(
        user_id=user_id,
        page=1,
        page_size=cache_config.V1_GRAPHS_PAGE_SIZE,
    )
    cache.get_cached_graph.cache_delete(
        graph_id=graph_id, version=None, user_id=user_id
    )
    cache.get_cached_graph_all_versions.cache_delete(graph_id, user_id=user_id)

    # Clear my agents cache so the user sees the agent removed immediately
    import backend.server.v2.store.cache

    backend.server.v2.store.cache._clear_my_agents_cache(user_id)

    # Clear library agent by graph_id cache
    library_cache.get_cached_library_agent_by_graph_id.cache_delete(
        graph_id=graph_id, user_id=user_id
    )

    return result
    return {"version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)}


@v1_router.put(
@@ -876,18 +776,6 @@ async def update_graph(
        include_subgraphs=True,
    )
    assert new_graph_version_with_subgraphs  # make type checker happy

    # Clear caches after updating graph
    cache.get_cached_graph.cache_delete(
        graph_id=graph_id, version=None, user_id=user_id
    )
    cache.get_cached_graph_all_versions.cache_delete(graph_id, user_id=user_id)
    cache.get_cached_graphs.cache_delete(
        user_id=user_id,
        page=1,
        page_size=cache_config.V1_GRAPHS_PAGE_SIZE,
    )

    return new_graph_version_with_subgraphs


@@ -953,29 +841,6 @@ async def execute_graph(
            detail="Insufficient balance to execute the agent. Please top up your account.",
        )

    # Invalidate caches before execution starts so frontend sees fresh data
    cache.get_cached_graphs_executions.cache_delete(
        user_id=user_id,
        page=1,
        page_size=cache_config.V1_GRAPHS_PAGE_SIZE,
    )
    for page in range(1, cache_config.MAX_PAGES_TO_CLEAR):
        cache.get_cached_graph_execution.cache_delete(
            graph_id=graph_id, user_id=user_id, version=graph_version
        )

        cache.get_cached_graph_executions.cache_delete(
            graph_id=graph_id,
            user_id=user_id,
            page=page,
            page_size=cache_config.V1_GRAPH_EXECUTIONS_PAGE_SIZE,
        )
        library_cache.get_cached_library_agents.cache_delete(
            user_id=user_id,
            page=page,
            page_size=cache_config.V1_LIBRARY_AGENTS_PAGE_SIZE,
        )

    try:
        result = await execution_utils.add_graph_execution(
            graph_id=graph_id,
@@ -988,7 +853,6 @@ async def execute_graph(
        # Record successful graph execution
        record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
        record_graph_operation(operation="execute", status="success")

        return result
    except GraphValidationError as e:
        # Record failed graph execution
@@ -1064,7 +928,7 @@ async def _stop_graph_run(
async def list_graphs_executions(
    user_id: Annotated[str, Security(get_user_id)],
) -> list[execution_db.GraphExecutionMeta]:
    paginated_result = await cache.get_cached_graphs_executions(
    paginated_result = await execution_db.get_graph_executions_paginated(
        user_id=user_id,
        page=1,
        page_size=250,
@@ -1086,7 +950,7 @@ async def list_graph_executions(
        25, ge=1, le=100, description="Number of executions per page"
    ),
) -> execution_db.GraphExecutionsPaginated:
    return await cache.get_cached_graph_executions(
    return await execution_db.get_graph_executions_paginated(
        graph_id=graph_id,
        user_id=user_id,
        page=page,

@@ -102,13 +102,13 @@ def test_get_graph_blocks(
    mock_block.id = "test-block"
    mock_block.disabled = False

    # Mock get_blocks where it's imported at the top of v1.py
    # Mock get_blocks
    mocker.patch(
        "backend.server.routers.v1.get_blocks",
        return_value={"test-block": lambda: mock_block},
    )

    # Mock block costs where it's imported inside the function
    # Mock block costs
    mocker.patch(
        "backend.data.credit.get_block_cost",
        return_value=[{"cost": 10, "type": "credit"}],

@@ -1,299 +0,0 @@
#!/usr/bin/env python3
"""
Complete audit of all @cached functions to verify proper cache invalidation.

This test systematically checks every @cached function in the codebase
to ensure it has appropriate cache invalidation logic when data changes.
"""

import pytest


class TestCacheInvalidationAudit:
    """Audit all @cached functions for proper invalidation."""

    def test_v1_router_caches(self):
        """
        V1 Router cached functions:
        - _get_cached_blocks(): ✓ NEVER CHANGES (blocks are static in code)
        """
        # No invalidation needed for static data
        pass

    def test_v1_cache_module_graph_caches(self):
        """
        V1 Cache module graph-related caches:
        - get_cached_graphs(user_id, page, page_size): ✓ HAS INVALIDATION
          Cleared in: v1.py create_graph(), delete_graph(), update_graph_metadata(), stop_graph_execution()

        - get_cached_graph(graph_id, version, user_id): ✓ HAS INVALIDATION
          Cleared in: v1.py delete_graph(), update_graph(), delete_graph_execution()

        - get_cached_graph_all_versions(graph_id, user_id): ✓ HAS INVALIDATION
          Cleared in: v1.py delete_graph(), update_graph(), delete_graph_execution()

        - get_cached_graph_executions(graph_id, user_id, page, page_size): ✓ HAS INVALIDATION
          Cleared in: v1.py stop_graph_execution()
          Also cleared in: v2/library/routes/presets.py

        - get_cached_graphs_executions(user_id, page, page_size): ✓ HAS INVALIDATION
          Cleared in: v1.py stop_graph_execution()

        - get_cached_graph_execution(graph_exec_id, user_id): ✓ HAS INVALIDATION
          Cleared in: v1.py stop_graph_execution()

        ISSUE: All use hardcoded page_size values instead of cache_config constants!
        """
        # Document that v1 routes should migrate to use cache_config
        pass

    def test_v1_cache_module_user_caches(self):
        """
        V1 Cache module user-related caches:
        - get_cached_user_timezone(user_id): ✓ HAS INVALIDATION
          Cleared in: v1.py update_user_profile()

        - get_cached_user_preferences(user_id): ✓ HAS INVALIDATION
          Cleared in: v1.py update_user_notification_preferences()
        """
        pass

    def test_v2_store_cache_functions(self):
        """
        V2 Store cached functions:
        - _get_cached_user_profile(user_id): ✓ HAS INVALIDATION
          Cleared in: v2/store/routes.py update_or_create_profile()

        - _get_cached_store_agents(...): ⚠️ PARTIAL INVALIDATION
          Cleared in: v2/admin/store_admin_routes.py review_submission() - uses cache_clear()
          NOT cleared when agents are created/updated!

        - _get_cached_agent_details(username, agent_name): ❌ NO INVALIDATION
          NEVER cleared! Relies only on TTL (15 min)

        - _get_cached_agent_graph(store_listing_version_id): ❌ NO INVALIDATION
          NEVER cleared! Relies only on TTL (1 hour)

        - _get_cached_store_agent_by_version(store_listing_version_id): ❌ NO INVALIDATION
          NEVER cleared! Relies only on TTL (1 hour)

        - _get_cached_store_creators(...): ❌ NO INVALIDATION
          NEVER cleared! Relies only on TTL (1 hour)

        - _get_cached_creator_details(username): ❌ NO INVALIDATION
          NEVER cleared! Relies only on TTL (1 hour)

        - _get_cached_my_agents(user_id, page, page_size): ❌ NO INVALIDATION
          NEVER cleared! Users won't see new agents for 5 minutes!
          CRITICAL BUG: Should be cleared when a user creates/deletes agents

        - _get_cached_submissions(user_id, page, page_size): ✓ HAS INVALIDATION
          Cleared via: _clear_submissions_cache() helper
          Called in: create_submission(), edit_submission(), delete_submission()
          Called in: v2/admin/store_admin_routes.py review_submission()
        """
        # Document critical issues
        CRITICAL_MISSING_INVALIDATION = [
            "_get_cached_my_agents - users won't see new agents immediately",
        ]

        # Acceptable TTL-only caches (documented, not asserted):
        # - _get_cached_agent_details (public data, 15min TTL acceptable)
        # - _get_cached_agent_graph (immutable data, 1hr TTL acceptable)
        # - _get_cached_store_agent_by_version (immutable version, 1hr TTL acceptable)
        # - _get_cached_store_creators (public data, 1hr TTL acceptable)
        # - _get_cached_creator_details (public data, 1hr TTL acceptable)

        assert (
            len(CRITICAL_MISSING_INVALIDATION) == 1
        ), "These caches need invalidation logic:\n" + "\n".join(
            CRITICAL_MISSING_INVALIDATION
        )

    def test_v2_library_cache_functions(self):
        """
        V2 Library cached functions:
        - get_cached_library_agents(user_id, page, page_size, ...): ✓ HAS INVALIDATION
          Cleared in: v1.py create_graph(), stop_graph_execution()
          Cleared in: v2/library/routes/agents.py add_library_agent(), remove_library_agent()

        - get_cached_library_agent_favorites(user_id, page, page_size): ✓ HAS INVALIDATION
          Cleared in: v2/library/routes/agents.py favorite/unfavorite endpoints

        - get_cached_library_agent(library_agent_id, user_id): ✓ HAS INVALIDATION
          Cleared in: v2/library/routes/agents.py remove_library_agent()

        - get_cached_library_agent_by_graph_id(graph_id, user_id): ❌ NO INVALIDATION
          NEVER cleared! Relies only on TTL (30 min)
          Should be cleared when a graph is deleted

        - get_cached_library_agent_by_store_version(store_listing_version_id, user_id): ❌ NO INVALIDATION
          NEVER cleared! Relies only on TTL (1 hour)
          Probably acceptable as store versions are immutable

        - get_cached_library_presets(user_id, page, page_size): ✓ HAS INVALIDATION
          Cleared via: _clear_presets_list_cache() helper
          Called in: v2/library/routes/presets.py preset mutations

        - get_cached_library_preset(preset_id, user_id): ✓ HAS INVALIDATION
          Cleared in: v2/library/routes/presets.py preset mutations

        ISSUE: Clearing uses hardcoded page_size values (10 and 20) instead of cache_config!
        """
        pass

    def test_immutable_singleton_caches(self):
        """
        Caches that never need invalidation (singleton or immutable):
        - get_webhook_block_ids(): ✓ STATIC (blocks in code)
        - get_io_block_ids(): ✓ STATIC (blocks in code)
        - get_supabase(): ✓ CLIENT INSTANCE (no invalidation needed)
        - get_async_supabase(): ✓ CLIENT INSTANCE (no invalidation needed)
        - _get_all_providers(): ✓ STATIC CONFIG (providers in code)
        - get_redis(): ✓ CLIENT INSTANCE (no invalidation needed)
        - load_webhook_managers(): ✓ STATIC (managers in code)
        - load_all_blocks(): ✓ STATIC (blocks in code)
        - get_cached_blocks(): ✓ STATIC (blocks in code)
        """
        pass

    def test_feature_flag_cache(self):
        """
        Feature flag cache:
        - _fetch_user_context_data(user_id): ⚠️ LONG TTL
          TTL: 24 hours
          NO INVALIDATION

        This is probably acceptable as user context changes infrequently.
        However, if user metadata changes, they won't see updated flags for 24 hours.
        """
        pass

    def test_onboarding_cache(self):
        """
        Onboarding cache:
        - onboarding_enabled(): ⚠️ NO INVALIDATION
          TTL: 5 minutes
          NO INVALIDATION

        Should probably be cleared when store agents are added/removed,
        but a 5-minute TTL is acceptable for this use case.
        """
        pass


class TestCacheInvalidationPageSizeConsistency:
    """Test that all cache_delete calls use consistent page_size values."""

    def test_v1_routes_hardcoded_page_sizes(self):
        """
        V1 routes use hardcoded page_size values that should migrate to cache_config:

        ❌ page_size=250 for graphs:
        - v1.py line 765: cache.get_cached_graphs.cache_delete(user_id, page=1, page_size=250)
        - v1.py line 791: cache.get_cached_graphs.cache_delete(user_id, page=1, page_size=250)
        - v1.py line 859: cache.get_cached_graphs.cache_delete(user_id, page=1, page_size=250)
        - v1.py line 929: cache.get_cached_graphs_executions.cache_delete(user_id, page=1, page_size=250)

        ❌ page_size=10 for library agents:
        - v1.py line 768: library_cache.get_cached_library_agents.cache_delete(..., page_size=10)
        - v1.py line 940: library_cache.get_cached_library_agents.cache_delete(..., page_size=10)

        ❌ page_size=25 for graph executions:
        - v1.py line 937: cache.get_cached_graph_executions.cache_delete(..., page_size=25)

        RECOMMENDATION: Create constants in cache_config and migrate v1 routes to use them.
        """
        from backend.server import cache_config

        # These constants exist but aren't used in v1 routes yet
        assert cache_config.V1_GRAPHS_PAGE_SIZE == 250
        assert cache_config.V1_LIBRARY_AGENTS_PAGE_SIZE == 10
        assert cache_config.V1_GRAPH_EXECUTIONS_PAGE_SIZE == 25

    def test_v2_library_routes_hardcoded_page_sizes(self):
        """
        V2 library routes use hardcoded page_size values:

        ❌ v2/library/routes/agents.py:
        - line 233: cache_delete(..., page_size=10)

        ❌ v2/library/routes/presets.py _clear_presets_list_cache():
        - Clears BOTH page_size=10 AND page_size=20
        - This suggests different consumers use different page sizes

        ❌ v2/library/routes/presets.py:
        - line 449: cache_delete(..., page_size=10)
        - line 452: cache_delete(..., page_size=25)

        RECOMMENDATION: Migrate to use cache_config constants.
        """
        from backend.server import cache_config

        # Constants exist for library
        assert cache_config.V2_LIBRARY_AGENTS_PAGE_SIZE == 10
        assert cache_config.V2_LIBRARY_PRESETS_PAGE_SIZE == 20
        assert cache_config.V2_LIBRARY_PRESETS_ALT_PAGE_SIZE == 10

    def test_only_page_1_cleared_risk(self):
        """
        Document cache_delete calls that only clear page=1.

        RISKY PATTERN: Many cache_delete calls only clear page=1:
        - v1.py create_graph(): Only clears page=1 of graphs
        - v1.py delete_graph(): Only clears page=1 of graphs
        - v1.py update_graph_metadata(): Only clears page=1 of graphs
        - v1.py stop_graph_execution(): Only clears page=1 of executions

        PROBLEM: If a user has more than one page, subsequent pages show stale data until the TTL expires.

        SOLUTIONS:
        1. Use cache_clear() to clear all pages (nuclear option)
        2. Loop through multiple pages like _clear_submissions_cache does
        3. Accept TTL-based expiry for pages 2+ (current approach)

        The current approach is probably acceptable given the TTL values are reasonable.
        """
        pass


class TestCriticalCacheBugs:
    """Document critical cache bugs that need fixing."""

    def test_my_agents_cache_never_cleared(self):
        """
        CRITICAL BUG: _get_cached_my_agents is NEVER cleared!

        Impact:
        - User creates a new agent → Won't see it in "My Agents" for 5 minutes
        - User deletes an agent → Still sees it in "My Agents" for 5 minutes

        Fix needed:
        1. Create a _clear_my_agents_cache() helper (like _clear_submissions_cache)
        2. Call it from v1.py create_graph() and delete_graph()
        3. Use the cache_config.V2_MY_AGENTS_PAGE_SIZE constant

        Location: v2/store/cache.py line 120
        """
        # This documents the bug
        NEEDS_CACHE_CLEARING = "_get_cached_my_agents"
        assert NEEDS_CACHE_CLEARING == "_get_cached_my_agents"

    def test_library_agent_by_graph_id_never_cleared(self):
        """
        BUG: get_cached_library_agent_by_graph_id is NEVER cleared!

        Impact:
        - User deletes a graph → Library still shows it as available for 30 minutes

        Fix needed:
        - Clear in v1.py delete_graph()
        - Clear in v2/library/routes/agents.py remove_library_agent()

        Location: v2/library/cache.py line 59
        """
        pass


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
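The fix described in test_my_agents_cache_never_cleared would mirror the existing _clear_submissions_cache helper. A sketch of that shape — _clear_my_agents_cache is indeed called from v1.py in the hunks earlier in this diff, but this body is illustrative, not the repo's exact code:

import backend.server.cache_config as cache_config


def _clear_my_agents_cache(user_id: str) -> None:
    """Drop every cached "My Agents" page for this user after a mutation."""
    from backend.server.v2.store.cache import _get_cached_my_agents

    for page in range(1, cache_config.MAX_PAGES_TO_CLEAR):
        _get_cached_my_agents.cache_delete(
            user_id=user_id,
            page=page,
            page_size=cache_config.V2_MY_AGENTS_PAGE_SIZE,
        )
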
@@ -1,95 +0,0 @@
#!/usr/bin/env python3
"""
Test suite to verify cache_config constants are being used correctly.

This ensures that the centralized cache_config.py constants are actually
used throughout the codebase, not just defined.
"""

import pytest

from backend.server import cache_config


class TestCacheConfigConstants:
    """Verify cache_config constants have expected values."""

    def test_v2_store_page_sizes(self):
        """Test V2 Store API page size constants."""
        assert cache_config.V2_STORE_AGENTS_PAGE_SIZE == 20
        assert cache_config.V2_STORE_CREATORS_PAGE_SIZE == 20
        assert cache_config.V2_STORE_SUBMISSIONS_PAGE_SIZE == 20
        assert cache_config.V2_MY_AGENTS_PAGE_SIZE == 20

    def test_v2_library_page_sizes(self):
        """Test V2 Library API page size constants."""
        assert cache_config.V2_LIBRARY_AGENTS_PAGE_SIZE == 10
        assert cache_config.V2_LIBRARY_PRESETS_PAGE_SIZE == 20
        assert cache_config.V2_LIBRARY_PRESETS_ALT_PAGE_SIZE == 10

    def test_v1_page_sizes(self):
        """Test V1 API page size constants."""
        assert cache_config.V1_GRAPHS_PAGE_SIZE == 250
        assert cache_config.V1_LIBRARY_AGENTS_PAGE_SIZE == 10
        assert cache_config.V1_GRAPH_EXECUTIONS_PAGE_SIZE == 25

    def test_cache_clearing_config(self):
        """Test cache clearing configuration."""
        assert cache_config.MAX_PAGES_TO_CLEAR == 20

    def test_get_page_sizes_for_clearing_helper(self):
        """Test the helper function for getting page sizes to clear."""
        # Single page size
        result = cache_config.get_page_sizes_for_clearing(20)
        assert result == [20]

        # Multiple page sizes
        result = cache_config.get_page_sizes_for_clearing(20, 10)
        assert result == [20, 10]

        # With None alt_page_size
        result = cache_config.get_page_sizes_for_clearing(20, None)
        assert result == [20]


class TestCacheConfigUsage:
    """Test that cache_config constants are actually used in the code."""

    def test_store_routes_import_cache_config(self):
        """Verify store routes imports cache_config."""
        import backend.server.v2.store.routes as store_routes

        # Check that cache_config is imported
        assert hasattr(store_routes, "backend")
        assert hasattr(store_routes.backend.server, "cache_config")

    def test_store_cache_uses_constants(self):
        """Verify store cache module uses cache_config constants."""
        import backend.server.v2.store.cache as store_cache

        # Check the module imports cache_config
        assert hasattr(store_cache, "backend")
        assert hasattr(store_cache.backend.server, "cache_config")

        # The _clear_submissions_cache function should use the constant
        import inspect

        source = inspect.getsource(store_cache._clear_submissions_cache)
        assert (
            "cache_config.V2_STORE_SUBMISSIONS_PAGE_SIZE" in source
        ), "_clear_submissions_cache must use cache_config.V2_STORE_SUBMISSIONS_PAGE_SIZE"
        assert (
            "cache_config.MAX_PAGES_TO_CLEAR" in source
        ), "_clear_submissions_cache must use cache_config.MAX_PAGES_TO_CLEAR"

    def test_admin_routes_use_constants(self):
        """Verify admin routes use cache_config constants."""
        import backend.server.v2.admin.store_admin_routes as admin_routes

        # Check that cache_config is imported
        assert hasattr(admin_routes, "backend")
        assert hasattr(admin_routes.backend.server, "cache_config")


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
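A plausible implementation of the helper those assertions pin down — a sketch consistent with the asserted behavior, not necessarily the real cache_config.py body:

def get_page_sizes_for_clearing(*page_sizes: int | None) -> list[int]:
    """Deduplicate and drop None values, preserving argument order."""
    seen: list[int] = []
    for size in page_sizes:
        if size is not None and size not in seen:
            seen.append(size)
    return seen


assert get_page_sizes_for_clearing(20) == [20]
assert get_page_sizes_for_clearing(20, 10) == [20, 10]
assert get_page_sizes_for_clearing(20, None) == [20]
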
@@ -1,263 +0,0 @@
#!/usr/bin/env python3
"""
Comprehensive test suite for cache invalidation consistency across the entire backend.

This test file identifies ALL locations where cache_delete is called with hardcoded
parameters (especially page_size) and ensures they match the corresponding route defaults.

CRITICAL: If any test in this file fails, it means cache invalidation will be broken
and users will see stale data after mutations.

Key problem areas identified:
1. v1.py routes: Uses page_size=250 for graphs, and cache clearing uses page_size=250 ✓
2. v1.py routes: Uses page_size=10 for library agents clearing
3. v2/library routes: Uses page_size=10 for library agents clearing
4. v2/store routes: Uses page_size=20 for submissions clearing (in _clear_submissions_cache)
5. v2/library presets: Uses page_size=10 AND page_size=20 for presets (dual clearing)
"""

import pytest


class TestCacheInvalidationConsistency:
    """Test that all cache_delete calls use correct parameters matching route defaults."""

    def test_v1_graphs_cache_page_size_consistency(self):
        """
        Test v1 graphs routes use consistent page_size.

        Locations that must match:
        - routes/v1.py line 682: default page_size=250
        - routes/v1.py line 765: cache_delete with page_size=250
        - routes/v1.py line 791: cache_delete with page_size=250
        - routes/v1.py line 859: cache_delete with page_size=250
        - routes/v1.py line 929: cache_delete with page_size=250
        - routes/v1.py line 1034: default page_size=250
        """
        V1_GRAPHS_DEFAULT_PAGE_SIZE = 250

        # This is the expected value - if this test fails, check all the above locations
        assert V1_GRAPHS_DEFAULT_PAGE_SIZE == 250, (
            "If you changed the default page_size for v1 graphs, you must update:\n"
            "1. routes/v1.py list_graphs() default parameter\n"
            "2. routes/v1.py create_graph() cache_delete call\n"
            "3. routes/v1.py delete_graph() cache_delete call\n"
            "4. routes/v1.py update_graph_metadata() cache_delete call\n"
            "5. routes/v1.py stop_graph_execution() cache_delete call\n"
            "6. routes/v1.py list_graph_run_events() default parameter"
        )

    def test_v1_library_agents_cache_page_size_consistency(self):
        """
        Test v1 library agents cache clearing uses consistent page_size.

        Locations that must match:
        - routes/v1.py line 768: cache_delete with page_size=10
        - routes/v1.py line 940: cache_delete with page_size=10
        - v2/library/routes/agents.py line 233: cache_delete with page_size=10

        WARNING: These hardcode page_size=10, but we need to verify this matches
        the actual page_size used when fetching library agents!
        """
        V1_LIBRARY_AGENTS_CLEARING_PAGE_SIZE = 10

        assert V1_LIBRARY_AGENTS_CLEARING_PAGE_SIZE == 10, (
            "If you changed the library agents clearing page_size, you must update:\n"
            "1. routes/v1.py create_graph() cache clearing loop\n"
            "2. routes/v1.py stop_graph_execution() cache clearing loop\n"
            "3. v2/library/routes/agents.py add_library_agent() cache clearing loop"
        )

        # TODO: This should be verified against the actual default used in library routes

    def test_v1_graph_executions_cache_page_size_consistency(self):
        """
        Test v1 graph executions cache clearing uses consistent page_size.

        Locations:
        - routes/v1.py line 937: cache_delete with page_size=25
        - v2/library/routes/presets.py line 449: cache_delete with page_size=10
        - v2/library/routes/presets.py line 452: cache_delete with page_size=25
        """
        V1_GRAPH_EXECUTIONS_CLEARING_PAGE_SIZE = 25

        # Note: presets.py clears BOTH page_size=10 AND page_size=25
        # This suggests there may be multiple consumers with different page sizes
        assert V1_GRAPH_EXECUTIONS_CLEARING_PAGE_SIZE == 25

    def test_v2_store_submissions_cache_page_size_consistency(self):
        """
        Test v2 store submissions use consistent page_size.

        Locations that must match:
        - v2/store/routes.py line 484: default page_size=20
        - v2/store/cache.py line 18: _clear_submissions_cache uses page_size=20

        This is already tested in test_cache_delete.py but documented here for completeness.
        """
        V2_STORE_SUBMISSIONS_DEFAULT_PAGE_SIZE = 20
        V2_STORE_SUBMISSIONS_CLEARING_PAGE_SIZE = 20

        assert (
            V2_STORE_SUBMISSIONS_DEFAULT_PAGE_SIZE
            == V2_STORE_SUBMISSIONS_CLEARING_PAGE_SIZE
        ), (
            "The default page_size for store submissions must match the hardcoded value in _clear_submissions_cache!\n"
            "Update both:\n"
            "1. v2/store/routes.py get_submissions() default parameter\n"
            "2. v2/store/cache.py _clear_submissions_cache() hardcoded page_size"
        )

    def test_v2_library_presets_cache_page_size_consistency(self):
        """
        Test v2 library presets cache clearing uses consistent page_size.

        Locations:
        - v2/library/routes/presets.py line 36: cache_delete with page_size=10
        - v2/library/routes/presets.py line 39: cache_delete with page_size=20

        This route clears BOTH page_size=10 and page_size=20, suggesting multiple consumers.
        """
        V2_LIBRARY_PRESETS_CLEARING_PAGE_SIZES = [10, 20]

        assert 10 in V2_LIBRARY_PRESETS_CLEARING_PAGE_SIZES
        assert 20 in V2_LIBRARY_PRESETS_CLEARING_PAGE_SIZES

        # TODO: Verify these match the actual page_size defaults used in preset routes

    def test_cache_clearing_helper_functions_documented(self):
        """
        Document all cache clearing helper functions and their hardcoded parameters.

        Helper functions that wrap cache_delete with hardcoded params:
        1. v2/store/cache.py::_clear_submissions_cache() - hardcodes page_size=20, num_pages=20
        2. v2/library/routes/presets.py::_clear_presets_list_cache() - hardcodes page_size=10 AND 20, num_pages=20

        These helpers are DANGEROUS because:
        - They hide the hardcoded parameters
        - They loop through multiple pages with a hardcoded page_size
        - If the route default changes, they won't clear the right cache entries
        """
        HELPER_FUNCTIONS = {
            "_clear_submissions_cache": {
                "file": "v2/store/cache.py",
                "page_size": 20,
                "num_pages": 20,
                "risk": "HIGH - single page_size, could miss entries if default changes",
            },
            "_clear_presets_list_cache": {
                "file": "v2/library/routes/presets.py",
                "page_size": [10, 20],
                "num_pages": 20,
                "risk": "MEDIUM - clears multiple page_sizes, but could still miss new ones",
            },
        }

        assert (
            len(HELPER_FUNCTIONS) == 2
        ), "If you add new cache clearing helper functions, document them here!"

    def test_cache_delete_without_page_loops_are_risky(self):
        """
        Document cache_delete calls that clear only page=1 (risky if there are multiple pages).

        Single-page cache_delete calls:
        - routes/v1.py line 765: Only clears page=1 with page_size=250
        - routes/v1.py line 791: Only clears page=1 with page_size=250
        - routes/v1.py line 859: Only clears page=1 with page_size=250

        These are RISKY because:
        - If a user has more than one page of graphs, pages 2+ won't be invalidated
        - The user could see stale data on pagination

        RECOMMENDATION: Use cache_clear() or loop through multiple pages like
        _clear_submissions_cache does.
        """
        SINGLE_PAGE_CLEARS = [
            "routes/v1.py line 765: create_graph clears only page=1",
            "routes/v1.py line 791: delete_graph clears only page=1",
            "routes/v1.py line 859: update_graph_metadata clears only page=1",
        ]

        # This test documents the issue but doesn't fail
        # Consider this a TODO to fix these cache clearing strategies
        assert (
            len(SINGLE_PAGE_CLEARS) >= 3
        ), "These cache_delete calls should probably loop through multiple pages"

    def test_all_cached_functions_have_proper_invalidation(self):
        """
        Verify all @cached functions have corresponding cache_delete calls.

        Functions with proper invalidation:
        ✓ get_cached_user_profile - cleared on profile update
        ✓ get_cached_store_agents - cleared on admin review (cache_clear)
        ✓ get_cached_submissions - cleared via _clear_submissions_cache helper
        ✓ get_cached_graphs - cleared on graph mutations
        ✓ get_cached_library_agents - cleared on library changes

        Functions that might not have proper invalidation:
        ? get_cached_agent_details - not explicitly cleared
        ? get_cached_store_creators - not explicitly cleared
        ? get_cached_my_agents - not explicitly cleared (no helper function exists!)

        This is a documentation test - actual verification requires code analysis.
        """
        NEEDS_VERIFICATION = [
            "get_cached_agent_details",
            "get_cached_store_creators",
            "get_cached_my_agents",  # NO CLEARING FUNCTION EXISTS!
        ]

        assert "get_cached_my_agents" in NEEDS_VERIFICATION, (
            "get_cached_my_agents has no cache clearing logic - this is a BUG!\n"
            "When a user creates/deletes an agent, their 'my agents' list won't update."
        )


class TestCacheKeyParameterOrdering:
    """
    Test that cache_delete calls use the same parameter order as the @cached function.

    The @cached decorator uses function signature order to create cache keys.
    cache_delete must use the exact same order or it won't find the cached entry!
    """

    def test_cached_function_parameter_order_matters(self):
        """
        Document that parameter order in cache_delete must match the @cached function signature.

        Example from v2/store/cache.py:

            @cached(...)
            async def _get_cached_submissions(user_id: str, page: int, page_size: int):
                ...

        CORRECT: _get_cached_submissions.cache_delete(user_id, page=1, page_size=20)
        WRONG:   _get_cached_submissions.cache_delete(page=1, user_id=user_id, page_size=20)

        The cached decorator generates keys based on the POSITIONAL order, so parameter
        order must match between the function definition and the cache_delete call.
        """
        # This is a documentation test - no assertion needed
        # Real verification requires inspecting each cache_delete call
        pass

    def test_named_parameters_vs_positional_in_cache_delete(self):
        """
        Document best practice: use named parameters in cache_delete for safety.

        Good practice seen in the codebase:
        - cache.get_cached_graphs.cache_delete(user_id=user_id, page=1, page_size=250)
        - library_cache.get_cached_library_agents.cache_delete(user_id=user_id, page=page, page_size=10)

        This is safer than positional arguments because it is:
        1. More readable
        2. Less likely to get the order wrong
        3. Self-documenting about what each parameter means
        """
        pass


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
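To make the ordering pitfall concrete: assuming a key builder of the shape key = (args, tuple(sorted(kwargs.items()))) — the kind of scheme sketched earlier in this document, not necessarily the real one — a positional populate and a keyword-only delete simply miss each other:

def make_key(*args, **kwargs):
    # Hypothetical key builder; positional and keyword calls hash differently.
    return (args, tuple(sorted(kwargs.items())))


# Populated with positional args (how the cached function was called):
populate_key = make_key("user-123", 1, 20)
# Attempted delete with keyword args misses the entry:
delete_key = make_key(user_id="user-123", page=1, page_size=20)
assert populate_key != delete_key  # the stale entry survives the "invalidation"
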
@@ -457,8 +457,7 @@ async def test_api_key_with_unicode_characters_normalization_attack(mock_request
    """Test that Unicode normalization doesn't bypass validation."""
    # Create auth with composed Unicode character
    auth = APIKeyAuthenticator(
        header_name="X-API-Key",
        expected_token="café",  # é is composed
        header_name="X-API-Key", expected_token="café"  # é is composed
    )

    # Try with decomposed version (c + a + f + e + ´)
@@ -523,8 +522,8 @@ async def test_api_keys_with_newline_variations(mock_request):
        "valid\r\ntoken",  # Windows newline
        "valid\rtoken",  # Mac newline
        "valid\x85token",  # NEL (Next Line)
        "valid\x0btoken",  # Vertical Tab
        "valid\x0ctoken",  # Form Feed
        "valid\x0Btoken",  # Vertical Tab
        "valid\x0Ctoken",  # Form Feed
    ]

    for api_key in newline_variations:

@@ -23,6 +23,7 @@ logger = logging.getLogger(__name__)


class AutoModManager:

    def __init__(self):
        self.config = self._load_config()


@@ -7,8 +7,6 @@ import fastapi
import fastapi.responses
import prisma.enums

import backend.server.cache_config
import backend.server.v2.store.cache
import backend.server.v2.store.db
import backend.server.v2.store.model
import backend.util.json
@@ -31,7 +29,7 @@ async def get_admin_listings_with_versions(
    status: typing.Optional[prisma.enums.SubmissionStatus] = None,
    search: typing.Optional[str] = None,
    page: int = 1,
    page_size: int = backend.server.cache_config.V2_STORE_SUBMISSIONS_PAGE_SIZE,
    page_size: int = 20,
):
    """
    Get store listings with their version history for admins.
@@ -95,8 +93,6 @@ async def review_submission(
            internal_comments=request.internal_comments or "",
            reviewer_id=user_id,
        )
        backend.server.v2.store.cache._clear_submissions_cache(submission.user_id)
        backend.server.v2.store.cache._get_cached_store_agents.cache_clear()
        return submission
    except Exception as e:
        logger.exception("Error reviewing submission: %s", e)

@@ -2,6 +2,7 @@ import logging
from datetime import datetime, timedelta, timezone

import prisma
from autogpt_libs.utils.cache import cached

import backend.data.block
from backend.blocks import load_all_blocks
@@ -17,7 +18,6 @@ from backend.server.v2.builder.model import (
    ProviderResponse,
    SearchBlocksResponse,
)
from backend.util.cache import cached
from backend.util.models import Pagination

logger = logging.getLogger(__name__)
@@ -296,7 +296,7 @@ def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
    return False


@cached(ttl_seconds=3600)
@cached()
def _get_all_providers() -> dict[ProviderName, Provider]:
    providers: dict[ProviderName, Provider] = {}


@@ -1,111 +0,0 @@
"""
Cache functions for Library API endpoints.

This module contains all caching decorators and helpers for the Library API,
separated from the main routes for better organization and maintainability.
"""

import backend.server.v2.library.db
from backend.util.cache import cached

# ===== Library Agent Caches =====


# Cache library agents list for 10 minutes
@cached(maxsize=1000, ttl_seconds=600, shared_cache=True)
async def get_cached_library_agents(
    user_id: str,
    page: int = 1,
    page_size: int = 20,
):
    """Cached helper to get library agents list."""
    return await backend.server.v2.library.db.list_library_agents(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )


# Cache user's favorite agents for 5 minutes - favorites change more frequently
@cached(maxsize=500, ttl_seconds=300, shared_cache=True)
async def get_cached_library_agent_favorites(
    user_id: str,
    page: int = 1,
    page_size: int = 20,
):
    """Cached helper to get user's favorite library agents."""
    return await backend.server.v2.library.db.list_favorite_library_agents(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )


# Cache individual library agent details for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_agent(
    library_agent_id: str,
    user_id: str,
):
    """Cached helper to get library agent details."""
    return await backend.server.v2.library.db.get_library_agent(
        id=library_agent_id,
        user_id=user_id,
    )


# Cache library agent by graph ID for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_agent_by_graph_id(
    graph_id: str,
    user_id: str,
):
    """Cached helper to get library agent by graph ID."""
    return await backend.server.v2.library.db.get_library_agent_by_graph_id(
        graph_id=graph_id,
        user_id=user_id,
    )


# Cache library agent by store version ID for 1 hour - marketplace agents are more stable
@cached(maxsize=500, ttl_seconds=3600, shared_cache=True)
async def get_cached_library_agent_by_store_version(
    store_listing_version_id: str,
    user_id: str,
):
    """Cached helper to get library agent by store version ID."""
    return await backend.server.v2.library.db.get_library_agent_by_store_version_id(
        store_listing_version_id=store_listing_version_id,
        user_id=user_id,
    )


# ===== Library Preset Caches =====


# Cache library presets list for 30 minutes
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_presets(
    user_id: str,
    page: int = 1,
    page_size: int = 20,
):
    """Cached helper to get library presets list."""
    return await backend.server.v2.library.db.list_presets(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )


# Cache individual preset details for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_preset(
    preset_id: str,
    user_id: str,
):
    """Cached helper to get library preset details."""
    return await backend.server.v2.library.db.get_preset(
        preset_id=preset_id,
        user_id=user_id,
    )
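In use, each read through these helpers pairs with a matching delete on mutation, with the argument order mirroring the signatures above. A sketch — the page range and page_size here are illustrative, not values taken from the routes:

import backend.server.v2.library.cache as library_cache


async def refresh_after_removal(user_id: str, library_agent_id: str) -> None:
    # Drop the per-agent entry first...
    library_cache.get_cached_library_agent.cache_delete(library_agent_id, user_id)
    # ...then the list pages that might still contain it.
    for page in range(1, 5):
        library_cache.get_cached_library_agents.cache_delete(user_id, page, 20)
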
@@ -1,286 +0,0 @@
"""
Tests for cache invalidation in Library API routes.

This module tests that library caches are properly invalidated when data is modified.
"""

import uuid
from unittest.mock import AsyncMock, patch

import pytest

import backend.server.v2.library.cache as library_cache
import backend.server.v2.library.db as library_db


@pytest.fixture
def mock_user_id():
    """Generate a mock user ID for testing."""
    return str(uuid.uuid4())


@pytest.fixture
def mock_library_agent_id():
    """Generate a mock library agent ID for testing."""
    return str(uuid.uuid4())


class TestLibraryAgentCacheInvalidation:
    """Test cache invalidation for library agent operations."""

    @pytest.mark.asyncio
    async def test_add_agent_clears_list_cache(self, mock_user_id):
        """Test that adding an agent clears the library agents list cache."""
        # Clear cache
        library_cache.get_cached_library_agents.cache_clear()

        with patch.object(
            library_db, "list_library_agents", new_callable=AsyncMock
        ) as mock_list:
            mock_response = {"agents": [], "total_count": 0, "page": 1, "page_size": 20}
            mock_list.return_value = mock_response

            # First call hits database
            await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
            assert mock_list.call_count == 1

            # Second call uses cache
            await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
            assert mock_list.call_count == 1  # Still 1, cache used

            # Simulate adding an agent (cache invalidation)
            for page in range(1, 5):
                library_cache.get_cached_library_agents.cache_delete(
                    mock_user_id, page, 15
                )
                library_cache.get_cached_library_agents.cache_delete(
                    mock_user_id, page, 20
                )

            # Next call should hit database
            await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
            assert mock_list.call_count == 2

    @pytest.mark.asyncio
    async def test_delete_agent_clears_multiple_caches(
        self, mock_user_id, mock_library_agent_id
    ):
        """Test that deleting an agent clears both specific and list caches."""
        # Clear caches
        library_cache.get_cached_library_agent.cache_clear()
        library_cache.get_cached_library_agents.cache_clear()

        with (
            patch.object(
                library_db, "get_library_agent", new_callable=AsyncMock
            ) as mock_get,
            patch.object(
                library_db, "list_library_agents", new_callable=AsyncMock
            ) as mock_list,
        ):
            mock_agent = {"id": mock_library_agent_id, "name": "Test Agent"}
            mock_get.return_value = mock_agent
            mock_list.return_value = {
                "agents": [mock_agent],
                "total_count": 1,
                "page": 1,
                "page_size": 20,
            }

            # Populate caches
            await library_cache.get_cached_library_agent(
                mock_library_agent_id, mock_user_id
            )
            await library_cache.get_cached_library_agents(mock_user_id, 1, 20)

            initial_calls = {
                "get": mock_get.call_count,
                "list": mock_list.call_count,
            }

            # Verify cache is used
            await library_cache.get_cached_library_agent(
                mock_library_agent_id, mock_user_id
            )
            await library_cache.get_cached_library_agents(mock_user_id, 1, 20)

            assert mock_get.call_count == initial_calls["get"]
            assert mock_list.call_count == initial_calls["list"]

            # Simulate delete_library_agent cache invalidation
            library_cache.get_cached_library_agent.cache_delete(
                mock_library_agent_id, mock_user_id
            )
            for page in range(1, 5):
                library_cache.get_cached_library_agents.cache_delete(
                    mock_user_id, page, 15
                )
                library_cache.get_cached_library_agents.cache_delete(
                    mock_user_id, page, 20
                )

            # Next calls should hit database
            await library_cache.get_cached_library_agent(
                mock_library_agent_id, mock_user_id
            )
            await library_cache.get_cached_library_agents(mock_user_id, 1, 20)

            assert mock_get.call_count == initial_calls["get"] + 1
            assert mock_list.call_count == initial_calls["list"] + 1

    @pytest.mark.asyncio
    async def test_favorites_cache_operations(self, mock_user_id):
        """Test that favorites cache works independently."""
        # Clear cache
        library_cache.get_cached_library_agent_favorites.cache_clear()

        with patch.object(
            library_db, "list_favorite_library_agents", new_callable=AsyncMock
        ) as mock_favs:
            mock_response = {"agents": [], "total_count": 0, "page": 1, "page_size": 20}
            mock_favs.return_value = mock_response

            # First call hits database
            await library_cache.get_cached_library_agent_favorites(mock_user_id, 1, 20)
            assert mock_favs.call_count == 1

            # Second call uses cache
            await library_cache.get_cached_library_agent_favorites(mock_user_id, 1, 20)
            assert mock_favs.call_count == 1  # Cache used

            # Clear cache
            library_cache.get_cached_library_agent_favorites.cache_delete(
                mock_user_id, 1, 20
            )

            # Next call hits database
            await library_cache.get_cached_library_agent_favorites(mock_user_id, 1, 20)
            assert mock_favs.call_count == 2


class TestLibraryPresetCacheInvalidation:
    """Test cache invalidation for library preset operations."""

    @pytest.mark.asyncio
    async def test_preset_cache_operations(self, mock_user_id):
        """Test preset cache and invalidation."""
        # Clear cache
        library_cache.get_cached_library_presets.cache_clear()
        library_cache.get_cached_library_preset.cache_clear()

        preset_id = str(uuid.uuid4())

        with (
            patch.object(
                library_db, "list_presets", new_callable=AsyncMock
            ) as mock_list,
            patch.object(library_db, "get_preset", new_callable=AsyncMock) as mock_get,
        ):
            mock_preset = {"id": preset_id, "name": "Test Preset"}
            mock_list.return_value = {
                "presets": [mock_preset],
                "total_count": 1,
                "page": 1,
                "page_size": 20,
            }
            mock_get.return_value = mock_preset

            # Populate caches
            await library_cache.get_cached_library_presets(mock_user_id, 1, 20)
            await library_cache.get_cached_library_preset(preset_id, mock_user_id)

            initial_calls = {
                "list": mock_list.call_count,
                "get": mock_get.call_count,
            }

            # Verify cache is used
            await library_cache.get_cached_library_presets(mock_user_id, 1, 20)
            await library_cache.get_cached_library_preset(preset_id, mock_user_id)

            assert mock_list.call_count == initial_calls["list"]
            assert mock_get.call_count == initial_calls["get"]

            # Clear specific preset cache
            library_cache.get_cached_library_preset.cache_delete(
                preset_id, mock_user_id
            )

            # Clear list cache
            library_cache.get_cached_library_presets.cache_delete(mock_user_id, 1, 20)

            # Next calls should hit database
            await library_cache.get_cached_library_presets(mock_user_id, 1, 20)
            await library_cache.get_cached_library_preset(preset_id, mock_user_id)

            assert mock_list.call_count == initial_calls["list"] + 1
            assert mock_get.call_count == initial_calls["get"] + 1


class TestLibraryCacheMetrics:
    """Test library cache metrics and management."""

    def test_cache_info_structure(self):
        """Test that cache_info returns expected structure."""
        info = library_cache.get_cached_library_agents.cache_info()

        assert "size" in info
        assert "maxsize" in info
        assert "ttl_seconds" in info
        assert (
            info["maxsize"] is None
        )  # Redis manages its own size with shared_cache=True
        assert info["ttl_seconds"] == 600  # 10 minutes

    def test_all_library_caches_can_be_cleared(self):
        """Test that all library caches can be cleared."""
        # Clear all library caches
        library_cache.get_cached_library_agents.cache_clear()
        library_cache.get_cached_library_agent_favorites.cache_clear()
        library_cache.get_cached_library_agent.cache_clear()
        library_cache.get_cached_library_agent_by_graph_id.cache_clear()
        library_cache.get_cached_library_agent_by_store_version.cache_clear()
        library_cache.get_cached_library_presets.cache_clear()
        library_cache.get_cached_library_preset.cache_clear()

        # Verify all are empty
        assert library_cache.get_cached_library_agents.cache_info()["size"] == 0
        assert (
            library_cache.get_cached_library_agent_favorites.cache_info()["size"] == 0
        )
        assert library_cache.get_cached_library_agent.cache_info()["size"] == 0
        assert (
            library_cache.get_cached_library_agent_by_graph_id.cache_info()["size"] == 0
        )
        assert (
            library_cache.get_cached_library_agent_by_store_version.cache_info()["size"]
            == 0
        )
        assert library_cache.get_cached_library_presets.cache_info()["size"] == 0
        assert library_cache.get_cached_library_preset.cache_info()["size"] == 0

    def test_cache_ttl_values(self):
        """Test that cache TTL values are set correctly."""
        # Library agents - 10 minutes
        assert (
            library_cache.get_cached_library_agents.cache_info()["ttl_seconds"] == 600
        )

        # Favorites - 5 minutes (more dynamic)
        assert (
            library_cache.get_cached_library_agent_favorites.cache_info()["ttl_seconds"]
            == 300
        )

        # Individual agent - 30 minutes
        assert (
            library_cache.get_cached_library_agent.cache_info()["ttl_seconds"] == 1800
        )

        # Presets - 30 minutes
        assert (
            library_cache.get_cached_library_presets.cache_info()["ttl_seconds"] == 1800
        )
        assert (
            library_cache.get_cached_library_preset.cache_info()["ttl_seconds"] == 1800
        )
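One subtlety these tests rely on: cache_delete rebuilds the key from its call arguments, and _make_hashable_key (see the cache utility at the end of this diff) keeps positional and keyword arguments separate. A delete therefore only matches an entry populated with the same call shape. An illustration, assuming a reachable shared cache:

import backend.server.v2.library.cache as library_cache

async def demo_key_shapes() -> None:
    # Populates a key under the positional shape ("user-1", 1, 20).
    await library_cache.get_cached_library_agents("user-1", 1, 20)

    # Deletes that entry: the call shape matches exactly.
    library_cache.get_cached_library_agents.cache_delete("user-1", 1, 20)

    # Would NOT have matched the entry above: keyword arguments hash to a
    # different key, so mixing shapes between reads and invalidation
    # silently leaves stale entries behind.
    library_cache.get_cached_library_agents.cache_delete(
        user_id="user-1", page=1, page_size=20
    )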
@@ -101,9 +101,7 @@ async def list_library_agents(
    try:
        library_agents = await prisma.models.LibraryAgent.prisma().find_many(
            where=where_clause,
            include=library_agent_include(
                user_id, include_nodes=False, include_executions=False
            ),
            include=library_agent_include(user_id),
            order=order_by,
            skip=(page - 1) * page_size,
            take=page_size,
@@ -187,9 +185,7 @@ async def list_favorite_library_agents(
    try:
        library_agents = await prisma.models.LibraryAgent.prisma().find_many(
            where=where_clause,
            include=library_agent_include(
                user_id, include_nodes=False, include_executions=False
            ),
            include=library_agent_include(user_id),
            order=order_by,
            skip=(page - 1) * page_size,
            take=page_size,
@@ -421,9 +417,7 @@ async def create_library_agent(
                }
            },
        ),
        include=library_agent_include(
            user_id, include_nodes=False, include_executions=False
        ),
        include=library_agent_include(user_id),
    )
    for graph_entry in graph_entries
)
@@ -648,9 +642,7 @@ async def add_store_agent_to_library(
            },
            "isCreatedByUser": False,
        },
        include=library_agent_include(
            user_id, include_nodes=False, include_executions=False
        ),
        include=library_agent_include(user_id),
    )
    logger.debug(
        f"Added graph #{graph.id} v{graph.version}"

@@ -177,9 +177,7 @@ async def test_add_agent_to_library(mocker):
            },
            "isCreatedByUser": False,
        },
        include=library_agent_include(
            "test-user", include_nodes=False, include_executions=False
        ),
        include=library_agent_include("test-user"),
    )
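The repeated change above suggests library_agent_include now defaults to the lightweight shape every call site was requesting. A hypothetical signature consistent with both call forms in the diff (the real helper lives in the library db module and may differ):

def library_agent_include(
    user_id: str,
    include_nodes: bool = False,
    include_executions: bool = False,
) -> dict:
    # Illustrative only: build a Prisma `include` clause for LibraryAgent
    # queries, opting into the heavier relations when flags are set.
    include: dict = {"Creator": True}
    if include_nodes:
        include["Nodes"] = True
    if include_executions:
        include["Executions"] = {"where": {"userId": user_id}}
    return include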
@@ -5,8 +5,6 @@ import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Body, HTTPException, Query, Security, status
from fastapi.responses import Response

import backend.server.cache_config
import backend.server.v2.library.cache as library_cache
import backend.server.v2.library.db as library_db
import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
@@ -66,22 +64,13 @@ async def list_library_agents(
        HTTPException: If a server/database error occurs.
    """
    try:
        # Use cache for default queries (no search term, default sort)
        if search_term is None and sort_by == library_model.LibraryAgentSort.UPDATED_AT:
            return await library_cache.get_cached_library_agents(
                user_id=user_id,
                page=page,
                page_size=page_size,
            )
        else:
            # Direct DB query for searches and custom sorts
            return await library_db.list_library_agents(
                user_id=user_id,
                search_term=search_term,
                sort_by=sort_by,
                page=page,
                page_size=page_size,
            )
        return await library_db.list_library_agents(
            user_id=user_id,
            search_term=search_term,
            sort_by=sort_by,
            page=page,
            page_size=page_size,
        )
    except Exception as e:
        logger.error(f"Could not list library agents for user #{user_id}: {e}")
        raise HTTPException(
@@ -125,7 +114,7 @@ async def list_favorite_library_agents(
        HTTPException: If a server/database error occurs.
    """
    try:
        return await library_cache.get_cached_library_agent_favorites(
        return await library_db.list_favorite_library_agents(
            user_id=user_id,
            page=page,
            page_size=page_size,
@@ -143,9 +132,7 @@ async def get_library_agent(
    library_agent_id: str,
    user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryAgent:
    return await library_cache.get_cached_library_agent(
        library_agent_id=library_agent_id, user_id=user_id
    )
    return await library_db.get_library_agent(id=library_agent_id, user_id=user_id)


@router.get("/by-graph/{graph_id}")
@@ -223,21 +210,11 @@ async def add_marketplace_agent_to_library(
        HTTPException(500): If a server/database error occurs.
    """
    try:
        result = await library_db.add_store_agent_to_library(
        return await library_db.add_store_agent_to_library(
            store_listing_version_id=store_listing_version_id,
            user_id=user_id,
        )

        # Clear library caches after adding new agent
        for page in range(1, backend.server.cache_config.MAX_PAGES_TO_CLEAR):
            library_cache.get_cached_library_agents.cache_delete(
                user_id=user_id,
                page=page,
                page_size=backend.server.cache_config.V2_LIBRARY_AGENTS_PAGE_SIZE,
            )

        return result

    except store_exceptions.AgentNotFoundError as e:
        logger.warning(
            f"Could not find store listing version {store_listing_version_id} "
@@ -286,22 +263,13 @@ async def update_library_agent(
        HTTPException(500): If a server/database error occurs.
    """
    try:
        result = await library_db.update_library_agent(
        return await library_db.update_library_agent(
            library_agent_id=library_agent_id,
            user_id=user_id,
            auto_update_version=payload.auto_update_version,
            is_favorite=payload.is_favorite,
            is_archived=payload.is_archived,
        )

        for page in range(1, backend.server.cache_config.MAX_PAGES_TO_CLEAR):
            library_cache.get_cached_library_agent_favorites.cache_delete(
                user_id=user_id,
                page=page,
                page_size=backend.server.cache_config.V2_LIBRARY_AGENTS_PAGE_SIZE,
            )

        return result
    except NotFoundError as e:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
@@ -352,18 +320,6 @@ async def delete_library_agent(
        await library_db.delete_library_agent(
            library_agent_id=library_agent_id, user_id=user_id
        )

        # Clear caches after deleting agent
        library_cache.get_cached_library_agent.cache_delete(
            library_agent_id=library_agent_id, user_id=user_id
        )
        for page in range(1, backend.server.cache_config.MAX_PAGES_TO_CLEAR):
            library_cache.get_cached_library_agents.cache_delete(
                user_id=user_id,
                page=page,
                page_size=backend.server.cache_config.V2_LIBRARY_AGENTS_PAGE_SIZE,
            )

        return Response(status_code=status.HTTP_204_NO_CONTENT)
    except NotFoundError as e:
        raise HTTPException(
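The removed route code above repeats the same page loop in three endpoints. A small helper in the spirit of _clear_presets_list_cache (shown in the next file) could have centralized it; a sketch, assuming only the names that appear in this diff:

import backend.server.cache_config
import backend.server.v2.library.cache as library_cache

def _clear_library_agents_list_cache(user_id: str) -> None:
    # Best-effort invalidation: only keys built with this exact page_size can
    # be deleted, so entries cached under any other page_size age out via TTL.
    for page in range(1, backend.server.cache_config.MAX_PAGES_TO_CLEAR):
        library_cache.get_cached_library_agents.cache_delete(
            user_id=user_id,
            page=page,
            page_size=backend.server.cache_config.V2_LIBRARY_AGENTS_PAGE_SIZE,
        )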
@@ -4,9 +4,6 @@ from typing import Any, Optional
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Body, HTTPException, Query, Security, status

import backend.server.cache_config
import backend.server.routers.cache as cache
import backend.server.v2.library.cache as library_cache
import backend.server.v2.library.db as db
import backend.server.v2.library.model as models
from backend.data.execution import GraphExecutionMeta
@@ -28,24 +25,6 @@ router = APIRouter(
)


def _clear_presets_list_cache(
    user_id: str, num_pages: int = backend.server.cache_config.MAX_PAGES_TO_CLEAR
):
    """
    Clear the presets list cache for the given user.
    Clears both primary and alternative page sizes for backward compatibility.
    """
    page_sizes = backend.server.cache_config.get_page_sizes_for_clearing(
        backend.server.cache_config.V2_LIBRARY_PRESETS_PAGE_SIZE,
        backend.server.cache_config.V2_LIBRARY_PRESETS_ALT_PAGE_SIZE,
    )
    for page in range(1, num_pages + 1):
        for page_size in page_sizes:
            library_cache.get_cached_library_presets.cache_delete(
                user_id=user_id, page=page, page_size=page_size
            )


@router.get(
    "/presets",
    summary="List presets",
@@ -72,21 +51,12 @@ async def list_presets(
        models.LibraryAgentPresetResponse: A response containing the list of presets.
    """
    try:
        # Use cache only for default queries (no filter)
        if graph_id is None:
            return await library_cache.get_cached_library_presets(
                user_id=user_id,
                page=page,
                page_size=page_size,
            )
        else:
            # Direct DB query for filtered requests
            return await db.list_presets(
                user_id=user_id,
                graph_id=graph_id,
                page=page,
                page_size=page_size,
            )
        return await db.list_presets(
            user_id=user_id,
            graph_id=graph_id,
            page=page,
            page_size=page_size,
        )
    except Exception as e:
        logger.exception("Failed to list presets for user %s: %s", user_id, e)
        raise HTTPException(
@@ -117,7 +87,7 @@ async def get_preset(
        HTTPException: If the preset is not found or an error occurs.
    """
    try:
        preset = await library_cache.get_cached_library_preset(preset_id, user_id)
        preset = await db.get_preset(user_id, preset_id)
    except Exception as e:
        logger.exception(
            "Error retrieving preset %s for user %s: %s", preset_id, user_id, e
@@ -161,13 +131,9 @@ async def create_preset(
    """
    try:
        if isinstance(preset, models.LibraryAgentPresetCreatable):
            result = await db.create_preset(user_id, preset)
            return await db.create_preset(user_id, preset)
        else:
            result = await db.create_preset_from_graph_execution(user_id, preset)

        _clear_presets_list_cache(user_id)

        return result
            return await db.create_preset_from_graph_execution(user_id, preset)
    except NotFoundError as e:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
    except Exception as e:
@@ -234,9 +200,6 @@ async def setup_trigger(
            is_active=True,
        ),
    )

    _clear_presets_list_cache(user_id)

    return new_preset


@@ -315,13 +278,6 @@ async def update_preset(
            description=preset.description,
            is_active=preset.is_active,
        )

        # Clear caches after updating preset
        library_cache.get_cached_library_preset.cache_delete(
            preset_id=preset_id, user_id=user_id
        )
        _clear_presets_list_cache(user_id)

    except Exception as e:
        logger.exception("Preset update failed for user %s: %s", user_id, e)
        raise HTTPException(
@@ -395,12 +351,6 @@ async def delete_preset(

    try:
        await db.delete_preset(user_id, preset_id)

        # Clear caches after deleting preset
        library_cache.get_cached_library_preset.cache_delete(
            preset_id=preset_id, user_id=user_id
        )
        _clear_presets_list_cache(user_id)
    except Exception as e:
        logger.exception(
            "Error deleting preset %s for user %s: %s", preset_id, user_id, e
@@ -451,33 +401,6 @@ async def execute_preset(
    merged_node_input = preset.inputs | inputs
    merged_credential_inputs = preset.credentials | credential_inputs

    # Clear graph executions cache - use both page sizes for compatibility
    for page in range(1, 10):
        # Clear with alternative page size
        cache.get_cached_graph_executions.cache_delete(
            graph_id=preset.graph_id,
            user_id=user_id,
            page=page,
            page_size=backend.server.cache_config.V2_GRAPH_EXECUTIONS_ALT_PAGE_SIZE,
        )
        cache.get_cached_graph_executions.cache_delete(
            user_id=user_id,
            page=page,
            page_size=backend.server.cache_config.V2_GRAPH_EXECUTIONS_ALT_PAGE_SIZE,
        )
        # Clear with v1 page size (25)
        cache.get_cached_graph_executions.cache_delete(
            graph_id=preset.graph_id,
            user_id=user_id,
            page=page,
            page_size=backend.server.cache_config.V1_GRAPH_EXECUTIONS_PAGE_SIZE,
        )
        cache.get_cached_graph_executions.cache_delete(
            user_id=user_id,
            page=page,
            page_size=backend.server.cache_config.V1_GRAPH_EXECUTIONS_PAGE_SIZE,
        )

    return await add_graph_execution(
        user_id=user_id,
        graph_id=preset.graph_id,

@@ -179,15 +179,14 @@ async def test_get_favorite_library_agents_success(
def test_get_favorite_library_agents_error(
    mocker: pytest_mock.MockFixture, test_user_id: str
):
    # Mock the cache function instead of the DB directly since routes now use cache
    mock_cache_call = mocker.patch(
        "backend.server.v2.library.routes.agents.library_cache.get_cached_library_agent_favorites"
    mock_db_call = mocker.patch(
        "backend.server.v2.library.db.list_favorite_library_agents"
    )
    mock_cache_call.side_effect = Exception("Test error")
    mock_db_call.side_effect = Exception("Test error")

    response = client.get("/agents/favorites")
    assert response.status_code == 500
    mock_cache_call.assert_called_once_with(
    mock_db_call.assert_called_once_with(
        user_id=test_user_id,
        page=1,
        page_size=15,
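_clear_presets_list_cache above deletes keys for every (page, page_size) combination it knows about, covering both the primary and the legacy page size. A hypothetical get_page_sizes_for_clearing consistent with how the diff uses it (the real implementation lives in backend.server.cache_config and may differ):

def get_page_sizes_for_clearing(primary: int, alternative: int) -> list[int]:
    # De-duplicate in case both constants are configured to the same value,
    # so each cache key is only deleted once per page.
    return [primary] if primary == alternative else [primary, alternative]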
@@ -1,152 +0,0 @@
"""
Cache functions for Store API endpoints.

This module contains all caching decorators and helpers for the Store API,
separated from the main routes for better organization and maintainability.
"""

import backend.server.cache_config
import backend.server.v2.store.db
from backend.util.cache import cached


def _clear_submissions_cache(
    user_id: str, num_pages: int = backend.server.cache_config.MAX_PAGES_TO_CLEAR
):
    """
    Clear the submissions cache for the given user.

    Args:
        user_id: User ID whose cache should be cleared
        num_pages: Number of pages to clear (default from cache_config)
    """
    for page in range(1, num_pages + 1):
        _get_cached_submissions.cache_delete(
            user_id=user_id,
            page=page,
            page_size=backend.server.cache_config.V2_STORE_SUBMISSIONS_PAGE_SIZE,
        )


def _clear_my_agents_cache(
    user_id: str, num_pages: int = backend.server.cache_config.MAX_PAGES_TO_CLEAR
):
    """
    Clear the my agents cache for the given user.

    Args:
        user_id: User ID whose cache should be cleared
        num_pages: Number of pages to clear (default from cache_config)
    """
    for page in range(1, num_pages + 1):
        _get_cached_my_agents.cache_delete(
            user_id=user_id,
            page=page,
            page_size=backend.server.cache_config.V2_MY_AGENTS_PAGE_SIZE,
        )


# Cache user profiles for 1 hour per user
@cached(maxsize=1000, ttl_seconds=3600, shared_cache=True)
async def _get_cached_user_profile(user_id: str):
    """Cached helper to get user profile."""
    return await backend.server.v2.store.db.get_user_profile(user_id)


# Cache store agents list for 15 minutes
# Different cache entries for different query combinations
@cached(maxsize=5000, ttl_seconds=900, shared_cache=True)
async def _get_cached_store_agents(
    featured: bool,
    creator: str | None,
    sorted_by: str | None,
    search_query: str | None,
    category: str | None,
    page: int,
    page_size: int,
):
    """Cached helper to get store agents."""
    return await backend.server.v2.store.db.get_store_agents(
        featured=featured,
        creators=[creator] if creator else None,
        sorted_by=sorted_by,
        search_query=search_query,
        category=category,
        page=page,
        page_size=page_size,
    )


# Cache individual agent details for 15 minutes
@cached(maxsize=200, ttl_seconds=900, shared_cache=True)
async def _get_cached_agent_details(username: str, agent_name: str):
    """Cached helper to get agent details."""
    return await backend.server.v2.store.db.get_store_agent_details(
        username=username, agent_name=agent_name
    )


# Cache agent graphs for 1 hour
@cached(maxsize=200, ttl_seconds=3600, shared_cache=True)
async def _get_cached_agent_graph(store_listing_version_id: str):
    """Cached helper to get agent graph."""
    return await backend.server.v2.store.db.get_available_graph(
        store_listing_version_id
    )


# Cache agent by version for 1 hour
@cached(maxsize=200, ttl_seconds=3600, shared_cache=True)
async def _get_cached_store_agent_by_version(store_listing_version_id: str):
    """Cached helper to get store agent by version ID."""
    return await backend.server.v2.store.db.get_store_agent_by_version_id(
        store_listing_version_id
    )


# Cache creators list for 1 hour
@cached(maxsize=200, ttl_seconds=3600, shared_cache=True)
async def _get_cached_store_creators(
    featured: bool,
    search_query: str | None,
    sorted_by: str | None,
    page: int,
    page_size: int,
):
    """Cached helper to get store creators."""
    return await backend.server.v2.store.db.get_store_creators(
        featured=featured,
        search_query=search_query,
        sorted_by=sorted_by,
        page=page,
        page_size=page_size,
    )


# Cache individual creator details for 1 hour
@cached(maxsize=100, ttl_seconds=3600, shared_cache=True)
async def _get_cached_creator_details(username: str):
    """Cached helper to get creator details."""
    return await backend.server.v2.store.db.get_store_creator_details(
        username=username.lower()
    )


# Cache user's own agents for 5 mins (shorter TTL as this changes more frequently)
@cached(maxsize=500, ttl_seconds=300, shared_cache=True)
async def _get_cached_my_agents(user_id: str, page: int, page_size: int):
    """Cached helper to get user's agents."""
    return await backend.server.v2.store.db.get_my_agents(
        user_id, page=page, page_size=page_size
    )


# Cache user's submissions for 1 hour
@cached(maxsize=500, ttl_seconds=3600, shared_cache=True)
async def _get_cached_submissions(user_id: str, page: int, page_size: int):
    """Cached helper to get user's submissions."""
    return await backend.server.v2.store.db.get_store_submissions(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )
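A minimal sketch of how a route would consume these helpers (the function name and wiring are illustrative only; the real routes appear later in this diff). Note that the read uses the same page_size constant as the clearing helper, which is what keeps the delete able to find the key:

import backend.server.cache_config as cache_config
import backend.server.v2.store.cache as store_cache

async def example_submission_flow(user_id: str) -> None:
    # Read through the shared cache.
    submissions = await store_cache._get_cached_submissions(
        user_id=user_id,
        page=1,
        page_size=cache_config.V2_STORE_SUBMISSIONS_PAGE_SIZE,
    )

    # ... perform a mutation against the user's submissions here ...

    # Invalidate every page the user may have cached.
    store_cache._clear_submissions_cache(user_id)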
@@ -493,7 +493,6 @@ async def get_store_submissions(
        submission_models = []
        for sub in submissions:
            submission_model = backend.server.v2.store.model.StoreSubmission(
                user_id=sub.user_id,
                agent_id=sub.agent_id,
                agent_version=sub.agent_version,
                name=sub.name,
@@ -711,7 +710,6 @@ async def create_store_submission(
        logger.debug(f"Created store listing for agent {agent_id}")
        # Return submission details
        return backend.server.v2.store.model.StoreSubmission(
            user_id=user_id,
            agent_id=agent_id,
            agent_version=agent_version,
            name=name,
@@ -862,7 +860,6 @@ async def edit_store_submission(
                "Failed to update store listing version"
            )
        return backend.server.v2.store.model.StoreSubmission(
            user_id=user_id,
            agent_id=current_version.agentGraphId,
            agent_version=current_version.agentGraphVersion,
            name=name,
@@ -996,7 +993,6 @@ async def create_store_version(
        )
        # Return submission details
        return backend.server.v2.store.model.StoreSubmission(
            user_id=user_id,
            agent_id=agent_id,
            agent_version=agent_version,
            name=name,
@@ -1497,7 +1493,7 @@ async def review_store_submission(
            include={"StoreListing": True},
        )

        if not submission or not submission.StoreListing:
        if not submission:
            raise backend.server.v2.store.exceptions.DatabaseError(
                f"Failed to update store listing version {store_listing_version_id}"
            )
@@ -1587,7 +1583,6 @@ async def review_store_submission(

        # Convert to Pydantic model for consistency
        return backend.server.v2.store.model.StoreSubmission(
            user_id=submission.StoreListing.owningUserId,
            agent_id=submission.agentGraphId,
            agent_version=submission.agentGraphVersion,
            name=submission.name,
@@ -1720,17 +1715,14 @@ async def get_admin_listings_with_versions(
        # Get total count for pagination
        total = await prisma.models.StoreListing.prisma().count(where=where)
        total_pages = (total + page_size - 1) // page_size

        # Convert to response models
        listings_with_versions = []
        for listing in listings:
            versions: list[backend.server.v2.store.model.StoreSubmission] = []
            if not listing.OwningUser:
                logger.error(f"Listing {listing.id} has no owning user")
                continue
            # If we have versions, turn them into StoreSubmission models
            for version in listing.Versions or []:
                version_model = backend.server.v2.store.model.StoreSubmission(
                    user_id=listing.OwningUser.id,
                    agent_id=version.agentGraphId,
                    agent_version=version.agentGraphVersion,
                    name=version.name,

@@ -98,7 +98,6 @@ class Profile(pydantic.BaseModel):


class StoreSubmission(pydantic.BaseModel):
    user_id: str = pydantic.Field(default="", exclude=True)
    agent_id: str
    agent_version: int
    name: str

@@ -135,7 +135,6 @@ def test_creator_details():

def test_store_submission():
    submission = backend.server.v2.store.model.StoreSubmission(
        user_id="user123",
        agent_id="agent123",
        agent_version=1,
        sub_heading="Test subheading",
@@ -157,7 +156,6 @@ def test_store_submissions_response():
    response = backend.server.v2.store.model.StoreSubmissionsResponse(
        submissions=[
            backend.server.v2.store.model.StoreSubmission(
                user_id="user123",
                agent_id="agent123",
                agent_version=1,
                sub_heading="Test subheading",
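The user_id field removed above was declared with pydantic.Field(default="", exclude=True), which keeps the value available in Python while dropping it from serialized output. A self-contained illustration of that behavior (hypothetical model, standard pydantic v2 semantics):

import pydantic

class Submission(pydantic.BaseModel):
    user_id: str = pydantic.Field(default="", exclude=True)
    name: str

sub = Submission(user_id="user123", name="Test Agent")
assert sub.user_id == "user123"                     # accessible in code
assert sub.model_dump() == {"name": "Test Agent"}   # excluded from API responses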
@@ -6,33 +6,132 @@ import urllib.parse
import autogpt_libs.auth
import fastapi
import fastapi.responses
from autogpt_libs.utils.cache import cached

import backend.data.graph
import backend.server.cache_config
import backend.server.v2.store.db
import backend.server.v2.store.exceptions
import backend.server.v2.store.image_gen
import backend.server.v2.store.media
import backend.server.v2.store.model
import backend.util.json
from backend.server.v2.store.cache import (
    _clear_submissions_cache,
    _get_cached_agent_details,
    _get_cached_agent_graph,
    _get_cached_creator_details,
    _get_cached_my_agents,
    _get_cached_store_agent_by_version,
    _get_cached_store_agents,
    _get_cached_store_creators,
    _get_cached_submissions,
    _get_cached_user_profile,
)

logger = logging.getLogger(__name__)

router = fastapi.APIRouter()


##############################################
############### Caches #######################
##############################################


# Cache user profiles for 1 hour per user
@cached(maxsize=1000, ttl_seconds=3600)
async def _get_cached_user_profile(user_id: str):
    """Cached helper to get user profile."""
    return await backend.server.v2.store.db.get_user_profile(user_id)


# Cache store agents list for 15 minutes
# Different cache entries for different query combinations
@cached(maxsize=5000, ttl_seconds=900)
async def _get_cached_store_agents(
    featured: bool,
    creator: str | None,
    sorted_by: str | None,
    search_query: str | None,
    category: str | None,
    page: int,
    page_size: int,
):
    """Cached helper to get store agents."""
    return await backend.server.v2.store.db.get_store_agents(
        featured=featured,
        creators=[creator] if creator else None,
        sorted_by=sorted_by,
        search_query=search_query,
        category=category,
        page=page,
        page_size=page_size,
    )


# Cache individual agent details for 15 minutes
@cached(maxsize=200, ttl_seconds=900)
async def _get_cached_agent_details(username: str, agent_name: str):
    """Cached helper to get agent details."""
    return await backend.server.v2.store.db.get_store_agent_details(
        username=username, agent_name=agent_name
    )


# Cache agent graphs for 1 hour
@cached(maxsize=200, ttl_seconds=3600)
async def _get_cached_agent_graph(store_listing_version_id: str):
    """Cached helper to get agent graph."""
    return await backend.server.v2.store.db.get_available_graph(
        store_listing_version_id
    )


# Cache agent by version for 1 hour
@cached(maxsize=200, ttl_seconds=3600)
async def _get_cached_store_agent_by_version(store_listing_version_id: str):
    """Cached helper to get store agent by version ID."""
    return await backend.server.v2.store.db.get_store_agent_by_version_id(
        store_listing_version_id
    )


# Cache creators list for 1 hour
@cached(maxsize=200, ttl_seconds=3600)
async def _get_cached_store_creators(
    featured: bool,
    search_query: str | None,
    sorted_by: str | None,
    page: int,
    page_size: int,
):
    """Cached helper to get store creators."""
    return await backend.server.v2.store.db.get_store_creators(
        featured=featured,
        search_query=search_query,
        sorted_by=sorted_by,
        page=page,
        page_size=page_size,
    )


# Cache individual creator details for 1 hour
@cached(maxsize=100, ttl_seconds=3600)
async def _get_cached_creator_details(username: str):
    """Cached helper to get creator details."""
    return await backend.server.v2.store.db.get_store_creator_details(
        username=username.lower()
    )


# Cache user's own agents for 5 mins (shorter TTL as this changes more frequently)
@cached(maxsize=500, ttl_seconds=300)
async def _get_cached_my_agents(user_id: str, page: int, page_size: int):
    """Cached helper to get user's agents."""
    return await backend.server.v2.store.db.get_my_agents(
        user_id, page=page, page_size=page_size
    )


# Cache user's submissions for 1 hour
@cached(maxsize=500, ttl_seconds=3600)
async def _get_cached_submissions(user_id: str, page: int, page_size: int):
    """Cached helper to get user's submissions."""
    return await backend.server.v2.store.db.get_store_submissions(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )


##############################################
############### Profile Endpoints ############
##############################################
@@ -131,7 +230,7 @@ async def get_agents(
    search_query: str | None = None,
    category: str | None = None,
    page: int = 1,
    page_size: int = backend.server.cache_config.V2_STORE_AGENTS_PAGE_SIZE,
    page_size: int = 20,
):
    """
    Get a paginated list of agents from the store with optional filtering and sorting.
@@ -329,7 +428,7 @@ async def get_creators(
    search_query: str | None = None,
    sorted_by: str | None = None,
    page: int = 1,
    page_size: int = backend.server.cache_config.V2_STORE_CREATORS_PAGE_SIZE,
    page_size: int = 20,
):
    """
    This is needed for:
@@ -415,9 +514,7 @@ async def get_creator(
async def get_my_agents(
    user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
    page: typing.Annotated[int, fastapi.Query(ge=1)] = 1,
    page_size: typing.Annotated[
        int, fastapi.Query(ge=1)
    ] = backend.server.cache_config.V2_MY_AGENTS_PAGE_SIZE,
    page_size: typing.Annotated[int, fastapi.Query(ge=1)] = 20,
):
    """
    Get user's own agents.
@@ -463,7 +560,10 @@ async def delete_submission(

        # Clear submissions cache for this specific user after deletion
        if result:
            _clear_submissions_cache(user_id)
            # Clear user's submissions cache - we don't know all page/size combinations
            for page in range(1, 20):
                # Clear user's submissions cache for common defaults
                _get_cached_submissions.cache_delete(user_id, page=page, page_size=20)

        return result
    except Exception:
@@ -484,7 +584,7 @@ async def delete_submission(
async def get_submissions(
    user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
    page: int = 1,
    page_size: int = backend.server.cache_config.V2_STORE_SUBMISSIONS_PAGE_SIZE,
    page_size: int = 20,
):
    """
    Get a paginated list of store submissions for the authenticated user.
@@ -566,7 +666,10 @@ async def create_submission(
            recommended_schedule_cron=submission_request.recommended_schedule_cron,
        )

        _clear_submissions_cache(user_id)
        # Clear user's submissions cache - we don't know all page/size combinations
        for page in range(1, 20):
            # Clear user's submissions cache for common defaults
            _get_cached_submissions.cache_delete(user_id, page=page, page_size=20)

        return result
    except Exception:
@@ -617,7 +720,10 @@ async def edit_submission(
            recommended_schedule_cron=submission_request.recommended_schedule_cron,
        )

        _clear_submissions_cache(user_id)
        # Clear user's submissions cache - we don't know all page/size combinations
        for page in range(1, 20):
            # Clear user's submissions cache for common defaults
            _get_cached_submissions.cache_delete(user_id, page=page, page_size=20)

        return result


@@ -534,7 +534,6 @@ def test_get_submissions_success(
    mocked_value = backend.server.v2.store.model.StoreSubmissionsResponse(
        submissions=[
            backend.server.v2.store.model.StoreSubmission(
                user_id="user123",
                name="Test Agent",
                description="Test agent description",
                image_urls=["test.jpg"],
@@ -345,150 +345,6 @@ class TestCacheDeletion:
        )
        assert deleted is False  # Different parameters, not in cache

    @pytest.mark.asyncio
    async def test_clear_submissions_cache_page_size_consistency(self):
        """
        Test that _clear_submissions_cache uses the correct page_size.
        This test ensures that if the default page_size in routes changes,
        the hardcoded value in _clear_submissions_cache must also change.
        """
        from backend.server.v2.store.model import StoreSubmissionsResponse

        mock_response = StoreSubmissionsResponse(
            submissions=[],
            pagination=Pagination(
                total_items=0,
                total_pages=1,
                current_page=1,
                page_size=20,
            ),
        )

        with patch(
            "backend.server.v2.store.db.get_store_submissions",
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            # Clear cache first
            routes._get_cached_submissions.cache_clear()

            # Populate cache with multiple pages using the default page_size
            DEFAULT_PAGE_SIZE = 20  # This should match the default in routes.py
            user_id = "test_user"

            # Add entries for pages 1-5
            for page in range(1, 6):
                await routes._get_cached_submissions(
                    user_id=user_id, page=page, page_size=DEFAULT_PAGE_SIZE
                )

            # Verify cache has entries
            cache_info_before = routes._get_cached_submissions.cache_info()
            assert cache_info_before["size"] == 5

            # Call _clear_submissions_cache
            routes._clear_submissions_cache(user_id, num_pages=20)

            # All entries should be cleared
            cache_info_after = routes._get_cached_submissions.cache_info()
            assert (
                cache_info_after["size"] == 0
            ), "Cache should be empty after _clear_submissions_cache"

    @pytest.mark.asyncio
    async def test_clear_submissions_cache_detects_page_size_mismatch(self):
        """
        Test that detects if _clear_submissions_cache is using wrong page_size.
        If this test fails, it means the hardcoded page_size in _clear_submissions_cache
        doesn't match the default page_size used in the routes.
        """
        from backend.server.v2.store.model import StoreSubmissionsResponse

        mock_response = StoreSubmissionsResponse(
            submissions=[],
            pagination=Pagination(
                total_items=0,
                total_pages=1,
                current_page=1,
                page_size=20,
            ),
        )

        with patch(
            "backend.server.v2.store.db.get_store_submissions",
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            # Clear cache first
            routes._get_cached_submissions.cache_clear()

            # WRONG_PAGE_SIZE simulates what happens if someone changes
            # the default page_size in routes but forgets to update _clear_submissions_cache
            WRONG_PAGE_SIZE = 25  # Different from the hardcoded value in cache.py
            user_id = "test_user"

            # Populate cache with the "wrong" page_size
            for page in range(1, 6):
                await routes._get_cached_submissions(
                    user_id=user_id, page=page, page_size=WRONG_PAGE_SIZE
                )

            # Verify cache has entries
            cache_info_before = routes._get_cached_submissions.cache_info()
            assert cache_info_before["size"] == 5

            # Call _clear_submissions_cache (which uses page_size=20 hardcoded)
            routes._clear_submissions_cache(user_id, num_pages=20)

            # If page_size is mismatched, entries won't be cleared
            cache_info_after = routes._get_cached_submissions.cache_info()

            # This assertion will FAIL if _clear_submissions_cache uses wrong page_size
            assert (
                cache_info_after["size"] == 5
            ), "Cache entries with different page_size should NOT be cleared (this is expected)"

    @pytest.mark.asyncio
    async def test_my_agents_cache_needs_clearing_too(self):
        """
        Test that demonstrates _get_cached_my_agents also needs cache clearing.
        Currently there's no _clear_my_agents_cache function, but there should be.
        """
        from backend.server.v2.store.model import MyAgentsResponse

        mock_response = MyAgentsResponse(
            agents=[],
            pagination=Pagination(
                total_items=0,
                total_pages=1,
                current_page=1,
                page_size=20,
            ),
        )

        with patch(
            "backend.server.v2.store.db.get_my_agents",
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            routes._get_cached_my_agents.cache_clear()

            DEFAULT_PAGE_SIZE = 20
            user_id = "test_user"

            # Populate cache
            for page in range(1, 6):
                await routes._get_cached_my_agents(
                    user_id=user_id, page=page, page_size=DEFAULT_PAGE_SIZE
                )

            cache_info = routes._get_cached_my_agents.cache_info()
            assert cache_info["size"] == 5

            # NOTE: Currently there's no _clear_my_agents_cache function
            # If we implement one, it should clear all pages consistently
            # For now we document this as a TODO


if __name__ == "__main__":
    # Run the tests
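These deleted tests pin down the failure mode precisely: the cache key includes page_size, so clearing with one value cannot evict entries written with another. A compact, in-memory illustration using the cached decorator shown at the end of this diff (the import path is an assumption; the decorator also ships as autogpt_libs.utils.cache per the route imports above):

from backend.util.cache import cached

@cached(ttl_seconds=60)
def lookup(user_id: str, page: int, page_size: int) -> str:
    return f"{user_id}:{page}:{page_size}"

lookup("u1", 1, 25)                               # populates key ("u1", 1, 25)
assert lookup.cache_delete("u1", 1, 20) is False  # wrong page_size: no match
assert lookup.cache_delete("u1", 1, 25) is True   # exact args: entry removed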
@@ -1,461 +0,0 @@
"""
Caching utilities for the AutoGPT platform.

Provides decorators for caching function results with support for:
- In-memory caching with TTL
- Shared Redis-backed caching across processes
- Thread-local caching for request-scoped data
- Thundering herd protection
- LRU eviction with optional TTL refresh
"""

import asyncio
import inspect
import logging
import threading
import time
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, ParamSpec, Protocol, TypeVar, cast, runtime_checkable

from redis import ConnectionPool, Redis

from backend.util.retry import conn_retry
from backend.util.settings import Settings

P = ParamSpec("P")
R = TypeVar("R")
R_co = TypeVar("R_co", covariant=True)

logger = logging.getLogger(__name__)
settings = Settings()

# RECOMMENDED REDIS CONFIGURATION FOR PRODUCTION:
# Configure Redis with the following settings for optimal caching performance:
# maxmemory-policy allkeys-lru  # Evict least recently used keys when memory limit reached
# maxmemory 2gb  # Set memory limit (adjust based on your needs)
# save ""  # Disable persistence if using Redis purely for caching

# Create a dedicated Redis connection pool for caching (binary mode for pickle)
_cache_pool: ConnectionPool | None = None


@conn_retry("Redis", "Acquiring cache connection pool")
def _get_cache_pool() -> ConnectionPool:
    """Get or create a connection pool for cache operations."""
    global _cache_pool
    if _cache_pool is None:
        _cache_pool = ConnectionPool(
            host=settings.config.redis_host,
            port=settings.config.redis_port,
            password=settings.config.redis_password or None,
            decode_responses=False,  # Binary mode for pickle
            max_connections=50,
            socket_keepalive=True,
            socket_connect_timeout=5,
            retry_on_timeout=True,
        )
    return _cache_pool


def _get_redis_client() -> Redis:
    """Get a Redis client from the connection pool."""
    return Redis(connection_pool=_get_cache_pool())


@dataclass
class CachedValue:
    """Wrapper for cached values with timestamp to avoid tuple ambiguity."""

    result: Any
    timestamp: float


def _make_hashable_key(
    args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[Any, ...]:
    """
    Convert args and kwargs into a hashable cache key.

    Handles unhashable types like dict, list, set by converting them to
    their sorted string representations.
    """

    def make_hashable(obj: Any) -> Any:
        """Recursively convert an object to a hashable representation."""
        if isinstance(obj, dict):
            # Sort dict items to ensure consistent ordering
            return (
                "__dict__",
                tuple(sorted((k, make_hashable(v)) for k, v in obj.items())),
            )
        elif isinstance(obj, (list, tuple)):
            return ("__list__", tuple(make_hashable(item) for item in obj))
        elif isinstance(obj, set):
            return ("__set__", tuple(sorted(make_hashable(item) for item in obj)))
        elif hasattr(obj, "__dict__"):
            # Handle objects with __dict__ attribute
            return ("__obj__", obj.__class__.__name__, make_hashable(obj.__dict__))
        else:
            # For basic hashable types (str, int, bool, None, etc.)
            try:
                hash(obj)
                return obj
            except TypeError:
                # Fallback: convert to string representation
                return ("__str__", str(obj))

    hashable_args = tuple(make_hashable(arg) for arg in args)
    hashable_kwargs = tuple(sorted((k, make_hashable(v)) for k, v in kwargs.items()))
    return (hashable_args, hashable_kwargs)
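A quick demonstration of what _make_hashable_key produces for otherwise unhashable arguments; the values shown follow directly from the function above:

from backend.util.cache import _make_hashable_key

key = _make_hashable_key(
    args=("user-1", {"sort": "updatedAt"}),
    kwargs={"tags": ["a", "b"]},
)
# The dict becomes a sorted, tagged tuple; the list becomes a tagged tuple.
assert key == (
    ("user-1", ("__dict__", (("sort", "updatedAt"),))),
    (("tags", ("__list__", ("a", "b"))),),
)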
def _make_redis_key(key: tuple[Any, ...]) -> str:
    """Convert a hashable key tuple to a Redis key string."""
    # Ensure key is already hashable
    hashable_key = key if isinstance(key, tuple) else (key,)
    return f"cache:{hash(hashable_key)}"


@runtime_checkable
class CachedFunction(Protocol[P, R_co]):
    """Protocol for cached functions with cache management methods."""

    def cache_clear(self, pattern: str | None = None) -> None:
        """Clear cached entries. If pattern provided, clear matching entries only."""
        return None

    def cache_info(self) -> dict[str, int | None]:
        """Get cache statistics."""
        return {}

    def cache_delete(self, *args: P.args, **kwargs: P.kwargs) -> bool:
        """Delete a specific cache entry by its arguments. Returns True if entry existed."""
        return False

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
        """Call the cached function."""
        return None  # type: ignore
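Because cached() returns CachedFunction, type checkers see the management methods on any decorated callable. A small sketch of how the protocol pays off (the function below is hypothetical):

def report(fn: CachedFunction) -> str:
    # Any @cached function satisfies the protocol, so this type-checks.
    info = fn.cache_info()
    return f"{info['size']}/{info['maxsize']} entries, ttl={info['ttl_seconds']}s"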
def cached(
    *,
    maxsize: int = 128,
    ttl_seconds: int,
    shared_cache: bool = False,
    refresh_ttl_on_get: bool = False,
) -> Callable[[Callable], CachedFunction]:
    """
    Thundering herd safe cache decorator for both sync and async functions.

    Uses double-checked locking to prevent multiple threads/coroutines from
    executing the expensive operation simultaneously during cache misses.

    Args:
        maxsize: Maximum number of cached entries (only for in-memory cache)
        ttl_seconds: Time to live in seconds. Required - entries must expire.
        shared_cache: If True, use Redis for cross-process caching
        refresh_ttl_on_get: If True, refresh TTL when cache entry is accessed (LRU behavior)

    Returns:
        Decorated function with caching capabilities

    Example:
        @cached(ttl_seconds=300)  # 5 minute TTL
        def expensive_sync_operation(param: str) -> dict:
            return {"result": param}

        @cached(ttl_seconds=600, shared_cache=True, refresh_ttl_on_get=True)
        async def expensive_async_operation(param: str) -> dict:
            return {"result": param}
    """
def decorator(target_func):
|
||||
cache_storage: dict[tuple, CachedValue] = {}
|
||||
_event_loop_locks: dict[Any, asyncio.Lock] = {}
|
||||
|
||||
def _get_from_redis(redis_key: str) -> Any | None:
|
||||
"""Get value from Redis, optionally refreshing TTL."""
|
||||
try:
|
||||
import pickle
|
||||
|
||||
redis = _get_redis_client()
|
||||
if refresh_ttl_on_get:
|
||||
# Use GETEX to get value and refresh expiry atomically
|
||||
cached_bytes = redis.getex(redis_key, ex=ttl_seconds)
|
||||
else:
|
||||
cached_bytes = redis.get(redis_key)
|
||||
|
||||
if cached_bytes and isinstance(cached_bytes, bytes):
|
||||
return pickle.loads(cached_bytes)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Redis error during cache check for {target_func.__name__}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def _set_to_redis(redis_key: str, value: Any) -> None:
|
||||
"""Set value in Redis with TTL."""
|
||||
try:
|
||||
import pickle
|
||||
|
||||
redis = _get_redis_client()
|
||||
pickled_value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
redis.setex(redis_key, ttl_seconds, pickled_value)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Redis error storing cache for {target_func.__name__}: {e}"
|
||||
)
|
||||
|
||||
def _get_from_memory(key: tuple) -> Any | None:
|
||||
"""Get value from in-memory cache, checking TTL."""
|
||||
if key in cache_storage:
|
||||
cached_data = cache_storage[key]
|
||||
if time.time() - cached_data.timestamp < ttl_seconds:
|
||||
logger.debug(
|
||||
f"Cache hit for {target_func.__name__} args: {key[0]} kwargs: {key[1]}"
|
||||
)
|
||||
return cached_data.result
|
||||
return None
|
||||
|
||||
def _set_to_memory(key: tuple, value: Any) -> None:
|
||||
"""Set value in in-memory cache with timestamp."""
|
||||
cache_storage[key] = CachedValue(result=value, timestamp=time.time())
|
||||
|
||||
# Cleanup if needed
|
||||
if len(cache_storage) > maxsize:
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
|
||||
        if inspect.iscoroutinefunction(target_func):

            def _get_cache_lock():
                """Get or create an asyncio.Lock for the current event loop."""
                try:
                    loop = asyncio.get_running_loop()
                except RuntimeError:
                    loop = None

                if loop not in _event_loop_locks:
                    _event_loop_locks[loop] = asyncio.Lock()
                return _event_loop_locks[loop]

            @wraps(target_func)
            async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
                key = _make_hashable_key(args, kwargs)
                redis_key = _make_redis_key(key) if shared_cache else ""

                # Fast path: check cache without lock
                if shared_cache:
                    result = _get_from_redis(redis_key)
                    if result is not None:
                        return result
                else:
                    result = _get_from_memory(key)
                    if result is not None:
                        return result

                # Slow path: acquire lock for cache miss/expiry
                async with _get_cache_lock():
                    # Double-check: another coroutine might have populated cache
                    if shared_cache:
                        result = _get_from_redis(redis_key)
                        if result is not None:
                            return result
                    else:
                        result = _get_from_memory(key)
                        if result is not None:
                            return result

                    # Cache miss - execute function
                    logger.debug(f"Cache miss for {target_func.__name__}")
                    result = await target_func(*args, **kwargs)

                    # Store result
                    if shared_cache:
                        _set_to_redis(redis_key, result)
                    else:
                        _set_to_memory(key, result)

                    return result

            wrapper = async_wrapper

        else:
            # Sync function with threading.Lock
            cache_lock = threading.Lock()

            @wraps(target_func)
            def sync_wrapper(*args: P.args, **kwargs: P.kwargs):
                key = _make_hashable_key(args, kwargs)
                redis_key = _make_redis_key(key) if shared_cache else ""

                # Fast path: check cache without lock
                if shared_cache:
                    result = _get_from_redis(redis_key)
                    if result is not None:
                        return result
                else:
                    result = _get_from_memory(key)
                    if result is not None:
                        return result

                # Slow path: acquire lock for cache miss/expiry
                with cache_lock:
                    # Double-check: another thread might have populated cache
                    if shared_cache:
                        result = _get_from_redis(redis_key)
                        if result is not None:
                            return result
                    else:
                        result = _get_from_memory(key)
                        if result is not None:
                            return result

                    # Cache miss - execute function
                    logger.debug(f"Cache miss for {target_func.__name__}")
                    result = target_func(*args, **kwargs)

                    # Store result
                    if shared_cache:
                        _set_to_redis(redis_key, result)
                    else:
                        _set_to_memory(key, result)

                    return result

            wrapper = sync_wrapper

        # Add cache management methods
        def cache_clear(pattern: str | None = None) -> None:
            """Clear cache entries. If pattern provided, clear matching entries."""
            if shared_cache:
                redis = _get_redis_client()
                if pattern:
                    # Clear entries matching pattern
                    keys = list(redis.scan_iter(f"cache:{pattern}", count=100))
                else:
                    # Clear all cache keys
                    keys = list(redis.scan_iter("cache:*", count=100))

                if keys:
                    pipeline = redis.pipeline()
                    for key in keys:
                        pipeline.delete(key)
                    pipeline.execute()
            else:
                if pattern:
                    # For in-memory cache, pattern matching not supported
                    logger.warning(
                        "Pattern-based clearing not supported for in-memory cache"
                    )
                else:
                    cache_storage.clear()

        def cache_info() -> dict[str, int | None]:
            if shared_cache:
                redis = _get_redis_client()
                cache_keys = list(redis.scan_iter("cache:*"))
                return {
                    "size": len(cache_keys),
                    "maxsize": None,  # Redis manages its own size
                    "ttl_seconds": ttl_seconds,
                }
            else:
                return {
                    "size": len(cache_storage),
                    "maxsize": maxsize,
                    "ttl_seconds": ttl_seconds,
                }

        def cache_delete(*args, **kwargs) -> bool:
            """Delete a specific cache entry. Returns True if entry existed."""
            key = _make_hashable_key(args, kwargs)
            if shared_cache:
                redis = _get_redis_client()
                redis_key = _make_redis_key(key)
                if redis.exists(redis_key):
                    redis.delete(redis_key)
                    return True
                return False
            else:
                if key in cache_storage:
                    del cache_storage[key]
                    return True
                return False

        setattr(wrapper, "cache_clear", cache_clear)
        setattr(wrapper, "cache_info", cache_info)
        setattr(wrapper, "cache_delete", cache_delete)

        return cast(CachedFunction, wrapper)

    return decorator
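
A minimal usage sketch of the decorator above (illustrative only: get_config is a hypothetical function, and this assumes ttl_seconds, maxsize, and shared_cache are exposed as keyword arguments of cached, matching the names used in the implementation). Note one caveat visible in both wrappers: cache hits are detected with "result is not None", so a function whose legitimate return value is None is re-executed on every call.

@cached(ttl_seconds=300, maxsize=128)  # per-process in-memory cache
def get_config(env: str) -> dict:
    return {"env": env}

get_config("prod")                # miss: executes and stores the result
get_config("prod")                # hit: served from cache until the TTL expires
get_config.cache_delete("prod")   # drop a single entry
get_config.cache_info()           # e.g. {"size": 0, "maxsize": 128, "ttl_seconds": 300}
get_config.cache_clear()          # drop everything (or matching keys when shared_cache is used)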

def thread_cached(func):
    """
    Thread-local cache decorator for both sync and async functions.

    Each thread gets its own cache, which is useful for request-scoped caching
    in web applications where you want to cache within a single request but
    not across requests.

    Args:
        func: The function to cache

    Returns:
        Decorated function with thread-local caching

    Example:
        @thread_cached
        def expensive_operation(param: str) -> dict:
            return {"result": param}

        @thread_cached  # Works with async too
        async def expensive_async_operation(param: str) -> dict:
            return {"result": param}
    """
    thread_local = threading.local()

    def _clear():
        if hasattr(thread_local, "cache"):
            del thread_local.cache

    if inspect.iscoroutinefunction(func):

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            cache = getattr(thread_local, "cache", None)
            if cache is None:
                cache = thread_local.cache = {}
            key = _make_hashable_key(args, kwargs)
            if key not in cache:
                cache[key] = await func(*args, **kwargs)
            return cache[key]

        setattr(async_wrapper, "clear_cache", _clear)
        return async_wrapper

    else:

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            cache = getattr(thread_local, "cache", None)
            if cache is None:
                cache = thread_local.cache = {}
            key = _make_hashable_key(args, kwargs)
            if key not in cache:
                cache[key] = func(*args, **kwargs)
            return cache[key]

        setattr(sync_wrapper, "clear_cache", _clear)
        return sync_wrapper


def clear_thread_cache(func: Callable) -> None:
    """Clear thread-local cache for a function."""
    if clear := getattr(func, "clear_cache", None):
        clear()
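
A short demonstration of the per-thread behaviour (illustrative; compute is a made-up function):

import threading

@thread_cached
def compute(x: int) -> int:
    print(f"computing in {threading.current_thread().name}")
    return x * 2

compute(21)  # computed in the main thread
compute(21)  # cache hit: same thread, same arguments

worker = threading.Thread(target=compute, args=(21,))
worker.start()
worker.join()  # computed again: the worker thread has its own empty cache

clear_thread_cache(compute)  # clears only the calling thread's cache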
@@ -4,7 +4,8 @@ Centralized service client helpers with thread caching.

from typing import TYPE_CHECKING

from backend.util.cache import cached, thread_cached
from autogpt_libs.utils.cache import cached, thread_cached

from backend.util.settings import Settings

settings = Settings()
@@ -117,7 +118,7 @@ def get_integration_credentials_store() -> "IntegrationCredentialsStore":
# ============ Supabase Clients ============ #


@cached(ttl_seconds=3600)
@cached()
def get_supabase() -> "Client":
    """Get a process-cached synchronous Supabase client instance."""
    from supabase import create_client
@@ -127,7 +128,7 @@ def get_supabase() -> "Client":
    )


@cached(ttl_seconds=3600)
@cached()
async def get_async_supabase() -> "AClient":
    """Get a process-cached asynchronous Supabase client instance."""
    from supabase import create_async_client
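
For context (an illustration, not part of the diff): with the ttl_seconds=3600 variant, repeated calls within the hour reuse one client instance, and the factory only runs again once the TTL lapses.

import time

client_a = get_supabase()
client_b = get_supabase()
assert client_a is client_b  # same cached instance inside the TTL window
time.sleep(3601)             # after ttl_seconds=3600 elapses...
client_c = get_supabase()    # ...the next call constructs a fresh client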
@@ -1,124 +0,0 @@
"""
Utilities for handling dynamic field names and delimiters in the AutoGPT Platform.

Dynamic fields allow graphs to connect complex data structures using special delimiters:
- _#_ for dictionary keys (e.g., "values_#_name" → values["name"])
- _$_ for list indices (e.g., "items_$_0" → items[0])
- _@_ for object attributes (e.g., "obj_@_attr" → obj.attr)

This module provides utilities for:
- Extracting base field names from dynamic field names
- Generating proper schemas for base fields
- Creating helper functions for field sanitization
"""

from backend.data.dynamic_fields import DICT_SPLIT, LIST_SPLIT, OBJC_SPLIT

# All dynamic field delimiters
DYNAMIC_DELIMITERS = (DICT_SPLIT, LIST_SPLIT, OBJC_SPLIT)


def extract_base_field_name(field_name: str) -> str:
    """
    Extract the base field name from a dynamic field name.

    Examples:
        extract_base_field_name("values_#_name") → "values"
        extract_base_field_name("items_$_0") → "items"
        extract_base_field_name("obj_@_attr") → "obj"
        extract_base_field_name("regular_field") → "regular_field"

    Args:
        field_name: The field name that may contain dynamic delimiters

    Returns:
        The base field name without any dynamic suffixes
    """
    base_name = field_name
    for delimiter in DYNAMIC_DELIMITERS:
        if delimiter in base_name:
            base_name = base_name.split(delimiter)[0]
    return base_name


def is_dynamic_field(field_name: str) -> bool:
    """
    Check if a field name contains dynamic delimiters.

    Args:
        field_name: The field name to check

    Returns:
        True if the field contains any dynamic delimiters, False otherwise
    """
    return any(delimiter in field_name for delimiter in DYNAMIC_DELIMITERS)


def get_dynamic_field_description(
    base_field_name: str, original_field_name: str
) -> str:
    """
    Generate a description for a dynamic field based on its base field and structure.

    Args:
        base_field_name: The base field name (e.g., "values")
        original_field_name: The full dynamic field name (e.g., "values_#_name")

    Returns:
        A descriptive string explaining what this dynamic field represents
    """
    if DICT_SPLIT in original_field_name:
        key_part = (
            original_field_name.split(DICT_SPLIT, 1)[1].split(DICT_SPLIT[0])[0]
            if DICT_SPLIT in original_field_name
            else "key"
        )
        return f"Dictionary value for {base_field_name}['{key_part}']"
    elif LIST_SPLIT in original_field_name:
        index_part = (
            original_field_name.split(LIST_SPLIT, 1)[1].split(LIST_SPLIT[0])[0]
            if LIST_SPLIT in original_field_name
            else "index"
        )
        return f"List item for {base_field_name}[{index_part}]"
    elif OBJC_SPLIT in original_field_name:
        attr_part = (
            original_field_name.split(OBJC_SPLIT, 1)[1].split(OBJC_SPLIT[0])[0]
            if OBJC_SPLIT in original_field_name
            else "attr"
        )
        return f"Object attribute for {base_field_name}.{attr_part}"
    else:
        return f"Dynamic value for {base_field_name}"


def group_fields_by_base_name(field_names: list[str]) -> dict[str, list[str]]:
    """
    Group a list of field names by their base field names.

    Args:
        field_names: List of field names that may contain dynamic delimiters

    Returns:
        Dictionary mapping base field names to lists of original field names

    Example:
        group_fields_by_base_name([
            "values_#_name",
            "values_#_age",
            "items_$_0",
            "regular_field"
        ])
        → {
            "values": ["values_#_name", "values_#_age"],
            "items": ["items_$_0"],
            "regular_field": ["regular_field"]
        }
    """
    grouped = {}
    for field_name in field_names:
        base_name = extract_base_field_name(field_name)
        if base_name not in grouped:
            grouped[base_name] = []
        grouped[base_name].append(field_name)
    return grouped
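
To make the delimiter semantics concrete, here is a hedged sketch (not part of the deleted module; resolve_dynamic_field is hypothetical, and the delimiters are assumed to be the literal strings shown in the module docstring). It resolves a single-level dynamic field name against a data dictionary:

DICT_SPLIT, LIST_SPLIT, OBJC_SPLIT = "_#_", "_$_", "_@_"  # assumed literal values

def resolve_dynamic_field(data: dict, field_name: str):
    """Resolve e.g. "values_#_name" against {"values": {"name": 42}} -> 42."""
    if DICT_SPLIT in field_name:
        base, key = field_name.split(DICT_SPLIT, 1)
        return data[base][key]            # dictionary key access
    if LIST_SPLIT in field_name:
        base, index = field_name.split(LIST_SPLIT, 1)
        return data[base][int(index)]     # list index access
    if OBJC_SPLIT in field_name:
        base, attr = field_name.split(OBJC_SPLIT, 1)
        return getattr(data[base], attr)  # object attribute access
    return data[field_name]               # regular, non-dynamic field

assert resolve_dynamic_field({"values": {"name": 42}}, "values_#_name") == 42
assert resolve_dynamic_field({"items": ["a", "b"]}, "items_$_1") == "b"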
@@ -1,175 +0,0 @@
"""Tests for dynamic field utilities."""

from backend.util.dynamic_fields import (
    extract_base_field_name,
    get_dynamic_field_description,
    group_fields_by_base_name,
    is_dynamic_field,
)


class TestExtractBaseFieldName:
    """Test extracting base field names from dynamic field names."""

    def test_extract_dict_field(self):
        """Test extracting base name from dictionary fields."""
        assert extract_base_field_name("values_#_name") == "values"
        assert extract_base_field_name("data_#_key1_#_key2") == "data"
        assert extract_base_field_name("config_#_database_#_host") == "config"

    def test_extract_list_field(self):
        """Test extracting base name from list fields."""
        assert extract_base_field_name("items_$_0") == "items"
        assert extract_base_field_name("results_$_5_$_10") == "results"
        assert extract_base_field_name("nested_$_0_$_1_$_2") == "nested"

    def test_extract_object_field(self):
        """Test extracting base name from object fields."""
        assert extract_base_field_name("user_@_name") == "user"
        assert extract_base_field_name("response_@_data_@_items") == "response"
        assert extract_base_field_name("obj_@_attr1_@_attr2") == "obj"

    def test_extract_mixed_fields(self):
        """Test extracting base name from mixed dynamic fields."""
        assert extract_base_field_name("data_$_0_#_key") == "data"
        assert extract_base_field_name("items_#_user_@_name") == "items"
        assert extract_base_field_name("complex_$_0_@_attr_#_key") == "complex"

    def test_extract_regular_field(self):
        """Test extracting base name from regular (non-dynamic) fields."""
        assert extract_base_field_name("regular_field") == "regular_field"
        assert extract_base_field_name("simple") == "simple"
        assert extract_base_field_name("") == ""

    def test_extract_field_with_underscores(self):
        """Test fields with regular underscores (not dynamic delimiters)."""
        assert extract_base_field_name("field_name_here") == "field_name_here"
        assert extract_base_field_name("my_field_#_key") == "my_field"


class TestIsDynamicField:
    """Test identifying dynamic fields."""

    def test_is_dynamic_dict_field(self):
        """Test identifying dictionary dynamic fields."""
        assert is_dynamic_field("values_#_name") is True
        assert is_dynamic_field("data_#_key1_#_key2") is True

    def test_is_dynamic_list_field(self):
        """Test identifying list dynamic fields."""
        assert is_dynamic_field("items_$_0") is True
        assert is_dynamic_field("results_$_5_$_10") is True

    def test_is_dynamic_object_field(self):
        """Test identifying object dynamic fields."""
        assert is_dynamic_field("user_@_name") is True
        assert is_dynamic_field("response_@_data_@_items") is True

    def test_is_dynamic_mixed_field(self):
        """Test identifying mixed dynamic fields."""
        assert is_dynamic_field("data_$_0_#_key") is True
        assert is_dynamic_field("items_#_user_@_name") is True

    def test_is_not_dynamic_field(self):
        """Test identifying non-dynamic fields."""
        assert is_dynamic_field("regular_field") is False
        assert is_dynamic_field("field_name_here") is False
        assert is_dynamic_field("simple") is False
        assert is_dynamic_field("") is False


class TestGetDynamicFieldDescription:
    """Test generating descriptions for dynamic fields."""

    def test_dict_field_description(self):
        """Test descriptions for dictionary fields."""
        desc = get_dynamic_field_description("values", "values_#_name")
        assert "Dictionary value for values['name']" == desc

        desc = get_dynamic_field_description("config", "config_#_database")
        assert "Dictionary value for config['database']" == desc

    def test_list_field_description(self):
        """Test descriptions for list fields."""
        desc = get_dynamic_field_description("items", "items_$_0")
        assert "List item for items[0]" == desc

        desc = get_dynamic_field_description("results", "results_$_5")
        assert "List item for results[5]" == desc

    def test_object_field_description(self):
        """Test descriptions for object fields."""
        desc = get_dynamic_field_description("user", "user_@_name")
        assert "Object attribute for user.name" == desc

        desc = get_dynamic_field_description("response", "response_@_data")
        assert "Object attribute for response.data" == desc

    def test_fallback_description(self):
        """Test fallback description for non-dynamic fields."""
        desc = get_dynamic_field_description("field", "field")
        assert "Dynamic value for field" == desc


class TestGroupFieldsByBaseName:
    """Test grouping fields by their base names."""

    def test_group_mixed_fields(self):
        """Test grouping a mix of dynamic and regular fields."""
        fields = [
            "values_#_name",
            "values_#_age",
            "items_$_0",
            "items_$_1",
            "user_@_email",
            "regular_field",
            "another_field",
        ]

        result = group_fields_by_base_name(fields)

        expected = {
            "values": ["values_#_name", "values_#_age"],
            "items": ["items_$_0", "items_$_1"],
            "user": ["user_@_email"],
            "regular_field": ["regular_field"],
            "another_field": ["another_field"],
        }

        assert result == expected

    def test_group_empty_list(self):
        """Test grouping an empty list."""
        result = group_fields_by_base_name([])
        assert result == {}

    def test_group_single_field(self):
        """Test grouping a single field."""
        result = group_fields_by_base_name(["values_#_name"])
        assert result == {"values": ["values_#_name"]}

    def test_group_complex_dynamic_fields(self):
        """Test grouping complex nested dynamic fields."""
        fields = [
            "data_$_0_#_key1",
            "data_$_0_#_key2",
            "data_$_1_#_key1",
            "other_@_attr",
        ]

        result = group_fields_by_base_name(fields)

        expected = {
            "data": ["data_$_0_#_key1", "data_$_0_#_key2", "data_$_1_#_key1"],
            "other": ["other_@_attr"],
        }

        assert result == expected

    def test_preserve_order(self):
        """Test that field order is preserved within groups."""
        fields = ["values_#_c", "values_#_a", "values_#_b"]
        result = group_fields_by_base_name(fields)

        # Should preserve the original order
        assert result["values"] == ["values_#_c", "values_#_a", "values_#_b"]
@@ -5,12 +5,12 @@ from functools import wraps
from typing import Any, Awaitable, Callable, TypeVar

import ldclient
from autogpt_libs.utils.cache import cached
from fastapi import HTTPException
from ldclient import Context, LDClient
from ldclient.config import Config
from typing_extensions import ParamSpec

from backend.util.cache import cached
from backend.util.settings import Settings

logger = logging.getLogger(__name__)
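
A hedged sketch of how a cached LaunchDarkly client getter built from these imports might look (get_ldclient and the SDK-key setting name are assumptions for illustration, not code from the diff):

@cached()
def get_ldclient() -> LDClient:
    # Assumed setting name; configure the SDK once, then reuse the client.
    ldclient.set_config(Config(settings.secrets.launch_darkly_sdk_key))
    return ldclient.get()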
@@ -66,18 +66,6 @@ async def store_media_file(
    base_path = Path(get_exec_file_path(graph_exec_id, ""))
    base_path.mkdir(parents=True, exist_ok=True)

    # Security fix: Add disk space limits to prevent DoS
    MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB per file
    MAX_TOTAL_DISK_USAGE = 1024 * 1024 * 1024  # 1GB total per execution directory

    # Check total disk usage in base_path
    if base_path.exists():
        current_usage = get_dir_size(base_path)
        if current_usage > MAX_TOTAL_DISK_USAGE:
            raise ValueError(
                f"Disk usage limit exceeded: {current_usage} bytes > {MAX_TOTAL_DISK_USAGE} bytes"
            )

    # Helper functions
    def _extension_from_mime(mime: str) -> str:
        ext = mimetypes.guess_extension(mime, strict=False)
@@ -120,12 +108,6 @@ async def store_media_file(
            filename = Path(path_part).name or f"{uuid.uuid4()}.bin"
            target_path = _ensure_inside_base(base_path / filename, base_path)

            # Check file size limit
            if len(cloud_content) > MAX_FILE_SIZE:
                raise ValueError(
                    f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE} bytes"
                )

            # Virus scan the cloud content before writing locally
            await scan_content_safe(cloud_content, filename=filename)
            target_path.write_bytes(cloud_content)
@@ -147,12 +129,6 @@ async def store_media_file(
            target_path = _ensure_inside_base(base_path / filename, base_path)
            content = base64.b64decode(b64_content)

            # Check file size limit
            if len(content) > MAX_FILE_SIZE:
                raise ValueError(
                    f"File too large: {len(content)} bytes > {MAX_FILE_SIZE} bytes"
                )

            # Virus scan the base64 content before writing
            await scan_content_safe(content, filename=filename)
            target_path.write_bytes(content)
@@ -166,12 +142,6 @@ async def store_media_file(
            # Download and save
            resp = await Requests().get(file)

            # Check file size limit
            if len(resp.content) > MAX_FILE_SIZE:
                raise ValueError(
                    f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE} bytes"
                )

            # Virus scan the downloaded content before writing
            await scan_content_safe(resp.content, filename=filename)
            target_path.write_bytes(resp.content)
@@ -189,18 +159,6 @@ async def store_media_file(
    return MediaFileType(_strip_base_prefix(target_path, base_path))


def get_dir_size(path: Path) -> int:
    """Get total size of directory."""
    total = 0
    try:
        for entry in path.glob("**/*"):
            if entry.is_file():
                total += entry.stat().st_size
    except Exception:
        pass
    return total
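
For reference, the quota pattern the removed hunks implemented can be rebuilt on top of get_dir_size (a sketch mirroring the deleted checks; ensure_quota is a hypothetical helper, and the limit is the value from the removed code):

MAX_TOTAL_DISK_USAGE = 1024 * 1024 * 1024  # 1GB, as in the removed check

def ensure_quota(base_path: Path) -> None:
    # Raise before writing if the execution directory already exceeds its quota.
    current_usage = get_dir_size(base_path)
    if current_usage > MAX_TOTAL_DISK_USAGE:
        raise ValueError(
            f"Disk usage limit exceeded: {current_usage} bytes > {MAX_TOTAL_DISK_USAGE} bytes"
        )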
def get_mime_type(file: str) -> str:
    """
    Get the MIME type of a file, whether it's a data URI, URL, or local path.
@@ -1,18 +1,20 @@
import json
import re
from typing import Any, Type, TypeGuard, TypeVar, overload

import jsonschema
import orjson
from fastapi.encoders import jsonable_encoder
from prisma import Json
from pydantic import BaseModel

from .type import type_match

# Precompiled regex to remove PostgreSQL-incompatible control characters
# Removes \u0000-\u0008, \u000B-\u000C, \u000E-\u001F, \u007F (keeps tab \u0009, newline \u000A, carriage return \u000D)
POSTGRES_CONTROL_CHARS = re.compile(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]")
# Try to import orjson for better performance
try:
    import orjson

    HAS_ORJSON = True
except ImportError:
    HAS_ORJSON = False


def to_dict(data) -> dict:
@@ -21,28 +23,22 @@ def to_dict(data) -> dict:
    return jsonable_encoder(data)


def dumps(
    data: Any, *args: Any, indent: int | None = None, option: int = 0, **kwargs: Any
) -> str:
def dumps(data: Any, *args: Any, **kwargs: Any) -> str:
    """
    Serialize data to JSON string with automatic conversion of Pydantic models and complex types.

    This function converts the input data to a JSON-serializable format using FastAPI's
    jsonable_encoder before dumping to JSON. It handles Pydantic models, complex types,
    and ensures proper serialization.
    and ensures proper serialization. Uses orjson for better performance when available.

    Parameters
    ----------
    data : Any
        The data to serialize. Can be any type including Pydantic models, dicts, lists, etc.
    *args : Any
        Additional positional arguments
    indent : int | None
        If not None, pretty-print with indentation
    option : int
        orjson option flags (default: 0)
        Additional positional arguments passed to json.dumps() (ignored if using orjson)
    **kwargs : Any
        Additional keyword arguments. Supported: default, ensure_ascii, separators, indent
        Additional keyword arguments passed to json.dumps() (limited support with orjson)

    Returns
    -------
@@ -59,19 +55,16 @@ def dumps(
    """
    serializable_data = to_dict(data)

    # Handle indent parameter
    if indent is not None or kwargs.get("indent") is not None:
        option |= orjson.OPT_INDENT_2

    # orjson only accepts specific parameters, filter out stdlib json params
    # ensure_ascii: orjson always produces UTF-8 (better than ASCII)
    # separators: orjson uses compact separators by default
    supported_orjson_params = {"default"}
    orjson_kwargs = {k: v for k, v in kwargs.items() if k in supported_orjson_params}

    return orjson.dumps(serializable_data, option=option, **orjson_kwargs).decode(
        "utf-8"
    )
    if HAS_ORJSON:
        # orjson is faster but has limited options support
        option = 0
        if kwargs.get("indent") is not None:
            option |= orjson.OPT_INDENT_2
        # orjson.dumps returns bytes, so we decode to str
        return orjson.dumps(serializable_data, option=option).decode("utf-8")
    else:
        # Fallback to standard json
        return json.dumps(serializable_data, *args, **kwargs)


T = TypeVar("T")
@@ -88,7 +81,14 @@ def loads(data: str | bytes, *args, **kwargs) -> Any: ...
def loads(
    data: str | bytes, *args, target_type: Type[T] | None = None, **kwargs
) -> Any:
    parsed = orjson.loads(data)
    if HAS_ORJSON:
        # orjson can handle both str and bytes directly
        parsed = orjson.loads(data)
    else:
        # Standard json requires string input
        if isinstance(data, bytes):
            data = data.decode("utf-8")
        parsed = json.loads(data, *args, **kwargs)

    if target_type:
        return type_match(parsed, target_type)
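
A quick illustration of the dual-path behaviour above (hedged: it assumes nothing beyond what the two branches show):

payload = {"name": "Ada", "scores": [1, 2, 3]}

text = dumps(payload, indent=2)  # orjson path uses OPT_INDENT_2; stdlib path passes indent through
data = loads(text)               # parsed by orjson when available, stdlib json otherwise
assert data == payload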
@@ -123,20 +123,41 @@ def convert_pydantic_to_json(output_data: Any) -> Any:
    return output_data


def _sanitize_null_bytes(data: Any) -> Any:
    """
    Recursively sanitize null bytes from data structures to prevent PostgreSQL 22P05 errors.
    PostgreSQL cannot store null bytes (\u0000) in text fields.
    """
    if isinstance(data, str):
        return data.replace("\u0000", "")
    elif isinstance(data, dict):
        return {key: _sanitize_null_bytes(value) for key, value in data.items()}
    elif isinstance(data, list):
        return [_sanitize_null_bytes(item) for item in data]
    elif isinstance(data, tuple):
        return tuple(_sanitize_null_bytes(item) for item in data)
    else:
        # For other types (int, float, bool, None, etc.), return as-is
        return data


def SafeJson(data: Any) -> Json:
    """
    Safely serialize data and return Prisma's Json type.
    Sanitizes null bytes to prevent PostgreSQL 22P05 errors.
    """
    if isinstance(data, BaseModel):
        json_string = data.model_dump_json(
            warnings="error",
            exclude_none=True,
            fallback=lambda v: None,
        )
    else:
        json_string = dumps(data, default=lambda v: None)
    # Sanitize null bytes before serialization
    sanitized_data = _sanitize_null_bytes(data)

    # Remove PostgreSQL-incompatible control characters in single regex operation
    sanitized_json = POSTGRES_CONTROL_CHARS.sub("", json_string)
    return Json(json.loads(sanitized_json))
    if isinstance(sanitized_data, BaseModel):
        return Json(
            sanitized_data.model_dump(
                mode="json",
                warnings="error",
                exclude_none=True,
                fallback=lambda v: None,
            )
        )
    # Round-trip through JSON to ensure proper serialization with fallback for non-serializable values
    json_string = dumps(sanitized_data, default=lambda v: None)
    return Json(json.loads(json_string))
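
The effect of the sanitisation, illustrated with made-up values:

dirty = {"note": "bad\u0000byte", "rows": ["ok", "also\u0000bad"]}
clean = _sanitize_null_bytes(dirty)
assert clean == {"note": "badbyte", "rows": ["ok", "alsobad"]}
# SafeJson(dirty) therefore yields a Json value PostgreSQL can store without a 22P05 error.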
@@ -4,7 +4,6 @@ from enum import Enum
import sentry_sdk
from pydantic import SecretStr
from sentry_sdk.integrations.anthropic import AnthropicIntegration
from sentry_sdk.integrations.asyncio import AsyncioIntegration
from sentry_sdk.integrations.logging import LoggingIntegration

from backend.util.settings import Settings
@@ -26,7 +25,6 @@ def sentry_init():
        environment=f"app:{settings.config.app_env.value}-behave:{settings.config.behave_as.value}",
        _experiments={"enable_logs": True},
        integrations=[
            AsyncioIntegration(),
            LoggingIntegration(sentry_logs_level=logging.INFO),
            AnthropicIntegration(
                include_prompts=False,
@@ -19,48 +19,9 @@ def _msg_tokens(msg: dict, enc) -> int:
    """
    OpenAI counts ≈3 wrapper tokens per chat message, plus 1 if "name"
    is present, plus the tokenised content length.
    For tool calls, we need to count tokens in tool_calls and content fields.
    """
    WRAPPER = 3 + (1 if "name" in msg else 0)

    # Count content tokens
    content_tokens = _tok_len(msg.get("content") or "", enc)

    # Count tool call tokens for both OpenAI and Anthropic formats
    tool_call_tokens = 0

    # OpenAI format: tool_calls array at message level
    if "tool_calls" in msg and isinstance(msg["tool_calls"], list):
        for tool_call in msg["tool_calls"]:
            # Count the tool call structure tokens
            tool_call_tokens += _tok_len(tool_call.get("id", ""), enc)
            tool_call_tokens += _tok_len(tool_call.get("type", ""), enc)
            if "function" in tool_call:
                tool_call_tokens += _tok_len(tool_call["function"].get("name", ""), enc)
                tool_call_tokens += _tok_len(
                    tool_call["function"].get("arguments", ""), enc
                )

    # Anthropic format: tool_use within content array
    content = msg.get("content")
    if isinstance(content, list):
        for item in content:
            if isinstance(item, dict) and item.get("type") == "tool_use":
                # Count the tool use structure tokens
                tool_call_tokens += _tok_len(item.get("id", ""), enc)
                tool_call_tokens += _tok_len(item.get("name", ""), enc)
                tool_call_tokens += _tok_len(json.dumps(item.get("input", {})), enc)
            elif isinstance(item, dict) and item.get("type") == "tool_result":
                # Count tool result tokens
                tool_call_tokens += _tok_len(item.get("tool_use_id", ""), enc)
                tool_call_tokens += _tok_len(item.get("content", ""), enc)
            elif isinstance(item, dict) and "content" in item:
                # Other content types with content field
                tool_call_tokens += _tok_len(item.get("content", ""), enc)
        # For list content, override content_tokens since we counted everything above
        content_tokens = 0

    return WRAPPER + content_tokens + tool_call_tokens
    return WRAPPER + _tok_len(msg.get("content") or "", enc)


def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
@@ -1,278 +0,0 @@
"""Tests for prompt utility functions, especially tool call token counting."""

import pytest
from tiktoken import encoding_for_model

from backend.util import json
from backend.util.prompt import _msg_tokens, estimate_token_count


class TestMsgTokens:
    """Test the _msg_tokens function with various message types."""

    @pytest.fixture
    def enc(self):
        """Get the encoding for gpt-4o model."""
        return encoding_for_model("gpt-4o")

    def test_regular_message_token_counting(self, enc):
        """Test that regular messages are counted correctly (backward compatibility)."""
        msg = {"role": "user", "content": "What's the weather like in San Francisco?"}

        tokens = _msg_tokens(msg, enc)

        # Should be wrapper (3) + content tokens
        expected = 3 + len(enc.encode(msg["content"]))
        assert tokens == expected
        assert tokens > 3  # Has content

    def test_regular_message_with_name(self, enc):
        """Test that messages with name field get extra wrapper token."""
        msg = {"role": "user", "name": "test_user", "content": "Hello!"}

        tokens = _msg_tokens(msg, enc)

        # Should be wrapper (3 + 1 for name) + content tokens
        expected = 4 + len(enc.encode(msg["content"]))
        assert tokens == expected

    def test_openai_tool_call_token_counting(self, enc):
        """Test OpenAI format tool call token counting."""
        msg = {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {
                    "id": "call_abc123",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"location": "San Francisco", "unit": "celsius"}',
                    },
                }
            ],
        }

        tokens = _msg_tokens(msg, enc)

        # Should count wrapper + all tool call components
        expected_tool_tokens = (
            len(enc.encode("call_abc123"))
            + len(enc.encode("function"))
            + len(enc.encode("get_weather"))
            + len(enc.encode('{"location": "San Francisco", "unit": "celsius"}'))
        )
        expected = 3 + expected_tool_tokens  # wrapper + tool tokens

        assert tokens == expected
        assert tokens > 8  # Should be significantly more than just wrapper

    def test_openai_multiple_tool_calls(self, enc):
        """Test OpenAI format with multiple tool calls."""
        msg = {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {
                    "id": "call_1",
                    "type": "function",
                    "function": {"name": "func1", "arguments": '{"arg": "value1"}'},
                },
                {
                    "id": "call_2",
                    "type": "function",
                    "function": {"name": "func2", "arguments": '{"arg": "value2"}'},
                },
            ],
        }

        tokens = _msg_tokens(msg, enc)

        # Should count all tool calls
        assert tokens > 20  # Should be more than single tool call

    def test_anthropic_tool_use_token_counting(self, enc):
        """Test Anthropic format tool use token counting."""
        msg = {
            "role": "assistant",
            "content": [
                {
                    "type": "tool_use",
                    "id": "toolu_xyz456",
                    "name": "get_weather",
                    "input": {"location": "San Francisco", "unit": "celsius"},
                }
            ],
        }

        tokens = _msg_tokens(msg, enc)

        # Should count wrapper + tool use components
        expected_tool_tokens = (
            len(enc.encode("toolu_xyz456"))
            + len(enc.encode("get_weather"))
            + len(
                enc.encode(json.dumps({"location": "San Francisco", "unit": "celsius"}))
            )
        )
        expected = 3 + expected_tool_tokens  # wrapper + tool tokens

        assert tokens == expected
        assert tokens > 8  # Should be significantly more than just wrapper

    def test_anthropic_tool_result_token_counting(self, enc):
        """Test Anthropic format tool result token counting."""
        msg = {
            "role": "user",
            "content": [
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_xyz456",
                    "content": "The weather in San Francisco is 22°C and sunny.",
                }
            ],
        }

        tokens = _msg_tokens(msg, enc)

        # Should count wrapper + tool result components
        expected_tool_tokens = len(enc.encode("toolu_xyz456")) + len(
            enc.encode("The weather in San Francisco is 22°C and sunny.")
        )
        expected = 3 + expected_tool_tokens  # wrapper + tool tokens

        assert tokens == expected
        assert tokens > 8  # Should be significantly more than just wrapper

    def test_anthropic_mixed_content(self, enc):
        """Test Anthropic format with mixed content types."""
        msg = {
            "role": "assistant",
            "content": [
                {"type": "text", "content": "I'll check the weather for you."},
                {
                    "type": "tool_use",
                    "id": "toolu_123",
                    "name": "get_weather",
                    "input": {"location": "SF"},
                },
            ],
        }

        tokens = _msg_tokens(msg, enc)

        # Should count all content items
        assert tokens > 15  # Should count both text and tool use

    def test_empty_content(self, enc):
        """Test message with empty or None content."""
        msg = {"role": "assistant", "content": None}

        tokens = _msg_tokens(msg, enc)
        assert tokens == 3  # Just wrapper tokens

        msg["content"] = ""
        tokens = _msg_tokens(msg, enc)
        assert tokens == 3  # Just wrapper tokens

    def test_string_content_with_tool_calls(self, enc):
        """Test OpenAI format where content is string but tool_calls exist."""
        msg = {
            "role": "assistant",
            "content": "Let me check that for you.",
            "tool_calls": [
                {
                    "id": "call_123",
                    "type": "function",
                    "function": {"name": "test_func", "arguments": "{}"},
                }
            ],
        }

        tokens = _msg_tokens(msg, enc)

        # Should count both content and tool calls
        content_tokens = len(enc.encode("Let me check that for you."))
        tool_tokens = (
            len(enc.encode("call_123"))
            + len(enc.encode("function"))
            + len(enc.encode("test_func"))
            + len(enc.encode("{}"))
        )
        expected = 3 + content_tokens + tool_tokens

        assert tokens == expected


class TestEstimateTokenCount:
    """Test the estimate_token_count function with conversations containing tool calls."""

    def test_conversation_with_tool_calls(self):
        """Test token counting for a complete conversation with tool calls."""
        conversation = [
            {"role": "user", "content": "What's the weather like in San Francisco?"},
            {
                "role": "assistant",
                "content": [
                    {
                        "type": "tool_use",
                        "id": "toolu_123",
                        "name": "get_weather",
                        "input": {"location": "San Francisco"},
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "tool_result",
                        "tool_use_id": "toolu_123",
                        "content": "22°C and sunny",
                    }
                ],
            },
            {
                "role": "assistant",
                "content": "The weather in San Francisco is 22°C and sunny.",
            },
        ]

        total_tokens = estimate_token_count(conversation)

        # Verify total equals sum of individual messages
        enc = encoding_for_model("gpt-4o")
        expected_total = sum(_msg_tokens(msg, enc) for msg in conversation)

        assert total_tokens == expected_total
        assert total_tokens > 40  # Should be substantial for this conversation

    def test_openai_conversation(self):
        """Test token counting for OpenAI format conversation."""
        conversation = [
            {"role": "user", "content": "Calculate 2 + 2"},
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "call_calc",
                        "type": "function",
                        "function": {
                            "name": "calculate",
                            "arguments": '{"expression": "2 + 2"}',
                        },
                    }
                ],
            },
            {"role": "tool", "tool_call_id": "call_calc", "content": "4"},
            {"role": "assistant", "content": "The result is 4."},
        ]

        total_tokens = estimate_token_count(conversation)

        # Verify total equals sum of individual messages
        enc = encoding_for_model("gpt-4o")
        expected_total = sum(_msg_tokens(msg, enc) for msg in conversation)

        assert total_tokens == expected_total
        assert total_tokens > 20  # Should be substantial
@@ -21,10 +21,10 @@ logger = logging.getLogger(__name__)
EXCESSIVE_RETRY_THRESHOLD = 50


def _send_critical_retry_alert(
def _send_retry_alert(
    func_name: str, attempt_number: int, exception: Exception, context: str = ""
):
    """Send alert when a function is approaching the retry failure threshold."""
    """Send alert for excessive retry attempts."""
    try:
        # Import here to avoid circular imports
        from backend.util.clients import get_notification_manager_client
@@ -33,19 +33,19 @@ def _send_critical_retry_alert(

        prefix = f"{context}: " if context else ""
        alert_msg = (
            f"🚨 CRITICAL: Operation Approaching Failure Threshold: {prefix}'{func_name}'\n\n"
            f"Current attempt: {attempt_number}/{EXCESSIVE_RETRY_THRESHOLD}\n"
            f"🚨 Excessive Retry Alert: {prefix}'{func_name}' has failed {attempt_number} times!\n\n"
            f"Error: {type(exception).__name__}: {exception}\n\n"
            f"This operation is about to fail permanently. Investigate immediately."
            f"This indicates a persistent issue that requires investigation. "
            f"The operation has been retrying for an extended period."
        )

        notification_client.discord_system_alert(alert_msg)
        logger.critical(
            f"CRITICAL ALERT SENT: Operation {func_name} at attempt {attempt_number}"
            f"ALERT SENT: Excessive retries detected for {func_name} after {attempt_number} attempts"
        )

    except Exception as alert_error:
        logger.error(f"Failed to send critical retry alert: {alert_error}")
        logger.error(f"Failed to send retry alert: {alert_error}")
        # Don't let alerting failures break the main flow


@@ -59,23 +59,22 @@ def _create_retry_callback(context: str = ""):

        prefix = f"{context}: " if context else ""

        # Send alert if we've exceeded the threshold
        if attempt_number >= EXCESSIVE_RETRY_THRESHOLD:
            _send_retry_alert(func_name, attempt_number, exception, context)

        if retry_state.outcome.failed and retry_state.next_action is None:
            # Final failure - just log the error (alert was already sent at excessive threshold)
            # Final failure
            logger.error(
                f"{prefix}Giving up after {attempt_number} attempts for '{func_name}': "
                f"{type(exception).__name__}: {exception}"
            )
        else:
            # Retry attempt - send critical alert only once at threshold
            if attempt_number == EXCESSIVE_RETRY_THRESHOLD:
                _send_critical_retry_alert(
                    func_name, attempt_number, exception, context
                )
            else:
                logger.warning(
                    f"{prefix}Retry attempt {attempt_number} for '{func_name}': "
                    f"{type(exception).__name__}: {exception}"
                )
            # Retry attempt
            logger.warning(
                f"{prefix}Retry attempt {attempt_number} for '{func_name}': "
                f"{type(exception).__name__}: {exception}"
            )

    return callback
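
A hedged sketch of how such a callback is typically wired up, assuming the surrounding module uses tenacity (whose retry_state carries the outcome, attempt_number, and next_action attributes read above; flaky_call is hypothetical):

from tenacity import retry, stop_after_attempt, wait_exponential

@retry(
    stop=stop_after_attempt(EXCESSIVE_RETRY_THRESHOLD + 10),
    wait=wait_exponential(max=60),
    after=_create_retry_callback(context="Example"),  # invoked after every attempt
)
def flaky_call():
    ...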
@@ -139,6 +138,13 @@ def conn_retry(
    def on_retry(retry_state):
        prefix = _log_prefix(resource_name, conn_id)
        exception = retry_state.outcome.exception()
        attempt_number = retry_state.attempt_number

        # Send alert if we've exceeded the threshold
        if attempt_number >= EXCESSIVE_RETRY_THRESHOLD:
            func_name = f"{resource_name}:{action_name}"
            context = f"Connection retry {resource_name}"
            _send_retry_alert(func_name, attempt_number, exception, context)

        if retry_state.outcome.failed and retry_state.next_action is None:
            logger.error(f"{prefix} {action_name} failed after retries: {exception}")
@@ -7,7 +7,7 @@ import os
import threading
import time
from abc import ABC, abstractmethod
from functools import update_wrapper
from functools import cached_property, update_wrapper
from typing import (
    Any,
    Awaitable,
@@ -375,8 +375,6 @@ def get_service_client(
        self.base_url = f"http://{host}:{port}".rstrip("/")
        self._connection_failure_count = 0
        self._last_client_reset = 0
        self._async_clients = {}  # None key for default async client
        self._sync_clients = {}  # For sync clients (no event loop concept)

    def _create_sync_client(self) -> httpx.Client:
        return httpx.Client(
@@ -400,33 +398,13 @@ def get_service_client(
            ),
        )

    @property
    @cached_property
    def sync_client(self) -> httpx.Client:
        """Get the sync client (thread-safe singleton)."""
        # Use service name as key for better identification
        service_name = service_client_type.get_service_type().__name__
        if client := self._sync_clients.get(service_name):
            return client
        return self._sync_clients.setdefault(
            service_name, self._create_sync_client()
        )
        return self._create_sync_client()

    @property
    @cached_property
    def async_client(self) -> httpx.AsyncClient:
        """Get the appropriate async client for the current context.

        Returns per-event-loop client when in async context,
        falls back to default client otherwise.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No event loop, use None as default key
            loop = None

        if client := self._async_clients.get(loop):
            return client
        return self._async_clients.setdefault(loop, self._create_async_client())
        return self._create_async_client()
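
The reset idiom that pairs with cached_property (a generic illustration, not code from the diff): functools.cached_property stores its value in the instance __dict__, so deleting the attribute forces recomputation on the next access, which is exactly what the error handler below does with the hasattr/delattr pair.

from functools import cached_property

class Example:
    @cached_property
    def client(self) -> object:
        return object()  # stand-in for expensive client construction

e = Example()
first = e.client
assert e.client is first      # cached in e.__dict__["client"]
delattr(e, "client")          # would raise AttributeError if nothing is cached yet
assert e.client is not first  # recomputed on next access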
    def _handle_connection_error(self, error: Exception) -> None:
        """Handle connection errors and implement self-healing"""
@@ -445,8 +423,10 @@ def get_service_client(

        # Clear cached clients to force recreation on next access
        # Only recreate when there's actually a problem
        self._sync_clients.clear()
        self._async_clients.clear()
        if hasattr(self, "sync_client"):
            delattr(self, "sync_client")
        if hasattr(self, "async_client"):
            delattr(self, "async_client")

        # Reset counters
        self._connection_failure_count = 0
@@ -512,37 +492,28 @@ def get_service_client(
            raise

    async def aclose(self) -> None:
        # Close all sync clients
        for client in self._sync_clients.values():
            client.close()
        self._sync_clients.clear()

        # Close all async clients (including default with None key)
        for client in self._async_clients.values():
            await client.aclose()
        self._async_clients.clear()
        if hasattr(self, "sync_client"):
            self.sync_client.close()
        if hasattr(self, "async_client"):
            await self.async_client.aclose()

    def close(self) -> None:
        # Close all sync clients
        for client in self._sync_clients.values():
            client.close()
        self._sync_clients.clear()
        # Note: Cannot close async clients synchronously
        # They will be cleaned up by garbage collection
        if hasattr(self, "sync_client"):
            self.sync_client.close()
        # Note: Cannot close async client synchronously

    def __del__(self):
        """Cleanup HTTP clients on garbage collection to prevent resource leaks."""
        try:
            # Close any remaining sync clients
            for client in self._sync_clients.values():
                client.close()

            # Warn if async clients weren't properly closed
            if self._async_clients:
            if hasattr(self, "sync_client"):
                self.sync_client.close()
            if hasattr(self, "async_client"):
                # Note: Can't await in __del__, so we just close sync
                # The async client will be cleaned up by garbage collection
                import warnings

                warnings.warn(
                    "DynamicClient async clients not explicitly closed. "
                    "DynamicClient async client not explicitly closed. "
                    "Call aclose() before destroying the client.",
                    ResourceWarning,
                    stacklevel=2,
@@ -59,19 +59,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
        le=1000,
        description="Maximum number of workers to use for graph execution.",
    )

    # FastAPI Thread Pool Configuration
    # IMPORTANT: FastAPI automatically offloads ALL sync functions to a thread pool:
    # - Sync endpoint functions (def instead of async def)
    # - Sync dependency functions (def instead of async def)
    # - Manually called run_in_threadpool() operations
    # Default thread pool size is only 40, which becomes a bottleneck under high concurrency
    fastapi_thread_pool_size: int = Field(
        default=60,
        ge=40,
        le=500,
        description="Thread pool size for FastAPI sync operations. All sync endpoints and dependencies automatically use this pool. Higher values support more concurrent sync operations but use more memory.",
    )
    pyro_host: str = Field(
        default="localhost",
        description="The default hostname of the Pyro server.",
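
For context on the removed fastapi_thread_pool_size setting: one common way to apply such a value is to resize AnyIO's default thread limiter at startup, since Starlette/FastAPI route sync endpoints and dependencies through it (a hedged sketch; this wiring is an assumption, not shown in the diff, and must run inside the serving event loop):

import anyio

def apply_thread_pool_size(size: int) -> None:
    # Starlette's run_in_threadpool uses AnyIO's default CapacityLimiter (40 tokens by default).
    anyio.to_thread.current_default_thread_limiter().total_tokens = size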
@@ -140,10 +127,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
        default=5 * 60,
        description="Time in seconds after which the execution stuck on QUEUED status is considered late.",
    )
    cluster_lock_timeout: int = Field(
        default=300,
        description="Cluster lock timeout in seconds for graph execution coordination.",
    )
    execution_late_notification_checkrange_secs: int = Field(
        default=60 * 60,
        description="Time in seconds for how far back to check for the late executions.",
@@ -258,19 +241,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
        description="The vhost for the RabbitMQ server",
    )

    redis_host: str = Field(
        default="localhost",
        description="The host for the Redis server",
    )
    redis_port: int = Field(
        default=6379,
        description="The port for the Redis server",
    )
    redis_password: str = Field(
        default="",
        description="The password for the Redis server (empty string if no password)",
    )

    postmark_sender_email: str = Field(
        default="invalid@invalid.com",
        description="The email address to use for sending emails",
Some files were not shown because too many files have changed in this diff.