mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-13 00:58:16 -05:00
Compare commits
95 Commits
fix/databa
...
fix/waitli
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ae20da8aaa | ||
|
|
8258338caf | ||
|
|
e62a56e8ba | ||
|
|
f3f9a60157 | ||
|
|
9469b9e2eb | ||
|
|
b7ae2c2fd2 | ||
|
|
8b995c2394 | ||
|
|
12b1067017 | ||
|
|
ba53cb78dc | ||
|
|
f9778cc87e | ||
|
|
b230b1b5cf | ||
|
|
1925e77733 | ||
|
|
9bc9b53b99 | ||
|
|
adfa75eca8 | ||
|
|
0f19d01483 | ||
|
|
112c39f6a6 | ||
|
|
22946f4617 | ||
|
|
938834ac83 | ||
|
|
934cb3a9c7 | ||
|
|
7b8499ec69 | ||
|
|
63076a67e1 | ||
|
|
41260a7b4a | ||
|
|
5f2d4643f8 | ||
|
|
9c8652b273 | ||
|
|
58ef687a54 | ||
|
|
c7dcbc64ec | ||
|
|
99ac206272 | ||
|
|
f67d78df3e | ||
|
|
e32c509ccc | ||
|
|
20acd8b51d | ||
|
|
a49c957467 | ||
|
|
cf6e724e99 | ||
|
|
b67555391d | ||
|
|
05a72f4185 | ||
|
|
36f634c417 | ||
|
|
18e169aa51 | ||
|
|
c5b90f7b09 | ||
|
|
a446c1acc9 | ||
|
|
59d242f69c | ||
|
|
a2cd5d9c1f | ||
|
|
df5b348676 | ||
|
|
4856bd1f3a | ||
|
|
2e1d3dd185 | ||
|
|
ff72343035 | ||
|
|
7982c34450 | ||
|
|
59c27fe248 | ||
|
|
c7575dc579 | ||
|
|
73603a8ce5 | ||
|
|
e562ca37aa | ||
|
|
f906fd9298 | ||
|
|
9e79add436 | ||
|
|
de6f4fca23 | ||
|
|
fb4b8ed9fc | ||
|
|
f3900127d7 | ||
|
|
7c47f54e25 | ||
|
|
927042d93e | ||
|
|
4244979a45 | ||
|
|
aa27365e7f | ||
|
|
b86aa8b14e | ||
|
|
e7ab2626f5 | ||
|
|
ff58ce174b | ||
|
|
2d8ab6b7c0 | ||
|
|
a7306970b8 | ||
|
|
c42f94ce2a | ||
|
|
4e1557e498 | ||
|
|
7f8cf36ceb | ||
|
|
0978566089 | ||
|
|
8b4eb6f87c | ||
|
|
4b7d17b9d2 | ||
|
|
0fc6a44389 | ||
|
|
f5ee579ab2 | ||
|
|
57a06f7088 | ||
|
|
258bf0b1a5 | ||
|
|
4a1cb6d64b | ||
|
|
7c9db7419a | ||
|
|
18bbd8e572 | ||
|
|
047f011520 | ||
|
|
d11917eb10 | ||
|
|
4663066e65 | ||
|
|
48a0faa611 | ||
|
|
70d00b4104 | ||
|
|
aad0434cb2 | ||
|
|
f33ec1f2ec | ||
|
|
e68b873bcf | ||
|
|
4530e97e59 | ||
|
|
477c261488 | ||
|
|
8ac2228e1e | ||
|
|
91dd9364bb | ||
|
|
f314fbf14f | ||
|
|
a97ff641c3 | ||
|
|
114f604d7b | ||
|
|
3abea1ed96 | ||
|
|
da6e1ad26d | ||
|
|
634fffb967 | ||
|
|
f3ec426c82 |
@@ -1,6 +1,3 @@
|
||||
[pr_reviewer]
|
||||
num_code_suggestions=0
|
||||
|
||||
[pr_code_suggestions]
|
||||
commitable_code_suggestions=false
|
||||
num_code_suggestions=0
|
||||
|
||||
47
autogpt_platform/Makefile
Normal file
47
autogpt_platform/Makefile
Normal file
@@ -0,0 +1,47 @@
|
||||
.PHONY: start-core stop-core logs-core format lint migrate run-backend run-frontend
|
||||
|
||||
# Run just Supabase + Redis + RabbitMQ
|
||||
start-core:
|
||||
docker compose up -d deps
|
||||
|
||||
# Stop core services
|
||||
stop-core:
|
||||
docker compose stop deps
|
||||
|
||||
# View logs for core services
|
||||
logs-core:
|
||||
docker compose logs -f deps
|
||||
|
||||
# Run formatting and linting for backend and frontend
|
||||
format:
|
||||
cd backend && poetry run format
|
||||
cd frontend && pnpm format
|
||||
cd frontend && pnpm lint
|
||||
|
||||
init-env:
|
||||
cp -n .env.default .env || true
|
||||
cd backend && cp -n .env.default .env || true
|
||||
cd frontend && cp -n .env.default .env || true
|
||||
|
||||
|
||||
# Run migrations for backend
|
||||
migrate:
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
|
||||
run-backend:
|
||||
cd backend && poetry run app
|
||||
|
||||
run-frontend:
|
||||
cd frontend && pnpm dev
|
||||
|
||||
help:
|
||||
@echo "Usage: make <target>"
|
||||
@echo "Targets:"
|
||||
@echo " start-core - Start just the core services (Supabase, Redis, RabbitMQ) in background"
|
||||
@echo " stop-core - Stop the core services"
|
||||
@echo " logs-core - Tail the logs for core services"
|
||||
@echo " format - Format & lint backend (Python) and frontend (TypeScript) code"
|
||||
@echo " migrate - Run backend database migrations"
|
||||
@echo " run-backend - Run the backend FastAPI server"
|
||||
@echo " run-frontend - Run the frontend Next.js development server"
|
||||
@@ -38,6 +38,37 @@ To run the AutoGPT Platform, follow these steps:
|
||||
|
||||
4. After all the services are in ready state, open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
|
||||
|
||||
### Running Just Core services
|
||||
|
||||
You can now run the following to enable just the core services.
|
||||
|
||||
```
|
||||
# For help
|
||||
make help
|
||||
|
||||
# Run just Supabase + Redis + RabbitMQ
|
||||
make start-core
|
||||
|
||||
# Stop core services
|
||||
make stop-core
|
||||
|
||||
# View logs from core services
|
||||
make logs-core
|
||||
|
||||
# Run formatting and linting for backend and frontend
|
||||
make format
|
||||
|
||||
# Run migrations for backend database
|
||||
make migrate
|
||||
|
||||
# Run backend server
|
||||
make run-backend
|
||||
|
||||
# Run frontend development server
|
||||
make run-frontend
|
||||
|
||||
```
|
||||
|
||||
### Docker Compose Commands
|
||||
|
||||
Here are some useful Docker Compose commands for managing your AutoGPT Platform:
|
||||
|
||||
@@ -10,7 +10,7 @@ from .jwt_utils import get_jwt_payload, verify_user
|
||||
from .models import User
|
||||
|
||||
|
||||
def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
|
||||
async def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
|
||||
"""
|
||||
FastAPI dependency that requires a valid authenticated user.
|
||||
|
||||
@@ -20,7 +20,9 @@ def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User
|
||||
return verify_user(jwt_payload, admin_only=False)
|
||||
|
||||
|
||||
def requires_admin_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
|
||||
async def requires_admin_user(
|
||||
jwt_payload: dict = fastapi.Security(get_jwt_payload),
|
||||
) -> User:
|
||||
"""
|
||||
FastAPI dependency that requires a valid admin user.
|
||||
|
||||
@@ -30,7 +32,7 @@ def requires_admin_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -
|
||||
return verify_user(jwt_payload, admin_only=True)
|
||||
|
||||
|
||||
def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
|
||||
async def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
|
||||
"""
|
||||
FastAPI dependency that returns the ID of the authenticated user.
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ class TestAuthDependencies:
|
||||
"""Create a test client."""
|
||||
return TestClient(app)
|
||||
|
||||
def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
|
||||
async def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
|
||||
"""Test requires_user with valid JWT payload."""
|
||||
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
|
||||
|
||||
@@ -53,12 +53,12 @@ class TestAuthDependencies:
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = requires_user(jwt_payload)
|
||||
user = await requires_user(jwt_payload)
|
||||
assert isinstance(user, User)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.role == "user"
|
||||
|
||||
def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
|
||||
async def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
|
||||
"""Test requires_user accepts admin users."""
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
@@ -69,28 +69,28 @@ class TestAuthDependencies:
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = requires_user(jwt_payload)
|
||||
user = await requires_user(jwt_payload)
|
||||
assert user.user_id == "admin-456"
|
||||
assert user.role == "admin"
|
||||
|
||||
def test_requires_user_missing_sub(self):
|
||||
async def test_requires_user_missing_sub(self):
|
||||
"""Test requires_user with missing user ID."""
|
||||
jwt_payload = {"role": "user", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
requires_user(jwt_payload)
|
||||
await requires_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found" in exc_info.value.detail
|
||||
|
||||
def test_requires_user_empty_sub(self):
|
||||
async def test_requires_user_empty_sub(self):
|
||||
"""Test requires_user with empty user ID."""
|
||||
jwt_payload = {"sub": "", "role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
requires_user(jwt_payload)
|
||||
await requires_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
|
||||
async def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
|
||||
"""Test requires_admin_user with admin role."""
|
||||
jwt_payload = {
|
||||
"sub": "admin-789",
|
||||
@@ -101,51 +101,51 @@ class TestAuthDependencies:
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = requires_admin_user(jwt_payload)
|
||||
user = await requires_admin_user(jwt_payload)
|
||||
assert user.user_id == "admin-789"
|
||||
assert user.role == "admin"
|
||||
|
||||
def test_requires_admin_user_with_regular_user(self):
|
||||
async def test_requires_admin_user_with_regular_user(self):
|
||||
"""Test requires_admin_user rejects regular users."""
|
||||
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
requires_admin_user(jwt_payload)
|
||||
await requires_admin_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Admin access required" in exc_info.value.detail
|
||||
|
||||
def test_requires_admin_user_missing_role(self):
|
||||
async def test_requires_admin_user_missing_role(self):
|
||||
"""Test requires_admin_user with missing role."""
|
||||
jwt_payload = {"sub": "user-123", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
requires_admin_user(jwt_payload)
|
||||
await requires_admin_user(jwt_payload)
|
||||
|
||||
def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
|
||||
async def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
|
||||
"""Test get_user_id extracts user ID correctly."""
|
||||
jwt_payload = {"sub": "user-id-xyz", "role": "user"}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user_id = get_user_id(jwt_payload)
|
||||
user_id = await get_user_id(jwt_payload)
|
||||
assert user_id == "user-id-xyz"
|
||||
|
||||
def test_get_user_id_missing_sub(self):
|
||||
async def test_get_user_id_missing_sub(self):
|
||||
"""Test get_user_id with missing user ID."""
|
||||
jwt_payload = {"role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_user_id(jwt_payload)
|
||||
await get_user_id(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found" in exc_info.value.detail
|
||||
|
||||
def test_get_user_id_none_sub(self):
|
||||
async def test_get_user_id_none_sub(self):
|
||||
"""Test get_user_id with None user ID."""
|
||||
jwt_payload = {"sub": None, "role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_user_id(jwt_payload)
|
||||
await get_user_id(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
@@ -170,7 +170,7 @@ class TestAuthDependenciesIntegration:
|
||||
|
||||
return _create_token
|
||||
|
||||
def test_endpoint_auth_enabled_no_token(self):
|
||||
async def test_endpoint_auth_enabled_no_token(self):
|
||||
"""Test endpoints require token when auth is enabled."""
|
||||
app = FastAPI()
|
||||
|
||||
@@ -184,7 +184,7 @@ class TestAuthDependenciesIntegration:
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_endpoint_with_valid_token(self, create_token):
|
||||
async def test_endpoint_with_valid_token(self, create_token):
|
||||
"""Test endpoint with valid JWT token."""
|
||||
app = FastAPI()
|
||||
|
||||
@@ -203,7 +203,7 @@ class TestAuthDependenciesIntegration:
|
||||
assert response.status_code == 200
|
||||
assert response.json()["user_id"] == "test-user"
|
||||
|
||||
def test_admin_endpoint_requires_admin_role(self, create_token):
|
||||
async def test_admin_endpoint_requires_admin_role(self, create_token):
|
||||
"""Test admin endpoint rejects non-admin users."""
|
||||
app = FastAPI()
|
||||
|
||||
@@ -240,7 +240,7 @@ class TestAuthDependenciesIntegration:
|
||||
class TestAuthDependenciesEdgeCases:
|
||||
"""Edge case tests for authentication dependencies."""
|
||||
|
||||
def test_dependency_with_complex_payload(self):
|
||||
async def test_dependency_with_complex_payload(self):
|
||||
"""Test dependencies handle complex JWT payloads."""
|
||||
complex_payload = {
|
||||
"sub": "user-123",
|
||||
@@ -256,14 +256,14 @@ class TestAuthDependenciesEdgeCases:
|
||||
"exp": 9999999999,
|
||||
}
|
||||
|
||||
user = requires_user(complex_payload)
|
||||
user = await requires_user(complex_payload)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.email == "test@example.com"
|
||||
|
||||
admin = requires_admin_user(complex_payload)
|
||||
admin = await requires_admin_user(complex_payload)
|
||||
assert admin.role == "admin"
|
||||
|
||||
def test_dependency_with_unicode_in_payload(self):
|
||||
async def test_dependency_with_unicode_in_payload(self):
|
||||
"""Test dependencies handle unicode in JWT payloads."""
|
||||
unicode_payload = {
|
||||
"sub": "user-😀-123",
|
||||
@@ -272,11 +272,11 @@ class TestAuthDependenciesEdgeCases:
|
||||
"name": "日本語",
|
||||
}
|
||||
|
||||
user = requires_user(unicode_payload)
|
||||
user = await requires_user(unicode_payload)
|
||||
assert "😀" in user.user_id
|
||||
assert user.email == "测试@example.com"
|
||||
|
||||
def test_dependency_with_null_values(self):
|
||||
async def test_dependency_with_null_values(self):
|
||||
"""Test dependencies handle null values in payload."""
|
||||
null_payload = {
|
||||
"sub": "user-123",
|
||||
@@ -286,18 +286,18 @@ class TestAuthDependenciesEdgeCases:
|
||||
"metadata": None,
|
||||
}
|
||||
|
||||
user = requires_user(null_payload)
|
||||
user = await requires_user(null_payload)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.email is None
|
||||
|
||||
def test_concurrent_requests_isolation(self):
|
||||
async def test_concurrent_requests_isolation(self):
|
||||
"""Test that concurrent requests don't interfere with each other."""
|
||||
payload1 = {"sub": "user-1", "role": "user"}
|
||||
payload2 = {"sub": "user-2", "role": "admin"}
|
||||
|
||||
# Simulate concurrent processing
|
||||
user1 = requires_user(payload1)
|
||||
user2 = requires_admin_user(payload2)
|
||||
user1 = await requires_user(payload1)
|
||||
user2 = await requires_admin_user(payload2)
|
||||
|
||||
assert user1.user_id == "user-1"
|
||||
assert user2.user_id == "user-2"
|
||||
@@ -314,7 +314,7 @@ class TestAuthDependenciesEdgeCases:
|
||||
({"sub": "user", "role": "user"}, "Admin access required", True),
|
||||
],
|
||||
)
|
||||
def test_dependency_error_cases(
|
||||
async def test_dependency_error_cases(
|
||||
self, payload, expected_error: str, admin_only: bool
|
||||
):
|
||||
"""Test that errors propagate correctly through dependencies."""
|
||||
@@ -325,7 +325,7 @@ class TestAuthDependenciesEdgeCases:
|
||||
verify_user(payload, admin_only=admin_only)
|
||||
assert expected_error in exc_info.value.detail
|
||||
|
||||
def test_dependency_valid_user(self):
|
||||
async def test_dependency_valid_user(self):
|
||||
"""Test valid user case for dependency."""
|
||||
# Import verify_user to test it directly since dependencies use FastAPI Security
|
||||
from autogpt_libs.auth.jwt_utils import verify_user
|
||||
|
||||
@@ -16,7 +16,7 @@ bearer_jwt_auth = HTTPBearer(
|
||||
)
|
||||
|
||||
|
||||
def get_jwt_payload(
|
||||
async def get_jwt_payload(
|
||||
credentials: HTTPAuthorizationCredentials | None = Security(bearer_jwt_auth),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -116,32 +116,32 @@ def test_parse_jwt_token_missing_audience():
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_get_jwt_payload_with_valid_token():
|
||||
async def test_get_jwt_payload_with_valid_token():
|
||||
"""Test extracting JWT payload with valid bearer token."""
|
||||
token = create_token(TEST_USER_PAYLOAD)
|
||||
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
|
||||
|
||||
result = jwt_utils.get_jwt_payload(credentials)
|
||||
result = await jwt_utils.get_jwt_payload(credentials)
|
||||
assert result["sub"] == "test-user-id"
|
||||
assert result["role"] == "user"
|
||||
|
||||
|
||||
def test_get_jwt_payload_no_credentials():
|
||||
async def test_get_jwt_payload_no_credentials():
|
||||
"""Test JWT payload when no credentials provided."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.get_jwt_payload(None)
|
||||
await jwt_utils.get_jwt_payload(None)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Authorization header is missing" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_get_jwt_payload_invalid_token():
|
||||
async def test_get_jwt_payload_invalid_token():
|
||||
"""Test JWT payload extraction with invalid token."""
|
||||
credentials = HTTPAuthorizationCredentials(
|
||||
scheme="Bearer", credentials="invalid.token.here"
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.get_jwt_payload(credentials)
|
||||
await jwt_utils.get_jwt_payload(credentials)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid token" in exc_info.value.detail
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import logging
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
@@ -139,8 +140,13 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
|
||||
print(f"Log directory: {config.log_dir}")
|
||||
|
||||
# Activity log handler (INFO and above)
|
||||
activity_log_handler = logging.FileHandler(
|
||||
config.log_dir / LOG_FILE, "a", "utf-8"
|
||||
# Security fix: Use RotatingFileHandler with size limits to prevent disk exhaustion
|
||||
activity_log_handler = RotatingFileHandler(
|
||||
config.log_dir / LOG_FILE,
|
||||
mode="a",
|
||||
encoding="utf-8",
|
||||
maxBytes=10 * 1024 * 1024, # 10MB per file
|
||||
backupCount=3, # Keep 3 backup files (40MB total)
|
||||
)
|
||||
activity_log_handler.setLevel(config.level)
|
||||
activity_log_handler.setFormatter(
|
||||
@@ -150,8 +156,13 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
|
||||
|
||||
if config.level == logging.DEBUG:
|
||||
# Debug log handler (all levels)
|
||||
debug_log_handler = logging.FileHandler(
|
||||
config.log_dir / DEBUG_LOG_FILE, "a", "utf-8"
|
||||
# Security fix: Use RotatingFileHandler with size limits
|
||||
debug_log_handler = RotatingFileHandler(
|
||||
config.log_dir / DEBUG_LOG_FILE,
|
||||
mode="a",
|
||||
encoding="utf-8",
|
||||
maxBytes=10 * 1024 * 1024, # 10MB per file
|
||||
backupCount=3, # Keep 3 backup files (40MB total)
|
||||
)
|
||||
debug_log_handler.setLevel(logging.DEBUG)
|
||||
debug_log_handler.setFormatter(
|
||||
@@ -160,8 +171,13 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
|
||||
log_handlers.append(debug_log_handler)
|
||||
|
||||
# Error log handler (ERROR and above)
|
||||
error_log_handler = logging.FileHandler(
|
||||
config.log_dir / ERROR_LOG_FILE, "a", "utf-8"
|
||||
# Security fix: Use RotatingFileHandler with size limits
|
||||
error_log_handler = RotatingFileHandler(
|
||||
config.log_dir / ERROR_LOG_FILE,
|
||||
mode="a",
|
||||
encoding="utf-8",
|
||||
maxBytes=10 * 1024 * 1024, # 10MB per file
|
||||
backupCount=3, # Keep 3 backup files (40MB total)
|
||||
)
|
||||
error_log_handler.setLevel(logging.ERROR)
|
||||
error_log_handler.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT, no_color=True))
|
||||
|
||||
@@ -115,12 +115,23 @@ def cached(
|
||||
"""
|
||||
|
||||
def decorator(target_func):
|
||||
# Cache storage and locks
|
||||
# Cache storage and per-event-loop locks
|
||||
cache_storage = {}
|
||||
_event_loop_locks = {} # Maps event loop to its asyncio.Lock
|
||||
|
||||
if inspect.iscoroutinefunction(target_func):
|
||||
# Async function with asyncio.Lock
|
||||
cache_lock = asyncio.Lock()
|
||||
|
||||
def _get_cache_lock():
|
||||
"""Get or create an asyncio.Lock for the current event loop."""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# No event loop, use None as default key
|
||||
loop = None
|
||||
|
||||
if loop not in _event_loop_locks:
|
||||
return _event_loop_locks.setdefault(loop, asyncio.Lock())
|
||||
return _event_loop_locks[loop]
|
||||
|
||||
@wraps(target_func)
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
@@ -141,7 +152,7 @@ def cached(
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
async with cache_lock:
|
||||
async with _get_cache_lock():
|
||||
# Double-check: another coroutine might have populated cache
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
|
||||
214
autogpt_platform/backend/backend/blocks/ai_condition.py
Normal file
214
autogpt_platform/backend/backend/blocks/ai_condition.py
Normal file
@@ -0,0 +1,214 @@
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks.llm import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
AIBlockBase,
|
||||
AICredentials,
|
||||
AICredentialsField,
|
||||
LlmModel,
|
||||
LLMResponse,
|
||||
llm_call,
|
||||
)
|
||||
from backend.data.block import BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
|
||||
|
||||
|
||||
class AIConditionBlock(AIBlockBase):
|
||||
"""
|
||||
An AI-powered condition block that uses natural language to evaluate conditions.
|
||||
|
||||
This block allows users to define conditions in plain English (e.g., "the input is an email address",
|
||||
"the input is a city in the USA") and uses AI to determine if the input satisfies the condition.
|
||||
It provides the same yes/no data pass-through functionality as the standard ConditionBlock.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
input_value: Any = SchemaField(
|
||||
description="The input value to evaluate with the AI condition",
|
||||
placeholder="Enter the value to be evaluated (text, number, or any data)",
|
||||
)
|
||||
condition: str = SchemaField(
|
||||
description="A plaintext English description of the condition to evaluate",
|
||||
placeholder="E.g., 'the input is the body of an email', 'the input is a City in the USA', 'the input is an error or a refusal'",
|
||||
)
|
||||
yes_value: Any = SchemaField(
|
||||
description="(Optional) Value to output if the condition is true. If not provided, input_value will be used.",
|
||||
placeholder="Leave empty to use input_value, or enter a specific value",
|
||||
default=None,
|
||||
)
|
||||
no_value: Any = SchemaField(
|
||||
description="(Optional) Value to output if the condition is false. If not provided, input_value will be used.",
|
||||
placeholder="Leave empty to use input_value, or enter a specific value",
|
||||
default=None,
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for evaluating the condition.",
|
||||
advanced=False,
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: bool = SchemaField(
|
||||
description="The result of the AI condition evaluation (True or False)"
|
||||
)
|
||||
yes_output: Any = SchemaField(
|
||||
description="The output value if the condition is true"
|
||||
)
|
||||
no_output: Any = SchemaField(
|
||||
description="The output value if the condition is false"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the AI evaluation is uncertain or fails"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="553ec5b8-6c45-4299-8d75-b394d05f72ff",
|
||||
input_schema=AIConditionBlock.Input,
|
||||
output_schema=AIConditionBlock.Output,
|
||||
description="Uses AI to evaluate natural language conditions and provide conditional outputs",
|
||||
categories={BlockCategory.AI, BlockCategory.LOGIC},
|
||||
test_input={
|
||||
"input_value": "john@example.com",
|
||||
"condition": "the input is an email address",
|
||||
"yes_value": "Valid email",
|
||||
"no_value": "Not an email",
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("result", True),
|
||||
("yes_output", "Valid email"),
|
||||
],
|
||||
test_mock={
|
||||
"llm_call": lambda *args, **kwargs: LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response="true",
|
||||
tool_calls=None,
|
||||
prompt_tokens=50,
|
||||
completion_tokens=10,
|
||||
reasoning=None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
async def llm_call(
|
||||
self,
|
||||
credentials: APIKeyCredentials,
|
||||
llm_model: LlmModel,
|
||||
prompt: list,
|
||||
max_tokens: int,
|
||||
) -> LLMResponse:
|
||||
"""Wrapper method for llm_call to enable mocking in tests."""
|
||||
return await llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=llm_model,
|
||||
prompt=prompt,
|
||||
force_json_output=False,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Evaluate the AI condition and return appropriate outputs.
|
||||
"""
|
||||
# Prepare the yes and no values, using input_value as default
|
||||
yes_value = (
|
||||
input_data.yes_value
|
||||
if input_data.yes_value is not None
|
||||
else input_data.input_value
|
||||
)
|
||||
no_value = (
|
||||
input_data.no_value
|
||||
if input_data.no_value is not None
|
||||
else input_data.input_value
|
||||
)
|
||||
|
||||
# Convert input_value to string for AI evaluation
|
||||
input_str = str(input_data.input_value)
|
||||
|
||||
# Create the prompt for AI evaluation
|
||||
prompt = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are an AI assistant that evaluates conditions based on input data. "
|
||||
"You must respond with only 'true' or 'false' (lowercase) to indicate whether "
|
||||
"the given condition is met by the input value. Be accurate and consider the "
|
||||
"context and meaning of both the input and the condition."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"Input value: {input_str}\n"
|
||||
f"Condition to evaluate: {input_data.condition}\n\n"
|
||||
f"Does the input value satisfy the condition? Respond with only 'true' or 'false'."
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
# Call the LLM
|
||||
try:
|
||||
response = await self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=10, # We only expect a true/false response
|
||||
)
|
||||
|
||||
# Extract the boolean result from the response
|
||||
response_text = response.response.strip().lower()
|
||||
if response_text == "true":
|
||||
result = True
|
||||
elif response_text == "false":
|
||||
result = False
|
||||
else:
|
||||
# If the response is not clear, try to interpret it using word boundaries
|
||||
import re
|
||||
|
||||
# Use word boundaries to avoid false positives like 'untrue' or '10'
|
||||
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", response_text))
|
||||
|
||||
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
|
||||
result = True
|
||||
elif tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
|
||||
result = False
|
||||
else:
|
||||
# Unclear or conflicting response - default to False and yield error
|
||||
result = False
|
||||
yield "error", f"Unclear AI response: '{response.response}'"
|
||||
|
||||
# Update internal stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
)
|
||||
)
|
||||
self.prompt = response.prompt
|
||||
|
||||
except Exception as e:
|
||||
# In case of any error, default to False to be safe
|
||||
result = False
|
||||
# Log the error but don't fail the block execution
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"AI condition evaluation failed: {str(e)}")
|
||||
yield "error", f"AI evaluation failed: {str(e)}"
|
||||
|
||||
# Yield results
|
||||
yield "result", result
|
||||
|
||||
if result:
|
||||
yield "yes_output", yes_value
|
||||
else:
|
||||
yield "no_output", no_value
|
||||
@@ -1,8 +1,10 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from e2b_code_interpreter import AsyncSandbox
|
||||
from pydantic import SecretStr
|
||||
from e2b_code_interpreter import Result as E2BExecutionResult
|
||||
from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
|
||||
from pydantic import BaseModel, JsonValue, SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
@@ -36,14 +38,135 @@ class ProgrammingLanguage(Enum):
|
||||
JAVA = "java"
|
||||
|
||||
|
||||
class CodeExecutionBlock(Block):
|
||||
class MainCodeExecutionResult(BaseModel):
|
||||
"""
|
||||
*Pydantic model mirroring `e2b_code_interpreter.Result`*
|
||||
|
||||
Represents the data to be displayed as a result of executing a cell in a Jupyter notebook.
|
||||
The result is similar to the structure returned by ipython kernel: https://ipython.readthedocs.io/en/stable/development/execution.html#execution-semantics
|
||||
|
||||
The result can contain multiple types of data, such as text, images, plots, etc. Each type of data is represented
|
||||
as a string, and the result can contain multiple types of data. The display calls don't have to have text representation,
|
||||
for the actual result the representation is always present for the result, the other representations are always optional.
|
||||
""" # noqa
|
||||
|
||||
class Chart(BaseModel, E2BExecutionResultChart):
|
||||
pass
|
||||
|
||||
text: Optional[str] = None
|
||||
html: Optional[str] = None
|
||||
markdown: Optional[str] = None
|
||||
svg: Optional[str] = None
|
||||
png: Optional[str] = None
|
||||
jpeg: Optional[str] = None
|
||||
pdf: Optional[str] = None
|
||||
latex: Optional[str] = None
|
||||
json: Optional[JsonValue] = None # type: ignore (reportIncompatibleMethodOverride)
|
||||
javascript: Optional[str] = None
|
||||
data: Optional[dict] = None
|
||||
chart: Optional[Chart] = None
|
||||
extra: Optional[dict] = None
|
||||
"""Extra data that can be included. Not part of the standard types."""
|
||||
|
||||
|
||||
class CodeExecutionResult(MainCodeExecutionResult):
|
||||
__doc__ = MainCodeExecutionResult.__doc__
|
||||
|
||||
is_main_result: bool = False
|
||||
"""Whether this data is the main result of the cell. Data can be produced by display calls of which can be multiple in a cell.""" # noqa
|
||||
|
||||
|
||||
class BaseE2BExecutorMixin:
|
||||
"""Shared implementation methods for E2B executor blocks."""
|
||||
|
||||
async def execute_code(
|
||||
self,
|
||||
api_key: str,
|
||||
code: str,
|
||||
language: ProgrammingLanguage,
|
||||
template_id: str = "",
|
||||
setup_commands: Optional[list[str]] = None,
|
||||
timeout: Optional[int] = None,
|
||||
sandbox_id: Optional[str] = None,
|
||||
dispose_sandbox: bool = False,
|
||||
):
|
||||
"""
|
||||
Unified code execution method that handles all three use cases:
|
||||
1. Create new sandbox and execute (ExecuteCodeBlock)
|
||||
2. Create new sandbox, execute, and return sandbox_id (InstantiateCodeSandboxBlock)
|
||||
3. Connect to existing sandbox and execute (ExecuteCodeStepBlock)
|
||||
""" # noqa
|
||||
sandbox = None
|
||||
try:
|
||||
if sandbox_id:
|
||||
# Connect to existing sandbox (ExecuteCodeStepBlock case)
|
||||
sandbox = await AsyncSandbox.connect(
|
||||
sandbox_id=sandbox_id, api_key=api_key
|
||||
)
|
||||
else:
|
||||
# Create new sandbox (ExecuteCodeBlock/InstantiateCodeSandboxBlock case)
|
||||
sandbox = await AsyncSandbox.create(
|
||||
api_key=api_key, template=template_id, timeout=timeout
|
||||
)
|
||||
if setup_commands:
|
||||
for cmd in setup_commands:
|
||||
await sandbox.commands.run(cmd)
|
||||
|
||||
# Execute the code
|
||||
execution = await sandbox.run_code(
|
||||
code,
|
||||
language=language.value,
|
||||
on_error=lambda e: sandbox.kill(), # Kill the sandbox on error
|
||||
)
|
||||
|
||||
if execution.error:
|
||||
raise Exception(execution.error)
|
||||
|
||||
results = execution.results
|
||||
text_output = execution.text
|
||||
stdout_logs = "".join(execution.logs.stdout)
|
||||
stderr_logs = "".join(execution.logs.stderr)
|
||||
|
||||
return results, text_output, stdout_logs, stderr_logs, sandbox.sandbox_id
|
||||
finally:
|
||||
# Dispose of sandbox if requested to reduce usage costs
|
||||
if dispose_sandbox and sandbox:
|
||||
await sandbox.kill()
|
||||
|
||||
def process_execution_results(
|
||||
self, results: list[E2BExecutionResult]
|
||||
) -> tuple[dict[str, Any] | None, list[dict[str, Any]]]:
|
||||
"""Process and filter execution results."""
|
||||
# Filter out empty formats and convert to dicts
|
||||
processed_results = [
|
||||
{
|
||||
f: value
|
||||
for f in [*r.formats(), "extra", "is_main_result"]
|
||||
if (value := getattr(r, f, None)) is not None
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
if main_result := next(
|
||||
(r for r in processed_results if r.get("is_main_result")), None
|
||||
):
|
||||
# Make main_result a copy we can modify & remove is_main_result
|
||||
(main_result := {**main_result}).pop("is_main_result")
|
||||
|
||||
return main_result, processed_results
|
||||
|
||||
|
||||
class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
|
||||
# TODO : Add support to upload and download files
|
||||
# Currently, You can customized the CPU and Memory, only by creating a pre customized sandbox template
|
||||
# NOTE: Currently, you can only customize the CPU and Memory
|
||||
# by creating a pre customized sandbox template
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.E2B], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
|
||||
description=(
|
||||
"Enter your API key for the E2B platform. "
|
||||
"You can get it in here - https://e2b.dev/docs"
|
||||
),
|
||||
)
|
||||
|
||||
# Todo : Option to run commond in background
|
||||
@@ -76,6 +199,14 @@ class CodeExecutionBlock(Block):
|
||||
description="Execution timeout in seconds", default=300
|
||||
)
|
||||
|
||||
dispose_sandbox: bool = SchemaField(
|
||||
description=(
|
||||
"Whether to dispose of the sandbox immediately after execution. "
|
||||
"If disabled, the sandbox will run until its timeout expires."
|
||||
),
|
||||
default=True,
|
||||
)
|
||||
|
||||
template_id: str = SchemaField(
|
||||
description=(
|
||||
"You can use an E2B sandbox template by entering its ID here. "
|
||||
@@ -87,7 +218,16 @@ class CodeExecutionBlock(Block):
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: str = SchemaField(description="Response from code execution")
|
||||
main_result: MainCodeExecutionResult = SchemaField(
|
||||
title="Main Result", description="The main result from the code execution"
|
||||
)
|
||||
results: list[CodeExecutionResult] = SchemaField(
|
||||
description="List of results from the code execution"
|
||||
)
|
||||
response: str = SchemaField(
|
||||
title="Main Text Output",
|
||||
description="Text output (if any) of the main execution result",
|
||||
)
|
||||
stdout_logs: str = SchemaField(
|
||||
description="Standard output logs from execution"
|
||||
)
|
||||
@@ -97,10 +237,10 @@ class CodeExecutionBlock(Block):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0b02b072-abe7-11ef-8372-fb5d162dd712",
|
||||
description="Executes code in an isolated sandbox environment with internet access.",
|
||||
description="Executes code in a sandbox environment with internet access.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=CodeExecutionBlock.Input,
|
||||
output_schema=CodeExecutionBlock.Output,
|
||||
input_schema=ExecuteCodeBlock.Input,
|
||||
output_schema=ExecuteCodeBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
@@ -111,91 +251,59 @@ class CodeExecutionBlock(Block):
|
||||
"template_id": "",
|
||||
},
|
||||
test_output=[
|
||||
("results", []),
|
||||
("response", "Hello World"),
|
||||
("stdout_logs", "Hello World\n"),
|
||||
],
|
||||
test_mock={
|
||||
"execute_code": lambda code, language, setup_commands, timeout, api_key, template_id: (
|
||||
"Hello World",
|
||||
"Hello World\n",
|
||||
"",
|
||||
"execute_code": lambda api_key, code, language, template_id, setup_commands, timeout, dispose_sandbox: ( # noqa
|
||||
[], # results
|
||||
"Hello World", # text_output
|
||||
"Hello World\n", # stdout_logs
|
||||
"", # stderr_logs
|
||||
"sandbox_id", # sandbox_id
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
async def execute_code(
|
||||
self,
|
||||
code: str,
|
||||
language: ProgrammingLanguage,
|
||||
setup_commands: list[str],
|
||||
timeout: int,
|
||||
api_key: str,
|
||||
template_id: str,
|
||||
):
|
||||
try:
|
||||
sandbox = None
|
||||
if template_id:
|
||||
sandbox = await AsyncSandbox.create(
|
||||
template=template_id, api_key=api_key, timeout=timeout
|
||||
)
|
||||
else:
|
||||
sandbox = await AsyncSandbox.create(api_key=api_key, timeout=timeout)
|
||||
|
||||
if not sandbox:
|
||||
raise Exception("Sandbox not created")
|
||||
|
||||
# Running setup commands
|
||||
for cmd in setup_commands:
|
||||
await sandbox.commands.run(cmd)
|
||||
|
||||
# Executing the code
|
||||
execution = await sandbox.run_code(
|
||||
code,
|
||||
language=language.value,
|
||||
on_error=lambda e: sandbox.kill(), # Kill the sandbox if there is an error
|
||||
)
|
||||
|
||||
if execution.error:
|
||||
raise Exception(execution.error)
|
||||
|
||||
response = execution.text
|
||||
stdout_logs = "".join(execution.logs.stdout)
|
||||
stderr_logs = "".join(execution.logs.stderr)
|
||||
|
||||
return response, stdout_logs, stderr_logs
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
response, stdout_logs, stderr_logs = await self.execute_code(
|
||||
input_data.code,
|
||||
input_data.language,
|
||||
input_data.setup_commands,
|
||||
input_data.timeout,
|
||||
credentials.api_key.get_secret_value(),
|
||||
input_data.template_id,
|
||||
results, text_output, stdout, stderr, _ = await self.execute_code(
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
code=input_data.code,
|
||||
language=input_data.language,
|
||||
template_id=input_data.template_id,
|
||||
setup_commands=input_data.setup_commands,
|
||||
timeout=input_data.timeout,
|
||||
dispose_sandbox=input_data.dispose_sandbox,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield "response", response
|
||||
if stdout_logs:
|
||||
yield "stdout_logs", stdout_logs
|
||||
if stderr_logs:
|
||||
yield "stderr_logs", stderr_logs
|
||||
# Determine result object shape & filter out empty formats
|
||||
main_result, results = self.process_execution_results(results)
|
||||
if main_result:
|
||||
yield "main_result", main_result
|
||||
yield "results", results
|
||||
if text_output:
|
||||
yield "response", text_output
|
||||
if stdout:
|
||||
yield "stdout_logs", stdout
|
||||
if stderr:
|
||||
yield "stderr_logs", stderr
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class InstantiationBlock(Block):
|
||||
class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.E2B], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
|
||||
description=(
|
||||
"Enter your API key for the E2B platform. "
|
||||
"You can get it in here - https://e2b.dev/docs"
|
||||
)
|
||||
)
|
||||
|
||||
# Todo : Option to run commond in background
|
||||
@@ -240,7 +348,10 @@ class InstantiationBlock(Block):
|
||||
|
||||
class Output(BlockSchema):
|
||||
sandbox_id: str = SchemaField(description="ID of the sandbox instance")
|
||||
response: str = SchemaField(description="Response from code execution")
|
||||
response: str = SchemaField(
|
||||
title="Text Result",
|
||||
description="Text result (if any) of the setup code execution",
|
||||
)
|
||||
stdout_logs: str = SchemaField(
|
||||
description="Standard output logs from execution"
|
||||
)
|
||||
@@ -250,10 +361,13 @@ class InstantiationBlock(Block):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ff0861c9-1726-4aec-9e5b-bf53f3622112",
|
||||
description="Instantiate an isolated sandbox environment with internet access where to execute code in.",
|
||||
description=(
|
||||
"Instantiate a sandbox environment with internet access "
|
||||
"in which you can execute code with the Execute Code Step block."
|
||||
),
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=InstantiationBlock.Input,
|
||||
output_schema=InstantiationBlock.Output,
|
||||
input_schema=InstantiateCodeSandboxBlock.Input,
|
||||
output_schema=InstantiateCodeSandboxBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
@@ -269,11 +383,12 @@ class InstantiationBlock(Block):
|
||||
("stdout_logs", "Hello World\n"),
|
||||
],
|
||||
test_mock={
|
||||
"execute_code": lambda setup_code, language, setup_commands, timeout, api_key, template_id: (
|
||||
"sandbox_id",
|
||||
"Hello World",
|
||||
"Hello World\n",
|
||||
"",
|
||||
"execute_code": lambda api_key, code, language, template_id, setup_commands, timeout: ( # noqa
|
||||
[], # results
|
||||
"Hello World", # text_output
|
||||
"Hello World\n", # stdout_logs
|
||||
"", # stderr_logs
|
||||
"sandbox_id", # sandbox_id
|
||||
),
|
||||
},
|
||||
)
|
||||
@@ -282,78 +397,38 @@ class InstantiationBlock(Block):
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
sandbox_id, response, stdout_logs, stderr_logs = await self.execute_code(
|
||||
input_data.setup_code,
|
||||
input_data.language,
|
||||
input_data.setup_commands,
|
||||
input_data.timeout,
|
||||
credentials.api_key.get_secret_value(),
|
||||
input_data.template_id,
|
||||
_, text_output, stdout, stderr, sandbox_id = await self.execute_code(
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
code=input_data.setup_code,
|
||||
language=input_data.language,
|
||||
template_id=input_data.template_id,
|
||||
setup_commands=input_data.setup_commands,
|
||||
timeout=input_data.timeout,
|
||||
)
|
||||
if sandbox_id:
|
||||
yield "sandbox_id", sandbox_id
|
||||
else:
|
||||
yield "error", "Sandbox ID not found"
|
||||
if response:
|
||||
yield "response", response
|
||||
if stdout_logs:
|
||||
yield "stdout_logs", stdout_logs
|
||||
if stderr_logs:
|
||||
yield "stderr_logs", stderr_logs
|
||||
|
||||
if text_output:
|
||||
yield "response", text_output
|
||||
if stdout:
|
||||
yield "stdout_logs", stdout
|
||||
if stderr:
|
||||
yield "stderr_logs", stderr
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
async def execute_code(
|
||||
self,
|
||||
code: str,
|
||||
language: ProgrammingLanguage,
|
||||
setup_commands: list[str],
|
||||
timeout: int,
|
||||
api_key: str,
|
||||
template_id: str,
|
||||
):
|
||||
try:
|
||||
sandbox = None
|
||||
if template_id:
|
||||
sandbox = await AsyncSandbox.create(
|
||||
template=template_id, api_key=api_key, timeout=timeout
|
||||
)
|
||||
else:
|
||||
sandbox = await AsyncSandbox.create(api_key=api_key, timeout=timeout)
|
||||
|
||||
if not sandbox:
|
||||
raise Exception("Sandbox not created")
|
||||
|
||||
# Running setup commands
|
||||
for cmd in setup_commands:
|
||||
await sandbox.commands.run(cmd)
|
||||
|
||||
# Executing the code
|
||||
execution = await sandbox.run_code(
|
||||
code,
|
||||
language=language.value,
|
||||
on_error=lambda e: sandbox.kill(), # Kill the sandbox if there is an error
|
||||
)
|
||||
|
||||
if execution.error:
|
||||
raise Exception(execution.error)
|
||||
|
||||
response = execution.text
|
||||
stdout_logs = "".join(execution.logs.stdout)
|
||||
stderr_logs = "".join(execution.logs.stderr)
|
||||
|
||||
return sandbox.sandbox_id, response, stdout_logs, stderr_logs
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
class StepExecutionBlock(Block):
|
||||
class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.E2B], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
|
||||
description=(
|
||||
"Enter your API key for the E2B platform. "
|
||||
"You can get it in here - https://e2b.dev/docs"
|
||||
),
|
||||
)
|
||||
|
||||
sandbox_id: str = SchemaField(
|
||||
@@ -374,8 +449,22 @@ class StepExecutionBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
dispose_sandbox: bool = SchemaField(
|
||||
description="Whether to dispose of the sandbox after executing this code.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: str = SchemaField(description="Response from code execution")
|
||||
main_result: MainCodeExecutionResult = SchemaField(
|
||||
title="Main Result", description="The main result from the code execution"
|
||||
)
|
||||
results: list[CodeExecutionResult] = SchemaField(
|
||||
description="List of results from the code execution"
|
||||
)
|
||||
response: str = SchemaField(
|
||||
title="Main Text Output",
|
||||
description="Text output (if any) of the main execution result",
|
||||
)
|
||||
stdout_logs: str = SchemaField(
|
||||
description="Standard output logs from execution"
|
||||
)
|
||||
@@ -385,10 +474,10 @@ class StepExecutionBlock(Block):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="82b59b8e-ea10-4d57-9161-8b169b0adba6",
|
||||
description="Execute code in a previously instantiated sandbox environment.",
|
||||
description="Execute code in a previously instantiated sandbox.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=StepExecutionBlock.Input,
|
||||
output_schema=StepExecutionBlock.Output,
|
||||
input_schema=ExecuteCodeStepBlock.Input,
|
||||
output_schema=ExecuteCodeStepBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
@@ -397,61 +486,43 @@ class StepExecutionBlock(Block):
|
||||
"language": ProgrammingLanguage.PYTHON.value,
|
||||
},
|
||||
test_output=[
|
||||
("results", []),
|
||||
("response", "Hello World"),
|
||||
("stdout_logs", "Hello World\n"),
|
||||
],
|
||||
test_mock={
|
||||
"execute_step_code": lambda sandbox_id, step_code, language, api_key: (
|
||||
"Hello World",
|
||||
"Hello World\n",
|
||||
"",
|
||||
"execute_code": lambda api_key, code, language, sandbox_id, dispose_sandbox: ( # noqa
|
||||
[], # results
|
||||
"Hello World", # text_output
|
||||
"Hello World\n", # stdout_logs
|
||||
"", # stderr_logs
|
||||
sandbox_id, # sandbox_id
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
async def execute_step_code(
|
||||
self,
|
||||
sandbox_id: str,
|
||||
code: str,
|
||||
language: ProgrammingLanguage,
|
||||
api_key: str,
|
||||
):
|
||||
try:
|
||||
sandbox = await AsyncSandbox.connect(sandbox_id=sandbox_id, api_key=api_key)
|
||||
if not sandbox:
|
||||
raise Exception("Sandbox not found")
|
||||
|
||||
# Executing the code
|
||||
execution = await sandbox.run_code(code, language=language.value)
|
||||
|
||||
if execution.error:
|
||||
raise Exception(execution.error)
|
||||
|
||||
response = execution.text
|
||||
stdout_logs = "".join(execution.logs.stdout)
|
||||
stderr_logs = "".join(execution.logs.stderr)
|
||||
|
||||
return response, stdout_logs, stderr_logs
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
response, stdout_logs, stderr_logs = await self.execute_step_code(
|
||||
input_data.sandbox_id,
|
||||
input_data.step_code,
|
||||
input_data.language,
|
||||
credentials.api_key.get_secret_value(),
|
||||
results, text_output, stdout, stderr, _ = await self.execute_code(
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
code=input_data.step_code,
|
||||
language=input_data.language,
|
||||
sandbox_id=input_data.sandbox_id,
|
||||
dispose_sandbox=input_data.dispose_sandbox,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield "response", response
|
||||
if stdout_logs:
|
||||
yield "stdout_logs", stdout_logs
|
||||
if stderr_logs:
|
||||
yield "stderr_logs", stderr_logs
|
||||
# Determine result object shape & filter out empty formats
|
||||
main_result, results = self.process_execution_results(results)
|
||||
if main_result:
|
||||
yield "main_result", main_result
|
||||
yield "results", results
|
||||
if text_output:
|
||||
yield "response", text_output
|
||||
if stdout:
|
||||
yield "stdout_logs", stdout
|
||||
if stderr:
|
||||
yield "stderr_logs", stderr
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
@@ -90,7 +90,7 @@ class CodeExtractionBlock(Block):
|
||||
for aliases in language_aliases.values()
|
||||
for alias in aliases
|
||||
)
|
||||
+ r")\s+[\s\S]*?```"
|
||||
+ r")[ \t]*\n[\s\S]*?```"
|
||||
)
|
||||
|
||||
remaining_text = re.sub(pattern, "", input_data.text).strip()
|
||||
@@ -103,7 +103,9 @@ class CodeExtractionBlock(Block):
|
||||
# Escape special regex characters in the language string
|
||||
language = re.escape(language)
|
||||
# Extract all code blocks enclosed in ```language``` blocks
|
||||
pattern = re.compile(rf"```{language}\s+(.*?)```", re.DOTALL | re.IGNORECASE)
|
||||
pattern = re.compile(
|
||||
rf"```{language}[ \t]*\n(.*?)\n```", re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
matches = pattern.finditer(text)
|
||||
# Combine all code blocks for this language with newlines between them
|
||||
code_blocks = [match.group(1).strip() for match in matches]
|
||||
|
||||
@@ -66,6 +66,7 @@ class AddToDictionaryBlock(Block):
|
||||
dictionary: dict[Any, Any] = SchemaField(
|
||||
default_factory=dict,
|
||||
description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
|
||||
advanced=False,
|
||||
)
|
||||
key: str = SchemaField(
|
||||
default="",
|
||||
|
||||
@@ -90,6 +90,7 @@ class DataForSeoKeywordSuggestionsBlock(Block):
|
||||
seed_keyword: str = SchemaField(
|
||||
description="The seed keyword used for the query"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the API call failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -161,43 +162,52 @@ class DataForSeoKeywordSuggestionsBlock(Block):
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the keyword suggestions query."""
|
||||
client = DataForSeoClient(credentials)
|
||||
try:
|
||||
client = DataForSeoClient(credentials)
|
||||
|
||||
results = await self._fetch_keyword_suggestions(client, input_data)
|
||||
results = await self._fetch_keyword_suggestions(client, input_data)
|
||||
|
||||
# Process and format the results
|
||||
suggestions = []
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
items = (
|
||||
first_result.get("items", []) if isinstance(first_result, dict) else []
|
||||
)
|
||||
for item in items:
|
||||
# Create the KeywordSuggestion object
|
||||
suggestion = KeywordSuggestion(
|
||||
keyword=item.get("keyword", ""),
|
||||
search_volume=item.get("keyword_info", {}).get("search_volume"),
|
||||
competition=item.get("keyword_info", {}).get("competition"),
|
||||
cpc=item.get("keyword_info", {}).get("cpc"),
|
||||
keyword_difficulty=item.get("keyword_properties", {}).get(
|
||||
"keyword_difficulty"
|
||||
),
|
||||
serp_info=(
|
||||
item.get("serp_info") if input_data.include_serp_info else None
|
||||
),
|
||||
clickstream_data=(
|
||||
item.get("clickstream_keyword_info")
|
||||
if input_data.include_clickstream_data
|
||||
else None
|
||||
),
|
||||
# Process and format the results
|
||||
suggestions = []
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
items = (
|
||||
first_result.get("items", [])
|
||||
if isinstance(first_result, dict)
|
||||
else []
|
||||
)
|
||||
yield "suggestion", suggestion
|
||||
suggestions.append(suggestion)
|
||||
if items is None:
|
||||
items = []
|
||||
for item in items:
|
||||
# Create the KeywordSuggestion object
|
||||
suggestion = KeywordSuggestion(
|
||||
keyword=item.get("keyword", ""),
|
||||
search_volume=item.get("keyword_info", {}).get("search_volume"),
|
||||
competition=item.get("keyword_info", {}).get("competition"),
|
||||
cpc=item.get("keyword_info", {}).get("cpc"),
|
||||
keyword_difficulty=item.get("keyword_properties", {}).get(
|
||||
"keyword_difficulty"
|
||||
),
|
||||
serp_info=(
|
||||
item.get("serp_info")
|
||||
if input_data.include_serp_info
|
||||
else None
|
||||
),
|
||||
clickstream_data=(
|
||||
item.get("clickstream_keyword_info")
|
||||
if input_data.include_clickstream_data
|
||||
else None
|
||||
),
|
||||
)
|
||||
yield "suggestion", suggestion
|
||||
suggestions.append(suggestion)
|
||||
|
||||
yield "suggestions", suggestions
|
||||
yield "total_count", len(suggestions)
|
||||
yield "seed_keyword", input_data.keyword
|
||||
yield "suggestions", suggestions
|
||||
yield "total_count", len(suggestions)
|
||||
yield "seed_keyword", input_data.keyword
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to fetch keyword suggestions: {str(e)}"
|
||||
|
||||
|
||||
class KeywordSuggestionExtractorBlock(Block):
|
||||
|
||||
@@ -98,6 +98,7 @@ class DataForSeoRelatedKeywordsBlock(Block):
|
||||
seed_keyword: str = SchemaField(
|
||||
description="The seed keyword used for the query"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the API call failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -171,50 +172,60 @@ class DataForSeoRelatedKeywordsBlock(Block):
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the related keywords query."""
|
||||
client = DataForSeoClient(credentials)
|
||||
try:
|
||||
client = DataForSeoClient(credentials)
|
||||
|
||||
results = await self._fetch_related_keywords(client, input_data)
|
||||
results = await self._fetch_related_keywords(client, input_data)
|
||||
|
||||
# Process and format the results
|
||||
related_keywords = []
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
items = (
|
||||
first_result.get("items", []) if isinstance(first_result, dict) else []
|
||||
)
|
||||
for item in items:
|
||||
# Extract keyword_data from the item
|
||||
keyword_data = item.get("keyword_data", {})
|
||||
|
||||
# Create the RelatedKeyword object
|
||||
keyword = RelatedKeyword(
|
||||
keyword=keyword_data.get("keyword", ""),
|
||||
search_volume=keyword_data.get("keyword_info", {}).get(
|
||||
"search_volume"
|
||||
),
|
||||
competition=keyword_data.get("keyword_info", {}).get("competition"),
|
||||
cpc=keyword_data.get("keyword_info", {}).get("cpc"),
|
||||
keyword_difficulty=keyword_data.get("keyword_properties", {}).get(
|
||||
"keyword_difficulty"
|
||||
),
|
||||
serp_info=(
|
||||
keyword_data.get("serp_info")
|
||||
if input_data.include_serp_info
|
||||
else None
|
||||
),
|
||||
clickstream_data=(
|
||||
keyword_data.get("clickstream_keyword_info")
|
||||
if input_data.include_clickstream_data
|
||||
else None
|
||||
),
|
||||
# Process and format the results
|
||||
related_keywords = []
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
items = (
|
||||
first_result.get("items", [])
|
||||
if isinstance(first_result, dict)
|
||||
else []
|
||||
)
|
||||
yield "related_keyword", keyword
|
||||
related_keywords.append(keyword)
|
||||
# Ensure items is never None
|
||||
if items is None:
|
||||
items = []
|
||||
for item in items:
|
||||
# Extract keyword_data from the item
|
||||
keyword_data = item.get("keyword_data", {})
|
||||
|
||||
yield "related_keywords", related_keywords
|
||||
yield "total_count", len(related_keywords)
|
||||
yield "seed_keyword", input_data.keyword
|
||||
# Create the RelatedKeyword object
|
||||
keyword = RelatedKeyword(
|
||||
keyword=keyword_data.get("keyword", ""),
|
||||
search_volume=keyword_data.get("keyword_info", {}).get(
|
||||
"search_volume"
|
||||
),
|
||||
competition=keyword_data.get("keyword_info", {}).get(
|
||||
"competition"
|
||||
),
|
||||
cpc=keyword_data.get("keyword_info", {}).get("cpc"),
|
||||
keyword_difficulty=keyword_data.get(
|
||||
"keyword_properties", {}
|
||||
).get("keyword_difficulty"),
|
||||
serp_info=(
|
||||
keyword_data.get("serp_info")
|
||||
if input_data.include_serp_info
|
||||
else None
|
||||
),
|
||||
clickstream_data=(
|
||||
keyword_data.get("clickstream_keyword_info")
|
||||
if input_data.include_clickstream_data
|
||||
else None
|
||||
),
|
||||
)
|
||||
yield "related_keyword", keyword
|
||||
related_keywords.append(keyword)
|
||||
|
||||
yield "related_keywords", related_keywords
|
||||
yield "total_count", len(related_keywords)
|
||||
yield "seed_keyword", input_data.keyword
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to fetch related keywords: {str(e)}"
|
||||
|
||||
|
||||
class RelatedKeywordExtractorBlock(Block):
|
||||
|
||||
@@ -171,11 +171,11 @@ class SendDiscordMessageBlock(Block):
|
||||
description="The content of the message to send"
|
||||
)
|
||||
channel_name: str = SchemaField(
|
||||
description="The name of the channel the message will be sent to"
|
||||
description="Channel ID or channel name to send the message to"
|
||||
)
|
||||
server_name: str = SchemaField(
|
||||
description="The name of the server where the channel is located",
|
||||
advanced=True, # Optional field for server name
|
||||
description="Server name (only needed if using channel name)",
|
||||
advanced=True,
|
||||
default="",
|
||||
)
|
||||
|
||||
@@ -231,25 +231,49 @@ class SendDiscordMessageBlock(Block):
|
||||
@client.event
|
||||
async def on_ready():
|
||||
print(f"Logged in as {client.user}")
|
||||
for guild in client.guilds:
|
||||
if server_name and guild.name != server_name:
|
||||
continue
|
||||
for channel in guild.text_channels:
|
||||
if channel.name == channel_name:
|
||||
# Split message into chunks if it exceeds 2000 characters
|
||||
chunks = self.chunk_message(message_content)
|
||||
last_message = None
|
||||
for chunk in chunks:
|
||||
last_message = await channel.send(chunk)
|
||||
result["status"] = "Message sent"
|
||||
result["message_id"] = (
|
||||
str(last_message.id) if last_message else ""
|
||||
)
|
||||
result["channel_id"] = str(channel.id)
|
||||
await client.close()
|
||||
return
|
||||
channel = None
|
||||
|
||||
result["status"] = "Channel not found"
|
||||
# Try to parse as channel ID first
|
||||
try:
|
||||
channel_id = int(channel_name)
|
||||
channel = client.get_channel(channel_id)
|
||||
except ValueError:
|
||||
# Not a valid ID, will try name lookup
|
||||
pass
|
||||
|
||||
# If not found by ID (or not an ID), try name lookup
|
||||
if not channel:
|
||||
for guild in client.guilds:
|
||||
if server_name and guild.name != server_name:
|
||||
continue
|
||||
for ch in guild.text_channels:
|
||||
if ch.name == channel_name:
|
||||
channel = ch
|
||||
break
|
||||
if channel:
|
||||
break
|
||||
|
||||
if not channel:
|
||||
result["status"] = f"Channel not found: {channel_name}"
|
||||
await client.close()
|
||||
return
|
||||
|
||||
# Type check - ensure it's a text channel that can send messages
|
||||
if not hasattr(channel, "send"):
|
||||
result["status"] = (
|
||||
f"Channel {channel_name} cannot receive messages (not a text channel)"
|
||||
)
|
||||
await client.close()
|
||||
return
|
||||
|
||||
# Split message into chunks if it exceeds 2000 characters
|
||||
chunks = self.chunk_message(message_content)
|
||||
last_message = None
|
||||
for chunk in chunks:
|
||||
last_message = await channel.send(chunk) # type: ignore
|
||||
result["status"] = "Message sent"
|
||||
result["message_id"] = str(last_message.id) if last_message else ""
|
||||
result["channel_id"] = str(channel.id)
|
||||
await client.close()
|
||||
|
||||
await client.start(token)
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ScrapeFormat(Enum):
|
||||
MARKDOWN = "markdown"
|
||||
HTML = "html"
|
||||
RAW_HTML = "rawHtml"
|
||||
LINKS = "links"
|
||||
SCREENSHOT = "screenshot"
|
||||
SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
|
||||
JSON = "json"
|
||||
CHANGE_TRACKING = "changeTracking"
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
"""Utility functions for converting between our ScrapeFormat enum and firecrawl FormatOption types."""
|
||||
|
||||
from typing import List
|
||||
|
||||
from firecrawl.v2.types import FormatOption, ScreenshotFormat
|
||||
|
||||
from backend.blocks.firecrawl._api import ScrapeFormat
|
||||
|
||||
|
||||
def convert_to_format_options(
|
||||
formats: List[ScrapeFormat],
|
||||
) -> List[FormatOption]:
|
||||
"""Convert our ScrapeFormat enum values to firecrawl FormatOption types.
|
||||
|
||||
Handles special cases like screenshot@fullPage which needs to be converted
|
||||
to a ScreenshotFormat object.
|
||||
"""
|
||||
result: List[FormatOption] = []
|
||||
|
||||
for format_enum in formats:
|
||||
if format_enum.value == "screenshot@fullPage":
|
||||
# Special case: convert to ScreenshotFormat with full_page=True
|
||||
result.append(ScreenshotFormat(type="screenshot", full_page=True))
|
||||
else:
|
||||
# Regular string literals
|
||||
result.append(format_enum.value)
|
||||
|
||||
return result
|
||||
@@ -1,8 +1,9 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from firecrawl import FirecrawlApp, ScrapeOptions
|
||||
from firecrawl import FirecrawlApp
|
||||
from firecrawl.v2.types import ScrapeOptions
|
||||
|
||||
from backend.blocks.firecrawl._api import ScrapeFormat
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -14,21 +15,10 @@ from backend.sdk import (
|
||||
)
|
||||
|
||||
from ._config import firecrawl
|
||||
|
||||
|
||||
class ScrapeFormat(Enum):
|
||||
MARKDOWN = "markdown"
|
||||
HTML = "html"
|
||||
RAW_HTML = "rawHtml"
|
||||
LINKS = "links"
|
||||
SCREENSHOT = "screenshot"
|
||||
SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
|
||||
JSON = "json"
|
||||
CHANGE_TRACKING = "changeTracking"
|
||||
from ._format_utils import convert_to_format_options
|
||||
|
||||
|
||||
class FirecrawlCrawlBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
url: str = SchemaField(description="The URL to crawl")
|
||||
@@ -78,18 +68,17 @@ class FirecrawlCrawlBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Sync call
|
||||
crawl_result = app.crawl_url(
|
||||
crawl_result = app.crawl(
|
||||
input_data.url,
|
||||
limit=input_data.limit,
|
||||
scrape_options=ScrapeOptions(
|
||||
formats=[format.value for format in input_data.formats],
|
||||
onlyMainContent=input_data.only_main_content,
|
||||
maxAge=input_data.max_age,
|
||||
waitFor=input_data.wait_for,
|
||||
formats=convert_to_format_options(input_data.formats),
|
||||
only_main_content=input_data.only_main_content,
|
||||
max_age=input_data.max_age,
|
||||
wait_for=input_data.wait_for,
|
||||
),
|
||||
)
|
||||
yield "data", crawl_result.data
|
||||
@@ -101,7 +90,7 @@ class FirecrawlCrawlBlock(Block):
|
||||
elif f == ScrapeFormat.HTML:
|
||||
yield "html", data.html
|
||||
elif f == ScrapeFormat.RAW_HTML:
|
||||
yield "raw_html", data.rawHtml
|
||||
yield "raw_html", data.raw_html
|
||||
elif f == ScrapeFormat.LINKS:
|
||||
yield "links", data.links
|
||||
elif f == ScrapeFormat.SCREENSHOT:
|
||||
@@ -109,6 +98,6 @@ class FirecrawlCrawlBlock(Block):
|
||||
elif f == ScrapeFormat.SCREENSHOT_FULL_PAGE:
|
||||
yield "screenshot_full_page", data.screenshot
|
||||
elif f == ScrapeFormat.CHANGE_TRACKING:
|
||||
yield "change_tracking", data.changeTracking
|
||||
yield "change_tracking", data.change_tracking
|
||||
elif f == ScrapeFormat.JSON:
|
||||
yield "json", data.json
|
||||
|
||||
@@ -20,7 +20,6 @@ from ._config import firecrawl
|
||||
|
||||
@cost(BlockCost(2, BlockCostType.RUN))
|
||||
class FirecrawlExtractBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
urls: list[str] = SchemaField(
|
||||
@@ -53,7 +52,6 @@ class FirecrawlExtractBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
extract_result = app.extract(
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from firecrawl import FirecrawlApp
|
||||
|
||||
from backend.sdk import (
|
||||
@@ -14,14 +16,16 @@ from ._config import firecrawl
|
||||
|
||||
|
||||
class FirecrawlMapWebsiteBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
|
||||
url: str = SchemaField(description="The website url to map")
|
||||
|
||||
class Output(BlockSchema):
|
||||
links: list[str] = SchemaField(description="The links of the website")
|
||||
links: list[str] = SchemaField(description="List of URLs found on the website")
|
||||
results: list[dict[str, Any]] = SchemaField(
|
||||
description="List of search results with url, title, and description"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -35,12 +39,22 @@ class FirecrawlMapWebsiteBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Sync call
|
||||
map_result = app.map_url(
|
||||
map_result = app.map(
|
||||
url=input_data.url,
|
||||
)
|
||||
|
||||
yield "links", map_result.links
|
||||
# Convert SearchResult objects to dicts
|
||||
results_data = [
|
||||
{
|
||||
"url": link.url,
|
||||
"title": link.title,
|
||||
"description": link.description,
|
||||
}
|
||||
for link in map_result.links
|
||||
]
|
||||
|
||||
yield "links", [link.url for link in map_result.links]
|
||||
yield "results", results_data
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from firecrawl import FirecrawlApp
|
||||
|
||||
from backend.blocks.firecrawl._api import ScrapeFormat
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -14,21 +14,10 @@ from backend.sdk import (
|
||||
)
|
||||
|
||||
from ._config import firecrawl
|
||||
|
||||
|
||||
class ScrapeFormat(Enum):
|
||||
MARKDOWN = "markdown"
|
||||
HTML = "html"
|
||||
RAW_HTML = "rawHtml"
|
||||
LINKS = "links"
|
||||
SCREENSHOT = "screenshot"
|
||||
SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
|
||||
JSON = "json"
|
||||
CHANGE_TRACKING = "changeTracking"
|
||||
from ._format_utils import convert_to_format_options
|
||||
|
||||
|
||||
class FirecrawlScrapeBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
url: str = SchemaField(description="The URL to crawl")
|
||||
@@ -78,12 +67,11 @@ class FirecrawlScrapeBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
scrape_result = app.scrape_url(
|
||||
scrape_result = app.scrape(
|
||||
input_data.url,
|
||||
formats=[format.value for format in input_data.formats],
|
||||
formats=convert_to_format_options(input_data.formats),
|
||||
only_main_content=input_data.only_main_content,
|
||||
max_age=input_data.max_age,
|
||||
wait_for=input_data.wait_for,
|
||||
@@ -96,7 +84,7 @@ class FirecrawlScrapeBlock(Block):
|
||||
elif f == ScrapeFormat.HTML:
|
||||
yield "html", scrape_result.html
|
||||
elif f == ScrapeFormat.RAW_HTML:
|
||||
yield "raw_html", scrape_result.rawHtml
|
||||
yield "raw_html", scrape_result.raw_html
|
||||
elif f == ScrapeFormat.LINKS:
|
||||
yield "links", scrape_result.links
|
||||
elif f == ScrapeFormat.SCREENSHOT:
|
||||
@@ -104,6 +92,6 @@ class FirecrawlScrapeBlock(Block):
|
||||
elif f == ScrapeFormat.SCREENSHOT_FULL_PAGE:
|
||||
yield "screenshot_full_page", scrape_result.screenshot
|
||||
elif f == ScrapeFormat.CHANGE_TRACKING:
|
||||
yield "change_tracking", scrape_result.changeTracking
|
||||
yield "change_tracking", scrape_result.change_tracking
|
||||
elif f == ScrapeFormat.JSON:
|
||||
yield "json", scrape_result.json
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from firecrawl import FirecrawlApp, ScrapeOptions
|
||||
from firecrawl import FirecrawlApp
|
||||
from firecrawl.v2.types import ScrapeOptions
|
||||
|
||||
from backend.blocks.firecrawl._api import ScrapeFormat
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -14,21 +15,10 @@ from backend.sdk import (
|
||||
)
|
||||
|
||||
from ._config import firecrawl
|
||||
|
||||
|
||||
class ScrapeFormat(Enum):
|
||||
MARKDOWN = "markdown"
|
||||
HTML = "html"
|
||||
RAW_HTML = "rawHtml"
|
||||
LINKS = "links"
|
||||
SCREENSHOT = "screenshot"
|
||||
SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
|
||||
JSON = "json"
|
||||
CHANGE_TRACKING = "changeTracking"
|
||||
from ._format_utils import convert_to_format_options
|
||||
|
||||
|
||||
class FirecrawlSearchBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
query: str = SchemaField(description="The query to search for")
|
||||
@@ -61,7 +51,6 @@ class FirecrawlSearchBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Sync call
|
||||
@@ -69,11 +58,12 @@ class FirecrawlSearchBlock(Block):
|
||||
input_data.query,
|
||||
limit=input_data.limit,
|
||||
scrape_options=ScrapeOptions(
|
||||
formats=[format.value for format in input_data.formats],
|
||||
maxAge=input_data.max_age,
|
||||
waitFor=input_data.wait_for,
|
||||
formats=convert_to_format_options(input_data.formats) or None,
|
||||
max_age=input_data.max_age,
|
||||
wait_for=input_data.wait_for,
|
||||
),
|
||||
)
|
||||
yield "data", scrape_result
|
||||
for site in scrape_result.data:
|
||||
yield "site", site
|
||||
if hasattr(scrape_result, "web") and scrape_result.web:
|
||||
for site in scrape_result.web:
|
||||
yield "site", site
|
||||
|
||||
@@ -554,6 +554,89 @@ class AgentToggleInputBlock(AgentInputBlock):
|
||||
)
|
||||
|
||||
|
||||
class AgentTableInputBlock(AgentInputBlock):
|
||||
"""
|
||||
This block allows users to input data in a table format.
|
||||
|
||||
Configure the table columns at build time, then users can input
|
||||
rows of data at runtime. Each row is output as a dictionary
|
||||
with column names as keys.
|
||||
"""
|
||||
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[list[dict[str, Any]]] = SchemaField(
|
||||
description="The table data as a list of dictionaries.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
column_headers: list[str] = SchemaField(
|
||||
description="Column headers for the table.",
|
||||
default_factory=lambda: ["Column 1", "Column 2", "Column 3"],
|
||||
advanced=False,
|
||||
title="Column Headers",
|
||||
)
|
||||
|
||||
def generate_schema(self):
|
||||
"""Generate schema for the value field with table format."""
|
||||
schema = super().generate_schema()
|
||||
schema["type"] = "array"
|
||||
schema["format"] = "table"
|
||||
schema["items"] = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
header: {"type": "string"}
|
||||
for header in (
|
||||
self.column_headers or ["Column 1", "Column 2", "Column 3"]
|
||||
)
|
||||
},
|
||||
}
|
||||
if self.value is not None:
|
||||
schema["default"] = self.value
|
||||
return schema
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: list[dict[str, Any]] = SchemaField(
|
||||
description="The table data as a list of dictionaries with headers as keys."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5603b273-f41e-4020-af7d-fbc9c6a8d928",
|
||||
description="Block for table data input with customizable headers.",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentTableInputBlock.Input,
|
||||
output_schema=AgentTableInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"name": "test_table",
|
||||
"column_headers": ["Name", "Age", "City"],
|
||||
"value": [
|
||||
{"Name": "John", "Age": "30", "City": "New York"},
|
||||
{"Name": "Jane", "Age": "25", "City": "London"},
|
||||
],
|
||||
"description": "Example table input",
|
||||
}
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
"result",
|
||||
[
|
||||
{"Name": "John", "Age": "30", "City": "New York"},
|
||||
{"Name": "Jane", "Age": "25", "City": "London"},
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, *args, **kwargs) -> BlockOutput:
|
||||
"""
|
||||
Yields the table data as a list of dictionaries.
|
||||
"""
|
||||
# Pass through the value, defaulting to empty list if None
|
||||
yield "result", input_data.value if input_data.value is not None else []
|
||||
|
||||
|
||||
IO_BLOCK_IDs = [
|
||||
AgentInputBlock().id,
|
||||
AgentOutputBlock().id,
|
||||
@@ -565,4 +648,5 @@ IO_BLOCK_IDs = [
|
||||
AgentFileInputBlock().id,
|
||||
AgentDropdownInputBlock().id,
|
||||
AgentToggleInputBlock().id,
|
||||
AgentTableInputBlock().id,
|
||||
]
|
||||
|
||||
@@ -54,20 +54,43 @@ class StepThroughItemsBlock(Block):
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# Security fix: Add limits to prevent DoS from large iterations
|
||||
MAX_ITEMS = 10000 # Maximum items to iterate
|
||||
MAX_ITEM_SIZE = 1024 * 1024 # 1MB per item
|
||||
|
||||
for data in [input_data.items, input_data.items_object, input_data.items_str]:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# Limit string size before parsing
|
||||
if isinstance(data, str):
|
||||
if len(data) > MAX_ITEM_SIZE:
|
||||
raise ValueError(
|
||||
f"Input too large: {len(data)} bytes > {MAX_ITEM_SIZE} bytes"
|
||||
)
|
||||
items = json.loads(data)
|
||||
else:
|
||||
items = data
|
||||
|
||||
# Check total item count
|
||||
if isinstance(items, (list, dict)):
|
||||
if len(items) > MAX_ITEMS:
|
||||
raise ValueError(f"Too many items: {len(items)} > {MAX_ITEMS}")
|
||||
|
||||
iteration_count = 0
|
||||
if isinstance(items, dict):
|
||||
# If items is a dictionary, iterate over its values
|
||||
for item in items.values():
|
||||
yield "item", item
|
||||
yield "key", item
|
||||
for key, value in items.items():
|
||||
if iteration_count >= MAX_ITEMS:
|
||||
break
|
||||
yield "item", value
|
||||
yield "key", key # Fixed: should yield key, not item
|
||||
iteration_count += 1
|
||||
else:
|
||||
# If items is a list, iterate over the list
|
||||
for index, item in enumerate(items):
|
||||
if iteration_count >= MAX_ITEMS:
|
||||
break
|
||||
yield "item", item
|
||||
yield "key", index
|
||||
iteration_count += 1
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from typing import List
|
||||
from urllib.parse import quote
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks.jina._auth import (
|
||||
JinaCredentials,
|
||||
JinaCredentialsField,
|
||||
@@ -10,6 +13,12 @@ from backend.data.model import SchemaField
|
||||
from backend.util.request import Requests
|
||||
|
||||
|
||||
class Reference(TypedDict):
|
||||
url: str
|
||||
keyQuote: str
|
||||
isSupportive: bool
|
||||
|
||||
|
||||
class FactCheckerBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
statement: str = SchemaField(
|
||||
@@ -23,6 +32,10 @@ class FactCheckerBlock(Block):
|
||||
)
|
||||
result: bool = SchemaField(description="The result of the factuality check")
|
||||
reason: str = SchemaField(description="The reason for the factuality result")
|
||||
references: List[Reference] = SchemaField(
|
||||
description="List of references supporting or contradicting the statement",
|
||||
default=[],
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the check fails")
|
||||
|
||||
def __init__(self):
|
||||
@@ -53,5 +66,11 @@ class FactCheckerBlock(Block):
|
||||
yield "factuality", data["factuality"]
|
||||
yield "result", data["result"]
|
||||
yield "reason", data["reason"]
|
||||
|
||||
# Yield references if present in the response
|
||||
if "references" in data:
|
||||
yield "references", data["references"]
|
||||
else:
|
||||
yield "references", []
|
||||
else:
|
||||
raise RuntimeError(f"Expected 'data' key not found in response: {data}")
|
||||
|
||||
@@ -37,5 +37,5 @@ class Project(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
priority: int
|
||||
progress: int
|
||||
content: str
|
||||
progress: float
|
||||
content: str | None
|
||||
|
||||
@@ -101,6 +101,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
|
||||
CLAUDE_4_OPUS = "claude-opus-4-20250514"
|
||||
CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
||||
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
|
||||
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
|
||||
CLAUDE_3_5_HAIKU = "claude-3-5-haiku-latest"
|
||||
@@ -213,6 +215,12 @@ MODEL_METADATA = {
|
||||
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000
|
||||
), # claude-4-sonnet-20250514
|
||||
LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000
|
||||
), # claude-sonnet-4-5-20250929
|
||||
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
|
||||
"anthropic", 200000, 64000
|
||||
), # claude-haiku-4-5-20251001
|
||||
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000
|
||||
), # claude-3-7-sonnet-20250219
|
||||
@@ -1400,11 +1408,27 @@ class AITextSummarizerBlock(AIBlockBase):
|
||||
|
||||
@staticmethod
|
||||
def _split_text(text: str, max_tokens: int, overlap: int) -> list[str]:
|
||||
# Security fix: Add validation to prevent DoS attacks
|
||||
# Limit text size to prevent memory exhaustion
|
||||
MAX_TEXT_LENGTH = 1_000_000 # 1MB character limit
|
||||
MAX_CHUNKS = 100 # Maximum number of chunks to prevent excessive memory use
|
||||
|
||||
if len(text) > MAX_TEXT_LENGTH:
|
||||
text = text[:MAX_TEXT_LENGTH]
|
||||
|
||||
# Ensure chunk_size is at least 1 to prevent infinite loops
|
||||
chunk_size = max(1, max_tokens - overlap)
|
||||
|
||||
# Ensure overlap is less than max_tokens to prevent invalid configurations
|
||||
if overlap >= max_tokens:
|
||||
overlap = max(0, max_tokens - 1)
|
||||
|
||||
words = text.split()
|
||||
chunks = []
|
||||
chunk_size = max_tokens - overlap
|
||||
|
||||
for i in range(0, len(words), chunk_size):
|
||||
if len(chunks) >= MAX_CHUNKS:
|
||||
break # Limit the number of chunks to prevent memory exhaustion
|
||||
chunk = " ".join(words[i : i + max_tokens])
|
||||
chunks.append(chunk)
|
||||
|
||||
|
||||
226
autogpt_platform/backend/backend/blocks/perplexity.py
Normal file
226
autogpt_platform/backend/backend/blocks/perplexity.py
Normal file
@@ -0,0 +1,226 @@
|
||||
# flake8: noqa: E501
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Literal
|
||||
|
||||
import openai
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
NodeExecutionStats,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.logging import TruncatedLogger
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[Perplexity-Block]")
|
||||
|
||||
|
||||
class PerplexityModel(str, Enum):
|
||||
"""Perplexity sonar models available via OpenRouter"""
|
||||
|
||||
SONAR = "perplexity/sonar"
|
||||
SONAR_PRO = "perplexity/sonar-pro"
|
||||
SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||
|
||||
|
||||
PerplexityCredentials = CredentialsMetaInput[
|
||||
Literal[ProviderName.OPEN_ROUTER], Literal["api_key"]
|
||||
]
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="test-perplexity-creds",
|
||||
provider="open_router",
|
||||
api_key=SecretStr("mock-openrouter-api-key"),
|
||||
title="Mock OpenRouter API key",
|
||||
expires_at=None,
|
||||
)
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
def PerplexityCredentialsField() -> PerplexityCredentials:
|
||||
return CredentialsField(
|
||||
description="OpenRouter API key for accessing Perplexity models.",
|
||||
)
|
||||
|
||||
|
||||
class PerplexityBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
prompt: str = SchemaField(
|
||||
description="The query to send to the Perplexity model.",
|
||||
placeholder="Enter your query here...",
|
||||
)
|
||||
model: PerplexityModel = SchemaField(
|
||||
title="Perplexity Model",
|
||||
default=PerplexityModel.SONAR,
|
||||
description="The Perplexity sonar model to use.",
|
||||
advanced=False,
|
||||
)
|
||||
credentials: PerplexityCredentials = PerplexityCredentialsField()
|
||||
system_prompt: str = SchemaField(
|
||||
title="System Prompt",
|
||||
default="",
|
||||
description="Optional system prompt to provide context to the model.",
|
||||
advanced=True,
|
||||
)
|
||||
max_tokens: int | None = SchemaField(
|
||||
advanced=True,
|
||||
default=None,
|
||||
description="The maximum number of tokens to generate.",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: str = SchemaField(
|
||||
description="The response from the Perplexity model."
|
||||
)
|
||||
annotations: list[dict[str, Any]] = SchemaField(
|
||||
description="List of URL citations and annotations from the response."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the API call failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c8a5f2e9-8b3d-4a7e-9f6c-1d5e3c9b7a4f",
|
||||
description="Query Perplexity's sonar models with real-time web search capabilities and receive annotated responses with source citations.",
|
||||
categories={BlockCategory.AI, BlockCategory.SEARCH},
|
||||
input_schema=PerplexityBlock.Input,
|
||||
output_schema=PerplexityBlock.Output,
|
||||
test_input={
|
||||
"prompt": "What is the weather today?",
|
||||
"model": PerplexityModel.SONAR,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("response", "The weather varies by location..."),
|
||||
("annotations", list),
|
||||
],
|
||||
test_mock={
|
||||
"call_perplexity": lambda *args, **kwargs: {
|
||||
"response": "The weather varies by location...",
|
||||
"annotations": [
|
||||
{
|
||||
"type": "url_citation",
|
||||
"url_citation": {
|
||||
"title": "weather.com",
|
||||
"url": "https://weather.com",
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
self.execution_stats = NodeExecutionStats()
|
||||
|
||||
async def call_perplexity(
|
||||
self,
|
||||
credentials: APIKeyCredentials,
|
||||
model: PerplexityModel,
|
||||
prompt: str,
|
||||
system_prompt: str = "",
|
||||
max_tokens: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Call Perplexity via OpenRouter and extract annotations."""
|
||||
client = openai.AsyncOpenAI(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
extra_headers={
|
||||
"HTTP-Referer": "https://agpt.co",
|
||||
"X-Title": "AutoGPT",
|
||||
},
|
||||
model=model.value,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
if not response.choices:
|
||||
raise ValueError("No response from Perplexity via OpenRouter.")
|
||||
|
||||
# Extract the response content
|
||||
response_content = response.choices[0].message.content or ""
|
||||
|
||||
# Extract annotations if present in the message
|
||||
annotations = []
|
||||
if hasattr(response.choices[0].message, "annotations"):
|
||||
# If annotations are directly available
|
||||
annotations = response.choices[0].message.annotations
|
||||
else:
|
||||
# Check if there's a raw response with annotations
|
||||
raw = getattr(response.choices[0].message, "_raw_response", None)
|
||||
if isinstance(raw, dict) and "annotations" in raw:
|
||||
annotations = raw["annotations"]
|
||||
|
||||
if not annotations and hasattr(response, "model_extra"):
|
||||
# Check model_extra for annotations
|
||||
model_extra = response.model_extra
|
||||
if isinstance(model_extra, dict):
|
||||
# Check in choices
|
||||
if "choices" in model_extra and len(model_extra["choices"]) > 0:
|
||||
choice = model_extra["choices"][0]
|
||||
if "message" in choice and "annotations" in choice["message"]:
|
||||
annotations = choice["message"]["annotations"]
|
||||
|
||||
# Also check the raw response object for annotations
|
||||
if not annotations:
|
||||
raw = getattr(response, "_raw_response", None)
|
||||
if isinstance(raw, dict):
|
||||
# Check various possible locations for annotations
|
||||
if "annotations" in raw:
|
||||
annotations = raw["annotations"]
|
||||
elif "choices" in raw and len(raw["choices"]) > 0:
|
||||
choice = raw["choices"][0]
|
||||
if "message" in choice and "annotations" in choice["message"]:
|
||||
annotations = choice["message"]["annotations"]
|
||||
|
||||
# Update execution stats
|
||||
if response.usage:
|
||||
self.execution_stats.input_token_count = response.usage.prompt_tokens
|
||||
self.execution_stats.output_token_count = (
|
||||
response.usage.completion_tokens
|
||||
)
|
||||
|
||||
return {"response": response_content, "annotations": annotations or []}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling Perplexity: {e}")
|
||||
raise
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
logger.debug(f"Running Perplexity block with model: {input_data.model}")
|
||||
|
||||
try:
|
||||
result = await self.call_perplexity(
|
||||
credentials=credentials,
|
||||
model=input_data.model,
|
||||
prompt=input_data.prompt,
|
||||
system_prompt=input_data.system_prompt,
|
||||
max_tokens=input_data.max_tokens,
|
||||
)
|
||||
|
||||
yield "response", result["response"]
|
||||
yield "annotations", result["annotations"]
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error calling Perplexity: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
yield "error", error_msg
|
||||
@@ -1,4 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
@@ -101,7 +104,38 @@ class ReadRSSFeedBlock(Block):
|
||||
|
||||
@staticmethod
|
||||
def parse_feed(url: str) -> dict[str, Any]:
|
||||
return feedparser.parse(url) # type: ignore
|
||||
# Security fix: Add protection against memory exhaustion attacks
|
||||
MAX_FEED_SIZE = 10 * 1024 * 1024 # 10MB limit for RSS feeds
|
||||
|
||||
# Validate URL
|
||||
parsed_url = urllib.parse.urlparse(url)
|
||||
if parsed_url.scheme not in ("http", "https"):
|
||||
raise ValueError(f"Invalid URL scheme: {parsed_url.scheme}")
|
||||
|
||||
# Download with size limit
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=30) as response:
|
||||
# Check content length if available
|
||||
content_length = response.headers.get("Content-Length")
|
||||
if content_length and int(content_length) > MAX_FEED_SIZE:
|
||||
raise ValueError(
|
||||
f"Feed too large: {content_length} bytes exceeds {MAX_FEED_SIZE} limit"
|
||||
)
|
||||
|
||||
# Read with size limit
|
||||
content = response.read(MAX_FEED_SIZE + 1)
|
||||
if len(content) > MAX_FEED_SIZE:
|
||||
raise ValueError(
|
||||
f"Feed too large: exceeds {MAX_FEED_SIZE} byte limit"
|
||||
)
|
||||
|
||||
# Parse with feedparser using the validated content
|
||||
# feedparser has built-in protection against XML attacks
|
||||
return feedparser.parse(content) # type: ignore
|
||||
except Exception as e:
|
||||
# Log error and return empty feed
|
||||
logging.warning(f"Failed to parse RSS feed from {url}: {e}")
|
||||
return {"entries": []}
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
keep_going = True
|
||||
|
||||
@@ -13,6 +13,11 @@ from backend.data.block import (
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
)
|
||||
from backend.data.dynamic_fields import (
|
||||
extract_base_field_name,
|
||||
get_dynamic_field_description,
|
||||
is_dynamic_field,
|
||||
)
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util import json
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
@@ -98,6 +103,22 @@ def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
|
||||
return {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
|
||||
|
||||
def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
|
||||
"""
|
||||
Safely convert raw_response to dictionary format for conversation history.
|
||||
Handles different response types from different LLM providers.
|
||||
"""
|
||||
if isinstance(raw_response, str):
|
||||
# Ollama returns a string, convert to dict format
|
||||
return {"role": "assistant", "content": raw_response}
|
||||
elif isinstance(raw_response, dict):
|
||||
# Already a dict (from tests or some providers)
|
||||
return raw_response
|
||||
else:
|
||||
# OpenAI/Anthropic return objects, convert with json.to_dict
|
||||
return json.to_dict(raw_response)
|
||||
|
||||
|
||||
def get_pending_tool_calls(conversation_history: list[Any]) -> dict[str, int]:
|
||||
"""
|
||||
All the tool calls entry in the conversation history requires a response.
|
||||
@@ -261,6 +282,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
|
||||
@staticmethod
|
||||
def cleanup(s: str):
|
||||
"""Clean up block names for use as tool function names."""
|
||||
return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()
|
||||
|
||||
@staticmethod
|
||||
@@ -288,41 +310,66 @@ class SmartDecisionMakerBlock(Block):
|
||||
}
|
||||
sink_block_input_schema = block.input_schema
|
||||
properties = {}
|
||||
field_mapping = {} # clean_name -> original_name
|
||||
|
||||
for link in links:
|
||||
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
|
||||
field_name = link.sink_name
|
||||
is_dynamic = is_dynamic_field(field_name)
|
||||
# Clean property key to ensure Anthropic API compatibility for ALL fields
|
||||
clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
|
||||
field_mapping[clean_field_name] = field_name
|
||||
|
||||
# Handle dynamic fields (e.g., values_#_*, items_$_*, etc.)
|
||||
# These are fields that get merged by the executor into their base field
|
||||
if (
|
||||
"_#_" in link.sink_name
|
||||
or "_$_" in link.sink_name
|
||||
or "_@_" in link.sink_name
|
||||
):
|
||||
# For dynamic fields, provide a generic string schema
|
||||
# The executor will handle merging these into the appropriate structure
|
||||
properties[sink_name] = {
|
||||
if is_dynamic:
|
||||
# For dynamic fields, use cleaned name but preserve original in description
|
||||
properties[clean_field_name] = {
|
||||
"type": "string",
|
||||
"description": f"Dynamic value for {link.sink_name}",
|
||||
"description": get_dynamic_field_description(field_name),
|
||||
}
|
||||
else:
|
||||
# For regular fields, use the block's schema
|
||||
# For regular fields, use the block's schema directly
|
||||
try:
|
||||
properties[sink_name] = sink_block_input_schema.get_field_schema(
|
||||
link.sink_name
|
||||
properties[clean_field_name] = (
|
||||
sink_block_input_schema.get_field_schema(field_name)
|
||||
)
|
||||
except (KeyError, AttributeError):
|
||||
# If the field doesn't exist in the schema, provide a generic schema
|
||||
properties[sink_name] = {
|
||||
# If field doesn't exist in schema, provide a generic one
|
||||
properties[clean_field_name] = {
|
||||
"type": "string",
|
||||
"description": f"Value for {link.sink_name}",
|
||||
"description": f"Value for {field_name}",
|
||||
}
|
||||
|
||||
# Build the parameters schema using a single unified path
|
||||
base_schema = block.input_schema.jsonschema()
|
||||
base_required = set(base_schema.get("required", []))
|
||||
|
||||
# Compute required fields at the leaf level:
|
||||
# - If a linked field is dynamic and its base is required in the block schema, require the leaf
|
||||
# - If a linked field is regular and is required in the block schema, require the leaf
|
||||
required_fields: set[str] = set()
|
||||
for link in links:
|
||||
field_name = link.sink_name
|
||||
is_dynamic = is_dynamic_field(field_name)
|
||||
# Always use cleaned field name for property key (Anthropic API compliance)
|
||||
clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
|
||||
|
||||
if is_dynamic:
|
||||
base_name = extract_base_field_name(field_name)
|
||||
if base_name in base_required:
|
||||
required_fields.add(clean_field_name)
|
||||
else:
|
||||
if field_name in base_required:
|
||||
required_fields.add(clean_field_name)
|
||||
|
||||
tool_function["parameters"] = {
|
||||
**block.input_schema.jsonschema(),
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"additionalProperties": False,
|
||||
"required": sorted(required_fields),
|
||||
}
|
||||
|
||||
# Store field mapping for later use in output processing
|
||||
tool_function["_field_mapping"] = field_mapping
|
||||
|
||||
return {"type": "function", "function": tool_function}
|
||||
|
||||
@staticmethod
|
||||
@@ -366,13 +413,12 @@ class SmartDecisionMakerBlock(Block):
|
||||
sink_block_properties = sink_block_input_schema.get("properties", {}).get(
|
||||
link.sink_name, {}
|
||||
)
|
||||
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
|
||||
description = (
|
||||
sink_block_properties["description"]
|
||||
if "description" in sink_block_properties
|
||||
else f"The {link.sink_name} of the tool"
|
||||
)
|
||||
properties[sink_name] = {
|
||||
properties[link.sink_name] = {
|
||||
"type": "string",
|
||||
"description": description,
|
||||
"default": json.dumps(sink_block_properties.get("default", None)),
|
||||
@@ -388,24 +434,17 @@ class SmartDecisionMakerBlock(Block):
|
||||
return {"type": "function", "function": tool_function}
|
||||
|
||||
@staticmethod
|
||||
async def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
|
||||
async def _create_function_signature(
|
||||
node_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Creates function signatures for tools linked to a specified node within a graph.
|
||||
|
||||
This method filters the graph links to identify those that are tools and are
|
||||
connected to the given node_id. It then constructs function signatures for each
|
||||
tool based on the metadata and input schema of the linked nodes.
|
||||
Creates function signatures for connected tools.
|
||||
|
||||
Args:
|
||||
node_id: The node_id for which to create function signatures.
|
||||
|
||||
Returns:
|
||||
list[dict[str, Any]]: A list of dictionaries, each representing a function signature
|
||||
for a tool, including its name, description, and parameters.
|
||||
|
||||
Raises:
|
||||
ValueError: If no tool links are found for the specified node_id, or if a sink node
|
||||
or its metadata cannot be found.
|
||||
List of function signatures for tools
|
||||
"""
|
||||
db_client = get_database_manager_async_client()
|
||||
tools = [
|
||||
@@ -430,20 +469,116 @@ class SmartDecisionMakerBlock(Block):
|
||||
raise ValueError(f"Sink node not found: {links[0].sink_id}")
|
||||
|
||||
if sink_node.block_id == AgentExecutorBlock().id:
|
||||
return_tool_functions.append(
|
||||
tool_func = (
|
||||
await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
)
|
||||
return_tool_functions.append(tool_func)
|
||||
else:
|
||||
return_tool_functions.append(
|
||||
tool_func = (
|
||||
await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
)
|
||||
return_tool_functions.append(tool_func)
|
||||
|
||||
return return_tool_functions
|
||||
|
||||
async def _attempt_llm_call_with_validation(
|
||||
self,
|
||||
credentials: llm.APIKeyCredentials,
|
||||
input_data: Input,
|
||||
current_prompt: list[dict],
|
||||
tool_functions: list[dict[str, Any]],
|
||||
):
|
||||
"""
|
||||
Attempt a single LLM call with tool validation.
|
||||
|
||||
Returns the response if successful, raises ValueError if validation fails.
|
||||
"""
|
||||
resp = await llm.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=current_prompt,
|
||||
max_tokens=input_data.max_tokens,
|
||||
tools=tool_functions,
|
||||
ollama_host=input_data.ollama_host,
|
||||
parallel_tool_calls=input_data.multiple_tool_calls,
|
||||
)
|
||||
|
||||
# Track LLM usage stats per call
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=resp.prompt_tokens,
|
||||
output_token_count=resp.completion_tokens,
|
||||
llm_call_count=1,
|
||||
)
|
||||
)
|
||||
|
||||
if not resp.tool_calls:
|
||||
return resp
|
||||
validation_errors_list: list[str] = []
|
||||
for tool_call in resp.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
try:
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
except Exception as e:
|
||||
validation_errors_list.append(
|
||||
f"Tool call '{tool_name}' has invalid JSON arguments: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Find the tool definition to get the expected arguments
|
||||
tool_def = next(
|
||||
(
|
||||
tool
|
||||
for tool in tool_functions
|
||||
if tool["function"]["name"] == tool_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
if tool_def is None and len(tool_functions) == 1:
|
||||
tool_def = tool_functions[0]
|
||||
|
||||
# Get parameters schema from tool definition
|
||||
if (
|
||||
tool_def
|
||||
and "function" in tool_def
|
||||
and "parameters" in tool_def["function"]
|
||||
):
|
||||
parameters = tool_def["function"]["parameters"]
|
||||
expected_args = parameters.get("properties", {})
|
||||
required_params = set(parameters.get("required", []))
|
||||
else:
|
||||
expected_args = {arg: {} for arg in tool_args.keys()}
|
||||
required_params = set()
|
||||
|
||||
# Validate tool call arguments
|
||||
provided_args = set(tool_args.keys())
|
||||
expected_args_set = set(expected_args.keys())
|
||||
|
||||
# Check for unexpected arguments (typos)
|
||||
unexpected_args = provided_args - expected_args_set
|
||||
# Only check for missing REQUIRED parameters
|
||||
missing_required_args = required_params - provided_args
|
||||
|
||||
if unexpected_args or missing_required_args:
|
||||
error_msg = f"Tool call '{tool_name}' has parameter errors:"
|
||||
if unexpected_args:
|
||||
error_msg += f" Unknown parameters: {sorted(unexpected_args)}."
|
||||
if missing_required_args:
|
||||
error_msg += f" Missing required parameters: {sorted(missing_required_args)}."
|
||||
error_msg += f" Expected parameters: {sorted(expected_args_set)}."
|
||||
if required_params:
|
||||
error_msg += f" Required parameters: {sorted(required_params)}."
|
||||
validation_errors_list.append(error_msg)
|
||||
|
||||
if validation_errors_list:
|
||||
raise ValueError("; ".join(validation_errors_list))
|
||||
|
||||
return resp
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
@@ -466,27 +601,19 @@ class SmartDecisionMakerBlock(Block):
|
||||
if pending_tool_calls and input_data.last_tool_output is None:
|
||||
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
|
||||
|
||||
# Only assign the last tool output to the first pending tool call
|
||||
tool_output = []
|
||||
if pending_tool_calls and input_data.last_tool_output is not None:
|
||||
# Get the first pending tool call ID
|
||||
first_call_id = next(iter(pending_tool_calls.keys()))
|
||||
tool_output.append(
|
||||
_create_tool_response(first_call_id, input_data.last_tool_output)
|
||||
)
|
||||
|
||||
# Add tool output to prompt right away
|
||||
prompt.extend(tool_output)
|
||||
|
||||
# Check if there are still pending tool calls after handling the first one
|
||||
remaining_pending_calls = get_pending_tool_calls(prompt)
|
||||
|
||||
# If there are still pending tool calls, yield the conversation and return early
|
||||
if remaining_pending_calls:
|
||||
yield "conversations", prompt
|
||||
return
|
||||
|
||||
# Fallback on adding tool output in the conversation history as user prompt.
|
||||
elif input_data.last_tool_output:
|
||||
logger.error(
|
||||
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
|
||||
@@ -519,24 +646,33 @@ class SmartDecisionMakerBlock(Block):
|
||||
):
|
||||
prompt.append({"role": "user", "content": prefix + input_data.prompt})
|
||||
|
||||
response = await llm.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=input_data.max_tokens,
|
||||
tools=tool_functions,
|
||||
ollama_host=input_data.ollama_host,
|
||||
parallel_tool_calls=input_data.multiple_tool_calls,
|
||||
)
|
||||
current_prompt = list(prompt)
|
||||
max_attempts = max(1, int(input_data.retry))
|
||||
response = None
|
||||
|
||||
# Track LLM usage stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
llm_call_count=1,
|
||||
last_error = None
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
response = await self._attempt_llm_call_with_validation(
|
||||
credentials, input_data, current_prompt, tool_functions
|
||||
)
|
||||
break
|
||||
|
||||
except ValueError as e:
|
||||
last_error = e
|
||||
error_feedback = (
|
||||
"Your tool call had parameter errors. Please fix the following issues and try again:\n"
|
||||
+ f"- {str(e)}\n"
|
||||
+ "\nPlease make sure to use the exact parameter names as specified in the function schema."
|
||||
)
|
||||
current_prompt = list(current_prompt) + [
|
||||
{"role": "user", "content": error_feedback}
|
||||
]
|
||||
|
||||
if response is None:
|
||||
raise last_error or ValueError(
|
||||
"Failed to get valid response after all retry attempts"
|
||||
)
|
||||
)
|
||||
|
||||
if not response.tool_calls:
|
||||
yield "finished", response.response
|
||||
@@ -546,7 +682,6 @@ class SmartDecisionMakerBlock(Block):
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
|
||||
# Find the tool definition to get the expected arguments
|
||||
tool_def = next(
|
||||
(
|
||||
tool
|
||||
@@ -555,7 +690,6 @@ class SmartDecisionMakerBlock(Block):
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if (
|
||||
tool_def
|
||||
and "function" in tool_def
|
||||
@@ -563,20 +697,38 @@ class SmartDecisionMakerBlock(Block):
|
||||
):
|
||||
expected_args = tool_def["function"]["parameters"].get("properties", {})
|
||||
else:
|
||||
expected_args = tool_args.keys()
|
||||
expected_args = {arg: {} for arg in tool_args.keys()}
|
||||
|
||||
# Yield provided arguments and None for missing ones
|
||||
for arg_name in expected_args:
|
||||
if arg_name in tool_args:
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", tool_args[arg_name]
|
||||
else:
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", None
|
||||
# Get field mapping from tool definition
|
||||
field_mapping = (
|
||||
tool_def.get("function", {}).get("_field_mapping", {})
|
||||
if tool_def
|
||||
else {}
|
||||
)
|
||||
|
||||
for clean_arg_name in expected_args:
|
||||
# arg_name is now always the cleaned field name (for Anthropic API compliance)
|
||||
# Get the original field name from field mapping for proper emit key generation
|
||||
original_field_name = field_mapping.get(clean_arg_name, clean_arg_name)
|
||||
arg_value = tool_args.get(clean_arg_name)
|
||||
|
||||
sanitized_tool_name = self.cleanup(tool_name)
|
||||
sanitized_arg_name = self.cleanup(original_field_name)
|
||||
emit_key = f"tools_^_{sanitized_tool_name}_~_{sanitized_arg_name}"
|
||||
|
||||
logger.debug(
|
||||
"[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s",
|
||||
graph_exec_id,
|
||||
node_exec_id,
|
||||
emit_key,
|
||||
)
|
||||
yield emit_key, arg_value
|
||||
|
||||
# Add reasoning to conversation history if available
|
||||
if response.reasoning:
|
||||
prompt.append(
|
||||
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
|
||||
)
|
||||
|
||||
prompt.append(response.raw_response)
|
||||
prompt.append(_convert_raw_response_to_dict(response.raw_response))
|
||||
|
||||
yield "conversations", prompt
|
||||
|
||||
@@ -19,7 +19,7 @@ async def test_block_ids_valid(block: Type[Block]):
|
||||
# Skip list for blocks with known invalid UUIDs
|
||||
skip_blocks = {
|
||||
"GetWeatherInformationBlock",
|
||||
"CodeExecutionBlock",
|
||||
"ExecuteCodeBlock",
|
||||
"CountdownTimerBlock",
|
||||
"TwitterGetListTweetsBlock",
|
||||
"TwitterRemoveListMemberBlock",
|
||||
|
||||
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
Test security fixes for various DoS vulnerabilities.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.code_extraction_block import CodeExtractionBlock
|
||||
from backend.blocks.iteration import StepThroughItemsBlock
|
||||
from backend.blocks.llm import AITextSummarizerBlock
|
||||
from backend.blocks.text import ExtractTextInformationBlock
|
||||
from backend.blocks.xml_parser import XMLParserBlock
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
class TestCodeExtractionBlockSecurity:
|
||||
"""Test ReDoS fixes in CodeExtractionBlock."""
|
||||
|
||||
async def test_redos_protection(self):
|
||||
"""Test that the regex patterns don't cause ReDoS."""
|
||||
block = CodeExtractionBlock()
|
||||
|
||||
# Test with input that would previously cause ReDoS
|
||||
malicious_input = "```python" + " " * 10000 # Large spaces
|
||||
|
||||
result = []
|
||||
async for output_name, output_data in block.run(
|
||||
CodeExtractionBlock.Input(text=malicious_input)
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Should complete without hanging
|
||||
assert len(result) >= 1
|
||||
assert any(name == "remaining_text" for name, _ in result)
|
||||
|
||||
|
||||
class TestAITextSummarizerBlockSecurity:
|
||||
"""Test memory exhaustion fixes in AITextSummarizerBlock."""
|
||||
|
||||
def test_split_text_limits(self):
|
||||
"""Test that _split_text has proper limits."""
|
||||
# Test text size limit
|
||||
large_text = "a" * 2_000_000 # 2MB text
|
||||
result = AITextSummarizerBlock._split_text(large_text, 1000, 100)
|
||||
|
||||
# Should be truncated to 1MB
|
||||
total_chars = sum(len(chunk) for chunk in result)
|
||||
assert total_chars <= 1_000_000 + 1000 # Allow for chunk boundary
|
||||
|
||||
# Test chunk count limit
|
||||
result = AITextSummarizerBlock._split_text("word " * 10000, 10, 9)
|
||||
assert len(result) <= 100 # MAX_CHUNKS limit
|
||||
|
||||
# Test parameter validation
|
||||
result = AITextSummarizerBlock._split_text(
|
||||
"test", 10, 15
|
||||
) # overlap > max_tokens
|
||||
assert len(result) >= 1 # Should still work
|
||||
|
||||
|
||||
class TestExtractTextInformationBlockSecurity:
|
||||
"""Test ReDoS and memory exhaustion fixes in ExtractTextInformationBlock."""
|
||||
|
||||
async def test_text_size_limits(self):
|
||||
"""Test text size limits."""
|
||||
block = ExtractTextInformationBlock()
|
||||
|
||||
# Test with large input
|
||||
large_text = "a" * 2_000_000 # 2MB
|
||||
|
||||
results = []
|
||||
async for output_name, output_data in block.run(
|
||||
ExtractTextInformationBlock.Input(
|
||||
text=large_text, pattern=r"a+", find_all=True, group=0
|
||||
)
|
||||
):
|
||||
results.append((output_name, output_data))
|
||||
|
||||
# Should complete and have limits applied
|
||||
matched_results = [r for name, r in results if name == "matched_results"]
|
||||
if matched_results:
|
||||
assert len(matched_results[0]) <= 1000 # MAX_MATCHES limit
|
||||
|
||||
async def test_dangerous_pattern_timeout(self):
|
||||
"""Test timeout protection for dangerous patterns."""
|
||||
block = ExtractTextInformationBlock()
|
||||
|
||||
# Test with potentially dangerous lookahead pattern
|
||||
test_input = "a" * 1000
|
||||
|
||||
# This should complete quickly due to timeout protection
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
results = []
|
||||
async for output_name, output_data in block.run(
|
||||
ExtractTextInformationBlock.Input(
|
||||
text=test_input, pattern=r"(?=.+)", find_all=True, group=0
|
||||
)
|
||||
):
|
||||
results.append((output_name, output_data))
|
||||
|
||||
end_time = asyncio.get_event_loop().time()
|
||||
# Should complete within reasonable time (much less than 5s timeout)
|
||||
assert (end_time - start_time) < 10
|
||||
|
||||
async def test_redos_catastrophic_backtracking(self):
|
||||
"""Test that ReDoS patterns with catastrophic backtracking are handled."""
|
||||
block = ExtractTextInformationBlock()
|
||||
|
||||
# Pattern that causes catastrophic backtracking: (a+)+b
|
||||
# With input "aaaaaaaaaaaaaaaaaaaaaaaaaaaa" (no 'b'), this causes exponential time
|
||||
dangerous_pattern = r"(a+)+b"
|
||||
test_input = "a" * 30 # 30 'a's without a 'b' at the end
|
||||
|
||||
# This should be handled by timeout protection or pattern detection
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
results = []
|
||||
|
||||
async for output_name, output_data in block.run(
|
||||
ExtractTextInformationBlock.Input(
|
||||
text=test_input, pattern=dangerous_pattern, find_all=True, group=0
|
||||
)
|
||||
):
|
||||
results.append((output_name, output_data))
|
||||
|
||||
end_time = asyncio.get_event_loop().time()
|
||||
elapsed = end_time - start_time
|
||||
|
||||
# Should complete within timeout (6 seconds to be safe)
|
||||
# The current threading.Timer approach doesn't work, so this will likely fail
|
||||
# demonstrating the need for a fix
|
||||
assert elapsed < 6, f"Regex took {elapsed}s, timeout mechanism failed"
|
||||
|
||||
# Should return empty results on timeout or no match
|
||||
matched_results = [r for name, r in results if name == "matched_results"]
|
||||
assert matched_results[0] == [] # No matches expected
|
||||
|
||||
|
||||
class TestStepThroughItemsBlockSecurity:
|
||||
"""Test iteration limits in StepThroughItemsBlock."""
|
||||
|
||||
async def test_item_count_limits(self):
|
||||
"""Test maximum item count limits."""
|
||||
block = StepThroughItemsBlock()
|
||||
|
||||
# Test with too many items
|
||||
large_list = list(range(20000)) # Exceeds MAX_ITEMS (10000)
|
||||
|
||||
with pytest.raises(ValueError, match="Too many items"):
|
||||
async for _ in block.run(StepThroughItemsBlock.Input(items=large_list)):
|
||||
pass
|
||||
|
||||
async def test_string_size_limits(self):
|
||||
"""Test string input size limits."""
|
||||
block = StepThroughItemsBlock()
|
||||
|
||||
# Test with large JSON string
|
||||
large_string = '["item"]' * 200000 # Large JSON string
|
||||
|
||||
with pytest.raises(ValueError, match="Input too large"):
|
||||
async for _ in block.run(
|
||||
StepThroughItemsBlock.Input(items_str=large_string)
|
||||
):
|
||||
pass
|
||||
|
||||
async def test_normal_iteration_works(self):
|
||||
"""Test that normal iteration still works."""
|
||||
block = StepThroughItemsBlock()
|
||||
|
||||
results = []
|
||||
async for output_name, output_data in block.run(
|
||||
StepThroughItemsBlock.Input(items=[1, 2, 3])
|
||||
):
|
||||
results.append((output_name, output_data))
|
||||
|
||||
# Should have 6 outputs (item, key for each of 3 items)
|
||||
assert len(results) == 6
|
||||
items = [data for name, data in results if name == "item"]
|
||||
assert items == [1, 2, 3]
|
||||
|
||||
|
||||
class TestXMLParserBlockSecurity:
|
||||
"""Test XML size limits in XMLParserBlock."""
|
||||
|
||||
async def test_xml_size_limits(self):
|
||||
"""Test XML input size limits."""
|
||||
block = XMLParserBlock()
|
||||
|
||||
# Test with large XML - need to exceed 10MB limit
|
||||
# Each "<item>data</item>" is 17 chars, need ~620K items for >10MB
|
||||
large_xml = "<root>" + "<item>data</item>" * 620000 + "</root>"
|
||||
|
||||
with pytest.raises(ValueError, match="XML too large"):
|
||||
async for _ in block.run(XMLParserBlock.Input(input_xml=large_xml)):
|
||||
pass
|
||||
|
||||
|
||||
class TestStoreMediaFileSecurity:
|
||||
"""Test file storage security limits."""
|
||||
|
||||
@patch("backend.util.file.scan_content_safe")
|
||||
@patch("backend.util.file.get_cloud_storage_handler")
|
||||
async def test_file_size_limits(self, mock_cloud_storage, mock_scan):
|
||||
"""Test file size limits."""
|
||||
# Mock cloud storage handler - get_cloud_storage_handler is async
|
||||
# but is_cloud_path and parse_cloud_path are sync methods
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.is_cloud_path.return_value = False
|
||||
|
||||
# Make get_cloud_storage_handler an async function that returns the mock handler
|
||||
async def async_get_handler():
|
||||
return mock_handler
|
||||
|
||||
mock_cloud_storage.side_effect = async_get_handler
|
||||
mock_scan.return_value = None
|
||||
|
||||
# Test with large base64 content
|
||||
large_content = "a" * (200 * 1024 * 1024) # 200MB
|
||||
large_data_uri = f"data:text/plain;base64,{large_content}"
|
||||
|
||||
with pytest.raises(ValueError, match="File too large"):
|
||||
await store_media_file(
|
||||
graph_exec_id="test",
|
||||
file=MediaFileType(large_data_uri),
|
||||
user_id="test_user",
|
||||
)
|
||||
|
||||
@patch("backend.util.file.Path")
|
||||
@patch("backend.util.file.scan_content_safe")
|
||||
@patch("backend.util.file.get_cloud_storage_handler")
|
||||
async def test_directory_size_limits(self, mock_cloud_storage, mock_scan, MockPath):
|
||||
"""Test directory size limits."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.is_cloud_path.return_value = False
|
||||
|
||||
async def async_get_handler():
|
||||
return mock_handler
|
||||
|
||||
mock_cloud_storage.side_effect = async_get_handler
|
||||
mock_scan.return_value = None
|
||||
|
||||
# Create mock path instance for the execution directory
|
||||
mock_path_instance = MagicMock()
|
||||
mock_path_instance.exists.return_value = True
|
||||
|
||||
# Mock glob to return files that total > 1GB
|
||||
mock_file = MagicMock()
|
||||
mock_file.is_file.return_value = True
|
||||
mock_file.stat.return_value.st_size = 2 * 1024 * 1024 * 1024 # 2GB
|
||||
mock_path_instance.glob.return_value = [mock_file]
|
||||
|
||||
# Make Path() return our mock
|
||||
MockPath.return_value = mock_path_instance
|
||||
|
||||
# Should raise an error when directory size exceeds limit
|
||||
with pytest.raises(ValueError, match="Disk usage limit exceeded"):
|
||||
await store_media_file(
|
||||
graph_exec_id="test",
|
||||
file=MediaFileType(
|
||||
"data:text/plain;base64,dGVzdA=="
|
||||
), # Small test file
|
||||
user_id="test_user",
|
||||
)
|
||||
@@ -216,8 +216,17 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
}
|
||||
|
||||
# Mock the _create_function_signature method to avoid database calls
|
||||
with patch("backend.blocks.llm.llm_call", return_value=mock_response), patch.object(
|
||||
SmartDecisionMakerBlock, "_create_function_signature", return_value=[]
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
|
||||
# Create test input
|
||||
@@ -249,3 +258,471 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
# Verify outputs
|
||||
assert "finished" in outputs # Should have finished since no tool calls
|
||||
assert outputs["finished"] == "I need to think about this."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_parameter_validation():
|
||||
"""Test that SmartDecisionMakerBlock correctly validates tool call parameters."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Mock tool functions with specific parameter schema
|
||||
mock_tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_keywords",
|
||||
"description": "Search for keywords with difficulty filtering",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"},
|
||||
"max_keyword_difficulty": {
|
||||
"type": "integer",
|
||||
"description": "Maximum keyword difficulty (required)",
|
||||
},
|
||||
"optional_param": {
|
||||
"type": "string",
|
||||
"description": "Optional parameter with default",
|
||||
"default": "default_value",
|
||||
},
|
||||
},
|
||||
"required": ["query", "max_keyword_difficulty"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Test case 1: Tool call with TYPO in parameter name (should retry and eventually fail)
|
||||
mock_tool_call_with_typo = MagicMock()
|
||||
mock_tool_call_with_typo.function.name = "search_keywords"
|
||||
mock_tool_call_with_typo.function.arguments = '{"query": "test", "maximum_keyword_difficulty": 50}' # TYPO: maximum instead of max
|
||||
|
||||
mock_response_with_typo = MagicMock()
|
||||
mock_response_with_typo.response = None
|
||||
mock_response_with_typo.tool_calls = [mock_tool_call_with_typo]
|
||||
mock_response_with_typo.prompt_tokens = 50
|
||||
mock_response_with_typo.completion_tokens = 25
|
||||
mock_response_with_typo.reasoning = None
|
||||
mock_response_with_typo.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_with_typo,
|
||||
) as mock_llm_call, patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2, # Set retry to 2 for testing
|
||||
)
|
||||
|
||||
# Should raise ValueError after retries due to typo'd parameter name
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Verify error message contains details about the typo
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Tool call 'search_keywords' has parameter errors" in error_msg
|
||||
assert "Unknown parameters: ['maximum_keyword_difficulty']" in error_msg
|
||||
|
||||
# Verify that LLM was called the expected number of times (retries)
|
||||
assert mock_llm_call.call_count == 2 # Should retry based on input_data.retry
|
||||
|
||||
# Test case 2: Tool call missing REQUIRED parameter (should raise ValueError)
|
||||
mock_tool_call_missing_required = MagicMock()
|
||||
mock_tool_call_missing_required.function.name = "search_keywords"
|
||||
mock_tool_call_missing_required.function.arguments = (
|
||||
'{"query": "test"}' # Missing required max_keyword_difficulty
|
||||
)
|
||||
|
||||
mock_response_missing_required = MagicMock()
|
||||
mock_response_missing_required.response = None
|
||||
mock_response_missing_required.tool_calls = [mock_tool_call_missing_required]
|
||||
mock_response_missing_required.prompt_tokens = 50
|
||||
mock_response_missing_required.completion_tokens = 25
|
||||
mock_response_missing_required.reasoning = None
|
||||
mock_response_missing_required.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_missing_required,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
# Should raise ValueError due to missing required parameter
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Tool call 'search_keywords' has parameter errors" in error_msg
|
||||
assert "Missing required parameters: ['max_keyword_difficulty']" in error_msg
|
||||
|
||||
# Test case 3: Valid tool call with OPTIONAL parameter missing (should succeed)
|
||||
mock_tool_call_valid = MagicMock()
|
||||
mock_tool_call_valid.function.name = "search_keywords"
|
||||
mock_tool_call_valid.function.arguments = '{"query": "test", "max_keyword_difficulty": 50}' # optional_param missing, but that's OK
|
||||
|
||||
mock_response_valid = MagicMock()
|
||||
mock_response_valid.response = None
|
||||
mock_response_valid.tool_calls = [mock_tool_call_valid]
|
||||
mock_response_valid.prompt_tokens = 50
|
||||
mock_response_valid.completion_tokens = 25
|
||||
mock_response_valid.reasoning = None
|
||||
mock_response_valid.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_valid,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
# Should succeed - optional parameter missing is OK
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Verify tool outputs were generated correctly
|
||||
assert "tools_^_search_keywords_~_query" in outputs
|
||||
assert outputs["tools_^_search_keywords_~_query"] == "test"
|
||||
assert "tools_^_search_keywords_~_max_keyword_difficulty" in outputs
|
||||
assert outputs["tools_^_search_keywords_~_max_keyword_difficulty"] == 50
|
||||
# Optional parameter should be None when not provided
|
||||
assert "tools_^_search_keywords_~_optional_param" in outputs
|
||||
assert outputs["tools_^_search_keywords_~_optional_param"] is None
|
||||
|
||||
# Test case 4: Valid tool call with ALL parameters (should succeed)
|
||||
mock_tool_call_all_params = MagicMock()
|
||||
mock_tool_call_all_params.function.name = "search_keywords"
|
||||
mock_tool_call_all_params.function.arguments = '{"query": "test", "max_keyword_difficulty": 50, "optional_param": "custom_value"}'
|
||||
|
||||
mock_response_all_params = MagicMock()
|
||||
mock_response_all_params.response = None
|
||||
mock_response_all_params.tool_calls = [mock_tool_call_all_params]
|
||||
mock_response_all_params.prompt_tokens = 50
|
||||
mock_response_all_params.completion_tokens = 25
|
||||
mock_response_all_params.reasoning = None
|
||||
mock_response_all_params.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_all_params,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
# Should succeed with all parameters
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Verify all tool outputs were generated correctly
|
||||
assert outputs["tools_^_search_keywords_~_query"] == "test"
|
||||
assert outputs["tools_^_search_keywords_~_max_keyword_difficulty"] == 50
|
||||
assert outputs["tools_^_search_keywords_~_optional_param"] == "custom_value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_raw_response_conversion():
|
||||
"""Test that SmartDecisionMaker correctly handles different raw_response types with retry mechanism."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Mock tool functions
|
||||
mock_tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"param": {"type": "string"}},
|
||||
"required": ["param"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Test case 1: Simulate ChatCompletionMessage raw_response that caused the original error
|
||||
class MockChatCompletionMessage:
|
||||
"""Simulate OpenAI's ChatCompletionMessage object that lacks .get() method"""
|
||||
|
||||
def __init__(self, role, content, tool_calls=None):
|
||||
self.role = role
|
||||
self.content = content
|
||||
self.tool_calls = tool_calls or []
|
||||
|
||||
# This is what caused the error - no .get() method
|
||||
# def get(self, key, default=None): # Intentionally missing
|
||||
|
||||
# First response: has invalid parameter name (triggers retry)
|
||||
mock_tool_call_invalid = MagicMock()
|
||||
mock_tool_call_invalid.function.name = "test_tool"
|
||||
mock_tool_call_invalid.function.arguments = (
|
||||
'{"wrong_param": "test_value"}' # Invalid parameter name
|
||||
)
|
||||
|
||||
mock_response_retry = MagicMock()
|
||||
mock_response_retry.response = None
|
||||
mock_response_retry.tool_calls = [mock_tool_call_invalid]
|
||||
mock_response_retry.prompt_tokens = 50
|
||||
mock_response_retry.completion_tokens = 25
|
||||
mock_response_retry.reasoning = None
|
||||
# This would cause the original error without our fix
|
||||
mock_response_retry.raw_response = MockChatCompletionMessage(
|
||||
role="assistant", content=None, tool_calls=[mock_tool_call_invalid]
|
||||
)
|
||||
|
||||
# Second response: successful (correct parameter name)
|
||||
mock_tool_call_valid = MagicMock()
|
||||
mock_tool_call_valid.function.name = "test_tool"
|
||||
mock_tool_call_valid.function.arguments = (
|
||||
'{"param": "test_value"}' # Correct parameter name
|
||||
)
|
||||
|
||||
mock_response_success = MagicMock()
|
||||
mock_response_success.response = None
|
||||
mock_response_success.tool_calls = [mock_tool_call_valid]
|
||||
mock_response_success.prompt_tokens = 50
|
||||
mock_response_success.completion_tokens = 25
|
||||
mock_response_success.reasoning = None
|
||||
mock_response_success.raw_response = MockChatCompletionMessage(
|
||||
role="assistant", content=None, tool_calls=[mock_tool_call_valid]
|
||||
)
|
||||
|
||||
# Mock llm_call to return different responses on different calls
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm_call, patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
# First call returns response that will trigger retry due to validation error
|
||||
# Second call returns successful response
|
||||
mock_llm_call.side_effect = [mock_response_retry, mock_response_success]
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2,
|
||||
)
|
||||
|
||||
# Should succeed after retry, demonstrating our helper function works
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Verify the tool output was generated successfully
|
||||
assert "tools_^_test_tool_~_param" in outputs
|
||||
assert outputs["tools_^_test_tool_~_param"] == "test_value"
|
||||
|
||||
# Verify conversation history was properly maintained
|
||||
assert "conversations" in outputs
|
||||
conversations = outputs["conversations"]
|
||||
assert len(conversations) > 0
|
||||
|
||||
# The conversations should contain properly converted raw_response objects as dicts
|
||||
# This would have failed with the original bug due to ChatCompletionMessage.get() error
|
||||
for msg in conversations:
|
||||
assert isinstance(msg, dict), f"Expected dict, got {type(msg)}"
|
||||
if msg.get("role") == "assistant":
|
||||
# Should have been converted from ChatCompletionMessage to dict
|
||||
assert "role" in msg
|
||||
|
||||
# Verify LLM was called twice (initial + 1 retry)
|
||||
assert mock_llm_call.call_count == 2
|
||||
|
||||
# Test case 2: Test with different raw_response types (Ollama string, dict)
|
||||
# Test Ollama string response
|
||||
mock_response_ollama = MagicMock()
|
||||
mock_response_ollama.response = "I'll help you with that."
|
||||
mock_response_ollama.tool_calls = None
|
||||
mock_response_ollama.prompt_tokens = 30
|
||||
mock_response_ollama.completion_tokens = 15
|
||||
mock_response_ollama.reasoning = None
|
||||
mock_response_ollama.raw_response = (
|
||||
"I'll help you with that." # Ollama returns string
|
||||
)
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_ollama,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[], # No tools for this test
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Simple prompt",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Should finish since no tool calls
|
||||
assert "finished" in outputs
|
||||
assert outputs["finished"] == "I'll help you with that."
|
||||
|
||||
# Test case 3: Test with dict raw_response (some providers/tests)
|
||||
mock_response_dict = MagicMock()
|
||||
mock_response_dict.response = "Test response"
|
||||
mock_response_dict.tool_calls = None
|
||||
mock_response_dict.prompt_tokens = 25
|
||||
mock_response_dict.completion_tokens = 10
|
||||
mock_response_dict.reasoning = None
|
||||
mock_response_dict.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": "Test response",
|
||||
} # Dict format
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_dict,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Another test",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
assert "finished" in outputs
|
||||
assert outputs["finished"] == "Test response"
|
||||
|
||||
@@ -48,16 +48,24 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
|
||||
assert "parameters" in signature["function"]
|
||||
assert "properties" in signature["function"]["parameters"]
|
||||
|
||||
# Check that dynamic fields are handled
|
||||
# Check that dynamic fields are handled with original names
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
assert len(properties) == 3 # Should have all three fields
|
||||
|
||||
# Each dynamic field should have proper schema
|
||||
for prop_value in properties.values():
|
||||
# Check that field names are cleaned (for Anthropic API compatibility)
|
||||
assert "values___name" in properties
|
||||
assert "values___age" in properties
|
||||
assert "values___city" in properties
|
||||
|
||||
# Each dynamic field should have proper schema with descriptive text
|
||||
for field_name, prop_value in properties.items():
|
||||
assert "type" in prop_value
|
||||
assert prop_value["type"] == "string" # Dynamic fields get string type
|
||||
assert "description" in prop_value
|
||||
assert "Dynamic value for" in prop_value["description"]
|
||||
# Check that descriptions properly explain the dynamic field
|
||||
if field_name == "values___name":
|
||||
assert "Dictionary field 'name'" in prop_value["description"]
|
||||
assert "values['name']" in prop_value["description"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -96,10 +104,18 @@ async def test_smart_decision_maker_handles_dynamic_list_fields():
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
assert len(properties) == 2 # Should have both list items
|
||||
|
||||
# Each dynamic field should have proper schema
|
||||
for prop_value in properties.values():
|
||||
# Check that field names are cleaned (for Anthropic API compatibility)
|
||||
assert "entries___0" in properties
|
||||
assert "entries___1" in properties
|
||||
|
||||
# Each dynamic field should have proper schema with descriptive text
|
||||
for field_name, prop_value in properties.items():
|
||||
assert prop_value["type"] == "string"
|
||||
assert "Dynamic value for" in prop_value["description"]
|
||||
assert "description" in prop_value
|
||||
# Check that descriptions properly explain the list field
|
||||
if field_name == "entries___0":
|
||||
assert "List item 0" in prop_value["description"]
|
||||
assert "entries[0]" in prop_value["description"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -0,0 +1,553 @@
|
||||
"""Comprehensive tests for SmartDecisionMakerBlock dynamic field handling."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.text import MatchTextPatternBlock
|
||||
from backend.data.dynamic_fields import get_dynamic_field_description
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_field_description_generation():
|
||||
"""Test that dynamic field descriptions are generated correctly."""
|
||||
# Test dictionary field description
|
||||
desc = get_dynamic_field_description("values_#_name")
|
||||
assert "Dictionary field 'name' for base field 'values'" in desc
|
||||
assert "values['name']" in desc
|
||||
|
||||
# Test list field description
|
||||
desc = get_dynamic_field_description("items_$_0")
|
||||
assert "List item 0 for base field 'items'" in desc
|
||||
assert "items[0]" in desc
|
||||
|
||||
# Test object field description
|
||||
desc = get_dynamic_field_description("user_@_email")
|
||||
assert "Object attribute 'email' for base field 'user'" in desc
|
||||
assert "user.email" in desc
|
||||
|
||||
# Test regular field fallback
|
||||
desc = get_dynamic_field_description("regular_field")
|
||||
assert desc == "Value for regular_field"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_block_function_signature_with_dict_fields():
|
||||
"""Test that function signatures are created correctly for dictionary dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Create a mock node for CreateDictionaryBlock
|
||||
mock_node = Mock()
|
||||
mock_node.block = CreateDictionaryBlock()
|
||||
mock_node.block_id = CreateDictionaryBlock().id
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create mock links with dynamic dictionary fields (source sanitized, sink original)
|
||||
mock_links = [
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_values___name", # Sanitized source
|
||||
sink_name="values_#_name", # Original sink
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_values___age", # Sanitized source
|
||||
sink_name="values_#_age", # Original sink
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_values___email", # Sanitized source
|
||||
sink_name="values_#_email", # Original sink
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
# Generate function signature
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links) # type: ignore
|
||||
|
||||
# Verify the signature structure
|
||||
assert signature["type"] == "function"
|
||||
assert "function" in signature
|
||||
assert "parameters" in signature["function"]
|
||||
assert "properties" in signature["function"]["parameters"]
|
||||
|
||||
# Check that dynamic fields are handled with original names
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
assert len(properties) == 3
|
||||
|
||||
# Check cleaned field names (for Anthropic API compatibility)
|
||||
assert "values___name" in properties
|
||||
assert "values___age" in properties
|
||||
assert "values___email" in properties
|
||||
|
||||
# Check descriptions mention they are dictionary fields
|
||||
assert "Dictionary field" in properties["values___name"]["description"]
|
||||
assert "values['name']" in properties["values___name"]["description"]
|
||||
|
||||
assert "Dictionary field" in properties["values___age"]["description"]
|
||||
assert "values['age']" in properties["values___age"]["description"]
|
||||
|
||||
assert "Dictionary field" in properties["values___email"]["description"]
|
||||
assert "values['email']" in properties["values___email"]["description"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_block_function_signature_with_list_fields():
|
||||
"""Test that function signatures are created correctly for list dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Create a mock node for AddToListBlock
|
||||
mock_node = Mock()
|
||||
mock_node.block = AddToListBlock()
|
||||
mock_node.block_id = AddToListBlock().id
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create mock links with dynamic list fields
|
||||
mock_links = [
|
||||
Mock(
|
||||
source_name="tools_^_add_list_~_0",
|
||||
sink_name="entries_$_0", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_add_list_~_1",
|
||||
sink_name="entries_$_1", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_add_list_~_2",
|
||||
sink_name="entries_$_2", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
# Generate function signature
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links) # type: ignore
|
||||
|
||||
# Verify the signature structure
|
||||
assert signature["type"] == "function"
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
|
||||
# Check cleaned field names (for Anthropic API compatibility)
|
||||
assert "entries___0" in properties
|
||||
assert "entries___1" in properties
|
||||
assert "entries___2" in properties
|
||||
|
||||
# Check descriptions mention they are list items
|
||||
assert "List item 0" in properties["entries___0"]["description"]
|
||||
assert "entries[0]" in properties["entries___0"]["description"]
|
||||
|
||||
assert "List item 1" in properties["entries___1"]["description"]
|
||||
assert "entries[1]" in properties["entries___1"]["description"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_block_function_signature_with_object_fields():
|
||||
"""Test that function signatures are created correctly for object dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Create a mock node for MatchTextPatternBlock (simulating object fields)
|
||||
mock_node = Mock()
|
||||
mock_node.block = MatchTextPatternBlock()
|
||||
mock_node.block_id = MatchTextPatternBlock().id
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create mock links with dynamic object fields
|
||||
mock_links = [
|
||||
Mock(
|
||||
source_name="tools_^_extract_~_user_name",
|
||||
sink_name="data_@_user_name", # Dynamic object field
|
||||
sink_id="extract_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_extract_~_user_email",
|
||||
sink_name="data_@_user_email", # Dynamic object field
|
||||
sink_id="extract_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
# Generate function signature
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links) # type: ignore
|
||||
|
||||
# Verify the signature structure
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
|
||||
# Check cleaned field names (for Anthropic API compatibility)
|
||||
assert "data___user_name" in properties
|
||||
assert "data___user_email" in properties
|
||||
|
||||
# Check descriptions mention they are object attributes
|
||||
assert "Object attribute" in properties["data___user_name"]["description"]
|
||||
assert "data.user_name" in properties["data___user_name"]["description"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_function_signature():
|
||||
"""Test that the mapping between sanitized and original field names is built correctly."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Mock the database client and connected nodes
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
|
||||
) as mock_db:
|
||||
mock_client = AsyncMock()
|
||||
mock_db.return_value = mock_client
|
||||
|
||||
# Create mock nodes and links
|
||||
mock_dict_node = Mock()
|
||||
mock_dict_node.block = CreateDictionaryBlock()
|
||||
mock_dict_node.block_id = CreateDictionaryBlock().id
|
||||
mock_dict_node.input_default = {}
|
||||
|
||||
mock_list_node = Mock()
|
||||
mock_list_node.block = AddToListBlock()
|
||||
mock_list_node.block_id = AddToListBlock().id
|
||||
mock_list_node.input_default = {}
|
||||
|
||||
# Mock links with dynamic fields
|
||||
dict_link1 = Mock(
|
||||
source_name="tools_^_create_dictionary_~_name",
|
||||
sink_name="values_#_name",
|
||||
sink_id="dict_node_id",
|
||||
source_id="test_node_id",
|
||||
)
|
||||
dict_link2 = Mock(
|
||||
source_name="tools_^_create_dictionary_~_age",
|
||||
sink_name="values_#_age",
|
||||
sink_id="dict_node_id",
|
||||
source_id="test_node_id",
|
||||
)
|
||||
list_link = Mock(
|
||||
source_name="tools_^_add_to_list_~_0",
|
||||
sink_name="entries_$_0",
|
||||
sink_id="list_node_id",
|
||||
source_id="test_node_id",
|
||||
)
|
||||
|
||||
mock_client.get_connected_output_nodes.return_value = [
|
||||
(dict_link1, mock_dict_node),
|
||||
(dict_link2, mock_dict_node),
|
||||
(list_link, mock_list_node),
|
||||
]
|
||||
|
||||
# Call the method that builds signatures
|
||||
tool_functions = await block._create_function_signature("test_node_id")
|
||||
|
||||
# Verify we got 2 tool functions (one for dict, one for list)
|
||||
assert len(tool_functions) == 2
|
||||
|
||||
# Verify the tool functions contain the dynamic field names
|
||||
dict_tool = next(
|
||||
(
|
||||
tool
|
||||
for tool in tool_functions
|
||||
if tool["function"]["name"] == "createdictionaryblock"
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert dict_tool is not None
|
||||
dict_properties = dict_tool["function"]["parameters"]["properties"]
|
||||
assert "values___name" in dict_properties
|
||||
assert "values___age" in dict_properties
|
||||
|
||||
list_tool = next(
|
||||
(
|
||||
tool
|
||||
for tool in tool_functions
|
||||
if tool["function"]["name"] == "addtolistblock"
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert list_tool is not None
|
||||
list_properties = list_tool["function"]["parameters"]["properties"]
|
||||
assert "entries___0" in list_properties
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_yielding_with_dynamic_fields():
|
||||
"""Test that outputs are yielded correctly with dynamic field names mapped back."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# No more sanitized mapping needed since we removed sanitization
|
||||
|
||||
# Mock LLM response with tool calls
|
||||
mock_response = Mock()
|
||||
mock_response.tool_calls = [
|
||||
Mock(
|
||||
function=Mock(
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"values___name": "Alice",
|
||||
"values___age": 30,
|
||||
"values___email": "alice@example.com",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
]
|
||||
# Ensure function name is a real string, not a Mock name
|
||||
mock_response.tool_calls[0].function.name = "createdictionaryblock"
|
||||
mock_response.reasoning = "Creating a dictionary with user information"
|
||||
mock_response.raw_response = {"role": "assistant", "content": "test"}
|
||||
mock_response.prompt_tokens = 100
|
||||
mock_response.completion_tokens = 50
|
||||
|
||||
# Mock the LLM call
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm:
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
# Mock the function signature creation
|
||||
with patch.object(
|
||||
block, "_create_function_signature", new_callable=AsyncMock
|
||||
) as mock_sig:
|
||||
mock_sig.return_value = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "createdictionaryblock",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"values___name": {"type": "string"},
|
||||
"values___age": {"type": "number"},
|
||||
"values___email": {"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Create input data
|
||||
from backend.blocks import llm
|
||||
|
||||
input_data = block.input_schema(
|
||||
prompt="Create a user dictionary",
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
)
|
||||
|
||||
# Run the block
|
||||
outputs = {}
|
||||
async for output_name, output_value in block.run(
|
||||
input_data,
|
||||
credentials=llm.TEST_CREDENTIALS,
|
||||
graph_id="test_graph",
|
||||
node_id="test_node",
|
||||
graph_exec_id="test_exec",
|
||||
node_exec_id="test_node_exec",
|
||||
user_id="test_user",
|
||||
):
|
||||
outputs[output_name] = output_value
|
||||
|
||||
# Verify the outputs use sanitized field names (matching frontend normalizeToolName)
|
||||
assert "tools_^_createdictionaryblock_~_values___name" in outputs
|
||||
assert outputs["tools_^_createdictionaryblock_~_values___name"] == "Alice"
|
||||
|
||||
assert "tools_^_createdictionaryblock_~_values___age" in outputs
|
||||
assert outputs["tools_^_createdictionaryblock_~_values___age"] == 30
|
||||
|
||||
assert "tools_^_createdictionaryblock_~_values___email" in outputs
|
||||
assert (
|
||||
outputs["tools_^_createdictionaryblock_~_values___email"]
|
||||
== "alice@example.com"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_regular_and_dynamic_fields():
|
||||
"""Test handling of blocks with both regular and dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Create a mock node
|
||||
mock_node = Mock()
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "A test block"
|
||||
mock_node.block.input_schema = Mock()
|
||||
|
||||
# Mock the get_field_schema to return a proper schema for regular fields
|
||||
def get_field_schema(field_name):
|
||||
if field_name == "regular_field":
|
||||
return {"type": "string", "description": "A regular field"}
|
||||
elif field_name == "values":
|
||||
return {"type": "object", "description": "A dictionary field"}
|
||||
else:
|
||||
raise KeyError(f"Field {field_name} not found")
|
||||
|
||||
mock_node.block.input_schema.get_field_schema = get_field_schema
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={"properties": {}, "required": []}
|
||||
)
|
||||
|
||||
# Create links with both regular and dynamic fields
|
||||
mock_links = [
|
||||
Mock(
|
||||
source_name="tools_^_test_~_regular",
|
||||
sink_name="regular_field", # Regular field
|
||||
sink_id="test_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_test_~_dict_key",
|
||||
sink_name="values_#_key1", # Dynamic dict field
|
||||
sink_id="test_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_test_~_dict_key2",
|
||||
sink_name="values_#_key2", # Dynamic dict field
|
||||
sink_id="test_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
# Generate function signature
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links) # type: ignore
|
||||
|
||||
# Check properties
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
assert len(properties) == 3
|
||||
|
||||
# Regular field should have its original schema
|
||||
assert "regular_field" in properties
|
||||
assert properties["regular_field"]["description"] == "A regular field"
|
||||
|
||||
# Dynamic fields should have generated descriptions
|
||||
assert "values___key1" in properties
|
||||
assert "Dictionary field" in properties["values___key1"]["description"]
|
||||
|
||||
assert "values___key2" in properties
|
||||
assert "Dictionary field" in properties["values___key2"]["description"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_errors_dont_pollute_conversation():
|
||||
"""Test that validation errors are only used during retries and don't pollute the conversation."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Track conversation history changes
|
||||
conversation_snapshots = []
|
||||
|
||||
# Mock response with invalid tool call (missing required parameter)
|
||||
invalid_response = Mock()
|
||||
invalid_response.tool_calls = [
|
||||
Mock(
|
||||
function=Mock(
|
||||
arguments=json.dumps({"wrong_param": "value"}), # Wrong parameter name
|
||||
)
|
||||
)
|
||||
]
|
||||
# Ensure function name is a real string, not a Mock name
|
||||
invalid_response.tool_calls[0].function.name = "test_tool"
|
||||
invalid_response.reasoning = None
|
||||
invalid_response.raw_response = {"role": "assistant", "content": "invalid"}
|
||||
invalid_response.prompt_tokens = 100
|
||||
invalid_response.completion_tokens = 50
|
||||
|
||||
# Mock valid response after retry
|
||||
valid_response = Mock()
|
||||
valid_response.tool_calls = [
|
||||
Mock(function=Mock(arguments=json.dumps({"correct_param": "value"})))
|
||||
]
|
||||
# Ensure function name is a real string, not a Mock name
|
||||
valid_response.tool_calls[0].function.name = "test_tool"
|
||||
valid_response.reasoning = None
|
||||
valid_response.raw_response = {"role": "assistant", "content": "valid"}
|
||||
valid_response.prompt_tokens = 100
|
||||
valid_response.completion_tokens = 50
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal call_count
|
||||
# Capture conversation state
|
||||
conversation_snapshots.append(kwargs.get("prompt", []).copy())
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return invalid_response
|
||||
else:
|
||||
return valid_response
|
||||
|
||||
# Mock the LLM call
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm:
|
||||
mock_llm.side_effect = mock_llm_call
|
||||
|
||||
# Mock the function signature creation
|
||||
with patch.object(
|
||||
block, "_create_function_signature", new_callable=AsyncMock
|
||||
) as mock_sig:
|
||||
mock_sig.return_value = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"correct_param": {
|
||||
"type": "string",
|
||||
"description": "The correct parameter",
|
||||
}
|
||||
},
|
||||
"required": ["correct_param"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Create input data
|
||||
from backend.blocks import llm
|
||||
|
||||
input_data = block.input_schema(
|
||||
prompt="Test prompt",
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
retry=3, # Allow retries
|
||||
)
|
||||
|
||||
# Run the block
|
||||
outputs = {}
|
||||
async for output_name, output_value in block.run(
|
||||
input_data,
|
||||
credentials=llm.TEST_CREDENTIALS,
|
||||
graph_id="test_graph",
|
||||
node_id="test_node",
|
||||
graph_exec_id="test_exec",
|
||||
node_exec_id="test_node_exec",
|
||||
user_id="test_user",
|
||||
):
|
||||
outputs[output_name] = output_value
|
||||
|
||||
# Verify we had 2 LLM calls (initial + retry)
|
||||
assert call_count == 2
|
||||
|
||||
# Check the final conversation output
|
||||
final_conversation = outputs.get("conversations", [])
|
||||
|
||||
# The final conversation should NOT contain the validation error message
|
||||
error_messages = [
|
||||
msg
|
||||
for msg in final_conversation
|
||||
if msg.get("role") == "user"
|
||||
and "parameter errors" in msg.get("content", "")
|
||||
]
|
||||
assert (
|
||||
len(error_messages) == 0
|
||||
), "Validation error leaked into final conversation"
|
||||
|
||||
# The final conversation should only have the successful response
|
||||
assert final_conversation[-1]["content"] == "valid"
|
||||
131
autogpt_platform/backend/backend/blocks/test/test_table_input.py
Normal file
131
autogpt_platform/backend/backend/blocks/test/test_table_input.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import pytest
|
||||
|
||||
from backend.blocks.io import AgentTableInputBlock
|
||||
from backend.util.test import execute_block_test
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_input_block():
|
||||
"""Test the AgentTableInputBlock with basic input/output."""
|
||||
block = AgentTableInputBlock()
|
||||
await execute_block_test(block)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_input_with_data():
|
||||
"""Test AgentTableInputBlock with actual table data."""
|
||||
block = AgentTableInputBlock()
|
||||
|
||||
input_data = block.Input(
|
||||
name="test_table",
|
||||
column_headers=["Name", "Age", "City"],
|
||||
value=[
|
||||
{"Name": "John", "Age": "30", "City": "New York"},
|
||||
{"Name": "Jane", "Age": "25", "City": "London"},
|
||||
{"Name": "Bob", "Age": "35", "City": "Paris"},
|
||||
],
|
||||
)
|
||||
|
||||
output_data = []
|
||||
async for output_name, output_value in block.run(input_data):
|
||||
output_data.append((output_name, output_value))
|
||||
|
||||
assert len(output_data) == 1
|
||||
assert output_data[0][0] == "result"
|
||||
|
||||
result = output_data[0][1]
|
||||
assert len(result) == 3
|
||||
assert result[0]["Name"] == "John"
|
||||
assert result[1]["Age"] == "25"
|
||||
assert result[2]["City"] == "Paris"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_input_empty_data():
|
||||
"""Test AgentTableInputBlock with empty data."""
|
||||
block = AgentTableInputBlock()
|
||||
|
||||
input_data = block.Input(
|
||||
name="empty_table", column_headers=["Col1", "Col2"], value=[]
|
||||
)
|
||||
|
||||
output_data = []
|
||||
async for output_name, output_value in block.run(input_data):
|
||||
output_data.append((output_name, output_value))
|
||||
|
||||
assert len(output_data) == 1
|
||||
assert output_data[0][0] == "result"
|
||||
assert output_data[0][1] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_input_with_missing_columns():
|
||||
"""Test AgentTableInputBlock passes through data with missing columns as-is."""
|
||||
block = AgentTableInputBlock()
|
||||
|
||||
input_data = block.Input(
|
||||
name="partial_table",
|
||||
column_headers=["Name", "Age", "City"],
|
||||
value=[
|
||||
{"Name": "John", "Age": "30"}, # Missing City
|
||||
{"Name": "Jane", "City": "London"}, # Missing Age
|
||||
{"Age": "35", "City": "Paris"}, # Missing Name
|
||||
],
|
||||
)
|
||||
|
||||
output_data = []
|
||||
async for output_name, output_value in block.run(input_data):
|
||||
output_data.append((output_name, output_value))
|
||||
|
||||
result = output_data[0][1]
|
||||
assert len(result) == 3
|
||||
|
||||
# Check data is passed through as-is
|
||||
assert result[0] == {"Name": "John", "Age": "30"}
|
||||
assert result[1] == {"Name": "Jane", "City": "London"}
|
||||
assert result[2] == {"Age": "35", "City": "Paris"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_input_none_value():
|
||||
"""Test AgentTableInputBlock with None value returns empty list."""
|
||||
block = AgentTableInputBlock()
|
||||
|
||||
input_data = block.Input(
|
||||
name="none_table", column_headers=["Name", "Age"], value=None
|
||||
)
|
||||
|
||||
output_data = []
|
||||
async for output_name, output_value in block.run(input_data):
|
||||
output_data.append((output_name, output_value))
|
||||
|
||||
assert len(output_data) == 1
|
||||
assert output_data[0][0] == "result"
|
||||
assert output_data[0][1] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_input_with_default_headers():
|
||||
"""Test AgentTableInputBlock with default column headers."""
|
||||
block = AgentTableInputBlock()
|
||||
|
||||
# Don't specify column_headers, should use defaults
|
||||
input_data = block.Input(
|
||||
name="default_headers_table",
|
||||
value=[
|
||||
{"Column 1": "A", "Column 2": "B", "Column 3": "C"},
|
||||
{"Column 1": "D", "Column 2": "E", "Column 3": "F"},
|
||||
],
|
||||
)
|
||||
|
||||
output_data = []
|
||||
async for output_name, output_value in block.run(input_data):
|
||||
output_data.append((output_name, output_value))
|
||||
|
||||
assert len(output_data) == 1
|
||||
assert output_data[0][0] == "result"
|
||||
|
||||
result = output_data[0][1]
|
||||
assert len(result) == 2
|
||||
assert result[0]["Column 1"] == "A"
|
||||
assert result[1]["Column 3"] == "F"
|
||||
@@ -2,6 +2,8 @@ import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import regex # Has built-in timeout support
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json, text
|
||||
@@ -137,6 +139,11 @@ class ExtractTextInformationBlock(Block):
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# Security fix: Add limits to prevent ReDoS and memory exhaustion
|
||||
MAX_TEXT_LENGTH = 1_000_000 # 1MB character limit
|
||||
MAX_MATCHES = 1000 # Maximum number of matches to prevent memory exhaustion
|
||||
MAX_MATCH_LENGTH = 10_000 # Maximum length per match
|
||||
|
||||
flags = 0
|
||||
if not input_data.case_sensitive:
|
||||
flags = flags | re.IGNORECASE
|
||||
@@ -148,20 +155,85 @@ class ExtractTextInformationBlock(Block):
|
||||
else:
|
||||
txt = json.dumps(input_data.text)
|
||||
|
||||
matches = [
|
||||
match.group(input_data.group)
|
||||
for match in re.finditer(input_data.pattern, txt, flags)
|
||||
if input_data.group <= len(match.groups())
|
||||
]
|
||||
if not input_data.find_all:
|
||||
matches = matches[:1]
|
||||
for match in matches:
|
||||
yield "positive", match
|
||||
if not matches:
|
||||
yield "negative", input_data.text
|
||||
# Limit text size to prevent DoS
|
||||
if len(txt) > MAX_TEXT_LENGTH:
|
||||
txt = txt[:MAX_TEXT_LENGTH]
|
||||
|
||||
yield "matched_results", matches
|
||||
yield "matched_count", len(matches)
|
||||
# Validate regex pattern to prevent dangerous patterns
|
||||
dangerous_patterns = [
|
||||
r".*\+.*\+", # Nested quantifiers
|
||||
r".*\*.*\*", # Nested quantifiers
|
||||
r"(?=.*\+)", # Lookahead with quantifier
|
||||
r"(?=.*\*)", # Lookahead with quantifier
|
||||
r"\(.+\)\+", # Group with nested quantifier
|
||||
r"\(.+\)\*", # Group with nested quantifier
|
||||
r"\([^)]+\+\)\+", # Nested quantifiers like (a+)+
|
||||
r"\([^)]+\*\)\*", # Nested quantifiers like (a*)*
|
||||
]
|
||||
|
||||
# Check if pattern is potentially dangerous
|
||||
is_dangerous = any(
|
||||
re.search(dangerous, input_data.pattern) for dangerous in dangerous_patterns
|
||||
)
|
||||
|
||||
# Use regex module with timeout for dangerous patterns
|
||||
# For safe patterns, use standard re module for compatibility
|
||||
try:
|
||||
matches = []
|
||||
match_count = 0
|
||||
|
||||
if is_dangerous:
|
||||
# Use regex module with timeout (5 seconds) for dangerous patterns
|
||||
# The regex module supports timeout parameter in finditer
|
||||
try:
|
||||
for match in regex.finditer(
|
||||
input_data.pattern, txt, flags=flags, timeout=5.0
|
||||
):
|
||||
if match_count >= MAX_MATCHES:
|
||||
break
|
||||
if input_data.group <= len(match.groups()):
|
||||
match_text = match.group(input_data.group)
|
||||
# Limit match length to prevent memory exhaustion
|
||||
if len(match_text) > MAX_MATCH_LENGTH:
|
||||
match_text = match_text[:MAX_MATCH_LENGTH]
|
||||
matches.append(match_text)
|
||||
match_count += 1
|
||||
except regex.error as e:
|
||||
# Timeout occurred or regex error
|
||||
if "timeout" in str(e).lower():
|
||||
# Timeout - return empty results
|
||||
pass
|
||||
else:
|
||||
# Other regex error
|
||||
raise
|
||||
else:
|
||||
# Use standard re module for non-dangerous patterns
|
||||
for match in re.finditer(input_data.pattern, txt, flags):
|
||||
if match_count >= MAX_MATCHES:
|
||||
break
|
||||
if input_data.group <= len(match.groups()):
|
||||
match_text = match.group(input_data.group)
|
||||
# Limit match length to prevent memory exhaustion
|
||||
if len(match_text) > MAX_MATCH_LENGTH:
|
||||
match_text = match_text[:MAX_MATCH_LENGTH]
|
||||
matches.append(match_text)
|
||||
match_count += 1
|
||||
|
||||
if not input_data.find_all:
|
||||
matches = matches[:1]
|
||||
|
||||
for match in matches:
|
||||
yield "positive", match
|
||||
if not matches:
|
||||
yield "negative", input_data.text
|
||||
|
||||
yield "matched_results", matches
|
||||
yield "matched_count", len(matches)
|
||||
except Exception:
|
||||
# Return empty results on any regex error
|
||||
yield "negative", input_data.text
|
||||
yield "matched_results", []
|
||||
yield "matched_count", 0
|
||||
|
||||
|
||||
class FillTextTemplateBlock(Block):
|
||||
|
||||
@@ -270,13 +270,17 @@ class GetCurrentDateBlock(Block):
|
||||
test_output=[
|
||||
(
|
||||
"date",
|
||||
lambda t: abs(datetime.now() - datetime.strptime(t, "%Y-%m-%d"))
|
||||
< timedelta(days=8), # 7 days difference + 1 day error margin.
|
||||
lambda t: abs(
|
||||
datetime.now().date() - datetime.strptime(t, "%Y-%m-%d").date()
|
||||
)
|
||||
<= timedelta(days=8), # 7 days difference + 1 day error margin.
|
||||
),
|
||||
(
|
||||
"date",
|
||||
lambda t: abs(datetime.now() - datetime.strptime(t, "%m/%d/%Y"))
|
||||
< timedelta(days=8),
|
||||
lambda t: abs(
|
||||
datetime.now().date() - datetime.strptime(t, "%m/%d/%Y").date()
|
||||
)
|
||||
<= timedelta(days=8),
|
||||
# 7 days difference + 1 day error margin.
|
||||
),
|
||||
(
|
||||
@@ -382,7 +386,7 @@ class GetCurrentDateAndTimeBlock(Block):
|
||||
lambda t: abs(
|
||||
datetime.now().date() - datetime.strptime(t, "%Y/%m/%d").date()
|
||||
)
|
||||
< timedelta(days=1), # Date format only, no time component
|
||||
<= timedelta(days=1), # Date format only, no time component
|
||||
),
|
||||
(
|
||||
"date_time",
|
||||
|
||||
@@ -26,6 +26,14 @@ class XMLParserBlock(Block):
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# Security fix: Add size limits to prevent XML bomb attacks
|
||||
MAX_XML_SIZE = 10 * 1024 * 1024 # 10MB limit for XML input
|
||||
|
||||
if len(input_data.input_xml) > MAX_XML_SIZE:
|
||||
raise ValueError(
|
||||
f"XML too large: {len(input_data.input_xml)} bytes > {MAX_XML_SIZE} bytes"
|
||||
)
|
||||
|
||||
try:
|
||||
tokens = tokenize(input_data.input_xml)
|
||||
parser = Parser(tokens)
|
||||
|
||||
@@ -9,6 +9,7 @@ from prisma.models import APIKey as PrismaAPIKey
|
||||
from prisma.types import APIKeyWhereUniqueInput
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.includes import MAX_USER_API_KEYS_FETCH
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -178,9 +179,13 @@ async def revoke_api_key(key_id: str, user_id: str) -> APIKeyInfo:
|
||||
return APIKeyInfo.from_db(updated_api_key)
|
||||
|
||||
|
||||
async def list_user_api_keys(user_id: str) -> list[APIKeyInfo]:
|
||||
async def list_user_api_keys(
|
||||
user_id: str, limit: int = MAX_USER_API_KEYS_FETCH
|
||||
) -> list[APIKeyInfo]:
|
||||
api_keys = await PrismaAPIKey.prisma().find_many(
|
||||
where={"userId": user_id}, order={"createdAt": "desc"}
|
||||
where={"userId": user_id},
|
||||
order={"createdAt": "desc"},
|
||||
take=limit,
|
||||
)
|
||||
|
||||
return [APIKeyInfo.from_db(key) for key in api_keys]
|
||||
|
||||
@@ -69,6 +69,8 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.CLAUDE_4_1_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_SONNET: 5,
|
||||
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
||||
LlmModel.CLAUDE_4_5_SONNET: 9,
|
||||
LlmModel.CLAUDE_3_7_SONNET: 5,
|
||||
LlmModel.CLAUDE_3_5_SONNET: 4,
|
||||
LlmModel.CLAUDE_3_5_HAIKU: 1, # $0.80 / $4.00
|
||||
|
||||
@@ -23,6 +23,7 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.includes import MAX_CREDIT_REFUND_REQUESTS_FETCH
|
||||
from backend.data.model import (
|
||||
AutoTopUpConfig,
|
||||
RefundRequest,
|
||||
@@ -905,7 +906,9 @@ class UserCredit(UserCreditBase):
|
||||
),
|
||||
)
|
||||
|
||||
async def get_refund_requests(self, user_id: str) -> list[RefundRequest]:
|
||||
async def get_refund_requests(
|
||||
self, user_id: str, limit: int = MAX_CREDIT_REFUND_REQUESTS_FETCH
|
||||
) -> list[RefundRequest]:
|
||||
return [
|
||||
RefundRequest(
|
||||
id=r.id,
|
||||
@@ -921,6 +924,7 @@ class UserCredit(UserCreditBase):
|
||||
for r in await CreditRefundRequest.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
order={"createdAt": "desc"},
|
||||
take=limit,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
284
autogpt_platform/backend/backend/data/dynamic_fields.py
Normal file
284
autogpt_platform/backend/backend/data/dynamic_fields.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""
|
||||
Utilities for handling dynamic field names with special delimiters.
|
||||
|
||||
Dynamic fields allow graphs to connect complex data structures using special delimiters:
|
||||
- _#_ for dictionary keys (e.g., "values_#_name" → values["name"])
|
||||
- _$_ for list indices (e.g., "items_$_0" → items[0])
|
||||
- _@_ for object attributes (e.g., "obj_@_attr" → obj.attr)
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.util.mock import MockObject
|
||||
|
||||
# Dynamic field delimiters
|
||||
LIST_SPLIT = "_$_"
|
||||
DICT_SPLIT = "_#_"
|
||||
OBJC_SPLIT = "_@_"
|
||||
|
||||
DYNAMIC_DELIMITERS = (LIST_SPLIT, DICT_SPLIT, OBJC_SPLIT)
|
||||
|
||||
|
||||
def extract_base_field_name(field_name: str) -> str:
|
||||
"""
|
||||
Extract the base field name from a dynamic field name by removing all dynamic suffixes.
|
||||
|
||||
Examples:
|
||||
extract_base_field_name("values_#_name") → "values"
|
||||
extract_base_field_name("items_$_0") → "items"
|
||||
extract_base_field_name("obj_@_attr") → "obj"
|
||||
extract_base_field_name("regular_field") → "regular_field"
|
||||
|
||||
Args:
|
||||
field_name: The field name that may contain dynamic delimiters
|
||||
|
||||
Returns:
|
||||
The base field name without any dynamic suffixes
|
||||
"""
|
||||
base_name = field_name
|
||||
for delimiter in DYNAMIC_DELIMITERS:
|
||||
if delimiter in base_name:
|
||||
base_name = base_name.split(delimiter)[0]
|
||||
return base_name
|
||||
|
||||
|
||||
def is_dynamic_field(field_name: str) -> bool:
|
||||
"""
|
||||
Check if a field name contains dynamic delimiters.
|
||||
|
||||
Args:
|
||||
field_name: The field name to check
|
||||
|
||||
Returns:
|
||||
True if the field contains any dynamic delimiters, False otherwise
|
||||
"""
|
||||
return any(delimiter in field_name for delimiter in DYNAMIC_DELIMITERS)
|
||||
|
||||
|
||||
def get_dynamic_field_description(field_name: str) -> str:
|
||||
"""
|
||||
Generate a description for a dynamic field based on its structure.
|
||||
|
||||
Args:
|
||||
field_name: The full dynamic field name (e.g., "values_#_name")
|
||||
|
||||
Returns:
|
||||
A descriptive string explaining what this dynamic field represents
|
||||
"""
|
||||
base_name = extract_base_field_name(field_name)
|
||||
|
||||
if DICT_SPLIT in field_name:
|
||||
# Extract the key part after _#_
|
||||
parts = field_name.split(DICT_SPLIT)
|
||||
if len(parts) > 1:
|
||||
key = parts[1].split("_")[0] if "_" in parts[1] else parts[1]
|
||||
return f"Dictionary field '{key}' for base field '{base_name}' ({base_name}['{key}'])"
|
||||
elif LIST_SPLIT in field_name:
|
||||
# Extract the index part after _$_
|
||||
parts = field_name.split(LIST_SPLIT)
|
||||
if len(parts) > 1:
|
||||
index = parts[1].split("_")[0] if "_" in parts[1] else parts[1]
|
||||
return (
|
||||
f"List item {index} for base field '{base_name}' ({base_name}[{index}])"
|
||||
)
|
||||
elif OBJC_SPLIT in field_name:
|
||||
# Extract the attribute part after _@_
|
||||
parts = field_name.split(OBJC_SPLIT)
|
||||
if len(parts) > 1:
|
||||
# Get the full attribute name (everything after _@_)
|
||||
attr = parts[1]
|
||||
return f"Object attribute '{attr}' for base field '{base_name}' ({base_name}.{attr})"
|
||||
|
||||
return f"Value for {field_name}"
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Dynamic field parsing and merging utilities
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def _next_delim(s: str) -> tuple[str | None, int]:
|
||||
"""
|
||||
Return the *earliest* delimiter appearing in `s` and its index.
|
||||
|
||||
If none present → (None, -1).
|
||||
"""
|
||||
first: str | None = None
|
||||
pos = len(s) # sentinel: larger than any real index
|
||||
for d in DYNAMIC_DELIMITERS:
|
||||
i = s.find(d)
|
||||
if 0 <= i < pos:
|
||||
first, pos = d, i
|
||||
return first, (pos if first else -1)
|
||||
|
||||
|
||||
def _tokenise(path: str) -> list[tuple[str, str]] | None:
|
||||
"""
|
||||
Convert the raw path string (starting with a delimiter) into
|
||||
[ (delimiter, identifier), … ] or None if the syntax is malformed.
|
||||
"""
|
||||
tokens: list[tuple[str, str]] = []
|
||||
while path:
|
||||
# 1. Which delimiter starts this chunk?
|
||||
delim = next((d for d in DYNAMIC_DELIMITERS if path.startswith(d)), None)
|
||||
if delim is None:
|
||||
return None # invalid syntax
|
||||
|
||||
# 2. Slice off the delimiter, then up to the next delimiter (or EOS)
|
||||
path = path[len(delim) :]
|
||||
nxt_delim, pos = _next_delim(path)
|
||||
token, path = (
|
||||
path[: pos if pos != -1 else len(path)],
|
||||
path[pos if pos != -1 else len(path) :],
|
||||
)
|
||||
if token == "":
|
||||
return None # empty identifier is invalid
|
||||
tokens.append((delim, token))
|
||||
return tokens
|
||||
|
||||
|
||||
def parse_execution_output(output: tuple[str, Any], name: str) -> Any:
|
||||
"""
|
||||
Retrieve a nested value out of `output` using the flattened *name*.
|
||||
|
||||
On any failure (wrong name, wrong type, out-of-range, bad path)
|
||||
returns **None**.
|
||||
|
||||
Args:
|
||||
output: Tuple of (base_name, data) representing a block output entry
|
||||
name: The flattened field name to extract from the output data
|
||||
|
||||
Returns:
|
||||
The value at the specified path, or None if not found/invalid
|
||||
"""
|
||||
base_name, data = output
|
||||
|
||||
# Exact match → whole object
|
||||
if name == base_name:
|
||||
return data
|
||||
|
||||
# Must start with the expected name
|
||||
if not name.startswith(base_name):
|
||||
return None
|
||||
path = name[len(base_name) :]
|
||||
if not path:
|
||||
return None # nothing left to parse
|
||||
|
||||
tokens = _tokenise(path)
|
||||
if tokens is None:
|
||||
return None
|
||||
|
||||
cur: Any = data
|
||||
for delim, ident in tokens:
|
||||
if delim == LIST_SPLIT:
|
||||
# list[index]
|
||||
try:
|
||||
idx = int(ident)
|
||||
except ValueError:
|
||||
return None
|
||||
if not isinstance(cur, list) or idx >= len(cur):
|
||||
return None
|
||||
cur = cur[idx]
|
||||
|
||||
elif delim == DICT_SPLIT:
|
||||
if not isinstance(cur, dict) or ident not in cur:
|
||||
return None
|
||||
cur = cur[ident]
|
||||
|
||||
elif delim == OBJC_SPLIT:
|
||||
if not hasattr(cur, ident):
|
||||
return None
|
||||
cur = getattr(cur, ident)
|
||||
|
||||
else:
|
||||
return None # unreachable
|
||||
|
||||
return cur
|
||||
|
||||
|
||||
def _assign(container: Any, tokens: list[tuple[str, str]], value: Any) -> Any:
|
||||
"""
|
||||
Recursive helper that *returns* the (possibly new) container with
|
||||
`value` assigned along the remaining `tokens` path.
|
||||
"""
|
||||
if not tokens:
|
||||
return value # leaf reached
|
||||
|
||||
delim, ident = tokens[0]
|
||||
rest = tokens[1:]
|
||||
|
||||
# ---------- list ----------
|
||||
if delim == LIST_SPLIT:
|
||||
try:
|
||||
idx = int(ident)
|
||||
except ValueError:
|
||||
raise ValueError("index must be an integer")
|
||||
|
||||
if container is None:
|
||||
container = []
|
||||
elif not isinstance(container, list):
|
||||
container = list(container) if hasattr(container, "__iter__") else []
|
||||
|
||||
while len(container) <= idx:
|
||||
container.append(None)
|
||||
container[idx] = _assign(container[idx], rest, value)
|
||||
return container
|
||||
|
||||
# ---------- dict ----------
|
||||
if delim == DICT_SPLIT:
|
||||
if container is None:
|
||||
container = {}
|
||||
elif not isinstance(container, dict):
|
||||
container = dict(container) if hasattr(container, "items") else {}
|
||||
container[ident] = _assign(container.get(ident), rest, value)
|
||||
return container
|
||||
|
||||
# ---------- object ----------
|
||||
if delim == OBJC_SPLIT:
|
||||
if container is None:
|
||||
container = MockObject()
|
||||
elif not hasattr(container, "__dict__"):
|
||||
# If it's not an object, create a new one
|
||||
container = MockObject()
|
||||
setattr(
|
||||
container,
|
||||
ident,
|
||||
_assign(getattr(container, ident, None), rest, value),
|
||||
)
|
||||
return container
|
||||
|
||||
return value # unreachable
|
||||
|
||||
|
||||
def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Reconstruct nested objects from a *flattened* dict of key → value.
|
||||
|
||||
Raises ValueError on syntactically invalid list indices.
|
||||
|
||||
Args:
|
||||
data: Dictionary with potentially flattened dynamic field keys
|
||||
|
||||
Returns:
|
||||
Dictionary with nested objects reconstructed from flattened keys
|
||||
"""
|
||||
merged: dict[str, Any] = {}
|
||||
|
||||
for key, value in data.items():
|
||||
# Split off the base name (before the first delimiter, if any)
|
||||
delim, pos = _next_delim(key)
|
||||
if delim is None:
|
||||
merged[key] = value
|
||||
continue
|
||||
|
||||
base, path = key[:pos], key[pos:]
|
||||
tokens = _tokenise(path)
|
||||
if tokens is None:
|
||||
# Invalid key; treat as scalar under the raw name
|
||||
merged[key] = value
|
||||
continue
|
||||
|
||||
merged[base] = _assign(merged.get(base), tokens, value)
|
||||
|
||||
data.update(merged)
|
||||
return data
|
||||
@@ -38,8 +38,8 @@ from prisma.types import (
|
||||
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
|
||||
from pydantic.fields import Field
|
||||
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
from backend.util import type as type_utils
|
||||
from backend.util.exceptions import DatabaseError
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
from backend.util.retry import func_retry
|
||||
@@ -478,6 +478,48 @@ async def get_graph_executions(
|
||||
return [GraphExecutionMeta.from_db(execution) for execution in executions]
|
||||
|
||||
|
||||
async def get_graph_executions_count(
|
||||
user_id: Optional[str] = None,
|
||||
graph_id: Optional[str] = None,
|
||||
statuses: Optional[list[ExecutionStatus]] = None,
|
||||
created_time_gte: Optional[datetime] = None,
|
||||
created_time_lte: Optional[datetime] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get count of graph executions with optional filters.
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID to filter by
|
||||
graph_id: Optional graph ID to filter by
|
||||
statuses: Optional list of execution statuses to filter by
|
||||
created_time_gte: Optional minimum creation time
|
||||
created_time_lte: Optional maximum creation time
|
||||
|
||||
Returns:
|
||||
Count of matching graph executions
|
||||
"""
|
||||
where_filter: AgentGraphExecutionWhereInput = {
|
||||
"isDeleted": False,
|
||||
}
|
||||
|
||||
if user_id:
|
||||
where_filter["userId"] = user_id
|
||||
|
||||
if graph_id:
|
||||
where_filter["agentGraphId"] = graph_id
|
||||
|
||||
if created_time_gte or created_time_lte:
|
||||
where_filter["createdAt"] = {
|
||||
"gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
|
||||
"lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
|
||||
}
|
||||
if statuses:
|
||||
where_filter["OR"] = [{"executionStatus": status} for status in statuses]
|
||||
|
||||
count = await AgentGraphExecution.prisma().count(where=where_filter)
|
||||
return count
|
||||
|
||||
|
||||
class GraphExecutionsPaginated(BaseModel):
|
||||
"""Response schema for paginated graph executions."""
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from prisma.enums import AgentExecutionStatus
|
||||
from backend.data.execution import get_graph_executions
|
||||
from backend.data.graph import get_graph_metadata
|
||||
from backend.data.model import UserExecutionSummaryStats
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
from backend.util.exceptions import DatabaseError
|
||||
from backend.util.logging import TruncatedLogger
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[SummaryData]")
|
||||
|
||||
@@ -20,6 +20,8 @@ from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.db import prisma as db
|
||||
from backend.data.dynamic_fields import extract_base_field_name
|
||||
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsFieldInfo,
|
||||
@@ -31,7 +33,15 @@ from backend.util import type as type_utils
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
|
||||
from .block import (
|
||||
Block,
|
||||
BlockInput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
EmptySchema,
|
||||
get_block,
|
||||
get_blocks,
|
||||
)
|
||||
from .db import BaseDbModel, query_raw_with_schema, transaction
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
|
||||
|
||||
@@ -72,12 +82,15 @@ class Node(BaseDbModel):
|
||||
output_links: list[Link] = []
|
||||
|
||||
@property
|
||||
def block(self) -> Block[BlockSchema, BlockSchema]:
|
||||
def block(self) -> "Block[BlockSchema, BlockSchema] | _UnknownBlockBase":
|
||||
"""Get the block for this node. Returns UnknownBlock if block is deleted/missing."""
|
||||
block = get_block(self.block_id)
|
||||
if not block:
|
||||
raise ValueError(
|
||||
f"Block #{self.block_id} does not exist -> Node #{self.id} is invalid"
|
||||
# Log warning but don't raise exception - return a placeholder block for deleted blocks
|
||||
logger.warning(
|
||||
f"Block #{self.block_id} does not exist for Node #{self.id} (deleted/missing block), using UnknownBlock"
|
||||
)
|
||||
return _UnknownBlockBase(self.block_id)
|
||||
return block
|
||||
|
||||
|
||||
@@ -116,17 +129,20 @@ class NodeModel(Node):
|
||||
Returns a copy of the node model, stripped of any non-transferable properties
|
||||
"""
|
||||
stripped_node = self.model_copy(deep=True)
|
||||
# Remove credentials from node input
|
||||
|
||||
# Remove credentials and other (possible) secrets from node input
|
||||
if stripped_node.input_default:
|
||||
stripped_node.input_default = NodeModel._filter_secrets_from_node_input(
|
||||
stripped_node.input_default, self.block.input_schema.jsonschema()
|
||||
)
|
||||
|
||||
# Remove default secret value from secret input nodes
|
||||
if (
|
||||
stripped_node.block.block_type == BlockType.INPUT
|
||||
and stripped_node.input_default.get("secret", False) is True
|
||||
and "value" in stripped_node.input_default
|
||||
):
|
||||
stripped_node.input_default["value"] = ""
|
||||
del stripped_node.input_default["value"]
|
||||
|
||||
# Remove webhook info
|
||||
stripped_node.webhook_id = None
|
||||
@@ -143,8 +159,10 @@ class NodeModel(Node):
|
||||
result = {}
|
||||
for key, value in input_data.items():
|
||||
field_schema: dict | None = field_schemas.get(key)
|
||||
if (field_schema and field_schema.get("secret", False)) or any(
|
||||
sensitive_key in key.lower() for sensitive_key in sensitive_keys
|
||||
if (field_schema and field_schema.get("secret", False)) or (
|
||||
any(sensitive_key in key.lower() for sensitive_key in sensitive_keys)
|
||||
# Prevent removing `secret` flag on input nodes
|
||||
and type(value) is not bool
|
||||
):
|
||||
# This is a secret value -> filter this key-value pair out
|
||||
continue
|
||||
@@ -729,7 +747,7 @@ def _is_tool_pin(name: str) -> bool:
|
||||
|
||||
|
||||
def _sanitize_pin_name(name: str) -> str:
|
||||
sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
|
||||
sanitized_name = extract_base_field_name(name)
|
||||
if _is_tool_pin(sanitized_name):
|
||||
return "tools"
|
||||
return sanitized_name
|
||||
@@ -1059,11 +1077,14 @@ async def set_graph_active_version(graph_id: str, version: int, user_id: str) ->
|
||||
)
|
||||
|
||||
|
||||
async def get_graph_all_versions(graph_id: str, user_id: str) -> list[GraphModel]:
|
||||
async def get_graph_all_versions(
|
||||
graph_id: str, user_id: str, limit: int = MAX_GRAPH_VERSIONS_FETCH
|
||||
) -> list[GraphModel]:
|
||||
graph_versions = await AgentGraph.prisma().find_many(
|
||||
where={"id": graph_id, "userId": user_id},
|
||||
order={"version": "desc"},
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
take=limit,
|
||||
)
|
||||
|
||||
if not graph_versions:
|
||||
@@ -1312,3 +1333,34 @@ async def migrate_llm_models(migrate_to: LlmModel):
|
||||
id,
|
||||
path,
|
||||
)
|
||||
|
||||
|
||||
# Simple placeholder class for deleted/missing blocks
|
||||
class _UnknownBlockBase(Block):
|
||||
"""
|
||||
Placeholder for deleted/missing blocks that inherits from Block
|
||||
but uses a name that doesn't end with 'Block' to avoid auto-discovery.
|
||||
"""
|
||||
|
||||
def __init__(self, block_id: str = "00000000-0000-0000-0000-000000000000"):
|
||||
# Initialize with minimal valid Block parameters
|
||||
super().__init__(
|
||||
id=block_id,
|
||||
description=f"Unknown or deleted block (original ID: {block_id})",
|
||||
disabled=True,
|
||||
input_schema=EmptySchema,
|
||||
output_schema=EmptySchema,
|
||||
categories=set(),
|
||||
contributors=[],
|
||||
static_output=False,
|
||||
block_type=BlockType.STANDARD,
|
||||
webhook_config=None,
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return "UnknownBlock"
|
||||
|
||||
async def run(self, input_data, **kwargs):
|
||||
"""Always yield an error for missing blocks."""
|
||||
yield "error", f"Block {self.id} no longer exists"
|
||||
|
||||
@@ -201,25 +201,56 @@ async def test_get_input_schema(server: SpinTestServer, snapshot: Snapshot):
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_clean_graph(server: SpinTestServer):
|
||||
"""
|
||||
Test the clean_graph function that:
|
||||
1. Clears input block values
|
||||
2. Removes credentials from nodes
|
||||
Test the stripped_for_export function that:
|
||||
1. Removes sensitive/secret fields from node inputs
|
||||
2. Removes webhook information
|
||||
3. Preserves non-sensitive data including input block values
|
||||
"""
|
||||
# Create a graph with input blocks and credentials
|
||||
# Create a graph with input blocks containing both sensitive and normal data
|
||||
graph = Graph(
|
||||
id="test_clean_graph",
|
||||
name="Test Clean Graph",
|
||||
description="Test graph cleaning",
|
||||
nodes=[
|
||||
Node(
|
||||
id="input_node",
|
||||
block_id=AgentInputBlock().id,
|
||||
input_default={
|
||||
"_test_id": "input_node",
|
||||
"name": "test_input",
|
||||
"value": "test value",
|
||||
"value": "test value", # This should be preserved
|
||||
"description": "Test input description",
|
||||
},
|
||||
),
|
||||
Node(
|
||||
block_id=AgentInputBlock().id,
|
||||
input_default={
|
||||
"_test_id": "input_node_secret",
|
||||
"name": "secret_input",
|
||||
"value": "another value",
|
||||
"secret": True, # This makes the input secret
|
||||
},
|
||||
),
|
||||
Node(
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={
|
||||
"_test_id": "node_with_secrets",
|
||||
"input": "normal_value",
|
||||
"control_test_input": "should be preserved",
|
||||
"api_key": "secret_api_key_123", # Should be filtered
|
||||
"password": "secret_password_456", # Should be filtered
|
||||
"token": "secret_token_789", # Should be filtered
|
||||
"credentials": { # Should be filtered
|
||||
"id": "fake-github-credentials-id",
|
||||
"provider": "github",
|
||||
"type": "api_key",
|
||||
},
|
||||
"anthropic_credentials": { # Should be filtered
|
||||
"id": "fake-anthropic-credentials-id",
|
||||
"provider": "anthropic",
|
||||
"type": "api_key",
|
||||
},
|
||||
},
|
||||
),
|
||||
],
|
||||
links=[],
|
||||
)
|
||||
@@ -231,15 +262,54 @@ async def test_clean_graph(server: SpinTestServer):
|
||||
)
|
||||
|
||||
# Clean the graph
|
||||
created_graph = await server.agent_server.test_get_graph(
|
||||
cleaned_graph = await server.agent_server.test_get_graph(
|
||||
created_graph.id, created_graph.version, DEFAULT_USER_ID, for_export=True
|
||||
)
|
||||
|
||||
# # Verify input block value is cleared
|
||||
# Verify sensitive fields are removed but normal fields are preserved
|
||||
input_node = next(
|
||||
n for n in created_graph.nodes if n.block_id == AgentInputBlock().id
|
||||
n for n in cleaned_graph.nodes if n.input_default["_test_id"] == "input_node"
|
||||
)
|
||||
assert input_node.input_default["value"] == ""
|
||||
|
||||
# Non-sensitive fields should be preserved
|
||||
assert input_node.input_default["name"] == "test_input"
|
||||
assert input_node.input_default["value"] == "test value" # Should be preserved now
|
||||
assert input_node.input_default["description"] == "Test input description"
|
||||
|
||||
# Sensitive fields should be filtered out
|
||||
assert "api_key" not in input_node.input_default
|
||||
assert "password" not in input_node.input_default
|
||||
|
||||
# Verify secret input node preserves non-sensitive fields but removes secret value
|
||||
secret_node = next(
|
||||
n
|
||||
for n in cleaned_graph.nodes
|
||||
if n.input_default["_test_id"] == "input_node_secret"
|
||||
)
|
||||
assert secret_node.input_default["name"] == "secret_input"
|
||||
assert "value" not in secret_node.input_default # Secret default should be removed
|
||||
assert secret_node.input_default["secret"] is True
|
||||
|
||||
# Verify sensitive fields are filtered from nodes with secrets
|
||||
secrets_node = next(
|
||||
n
|
||||
for n in cleaned_graph.nodes
|
||||
if n.input_default["_test_id"] == "node_with_secrets"
|
||||
)
|
||||
# Normal fields should be preserved
|
||||
assert secrets_node.input_default["input"] == "normal_value"
|
||||
assert secrets_node.input_default["control_test_input"] == "should be preserved"
|
||||
# Sensitive fields should be filtered out
|
||||
assert "api_key" not in secrets_node.input_default
|
||||
assert "password" not in secrets_node.input_default
|
||||
assert "token" not in secrets_node.input_default
|
||||
assert "credentials" not in secrets_node.input_default
|
||||
assert "anthropic_credentials" not in secrets_node.input_default
|
||||
|
||||
# Verify webhook info is removed (if any nodes had it)
|
||||
for node in cleaned_graph.nodes:
|
||||
assert node.webhook_id is None
|
||||
assert node.webhook is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
|
||||
@@ -14,6 +14,7 @@ AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
|
||||
"Nodes": {"include": AGENT_NODE_INCLUDE}
|
||||
}
|
||||
|
||||
|
||||
EXECUTION_RESULT_ORDER: list[prisma.types.AgentNodeExecutionOrderByInput] = [
|
||||
{"queuedTime": "desc"},
|
||||
# Fallback: Incomplete execs has no queuedTime.
|
||||
@@ -28,6 +29,13 @@ EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
|
||||
}
|
||||
|
||||
MAX_NODE_EXECUTIONS_FETCH = 1000
|
||||
MAX_LIBRARY_AGENT_EXECUTIONS_FETCH = 10
|
||||
|
||||
# Default limits for potentially large result sets
|
||||
MAX_CREDIT_REFUND_REQUESTS_FETCH = 100
|
||||
MAX_INTEGRATION_WEBHOOKS_FETCH = 100
|
||||
MAX_USER_API_KEYS_FETCH = 500
|
||||
MAX_GRAPH_VERSIONS_FETCH = 50
|
||||
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
|
||||
"NodeExecutions": {
|
||||
@@ -71,13 +79,56 @@ INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
|
||||
}
|
||||
|
||||
|
||||
def library_agent_include(user_id: str) -> prisma.types.LibraryAgentInclude:
|
||||
return {
|
||||
"AgentGraph": {
|
||||
"include": {
|
||||
**AGENT_GRAPH_INCLUDE,
|
||||
"Executions": {"where": {"userId": user_id}},
|
||||
}
|
||||
},
|
||||
"Creator": True,
|
||||
def library_agent_include(
|
||||
user_id: str,
|
||||
include_nodes: bool = True,
|
||||
include_executions: bool = True,
|
||||
execution_limit: int = MAX_LIBRARY_AGENT_EXECUTIONS_FETCH,
|
||||
) -> prisma.types.LibraryAgentInclude:
|
||||
"""
|
||||
Fully configurable includes for library agent queries with performance optimization.
|
||||
|
||||
Args:
|
||||
user_id: User ID for filtering user-specific data
|
||||
include_nodes: Whether to include graph nodes (default: True, needed for get_sub_graphs)
|
||||
include_executions: Whether to include executions (default: True, safe with execution_limit)
|
||||
execution_limit: Limit on executions to fetch (default: MAX_LIBRARY_AGENT_EXECUTIONS_FETCH)
|
||||
|
||||
Defaults maintain backward compatibility and safety - includes everything needed for all functionality.
|
||||
For performance optimization, explicitly set include_nodes=False and include_executions=False
|
||||
for listing views where frontend fetches data separately.
|
||||
|
||||
Performance impact:
|
||||
- Default (full nodes + limited executions): Original performance, works everywhere
|
||||
- Listing optimization (no nodes/executions): ~2s for 15 agents vs potential timeouts
|
||||
- Unlimited executions: varies by user (thousands of executions = timeouts)
|
||||
"""
|
||||
result: prisma.types.LibraryAgentInclude = {
|
||||
"Creator": True, # Always needed for creator info
|
||||
}
|
||||
|
||||
# Build AgentGraph include based on requested options
|
||||
if include_nodes or include_executions:
|
||||
agent_graph_include = {}
|
||||
|
||||
# Add nodes if requested (always full nodes)
|
||||
if include_nodes:
|
||||
agent_graph_include.update(AGENT_GRAPH_INCLUDE) # Full nodes
|
||||
|
||||
# Add executions if requested
|
||||
if include_executions:
|
||||
agent_graph_include["Executions"] = {
|
||||
"where": {"userId": user_id},
|
||||
"order_by": {"createdAt": "desc"},
|
||||
"take": execution_limit,
|
||||
}
|
||||
|
||||
result["AgentGraph"] = cast(
|
||||
prisma.types.AgentGraphArgsFromLibraryAgent,
|
||||
{"include": agent_graph_include},
|
||||
)
|
||||
else:
|
||||
# Default: Basic metadata only (fast - recommended for most use cases)
|
||||
result["AgentGraph"] = True # Basic graph metadata (name, description, id)
|
||||
|
||||
return result
|
||||
|
||||
@@ -11,7 +11,10 @@ from prisma.types import (
|
||||
from pydantic import Field, computed_field
|
||||
|
||||
from backend.data.event_bus import AsyncRedisEventBus
|
||||
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
|
||||
from backend.data.includes import (
|
||||
INTEGRATION_WEBHOOK_INCLUDE,
|
||||
MAX_INTEGRATION_WEBHOOKS_FETCH,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.utils import webhook_ingress_url
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
@@ -128,22 +131,36 @@ async def get_webhook(
|
||||
|
||||
@overload
|
||||
async def get_all_webhooks_by_creds(
|
||||
user_id: str, credentials_id: str, *, include_relations: Literal[True]
|
||||
user_id: str,
|
||||
credentials_id: str,
|
||||
*,
|
||||
include_relations: Literal[True],
|
||||
limit: int = MAX_INTEGRATION_WEBHOOKS_FETCH,
|
||||
) -> list[WebhookWithRelations]: ...
|
||||
@overload
|
||||
async def get_all_webhooks_by_creds(
|
||||
user_id: str, credentials_id: str, *, include_relations: Literal[False] = False
|
||||
user_id: str,
|
||||
credentials_id: str,
|
||||
*,
|
||||
include_relations: Literal[False] = False,
|
||||
limit: int = MAX_INTEGRATION_WEBHOOKS_FETCH,
|
||||
) -> list[Webhook]: ...
|
||||
|
||||
|
||||
async def get_all_webhooks_by_creds(
|
||||
user_id: str, credentials_id: str, *, include_relations: bool = False
|
||||
user_id: str,
|
||||
credentials_id: str,
|
||||
*,
|
||||
include_relations: bool = False,
|
||||
limit: int = MAX_INTEGRATION_WEBHOOKS_FETCH,
|
||||
) -> list[Webhook] | list[WebhookWithRelations]:
|
||||
if not credentials_id:
|
||||
raise ValueError("credentials_id must not be empty")
|
||||
webhooks = await IntegrationWebhook.prisma().find_many(
|
||||
where={"userId": user_id, "credentialsId": credentials_id},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE if include_relations else None,
|
||||
order={"createdAt": "desc"},
|
||||
take=limit,
|
||||
)
|
||||
return [
|
||||
(WebhookWithRelations if include_relations else Webhook).from_db(webhook)
|
||||
|
||||
@@ -270,6 +270,7 @@ def SchemaField(
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
discriminator: Optional[str] = None,
|
||||
format: Optional[str] = None,
|
||||
json_schema_extra: Optional[dict[str, Any]] = None,
|
||||
) -> T:
|
||||
if default is PydanticUndefined and default_factory is None:
|
||||
@@ -285,6 +286,7 @@ def SchemaField(
|
||||
"advanced": advanced,
|
||||
"hidden": hidden,
|
||||
"depends_on": depends_on,
|
||||
"format": format,
|
||||
**(json_schema_extra or {}),
|
||||
}.items()
|
||||
if v is not None
|
||||
|
||||
@@ -15,7 +15,7 @@ from prisma.types import (
|
||||
# from backend.notifications.models import NotificationEvent
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
|
||||
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
from backend.util.exceptions import DatabaseError
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.logging import TruncatedLogger
|
||||
|
||||
@@ -235,6 +235,7 @@ class BaseEventModel(BaseModel):
|
||||
|
||||
|
||||
class NotificationEventModel(BaseEventModel, Generic[NotificationDataType_co]):
|
||||
id: Optional[str] = None # None when creating, populated when reading from DB
|
||||
data: NotificationDataType_co
|
||||
|
||||
@property
|
||||
@@ -378,6 +379,7 @@ class NotificationPreference(BaseModel):
|
||||
|
||||
|
||||
class UserNotificationEventDTO(BaseModel):
|
||||
id: str # Added to track notifications for removal
|
||||
type: NotificationType
|
||||
data: dict
|
||||
created_at: datetime
|
||||
@@ -386,6 +388,7 @@ class UserNotificationEventDTO(BaseModel):
|
||||
@staticmethod
|
||||
def from_db(model: NotificationEvent) -> "UserNotificationEventDTO":
|
||||
return UserNotificationEventDTO(
|
||||
id=model.id,
|
||||
type=model.type,
|
||||
data=dict(model.data),
|
||||
created_at=model.createdAt,
|
||||
@@ -541,6 +544,79 @@ async def empty_user_notification_batch(
|
||||
) from e
|
||||
|
||||
|
||||
async def clear_all_user_notification_batches(user_id: str) -> None:
|
||||
"""Clear ALL notification batches for a user across all types.
|
||||
|
||||
Used when user's email is bounced/inactive and we should stop
|
||||
trying to send them ANY emails.
|
||||
"""
|
||||
try:
|
||||
# Delete all notification events for this user
|
||||
await NotificationEvent.prisma().delete_many(
|
||||
where={"UserNotificationBatch": {"is": {"userId": user_id}}}
|
||||
)
|
||||
|
||||
# Delete all batches for this user
|
||||
await UserNotificationBatch.prisma().delete_many(where={"userId": user_id})
|
||||
|
||||
logger.info(f"Cleared all notification batches for user {user_id}")
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to clear all notification batches for user {user_id}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
async def remove_notifications_from_batch(
|
||||
user_id: str, notification_type: NotificationType, notification_ids: list[str]
|
||||
) -> None:
|
||||
"""Remove specific notifications from a user's batch by their IDs.
|
||||
|
||||
This is used after successful sending to remove only the
|
||||
sent notifications, preventing duplicates on retry.
|
||||
"""
|
||||
if not notification_ids:
|
||||
return
|
||||
|
||||
try:
|
||||
# Delete the specific notification events
|
||||
deleted_count = await NotificationEvent.prisma().delete_many(
|
||||
where={
|
||||
"id": {"in": notification_ids},
|
||||
"UserNotificationBatch": {
|
||||
"is": {"userId": user_id, "type": notification_type}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Removed {deleted_count} notifications from batch for user {user_id}"
|
||||
)
|
||||
|
||||
# Check if batch is now empty and delete it if so
|
||||
remaining = await NotificationEvent.prisma().count(
|
||||
where={
|
||||
"UserNotificationBatch": {
|
||||
"is": {"userId": user_id, "type": notification_type}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if remaining == 0:
|
||||
await UserNotificationBatch.prisma().delete_many(
|
||||
where=UserNotificationBatchWhereInput(
|
||||
userId=user_id,
|
||||
type=notification_type,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
f"Deleted empty batch for user {user_id} and type {notification_type}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to remove notifications from batch for user {user_id} and type {notification_type}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
async def get_user_notification_batch(
|
||||
user_id: str,
|
||||
notification_type: NotificationType,
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
import prisma
|
||||
import pydantic
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from prisma.enums import OnboardingStep
|
||||
from prisma.models import UserOnboarding
|
||||
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
|
||||
|
||||
from backend.data.block import get_blocks
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.server.v2.store.model import StoreAgentDetails
|
||||
from backend.util.json import SafeJson
|
||||
@@ -30,7 +31,7 @@ user_credit = get_user_credit_model()
|
||||
|
||||
class UserOnboardingUpdate(pydantic.BaseModel):
|
||||
completedSteps: Optional[list[OnboardingStep]] = None
|
||||
notificationDot: Optional[bool] = None
|
||||
walletShown: Optional[bool] = None
|
||||
notified: Optional[list[OnboardingStep]] = None
|
||||
usageReason: Optional[str] = None
|
||||
integrations: Optional[list[str]] = None
|
||||
@@ -39,6 +40,8 @@ class UserOnboardingUpdate(pydantic.BaseModel):
|
||||
agentInput: Optional[dict[str, Any]] = None
|
||||
onboardingAgentExecutionId: Optional[str] = None
|
||||
agentRuns: Optional[int] = None
|
||||
lastRunAt: Optional[datetime] = None
|
||||
consecutiveRunDays: Optional[int] = None
|
||||
|
||||
|
||||
async def get_user_onboarding(user_id: str):
|
||||
@@ -57,16 +60,22 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
||||
update["completedSteps"] = list(set(data.completedSteps))
|
||||
for step in (
|
||||
OnboardingStep.AGENT_NEW_RUN,
|
||||
OnboardingStep.RUN_AGENTS,
|
||||
OnboardingStep.MARKETPLACE_VISIT,
|
||||
OnboardingStep.MARKETPLACE_ADD_AGENT,
|
||||
OnboardingStep.MARKETPLACE_RUN_AGENT,
|
||||
OnboardingStep.BUILDER_SAVE_AGENT,
|
||||
OnboardingStep.BUILDER_RUN_AGENT,
|
||||
OnboardingStep.RE_RUN_AGENT,
|
||||
OnboardingStep.SCHEDULE_AGENT,
|
||||
OnboardingStep.RUN_AGENTS,
|
||||
OnboardingStep.RUN_3_DAYS,
|
||||
OnboardingStep.TRIGGER_WEBHOOK,
|
||||
OnboardingStep.RUN_14_DAYS,
|
||||
OnboardingStep.RUN_AGENTS_100,
|
||||
):
|
||||
if step in data.completedSteps:
|
||||
await reward_user(user_id, step)
|
||||
if data.notificationDot is not None:
|
||||
update["notificationDot"] = data.notificationDot
|
||||
if data.walletShown is not None:
|
||||
update["walletShown"] = data.walletShown
|
||||
if data.notified is not None:
|
||||
update["notified"] = list(set(data.notified))
|
||||
if data.usageReason is not None:
|
||||
@@ -83,6 +92,10 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
||||
update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
|
||||
if data.agentRuns is not None:
|
||||
update["agentRuns"] = data.agentRuns
|
||||
if data.lastRunAt is not None:
|
||||
update["lastRunAt"] = data.lastRunAt
|
||||
if data.consecutiveRunDays is not None:
|
||||
update["consecutiveRunDays"] = data.consecutiveRunDays
|
||||
|
||||
return await UserOnboarding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
@@ -101,16 +114,28 @@ async def reward_user(user_id: str, step: OnboardingStep):
|
||||
# This is seen as a reward for the GET_RESULTS step in the wallet
|
||||
case OnboardingStep.AGENT_NEW_RUN:
|
||||
reward = 300
|
||||
case OnboardingStep.RUN_AGENTS:
|
||||
reward = 300
|
||||
case OnboardingStep.MARKETPLACE_VISIT:
|
||||
reward = 100
|
||||
case OnboardingStep.MARKETPLACE_ADD_AGENT:
|
||||
reward = 100
|
||||
case OnboardingStep.MARKETPLACE_RUN_AGENT:
|
||||
reward = 100
|
||||
case OnboardingStep.BUILDER_SAVE_AGENT:
|
||||
reward = 100
|
||||
case OnboardingStep.BUILDER_RUN_AGENT:
|
||||
case OnboardingStep.RE_RUN_AGENT:
|
||||
reward = 100
|
||||
case OnboardingStep.SCHEDULE_AGENT:
|
||||
reward = 100
|
||||
case OnboardingStep.RUN_AGENTS:
|
||||
reward = 300
|
||||
case OnboardingStep.RUN_3_DAYS:
|
||||
reward = 100
|
||||
case OnboardingStep.TRIGGER_WEBHOOK:
|
||||
reward = 100
|
||||
case OnboardingStep.RUN_14_DAYS:
|
||||
reward = 300
|
||||
case OnboardingStep.RUN_AGENTS_100:
|
||||
reward = 300
|
||||
|
||||
if reward == 0:
|
||||
return
|
||||
@@ -132,6 +157,22 @@ async def reward_user(user_id: str, step: OnboardingStep):
|
||||
)
|
||||
|
||||
|
||||
async def complete_webhook_trigger_step(user_id: str):
|
||||
"""
|
||||
Completes the TRIGGER_WEBHOOK onboarding step for the user if not already completed.
|
||||
"""
|
||||
|
||||
onboarding = await get_user_onboarding(user_id)
|
||||
if OnboardingStep.TRIGGER_WEBHOOK not in onboarding.completedSteps:
|
||||
await update_user_onboarding(
|
||||
user_id,
|
||||
UserOnboardingUpdate(
|
||||
completedSteps=onboarding.completedSteps
|
||||
+ [OnboardingStep.TRIGGER_WEBHOOK]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def clean_and_split(text: str) -> list[str]:
|
||||
"""
|
||||
Removes all special characters from a string, truncates it to 100 characters,
|
||||
@@ -236,8 +277,14 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
|
||||
for word in user_onboarding.integrations
|
||||
]
|
||||
|
||||
where_clause["is_available"] = True
|
||||
|
||||
# Try to take only agents that are available and allowed for onboarding
|
||||
storeAgents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=prisma.types.StoreAgentWhereInput(**where_clause),
|
||||
where={
|
||||
"is_available": True,
|
||||
"useForOnboarding": True,
|
||||
},
|
||||
order=[
|
||||
{"featured": "desc"},
|
||||
{"runs": "desc"},
|
||||
@@ -246,59 +293,16 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
|
||||
take=100,
|
||||
)
|
||||
|
||||
agentListings = await prisma.models.StoreListingVersion.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": [agent.storeListingVersionId for agent in storeAgents]},
|
||||
},
|
||||
include={"AgentGraph": True},
|
||||
)
|
||||
|
||||
for listing in agentListings:
|
||||
agent = listing.AgentGraph
|
||||
if agent is None:
|
||||
continue
|
||||
graph = GraphModel.from_db(agent)
|
||||
# Remove agents with empty input schema
|
||||
if not graph.input_schema:
|
||||
storeAgents = [
|
||||
a for a in storeAgents if a.storeListingVersionId != listing.id
|
||||
]
|
||||
continue
|
||||
|
||||
# Remove agents with empty credentials
|
||||
# Get nodes from this agent that have credentials
|
||||
nodes = await prisma.models.AgentNode.prisma().find_many(
|
||||
where={
|
||||
"agentGraphId": agent.id,
|
||||
"agentBlockId": {"in": list(CREDENTIALS_FIELDS.keys())},
|
||||
},
|
||||
)
|
||||
for node in nodes:
|
||||
block_id = node.agentBlockId
|
||||
field_name = CREDENTIALS_FIELDS[block_id]
|
||||
# If there are no credentials or they are empty, remove the agent
|
||||
# FIXME ignores default values
|
||||
if (
|
||||
field_name not in node.constantInput
|
||||
or node.constantInput[field_name] is None
|
||||
):
|
||||
storeAgents = [
|
||||
a for a in storeAgents if a.storeListingVersionId != listing.id
|
||||
]
|
||||
break
|
||||
|
||||
# If there are less than 2 agents, add more agents to the list
|
||||
# If not enough agents found, relax the useForOnboarding filter
|
||||
if len(storeAgents) < 2:
|
||||
storeAgents += await prisma.models.StoreAgent.prisma().find_many(
|
||||
where={
|
||||
"listing_id": {"not_in": [agent.listing_id for agent in storeAgents]},
|
||||
},
|
||||
storeAgents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=prisma.types.StoreAgentWhereInput(**where_clause),
|
||||
order=[
|
||||
{"featured": "desc"},
|
||||
{"runs": "desc"},
|
||||
{"rating": "desc"},
|
||||
],
|
||||
take=2 - len(storeAgents),
|
||||
take=100,
|
||||
)
|
||||
|
||||
# Calculate points for the first X agents and choose the top 2
|
||||
@@ -333,8 +337,13 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
|
||||
]
|
||||
|
||||
|
||||
@cached(maxsize=1, ttl_seconds=300) # Cache for 5 minutes since this rarely changes
|
||||
async def onboarding_enabled() -> bool:
|
||||
"""
|
||||
Check if onboarding should be enabled based on store agent count.
|
||||
Cached to prevent repeated slow database queries.
|
||||
"""
|
||||
# Use a more efficient query that stops counting after finding enough agents
|
||||
count = await prisma.models.StoreAgent.prisma().count(take=MIN_AGENT_COUNT + 1)
|
||||
|
||||
# Onboading is enabled if there are at least 2 agents in the store
|
||||
# Onboarding is enabled if there are at least 2 agents in the store
|
||||
return count >= MIN_AGENT_COUNT
|
||||
|
||||
@@ -16,8 +16,8 @@ from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
|
||||
from backend.data.db import prisma
|
||||
from backend.data.model import User, UserIntegrations, UserMetadata
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
from backend.util.encryption import JSONCryptor
|
||||
from backend.util.exceptions import DatabaseError
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.settings import Settings
|
||||
|
||||
@@ -354,6 +354,36 @@ async def set_user_email_verification(user_id: str, verified: bool) -> None:
|
||||
) from e
|
||||
|
||||
|
||||
async def disable_all_user_notifications(user_id: str) -> None:
|
||||
"""Disable all notification preferences for a user.
|
||||
|
||||
Used when user's email bounces/is inactive to prevent any future notifications.
|
||||
"""
|
||||
try:
|
||||
await PrismaUser.prisma().update(
|
||||
where={"id": user_id},
|
||||
data={
|
||||
"notifyOnAgentRun": False,
|
||||
"notifyOnZeroBalance": False,
|
||||
"notifyOnLowBalance": False,
|
||||
"notifyOnBlockExecutionFailed": False,
|
||||
"notifyOnContinuousAgentError": False,
|
||||
"notifyOnDailySummary": False,
|
||||
"notifyOnWeeklySummary": False,
|
||||
"notifyOnMonthlySummary": False,
|
||||
"notifyOnAgentApproved": False,
|
||||
"notifyOnAgentRejected": False,
|
||||
},
|
||||
)
|
||||
# Invalidate cache for this user
|
||||
get_user_by_id.cache_delete(user_id)
|
||||
logger.info(f"Disabled all notification preferences for user {user_id}")
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to disable notifications for user {user_id}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
async def get_user_email_verification(user_id: str) -> bool:
|
||||
"""Get the email verification status for a user."""
|
||||
try:
|
||||
|
||||
@@ -4,7 +4,12 @@ Module for generating AI-based activity status for graph executions.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, NotRequired, TypedDict
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
try:
|
||||
from typing import NotRequired
|
||||
except ImportError:
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
@@ -146,17 +151,35 @@ async def generate_activity_status_for_execution(
|
||||
"Focus on the ACTUAL TASK the user wanted done, not the internal workflow steps. "
|
||||
"Avoid technical terms like 'workflow', 'execution', 'components', 'nodes', 'processing', etc. "
|
||||
"Keep it to 3 sentences maximum. Be conversational and human-friendly.\n\n"
|
||||
"UNDERSTAND THE INTENDED PURPOSE:\n"
|
||||
"- FIRST: Read the graph description carefully to understand what the user wanted to accomplish\n"
|
||||
"- The graph name and description tell you the main goal/intention of this automation\n"
|
||||
"- Use this intended purpose as your PRIMARY criteria for success/failure evaluation\n"
|
||||
"- Ask yourself: 'Did this execution actually accomplish what the graph was designed to do?'\n\n"
|
||||
"CRITICAL OUTPUT ANALYSIS:\n"
|
||||
"- Check if blocks that should produce user-facing results actually produced outputs\n"
|
||||
"- Blocks with names containing 'Output', 'Post', 'Create', 'Send', 'Publish', 'Generate' are usually meant to produce final results\n"
|
||||
"- If these critical blocks have NO outputs (empty recent_outputs), the task likely FAILED even if status shows 'completed'\n"
|
||||
"- Sub-agents (AgentExecutorBlock) that produce no outputs usually indicate failed sub-tasks\n"
|
||||
"- Most importantly: Does the execution result match what the graph description promised to deliver?\n\n"
|
||||
"SUCCESS EVALUATION BASED ON INTENTION:\n"
|
||||
"- If the graph is meant to 'create blog posts' → check if blog content was actually created\n"
|
||||
"- If the graph is meant to 'send emails' → check if emails were actually sent\n"
|
||||
"- If the graph is meant to 'analyze data' → check if analysis results were produced\n"
|
||||
"- If the graph is meant to 'generate reports' → check if reports were generated\n"
|
||||
"- Technical completion ≠ goal achievement. Focus on whether the USER'S INTENDED OUTCOME was delivered\n\n"
|
||||
"IMPORTANT: Be HONEST about what actually happened:\n"
|
||||
"- If the input was invalid/nonsensical, say so directly\n"
|
||||
"- If the task failed, explain what went wrong in simple terms\n"
|
||||
"- If errors occurred, focus on what the user needs to know\n"
|
||||
"- Only claim success if the task was genuinely completed\n"
|
||||
"- Don't sugar-coat failures or present them as helpful feedback\n\n"
|
||||
"- Only claim success if the INTENDED PURPOSE was genuinely accomplished AND produced expected outputs\n"
|
||||
"- Don't sugar-coat failures or present them as helpful feedback\n"
|
||||
"- ESPECIALLY: If the graph's main purpose wasn't achieved, this is a failure regardless of 'completed' status\n\n"
|
||||
"Understanding Errors:\n"
|
||||
"- Node errors: Individual steps may fail but the overall task might still complete (e.g., one data source fails but others work)\n"
|
||||
"- Graph error (in overall_status.graph_error): This means the entire execution failed and nothing was accomplished\n"
|
||||
"- Even if execution shows 'completed', check if critical nodes failed that would prevent the desired outcome\n"
|
||||
"- Focus on the end result the user wanted, not whether technical steps completed"
|
||||
"- Missing outputs from critical blocks: Even if no errors, this means the task failed to produce expected results\n"
|
||||
"- Focus on whether the graph's intended purpose was fulfilled, not whether technical steps completed"
|
||||
),
|
||||
},
|
||||
{
|
||||
@@ -165,15 +188,28 @@ async def generate_activity_status_for_execution(
|
||||
f"A user ran '{graph_name}' to accomplish something. Based on this execution data, "
|
||||
f"write what they achieved in simple, user-friendly terms:\n\n"
|
||||
f"{json.dumps(execution_data, indent=2)}\n\n"
|
||||
"CRITICAL: Check overall_status.graph_error FIRST - if present, the entire execution failed.\n"
|
||||
"Then check individual node errors to understand partial failures.\n\n"
|
||||
"ANALYSIS CHECKLIST:\n"
|
||||
"1. READ graph_info.description FIRST - this tells you what the user intended to accomplish\n"
|
||||
"2. Check overall_status.graph_error - if present, the entire execution failed\n"
|
||||
"3. Look for nodes with 'Output', 'Post', 'Create', 'Send', 'Publish', 'Generate' in their block_name\n"
|
||||
"4. Check if these critical blocks have empty recent_outputs arrays - this indicates failure\n"
|
||||
"5. Look for AgentExecutorBlock (sub-agents) with no outputs - this suggests sub-task failures\n"
|
||||
"6. Count how many nodes produced outputs vs total nodes - low ratio suggests problems\n"
|
||||
"7. MOST IMPORTANT: Does the execution outcome match what graph_info.description promised?\n\n"
|
||||
"INTENTION-BASED EVALUATION:\n"
|
||||
"- If description mentions 'blog writing' → did it create blog content?\n"
|
||||
"- If description mentions 'email automation' → were emails actually sent?\n"
|
||||
"- If description mentions 'data analysis' → were analysis results produced?\n"
|
||||
"- If description mentions 'content generation' → was content actually generated?\n"
|
||||
"- If description mentions 'social media posting' → were posts actually made?\n"
|
||||
"- Match the outputs to the stated intention, not just technical completion\n\n"
|
||||
"Write 1-3 sentences about what the user accomplished, such as:\n"
|
||||
"- 'I analyzed your resume and provided detailed feedback for the IT industry.'\n"
|
||||
"- 'I couldn't analyze your resume because the input was just nonsensical text.'\n"
|
||||
"- 'I failed to complete the task due to missing API access.'\n"
|
||||
"- 'I couldn't complete the task because critical steps failed to produce any results.'\n"
|
||||
"- 'I failed to generate the content you requested due to missing API access.'\n"
|
||||
"- 'I extracted key information from your documents and organized it into a summary.'\n"
|
||||
"- 'The task failed to run due to system configuration issues.'\n\n"
|
||||
"Focus on what ACTUALLY happened, not what was attempted."
|
||||
"- 'The task failed because the blog post creation step didn't produce any output.'\n\n"
|
||||
"BE CRITICAL: If the graph's intended purpose (from description) wasn't achieved, report this as a failure even if status is 'completed'."
|
||||
),
|
||||
},
|
||||
]
|
||||
@@ -197,6 +233,7 @@ async def generate_activity_status_for_execution(
|
||||
logger.debug(
|
||||
f"Generated activity status for {graph_exec_id}: {activity_status}"
|
||||
)
|
||||
|
||||
return activity_status
|
||||
|
||||
except Exception as e:
|
||||
|
||||
115
autogpt_platform/backend/backend/executor/cluster_lock.py
Normal file
115
autogpt_platform/backend/backend/executor/cluster_lock.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Redis-based distributed locking for cluster coordination."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis import Redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClusterLock:
|
||||
"""Simple Redis-based distributed lock for preventing duplicate execution."""
|
||||
|
||||
def __init__(self, redis: "Redis", key: str, owner_id: str, timeout: int = 300):
|
||||
self.redis = redis
|
||||
self.key = key
|
||||
self.owner_id = owner_id
|
||||
self.timeout = timeout
|
||||
self._last_refresh = 0.0
|
||||
|
||||
def try_acquire(self) -> str | None:
|
||||
"""Try to acquire the lock.
|
||||
|
||||
Returns:
|
||||
- owner_id (self.owner_id) if successfully acquired
|
||||
- different owner_id if someone else holds the lock
|
||||
- None if Redis is unavailable or other error
|
||||
"""
|
||||
try:
|
||||
success = self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout)
|
||||
if success:
|
||||
self._last_refresh = time.time()
|
||||
return self.owner_id # Successfully acquired
|
||||
|
||||
# Failed to acquire, get current owner
|
||||
current_value = self.redis.get(self.key)
|
||||
if current_value:
|
||||
current_owner = (
|
||||
current_value.decode("utf-8")
|
||||
if isinstance(current_value, bytes)
|
||||
else str(current_value)
|
||||
)
|
||||
return current_owner
|
||||
|
||||
# Key doesn't exist but we failed to set it - race condition or Redis issue
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ClusterLock.try_acquire failed for key {self.key}: {e}")
|
||||
return None
|
||||
|
||||
def refresh(self) -> bool:
|
||||
"""Refresh lock TTL if we still own it.
|
||||
|
||||
Rate limited to at most once every timeout/10 seconds (minimum 1 second).
|
||||
During rate limiting, still verifies lock existence but skips TTL extension.
|
||||
Setting _last_refresh to 0 bypasses rate limiting for testing.
|
||||
"""
|
||||
# Calculate refresh interval: max(timeout // 10, 1)
|
||||
refresh_interval = max(self.timeout // 10, 1)
|
||||
current_time = time.time()
|
||||
|
||||
# Check if we're within the rate limit period
|
||||
# _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
|
||||
is_rate_limited = (
|
||||
self._last_refresh > 0
|
||||
and (current_time - self._last_refresh) < refresh_interval
|
||||
)
|
||||
|
||||
try:
|
||||
# Always verify lock existence, even during rate limiting
|
||||
current_value = self.redis.get(self.key)
|
||||
if not current_value:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
stored_owner = (
|
||||
current_value.decode("utf-8")
|
||||
if isinstance(current_value, bytes)
|
||||
else str(current_value)
|
||||
)
|
||||
if stored_owner != self.owner_id:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
# If rate limited, return True but don't update TTL or timestamp
|
||||
if is_rate_limited:
|
||||
return True
|
||||
|
||||
# Perform actual refresh
|
||||
if self.redis.expire(self.key, self.timeout):
|
||||
self._last_refresh = current_time
|
||||
return True
|
||||
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ClusterLock.refresh failed for key {self.key}: {e}")
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
def release(self):
|
||||
"""Release the lock."""
|
||||
if self._last_refresh == 0:
|
||||
return
|
||||
|
||||
try:
|
||||
self.redis.delete(self.key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._last_refresh = 0.0
|
||||
507
autogpt_platform/backend/backend/executor/cluster_lock_test.py
Normal file
507
autogpt_platform/backend/backend/executor/cluster_lock_test.py
Normal file
@@ -0,0 +1,507 @@
|
||||
"""
|
||||
Integration tests for ClusterLock - Redis-based distributed locking.
|
||||
|
||||
Tests the complete lock lifecycle without mocking Redis to ensure
|
||||
real-world behavior is correct. Covers acquisition, refresh, expiry,
|
||||
contention, and error scenarios.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from threading import Thread
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
|
||||
from .cluster_lock import ClusterLock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture
def redis_client():
    """Provide a Redis client configured identically to the backend's.

    ``decode_responses`` is disabled because ClusterLock verifies ownership
    against raw bytes. Any keys left behind by earlier test runs are swept
    best-effort before the client is handed to the test.
    """
    from backend.data.redis_client import HOST, PASSWORD, PORT

    client = redis.Redis(
        host=HOST,
        port=PORT,
        password=PASSWORD,
        decode_responses=False,  # ClusterLock needs raw bytes for ownership verification
    )

    # Best-effort sweep of stale test keys; failures here are non-fatal.
    try:
        stale_keys = list(client.scan_iter(match="test_lock:*"))
        for stale_key in stale_keys:
            client.delete(stale_key)
    except Exception:
        pass  # Ignore cleanup errors

    return client
|
||||
|
||||
|
||||
@pytest.fixture
def lock_key():
    """Produce a fresh, collision-free lock key for each test."""
    unique_suffix = uuid.uuid4()
    return f"test_lock:{unique_suffix}"
|
||||
|
||||
|
||||
@pytest.fixture
def owner_id():
    """Produce a unique owner identifier for each test."""
    unique_owner = uuid.uuid4()
    return str(unique_owner)
|
||||
|
||||
|
||||
class TestClusterLockBasic:
    """Basic lock acquisition and release functionality.

    Integration tests against a real Redis instance (no mocking); relies on
    the redis_client/lock_key/owner_id fixtures for isolation between tests.
    """

    def test_lock_acquisition_success(self, redis_client, lock_key, owner_id):
        """Test basic lock acquisition succeeds."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        # Lock should be acquired successfully
        result = lock.try_acquire()
        assert result == owner_id  # Returns our owner_id when successfully acquired
        assert lock._last_refresh > 0

        # Lock key should exist in Redis
        assert redis_client.exists(lock_key) == 1
        assert redis_client.get(lock_key).decode("utf-8") == owner_id

    def test_lock_acquisition_contention(self, redis_client, lock_key):
        """Test second acquisition fails when lock is held."""
        owner1 = str(uuid.uuid4())
        owner2 = str(uuid.uuid4())

        lock1 = ClusterLock(redis_client, lock_key, owner1, timeout=60)
        lock2 = ClusterLock(redis_client, lock_key, owner2, timeout=60)

        # First lock should succeed
        result1 = lock1.try_acquire()
        assert result1 == owner1  # Successfully acquired, returns our owner_id

        # Second lock should fail and return the first owner
        result2 = lock2.try_acquire()
        assert result2 == owner1  # Returns the current owner (first owner)
        assert lock2._last_refresh == 0

    def test_lock_release_deletes_redis_key(self, redis_client, lock_key, owner_id):
        """Test lock release deletes Redis key and marks locally as released."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        lock.try_acquire()
        assert lock._last_refresh > 0
        assert redis_client.exists(lock_key) == 1

        # Release should delete Redis key and mark locally as released
        lock.release()
        assert lock._last_refresh == 0
        # NOTE(review): redundant with the assertion above — 0 == 0.0 in Python.
        assert lock._last_refresh == 0.0

        # Redis key should be deleted for immediate release
        assert redis_client.exists(lock_key) == 0

        # Another lock should be able to acquire immediately
        new_owner_id = str(uuid.uuid4())
        new_lock = ClusterLock(redis_client, lock_key, new_owner_id, timeout=60)
        assert new_lock.try_acquire() == new_owner_id
|
||||
|
||||
|
||||
class TestClusterLockRefresh:
    """Lock refresh and TTL management.

    Covers TTL extension, rate limiting of refreshes, and refresh failure
    when the key is missing or ownership has been lost.
    """

    def test_lock_refresh_success(self, redis_client, lock_key, owner_id):
        """Test lock refresh extends TTL."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        lock.try_acquire()
        original_ttl = redis_client.ttl(lock_key)

        # Wait a bit then refresh
        time.sleep(1)
        lock._last_refresh = 0  # Force refresh past rate limit
        assert lock.refresh() is True

        # TTL should be reset to full timeout (allow for small timing differences)
        new_ttl = redis_client.ttl(lock_key)
        assert new_ttl >= original_ttl or new_ttl >= 58  # Allow for timing variance

    def test_lock_refresh_rate_limiting(self, redis_client, lock_key, owner_id):
        """Test refresh is rate-limited to timeout/10."""
        lock = ClusterLock(
            redis_client, lock_key, owner_id, timeout=100
        )  # 100s timeout

        lock.try_acquire()

        # First refresh should work
        # NOTE(review): try_acquire already set _last_refresh, so with a 10s
        # interval this "first" refresh is itself served by the rate-limited
        # path — it returns True without extending the TTL.
        assert lock.refresh() is True
        first_refresh_time = lock._last_refresh

        # Immediate second refresh should be skipped (rate limited) but verify key exists
        assert lock.refresh() is True  # Returns True but skips actual refresh
        assert lock._last_refresh == first_refresh_time  # Time unchanged

    def test_lock_refresh_verifies_existence_during_rate_limit(
        self, redis_client, lock_key, owner_id
    ):
        """Test refresh verifies lock existence even during rate limiting."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=100)

        lock.try_acquire()

        # Manually delete the key (simulates expiry or external deletion)
        redis_client.delete(lock_key)

        # Refresh should detect missing key even during rate limit period
        assert lock.refresh() is False
        assert lock._last_refresh == 0

    def test_lock_refresh_ownership_lost(self, redis_client, lock_key, owner_id):
        """Test refresh fails when ownership is lost."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        lock.try_acquire()

        # Simulate another process taking the lock
        different_owner = str(uuid.uuid4())
        redis_client.set(lock_key, different_owner, ex=60)

        # Force refresh past rate limit and verify it fails
        lock._last_refresh = 0  # Force refresh past rate limit
        assert lock.refresh() is False
        assert lock._last_refresh == 0

    def test_lock_refresh_when_not_acquired(self, redis_client, lock_key, owner_id):
        """Test refresh fails when lock was never acquired."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        # Refresh without acquiring should fail (key does not exist in Redis)
        assert lock.refresh() is False
|
||||
|
||||
|
||||
class TestClusterLockExpiry:
    """Lock expiry and timeout behavior.

    These tests use real wall-clock sleeps against Redis TTLs, so they are
    timing-sensitive by design.
    """

    def test_lock_natural_expiry(self, redis_client, lock_key, owner_id):
        """Test lock expires naturally via Redis TTL."""
        lock = ClusterLock(
            redis_client, lock_key, owner_id, timeout=2
        )  # 2 second timeout

        lock.try_acquire()
        assert redis_client.exists(lock_key) == 1

        # Wait for expiry (sleep 1s past the TTL to avoid flakiness)
        time.sleep(3)
        assert redis_client.exists(lock_key) == 0

        # New lock with same key should succeed
        new_lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
        assert new_lock.try_acquire() == owner_id

    def test_lock_refresh_prevents_expiry(self, redis_client, lock_key, owner_id):
        """Test refreshing prevents lock from expiring."""
        lock = ClusterLock(
            redis_client, lock_key, owner_id, timeout=3
        )  # 3 second timeout

        lock.try_acquire()

        # Wait and refresh before expiry
        time.sleep(1)
        lock._last_refresh = 0  # Force refresh past rate limit
        assert lock.refresh() is True

        # Wait beyond original timeout; total elapsed (3.5s) exceeds the
        # original 3s TTL, so survival proves the refresh extended it.
        time.sleep(2.5)
        assert redis_client.exists(lock_key) == 1  # Should still exist
|
||||
|
||||
|
||||
class TestClusterLockConcurrency:
    """Concurrent access patterns.

    Exercises the lock from multiple threads; outcomes depend on thread
    scheduling and sleeps, so assertions are kept coarse (counts/ownership).
    """

    def test_multiple_threads_contention(self, redis_client, lock_key):
        """Test multiple threads competing for same lock."""
        num_threads = 5
        successful_acquisitions = []

        def try_acquire_lock(thread_id):
            # Each thread gets its own owner ID and lock instance.
            owner_id = f"thread_{thread_id}"
            lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
            if lock.try_acquire() == owner_id:
                # list.append is thread-safe under the GIL, so no lock needed here.
                successful_acquisitions.append(thread_id)
                time.sleep(0.1)  # Hold lock briefly
                lock.release()

        threads = []
        for i in range(num_threads):
            thread = Thread(target=try_acquire_lock, args=(i,))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        # Only one thread should have acquired the lock
        assert len(successful_acquisitions) == 1

    def test_sequential_lock_reuse(self, redis_client, lock_key):
        """Test lock can be reused after natural expiry."""
        owners = [str(uuid.uuid4()) for _ in range(3)]

        for i, owner_id in enumerate(owners):
            lock = ClusterLock(redis_client, lock_key, owner_id, timeout=1)  # 1 second

            assert lock.try_acquire() == owner_id
            time.sleep(1.5)  # Wait for expiry

            # Verify lock expired
            assert redis_client.exists(lock_key) == 0

    def test_refresh_during_concurrent_access(self, redis_client, lock_key):
        """Test lock refresh works correctly during concurrent access attempts."""
        owner1 = str(uuid.uuid4())
        owner2 = str(uuid.uuid4())

        lock1 = ClusterLock(redis_client, lock_key, owner1, timeout=5)
        lock2 = ClusterLock(redis_client, lock_key, owner2, timeout=5)

        # Thread 1 holds lock and refreshes
        assert lock1.try_acquire() == owner1

        def refresh_continuously():
            # ~1 second of refreshes (10 x 0.1s), each forced past rate limiting.
            for _ in range(10):
                lock1._last_refresh = 0  # Force refresh
                lock1.refresh()
                time.sleep(0.1)

        def try_acquire_continuously():
            # ~2 seconds of acquisition attempts; should never succeed while
            # lock1 keeps refreshing.
            attempts = 0
            while attempts < 20:
                if lock2.try_acquire() == owner2:
                    return True
                time.sleep(0.1)
                attempts += 1
            return False

        refresh_thread = Thread(target=refresh_continuously)
        acquire_thread = Thread(target=try_acquire_continuously)

        refresh_thread.start()
        acquire_thread.start()

        refresh_thread.join()
        acquire_thread.join()

        # Lock1 should still own the lock due to refreshes
        assert lock1._last_refresh > 0
        assert lock2._last_refresh == 0
|
||||
|
||||
|
||||
class TestClusterLockErrorHandling:
    """Error handling and edge cases.

    Verifies graceful degradation (None/False returns, no exceptions) when
    Redis is unreachable or the lock key disappears out from under us.
    """

    def test_redis_connection_failure_on_acquire(self, lock_key, owner_id):
        """Test graceful handling when Redis is unavailable during acquisition."""
        # Use invalid Redis connection (unresolvable host, short timeout so
        # the test fails fast rather than hanging)
        bad_redis = redis.Redis(
            host="invalid_host", port=1234, socket_connect_timeout=1
        )
        lock = ClusterLock(bad_redis, lock_key, owner_id, timeout=60)

        # Should return None for Redis connection failures
        result = lock.try_acquire()
        assert result is None  # Returns None when Redis fails
        assert lock._last_refresh == 0

    def test_redis_connection_failure_on_refresh(
        self, redis_client, lock_key, owner_id
    ):
        """Test graceful handling when Redis fails during refresh."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        # Acquire normally
        assert lock.try_acquire() == owner_id

        # Replace Redis client with failing one
        lock.redis = redis.Redis(
            host="invalid_host", port=1234, socket_connect_timeout=1
        )

        # Refresh should fail gracefully
        lock._last_refresh = 0  # Force refresh
        assert lock.refresh() is False
        assert lock._last_refresh == 0

    def test_invalid_lock_parameters(self, redis_client):
        """Test validation of lock parameters."""
        owner_id = str(uuid.uuid4())

        # All parameters are now simple - no validation needed
        # Just test basic construction works
        lock = ClusterLock(redis_client, "test_key", owner_id, timeout=60)
        assert lock.key == "test_key"
        assert lock.owner_id == owner_id
        assert lock.timeout == 60

    def test_refresh_after_redis_key_deleted(self, redis_client, lock_key, owner_id):
        """Test refresh behavior when Redis key is manually deleted."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        lock.try_acquire()

        # Manually delete the key (simulates external deletion)
        redis_client.delete(lock_key)

        # Refresh should fail and mark as not acquired
        lock._last_refresh = 0  # Force refresh
        assert lock.refresh() is False
        assert lock._last_refresh == 0
|
||||
|
||||
|
||||
class TestClusterLockDynamicRefreshInterval:
    """Dynamic refresh interval based on timeout.

    The refresh rate limit is derived from the lock timeout as
    max(timeout // 10, 1); this mirrors the implementation's formula.
    """

    def test_refresh_interval_calculation(self, redis_client, lock_key, owner_id):
        """Test refresh interval is calculated as max(timeout/10, 1)."""
        # (timeout, expected rate-limit interval in seconds)
        test_cases = [
            (5, 1),  # 5/10 = 0, but minimum is 1
            (10, 1),  # 10/10 = 1
            (30, 3),  # 30/10 = 3
            (100, 10),  # 100/10 = 10
            (200, 20),  # 200/10 = 20
            (1000, 100),  # 1000/10 = 100
        ]

        for timeout, expected_interval in test_cases:
            # Suffix the key with the timeout so each case uses a distinct lock.
            lock = ClusterLock(
                redis_client, f"{lock_key}_{timeout}", owner_id, timeout=timeout
            )
            lock.try_acquire()

            # Calculate expected interval using same logic as implementation
            refresh_interval = max(timeout // 10, 1)
            assert refresh_interval == expected_interval

            # Test rate limiting works with calculated interval
            assert lock.refresh() is True
            first_refresh_time = lock._last_refresh

            # Sleep less than interval - should be rate limited
            time.sleep(0.1)
            assert lock.refresh() is True
            assert lock._last_refresh == first_refresh_time  # No actual refresh
|
||||
|
||||
|
||||
class TestClusterLockRealWorldScenarios:
    """Real-world usage patterns.

    Simulates how the executor uses ClusterLock: multiple pods coordinating
    on one graph execution, long-running work with periodic refreshes, and
    graceful degradation when Redis goes away mid-execution.
    """

    def test_execution_coordination_simulation(self, redis_client):
        """Simulate graph execution coordination across multiple pods."""
        graph_exec_id = str(uuid.uuid4())
        lock_key = f"execution:{graph_exec_id}"

        # Simulate 3 pods trying to execute same graph
        pods = [f"pod_{i}" for i in range(3)]
        execution_results = {}

        def execute_graph(pod_id):
            """Simulate graph execution with cluster lock."""
            lock = ClusterLock(redis_client, lock_key, pod_id, timeout=300)

            if lock.try_acquire() == pod_id:
                # Simulate execution work
                execution_results[pod_id] = "executed"
                time.sleep(0.1)
                lock.release()
            else:
                execution_results[pod_id] = "rejected"

        threads = []
        for pod_id in pods:
            thread = Thread(target=execute_graph, args=(pod_id,))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        # Only one pod should have executed
        executed_count = sum(
            1 for result in execution_results.values() if result == "executed"
        )
        rejected_count = sum(
            1 for result in execution_results.values() if result == "rejected"
        )

        assert executed_count == 1
        assert rejected_count == 2

    def test_long_running_execution_with_refresh(
        self, redis_client, lock_key, owner_id
    ):
        """Test lock maintains ownership during long execution with periodic refresh."""
        lock = ClusterLock(
            redis_client, lock_key, owner_id, timeout=30
        )  # 30 second timeout, refresh interval = max(30//10, 1) = 3 seconds

        def long_execution_with_refresh():
            """Simulate long-running execution with periodic refresh."""
            assert lock.try_acquire() == owner_id

            # Simulate 10 seconds of work with refreshes every 2 seconds
            # This respects rate limiting - actual refreshes will happen at 0s, 3s, 6s, 9s
            try:
                for i in range(5):  # 5 iterations * 2 seconds = 10 seconds total
                    time.sleep(2)
                    refresh_success = lock.refresh()
                    assert refresh_success is True, f"Refresh failed at iteration {i}"
                return "completed"
            finally:
                # Always release, even if a refresh assertion fails mid-loop.
                lock.release()

        # Should complete successfully without losing lock
        result = long_execution_with_refresh()
        assert result == "completed"

    def test_graceful_degradation_pattern(self, redis_client, lock_key):
        """Test graceful degradation when Redis becomes unavailable."""
        owner_id = str(uuid.uuid4())
        lock = ClusterLock(
            redis_client, lock_key, owner_id, timeout=3
        )  # Use shorter timeout

        # Normal operation
        assert lock.try_acquire() == owner_id
        lock._last_refresh = 0  # Force refresh past rate limit
        assert lock.refresh() is True

        # Simulate Redis becoming unavailable
        original_redis = lock.redis
        lock.redis = redis.Redis(
            host="invalid_host",
            port=1234,
            socket_connect_timeout=1,
            decode_responses=False,
        )

        # Should degrade gracefully
        lock._last_refresh = 0  # Force refresh past rate limit
        assert lock.refresh() is False
        assert lock._last_refresh == 0

        # Restore Redis and verify can acquire again
        lock.redis = original_redis
        # Wait for original lock to expire (use longer wait for 3s timeout)
        time.sleep(4)

        new_lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
        assert new_lock.try_acquire() == owner_id
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run specific test for quick validation.
    # Allows `python cluster_lock_test.py` without invoking pytest directly.
    pytest.main([__file__, "-v"])
|
||||
@@ -9,6 +9,7 @@ from backend.data.execution import (
|
||||
get_execution_kv_data,
|
||||
get_graph_execution_meta,
|
||||
get_graph_executions,
|
||||
get_graph_executions_count,
|
||||
get_latest_node_execution,
|
||||
get_node_execution,
|
||||
get_node_executions,
|
||||
@@ -28,11 +29,13 @@ from backend.data.graph import (
|
||||
get_node,
|
||||
)
|
||||
from backend.data.notifications import (
|
||||
clear_all_user_notification_batches,
|
||||
create_or_add_to_user_notification_batch,
|
||||
empty_user_notification_batch,
|
||||
get_all_batches_by_type,
|
||||
get_user_notification_batch,
|
||||
get_user_notification_oldest_message_in_batch,
|
||||
remove_notifications_from_batch,
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_active_user_ids_in_timerange,
|
||||
@@ -71,7 +74,6 @@ async def _get_credits(user_id: str) -> int:
|
||||
|
||||
|
||||
class DatabaseManager(AppService):
|
||||
|
||||
def run_service(self) -> None:
|
||||
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
|
||||
self.run_and_wait(db.connect())
|
||||
@@ -111,6 +113,7 @@ class DatabaseManager(AppService):
|
||||
|
||||
# Executions
|
||||
get_graph_executions = _(get_graph_executions)
|
||||
get_graph_executions_count = _(get_graph_executions_count)
|
||||
get_graph_execution_meta = _(get_graph_execution_meta)
|
||||
create_graph_execution = _(create_graph_execution)
|
||||
get_node_execution = _(get_node_execution)
|
||||
@@ -147,10 +150,12 @@ class DatabaseManager(AppService):
|
||||
get_user_notification_preference = _(get_user_notification_preference)
|
||||
|
||||
# Notifications - async
|
||||
clear_all_user_notification_batches = _(clear_all_user_notification_batches)
|
||||
create_or_add_to_user_notification_batch = _(
|
||||
create_or_add_to_user_notification_batch
|
||||
)
|
||||
empty_user_notification_batch = _(empty_user_notification_batch)
|
||||
remove_notifications_from_batch = _(remove_notifications_from_batch)
|
||||
get_all_batches_by_type = _(get_all_batches_by_type)
|
||||
get_user_notification_batch = _(get_user_notification_batch)
|
||||
get_user_notification_oldest_message_in_batch = _(
|
||||
@@ -179,6 +184,7 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
|
||||
# Executions
|
||||
get_graph_executions = _(d.get_graph_executions)
|
||||
get_graph_executions_count = _(d.get_graph_executions_count)
|
||||
get_graph_execution_meta = _(d.get_graph_execution_meta)
|
||||
get_node_executions = _(d.get_node_executions)
|
||||
update_node_execution_status = _(d.update_node_execution_status)
|
||||
@@ -241,10 +247,12 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
get_user_notification_preference = d.get_user_notification_preference
|
||||
|
||||
# Notifications
|
||||
clear_all_user_notification_batches = d.clear_all_user_notification_batches
|
||||
create_or_add_to_user_notification_batch = (
|
||||
d.create_or_add_to_user_notification_batch
|
||||
)
|
||||
empty_user_notification_batch = d.empty_user_notification_batch
|
||||
remove_notifications_from_batch = d.remove_notifications_from_batch
|
||||
get_all_batches_by_type = d.get_all_batches_by_type
|
||||
get_user_notification_batch = d.get_user_notification_batch
|
||||
get_user_notification_oldest_message_in_batch = (
|
||||
|
||||
@@ -3,16 +3,42 @@ import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
|
||||
|
||||
import sentry_sdk
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pika.spec import Basic, BasicProperties
|
||||
from redis.asyncio.lock import Lock as RedisLock
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
from redis.asyncio.lock import Lock as AsyncRedisLock
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.block import (
|
||||
BlockInput,
|
||||
BlockOutput,
|
||||
BlockOutputEntry,
|
||||
BlockSchema,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
from backend.data.dynamic_fields import parse_execution_output
|
||||
from backend.data.execution import (
|
||||
ExecutionQueue,
|
||||
ExecutionStatus,
|
||||
GraphExecution,
|
||||
GraphExecutionEntry,
|
||||
NodeExecutionEntry,
|
||||
NodeExecutionResult,
|
||||
NodesInputMasks,
|
||||
UserContext,
|
||||
)
|
||||
from backend.data.graph import Link, Node
|
||||
from backend.data.model import GraphExecutionStats, NodeExecutionStats
|
||||
from backend.data.notifications import (
|
||||
AgentRunData,
|
||||
@@ -25,50 +51,21 @@ from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.activity_status_generator import (
|
||||
generate_activity_status_for_execution,
|
||||
)
|
||||
from backend.executor.utils import LogMetadata
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.util.exceptions import InsufficientBalanceError, ModerationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerClient, DatabaseManagerAsyncClient
|
||||
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.block import (
|
||||
BlockInput,
|
||||
BlockOutput,
|
||||
BlockOutputEntry,
|
||||
BlockSchema,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
from backend.data.execution import (
|
||||
ExecutionQueue,
|
||||
ExecutionStatus,
|
||||
GraphExecution,
|
||||
GraphExecutionEntry,
|
||||
NodeExecutionEntry,
|
||||
NodeExecutionResult,
|
||||
NodesInputMasks,
|
||||
UserContext,
|
||||
)
|
||||
from backend.data.graph import Link, Node
|
||||
from backend.executor.utils import (
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
CancelExecutionEvent,
|
||||
ExecutionOutputEntry,
|
||||
LogMetadata,
|
||||
NodeExecutionProgress,
|
||||
block_usage_cost,
|
||||
create_execution_queue_config,
|
||||
execution_usage_cost,
|
||||
parse_execution_output,
|
||||
validate_exec,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.server.v2.AutoMod.manager import automod_manager
|
||||
from backend.util import json
|
||||
from backend.util.clients import (
|
||||
@@ -84,13 +81,24 @@ from backend.util.decorator import (
|
||||
error_logged,
|
||||
time_measured,
|
||||
)
|
||||
from backend.util.exceptions import InsufficientBalanceError, ModerationError
|
||||
from backend.util.file import clean_exec_files
|
||||
from backend.util.logging import TruncatedLogger, configure_logging
|
||||
from backend.util.metrics import DiscordChannel
|
||||
from backend.util.process import AppProcess, set_service_name
|
||||
from backend.util.retry import continuous_retry, func_retry
|
||||
from backend.util.retry import (
|
||||
continuous_retry,
|
||||
func_retry,
|
||||
send_rate_limited_discord_alert,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .cluster_lock import ClusterLock
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
logger = TruncatedLogger(_logger, prefix="[GraphExecutor]")
|
||||
settings = Settings()
|
||||
@@ -106,6 +114,7 @@ utilization_gauge = Gauge(
|
||||
"Ratio of active graph runs to max graph workers",
|
||||
)
|
||||
|
||||
|
||||
# Thread-local storage for ExecutionProcessor instances
|
||||
_tls = threading.local()
|
||||
|
||||
@@ -117,10 +126,14 @@ def init_worker():
|
||||
|
||||
|
||||
def execute_graph(
|
||||
graph_exec_entry: "GraphExecutionEntry", cancel_event: threading.Event
|
||||
graph_exec_entry: "GraphExecutionEntry",
|
||||
cancel_event: threading.Event,
|
||||
cluster_lock: ClusterLock,
|
||||
):
|
||||
"""Execute graph using thread-local ExecutionProcessor instance"""
|
||||
return _tls.processor.on_graph_execution(graph_exec_entry, cancel_event)
|
||||
return _tls.processor.on_graph_execution(
|
||||
graph_exec_entry, cancel_event, cluster_lock
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -177,6 +190,7 @@ async def execute_node(
|
||||
_input_data.inputs = input_data
|
||||
if nodes_input_masks:
|
||||
_input_data.nodes_input_masks = nodes_input_masks
|
||||
_input_data.user_id = user_id
|
||||
input_data = _input_data.model_dump()
|
||||
data.inputs = input_data
|
||||
|
||||
@@ -211,6 +225,23 @@ async def execute_node(
|
||||
extra_exec_kwargs[field_name] = credentials
|
||||
|
||||
output_size = 0
|
||||
|
||||
# sentry tracking nonsense to get user counts for blocks because isolation scopes don't work :(
|
||||
scope = sentry_sdk.get_current_scope()
|
||||
|
||||
# save the tags
|
||||
original_user = scope._user
|
||||
original_tags = dict(scope._tags) if scope._tags else {}
|
||||
# Set user ID for error tracking
|
||||
scope.set_user({"id": user_id})
|
||||
|
||||
scope.set_tag("graph_id", graph_id)
|
||||
scope.set_tag("node_id", node_id)
|
||||
scope.set_tag("block_name", node_block.name)
|
||||
scope.set_tag("block_id", node_block.id)
|
||||
for k, v in (data.user_context or UserContext(timezone="UTC")).model_dump().items():
|
||||
scope.set_tag(f"user_context.{k}", v)
|
||||
|
||||
try:
|
||||
async for output_name, output_data in node_block.execute(
|
||||
input_data, **extra_exec_kwargs
|
||||
@@ -219,6 +250,12 @@ async def execute_node(
|
||||
output_size += len(json.dumps(output_data))
|
||||
log_metadata.debug("Node produced output", **{output_name: output_data})
|
||||
yield output_name, output_data
|
||||
except Exception:
|
||||
# Capture exception WITH context still set before restoring scope
|
||||
sentry_sdk.capture_exception(scope=scope)
|
||||
sentry_sdk.flush() # Ensure it's sent before we restore scope
|
||||
# Re-raise to maintain normal error flow
|
||||
raise
|
||||
finally:
|
||||
# Ensure credentials are released even if execution fails
|
||||
if creds_lock and (await creds_lock.locked()) and (await creds_lock.owned()):
|
||||
@@ -233,6 +270,10 @@ async def execute_node(
|
||||
execution_stats.input_size = input_size
|
||||
execution_stats.output_size = output_size
|
||||
|
||||
# Restore scope AFTER error has been captured
|
||||
scope._user = original_user
|
||||
scope._tags = original_tags
|
||||
|
||||
|
||||
async def _enqueue_next_nodes(
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
@@ -429,7 +470,7 @@ class ExecutionProcessor:
|
||||
graph_id=node_exec.graph_id,
|
||||
node_eid=node_exec.node_exec_id,
|
||||
node_id=node_exec.node_id,
|
||||
block_name="-",
|
||||
block_name=b.name if (b := get_block(node_exec.block_id)) else "-",
|
||||
)
|
||||
db_client = get_db_async_client()
|
||||
node = await db_client.get_node(node_exec.node_id)
|
||||
@@ -557,7 +598,6 @@ class ExecutionProcessor:
|
||||
await persist_output(
|
||||
"error", str(stats.error) or type(stats.error).__name__
|
||||
)
|
||||
|
||||
return status
|
||||
|
||||
@func_retry
|
||||
@@ -583,6 +623,7 @@ class ExecutionProcessor:
|
||||
self,
|
||||
graph_exec: GraphExecutionEntry,
|
||||
cancel: threading.Event,
|
||||
cluster_lock: ClusterLock,
|
||||
):
|
||||
log_metadata = LogMetadata(
|
||||
logger=_logger,
|
||||
@@ -641,6 +682,7 @@ class ExecutionProcessor:
|
||||
cancel=cancel,
|
||||
log_metadata=log_metadata,
|
||||
execution_stats=exec_stats,
|
||||
cluster_lock=cluster_lock,
|
||||
)
|
||||
exec_stats.walltime += timing_info.wall_time
|
||||
exec_stats.cputime += timing_info.cpu_time
|
||||
@@ -742,6 +784,7 @@ class ExecutionProcessor:
|
||||
cancel: threading.Event,
|
||||
log_metadata: LogMetadata,
|
||||
execution_stats: GraphExecutionStats,
|
||||
cluster_lock: ClusterLock,
|
||||
) -> ExecutionStatus:
|
||||
"""
|
||||
Returns:
|
||||
@@ -927,7 +970,7 @@ class ExecutionProcessor:
|
||||
and execution_queue.empty()
|
||||
and (running_node_execution or running_node_evaluation)
|
||||
):
|
||||
# There is nothing to execute, and no output to process, let's relax for a while.
|
||||
cluster_lock.refresh()
|
||||
time.sleep(0.1)
|
||||
|
||||
# loop done --------------------------------------------------
|
||||
@@ -969,16 +1012,31 @@ class ExecutionProcessor:
|
||||
if isinstance(e, Exception)
|
||||
else Exception(f"{e.__class__.__name__}: {e}")
|
||||
)
|
||||
if not execution_stats.error:
|
||||
execution_stats.error = str(error)
|
||||
|
||||
known_errors = (InsufficientBalanceError, ModerationError)
|
||||
if isinstance(error, known_errors):
|
||||
execution_stats.error = str(error)
|
||||
return ExecutionStatus.FAILED
|
||||
|
||||
execution_status = ExecutionStatus.FAILED
|
||||
log_metadata.exception(
|
||||
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
|
||||
)
|
||||
|
||||
# Send rate-limited Discord alert for unknown/unexpected errors
|
||||
send_rate_limited_discord_alert(
|
||||
"graph_execution",
|
||||
error,
|
||||
"unknown_error",
|
||||
f"🚨 **Unknown Graph Execution Error**\n"
|
||||
f"User: {graph_exec.user_id}\n"
|
||||
f"Graph ID: {graph_exec.graph_id}\n"
|
||||
f"Execution ID: {graph_exec.graph_exec_id}\n"
|
||||
f"Error Type: {type(error).__name__}\n"
|
||||
f"Error: {str(error)[:200]}{'...' if len(str(error)) > 200 else ''}\n",
|
||||
)
|
||||
|
||||
raise
|
||||
|
||||
finally:
|
||||
@@ -1153,9 +1211,9 @@ class ExecutionProcessor:
|
||||
f"❌ **Insufficient Funds Alert**\n"
|
||||
f"User: {user_email or user_id}\n"
|
||||
f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
|
||||
f"Current balance: ${e.balance/100:.2f}\n"
|
||||
f"Attempted cost: ${abs(e.amount)/100:.2f}\n"
|
||||
f"Shortfall: ${abs(shortfall)/100:.2f}\n"
|
||||
f"Current balance: ${e.balance / 100:.2f}\n"
|
||||
f"Attempted cost: ${abs(e.amount) / 100:.2f}\n"
|
||||
f"Shortfall: ${abs(shortfall) / 100:.2f}\n"
|
||||
f"[View User Details]({base_url}/admin/spending?search={user_email})"
|
||||
)
|
||||
|
||||
@@ -1202,9 +1260,9 @@ class ExecutionProcessor:
|
||||
alert_message = (
|
||||
f"⚠️ **Low Balance Alert**\n"
|
||||
f"User: {user_email or user_id}\n"
|
||||
f"Balance dropped below ${LOW_BALANCE_THRESHOLD/100:.2f}\n"
|
||||
f"Current balance: ${current_balance/100:.2f}\n"
|
||||
f"Transaction cost: ${transaction_cost/100:.2f}\n"
|
||||
f"Balance dropped below ${LOW_BALANCE_THRESHOLD / 100:.2f}\n"
|
||||
f"Current balance: ${current_balance / 100:.2f}\n"
|
||||
f"Transaction cost: ${transaction_cost / 100:.2f}\n"
|
||||
f"[View User Details]({base_url}/admin/spending?search={user_email})"
|
||||
)
|
||||
get_notification_manager_client().discord_system_alert(
|
||||
@@ -1219,6 +1277,7 @@ class ExecutionManager(AppProcess):
|
||||
super().__init__()
|
||||
self.pool_size = settings.config.num_graph_workers
|
||||
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
|
||||
self.executor_id = str(uuid.uuid4())
|
||||
|
||||
self._executor = None
|
||||
self._stop_consuming = None
|
||||
@@ -1228,6 +1287,8 @@ class ExecutionManager(AppProcess):
|
||||
self._run_thread = None
|
||||
self._run_client = None
|
||||
|
||||
self._execution_locks = {}
|
||||
|
||||
@property
|
||||
def cancel_thread(self) -> threading.Thread:
|
||||
if self._cancel_thread is None:
|
||||
@@ -1432,20 +1493,78 @@ class ExecutionManager(AppProcess):
|
||||
return
|
||||
|
||||
graph_exec_id = graph_exec_entry.graph_exec_id
|
||||
user_id = graph_exec_entry.user_id
|
||||
graph_id = graph_exec_entry.graph_id
|
||||
logger.info(
|
||||
f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}"
|
||||
f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}, user_id={user_id}"
|
||||
)
|
||||
if graph_exec_id in self.active_graph_runs:
|
||||
# TODO: Make this check cluster-wide, prevent duplicate runs across executor pods.
|
||||
logger.error(
|
||||
f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
|
||||
|
||||
# Check user rate limit before processing
|
||||
try:
|
||||
# Only check executions from the last 24 hours for performance
|
||||
current_running_count = get_db_client().get_graph_executions_count(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
statuses=[ExecutionStatus.RUNNING],
|
||||
created_time_gte=datetime.now(timezone.utc) - timedelta(hours=24),
|
||||
)
|
||||
_ack_message(reject=True, requeue=False)
|
||||
|
||||
if (
|
||||
current_running_count
|
||||
>= settings.config.max_concurrent_graph_executions_per_user
|
||||
):
|
||||
logger.warning(
|
||||
f"[{self.service_name}] Rate limit exceeded for user {user_id} on graph {graph_id}: "
|
||||
f"{current_running_count}/{settings.config.max_concurrent_graph_executions_per_user} running executions"
|
||||
)
|
||||
_ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[{self.service_name}] Failed to check rate limit for user {user_id}: {e}, proceeding with execution"
|
||||
)
|
||||
# If rate limit check fails, proceed to avoid blocking executions
|
||||
|
||||
# Check for local duplicate execution first
|
||||
if graph_exec_id in self.active_graph_runs:
|
||||
logger.warning(
|
||||
f"[{self.service_name}] Graph {graph_exec_id} already running locally; rejecting duplicate."
|
||||
)
|
||||
_ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
# Try to acquire cluster-wide execution lock
|
||||
cluster_lock = ClusterLock(
|
||||
redis=redis.get_redis(),
|
||||
key=f"exec_lock:{graph_exec_id}",
|
||||
owner_id=self.executor_id,
|
||||
timeout=settings.config.cluster_lock_timeout,
|
||||
)
|
||||
current_owner = cluster_lock.try_acquire()
|
||||
if current_owner != self.executor_id:
|
||||
# Either someone else has it or Redis is unavailable
|
||||
if current_owner is not None:
|
||||
logger.warning(
|
||||
f"[{self.service_name}] Graph {graph_exec_id} already running on pod {current_owner}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[{self.service_name}] Could not acquire lock for {graph_exec_id} - Redis unavailable"
|
||||
)
|
||||
_ack_message(reject=True, requeue=True)
|
||||
return
|
||||
self._execution_locks[graph_exec_id] = cluster_lock
|
||||
|
||||
logger.info(
|
||||
f"[{self.service_name}] Acquired cluster lock for {graph_exec_id} with executor {self.executor_id}"
|
||||
)
|
||||
|
||||
cancel_event = threading.Event()
|
||||
|
||||
future = self.executor.submit(execute_graph, graph_exec_entry, cancel_event)
|
||||
future = self.executor.submit(
|
||||
execute_graph, graph_exec_entry, cancel_event, cluster_lock
|
||||
)
|
||||
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
|
||||
self._update_prompt_metrics()
|
||||
|
||||
@@ -1464,6 +1583,10 @@ class ExecutionManager(AppProcess):
|
||||
f"[{self.service_name}] Error in run completion callback: {e}"
|
||||
)
|
||||
finally:
|
||||
# Release the cluster-wide execution lock
|
||||
if graph_exec_id in self._execution_locks:
|
||||
self._execution_locks[graph_exec_id].release()
|
||||
del self._execution_locks[graph_exec_id]
|
||||
self._cleanup_completed_runs()
|
||||
|
||||
future.add_done_callback(_on_run_done)
|
||||
@@ -1546,6 +1669,10 @@ class ExecutionManager(AppProcess):
|
||||
f"{prefix} ⏳ Still waiting for {len(self.active_graph_runs)} executions: {ids}"
|
||||
)
|
||||
|
||||
for graph_exec_id in self.active_graph_runs:
|
||||
if lock := self._execution_locks.get(graph_exec_id):
|
||||
lock.refresh()
|
||||
|
||||
time.sleep(wait_interval)
|
||||
waited += wait_interval
|
||||
|
||||
@@ -1563,6 +1690,15 @@ class ExecutionManager(AppProcess):
|
||||
except Exception as e:
|
||||
logger.error(f"{prefix} ⚠️ Error during executor shutdown: {type(e)} {e}")
|
||||
|
||||
# Release remaining execution locks
|
||||
try:
|
||||
for lock in self._execution_locks.values():
|
||||
lock.release()
|
||||
self._execution_locks.clear()
|
||||
logger.info(f"{prefix} ✅ Released execution locks")
|
||||
except Exception as e:
|
||||
logger.warning(f"{prefix} ⚠️ Failed to release all locks: {e}")
|
||||
|
||||
# Disconnect the run execution consumer
|
||||
self._stop_message_consumers(
|
||||
self.run_thread,
|
||||
@@ -1668,15 +1804,18 @@ def update_graph_execution_state(
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def synchronized(key: str, timeout: int = 60):
|
||||
async def synchronized(key: str, timeout: int = settings.config.cluster_lock_timeout):
|
||||
r = await redis.get_redis_async()
|
||||
lock: RedisLock = r.lock(f"lock:{key}", timeout=timeout)
|
||||
lock: AsyncRedisLock = r.lock(f"lock:{key}", timeout=timeout)
|
||||
try:
|
||||
await lock.acquire()
|
||||
yield
|
||||
finally:
|
||||
if await lock.locked() and await lock.owned():
|
||||
await lock.release()
|
||||
try:
|
||||
await lock.release()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to release lock for key {key}: {e}")
|
||||
|
||||
|
||||
def increment_execution_count(user_id: str) -> int:
|
||||
|
||||
@@ -4,7 +4,7 @@ import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future
|
||||
from typing import Any, Mapping, Optional, cast
|
||||
from typing import Mapping, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, JsonValue, ValidationError
|
||||
|
||||
@@ -20,6 +20,9 @@ from backend.data.block import (
|
||||
)
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.db import prisma
|
||||
|
||||
# Import dynamic field utilities from centralized location
|
||||
from backend.data.dynamic_fields import merge_execution_input
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
GraphExecutionStats,
|
||||
@@ -39,7 +42,6 @@ from backend.util.clients import (
|
||||
)
|
||||
from backend.util.exceptions import GraphValidationError, NotFoundError
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.mock import MockObject
|
||||
from backend.util.settings import Config
|
||||
from backend.util.type import convert
|
||||
|
||||
@@ -186,195 +188,7 @@ def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bo
|
||||
|
||||
# ============ Execution Input Helpers ============ #
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Delimiters
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
LIST_SPLIT = "_$_"
|
||||
DICT_SPLIT = "_#_"
|
||||
OBJC_SPLIT = "_@_"
|
||||
|
||||
_DELIMS = (LIST_SPLIT, DICT_SPLIT, OBJC_SPLIT)
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Tokenisation utilities
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def _next_delim(s: str) -> tuple[str | None, int]:
|
||||
"""
|
||||
Return the *earliest* delimiter appearing in `s` and its index.
|
||||
|
||||
If none present → (None, -1).
|
||||
"""
|
||||
first: str | None = None
|
||||
pos = len(s) # sentinel: larger than any real index
|
||||
for d in _DELIMS:
|
||||
i = s.find(d)
|
||||
if 0 <= i < pos:
|
||||
first, pos = d, i
|
||||
return first, (pos if first else -1)
|
||||
|
||||
|
||||
def _tokenise(path: str) -> list[tuple[str, str]] | None:
|
||||
"""
|
||||
Convert the raw path string (starting with a delimiter) into
|
||||
[ (delimiter, identifier), … ] or None if the syntax is malformed.
|
||||
"""
|
||||
tokens: list[tuple[str, str]] = []
|
||||
while path:
|
||||
# 1. Which delimiter starts this chunk?
|
||||
delim = next((d for d in _DELIMS if path.startswith(d)), None)
|
||||
if delim is None:
|
||||
return None # invalid syntax
|
||||
|
||||
# 2. Slice off the delimiter, then up to the next delimiter (or EOS)
|
||||
path = path[len(delim) :]
|
||||
nxt_delim, pos = _next_delim(path)
|
||||
token, path = (
|
||||
path[: pos if pos != -1 else len(path)],
|
||||
path[pos if pos != -1 else len(path) :],
|
||||
)
|
||||
if token == "":
|
||||
return None # empty identifier is invalid
|
||||
tokens.append((delim, token))
|
||||
return tokens
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Public API – parsing (flattened ➜ concrete)
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def parse_execution_output(output: BlockOutputEntry, name: str) -> JsonValue | None:
|
||||
"""
|
||||
Retrieve a nested value out of `output` using the flattened *name*.
|
||||
|
||||
On any failure (wrong name, wrong type, out-of-range, bad path)
|
||||
returns **None**.
|
||||
"""
|
||||
base_name, data = output
|
||||
|
||||
# Exact match → whole object
|
||||
if name == base_name:
|
||||
return data
|
||||
|
||||
# Must start with the expected name
|
||||
if not name.startswith(base_name):
|
||||
return None
|
||||
path = name[len(base_name) :]
|
||||
if not path:
|
||||
return None # nothing left to parse
|
||||
|
||||
tokens = _tokenise(path)
|
||||
if tokens is None:
|
||||
return None
|
||||
|
||||
cur: JsonValue = data
|
||||
for delim, ident in tokens:
|
||||
if delim == LIST_SPLIT:
|
||||
# list[index]
|
||||
try:
|
||||
idx = int(ident)
|
||||
except ValueError:
|
||||
return None
|
||||
if not isinstance(cur, list) or idx >= len(cur):
|
||||
return None
|
||||
cur = cur[idx]
|
||||
|
||||
elif delim == DICT_SPLIT:
|
||||
if not isinstance(cur, dict) or ident not in cur:
|
||||
return None
|
||||
cur = cur[ident]
|
||||
|
||||
elif delim == OBJC_SPLIT:
|
||||
if not hasattr(cur, ident):
|
||||
return None
|
||||
cur = getattr(cur, ident)
|
||||
|
||||
else:
|
||||
return None # unreachable
|
||||
|
||||
return cur
|
||||
|
||||
|
||||
def _assign(container: Any, tokens: list[tuple[str, str]], value: Any) -> Any:
|
||||
"""
|
||||
Recursive helper that *returns* the (possibly new) container with
|
||||
`value` assigned along the remaining `tokens` path.
|
||||
"""
|
||||
if not tokens:
|
||||
return value # leaf reached
|
||||
|
||||
delim, ident = tokens[0]
|
||||
rest = tokens[1:]
|
||||
|
||||
# ---------- list ----------
|
||||
if delim == LIST_SPLIT:
|
||||
try:
|
||||
idx = int(ident)
|
||||
except ValueError:
|
||||
raise ValueError("index must be an integer")
|
||||
|
||||
if container is None:
|
||||
container = []
|
||||
elif not isinstance(container, list):
|
||||
container = list(container) if hasattr(container, "__iter__") else []
|
||||
|
||||
while len(container) <= idx:
|
||||
container.append(None)
|
||||
container[idx] = _assign(container[idx], rest, value)
|
||||
return container
|
||||
|
||||
# ---------- dict ----------
|
||||
if delim == DICT_SPLIT:
|
||||
if container is None:
|
||||
container = {}
|
||||
elif not isinstance(container, dict):
|
||||
container = dict(container) if hasattr(container, "items") else {}
|
||||
container[ident] = _assign(container.get(ident), rest, value)
|
||||
return container
|
||||
|
||||
# ---------- object ----------
|
||||
if delim == OBJC_SPLIT:
|
||||
if container is None or not isinstance(container, MockObject):
|
||||
container = MockObject()
|
||||
setattr(
|
||||
container,
|
||||
ident,
|
||||
_assign(getattr(container, ident, None), rest, value),
|
||||
)
|
||||
return container
|
||||
|
||||
return value # unreachable
|
||||
|
||||
|
||||
def merge_execution_input(data: BlockInput) -> BlockInput:
|
||||
"""
|
||||
Reconstruct nested objects from a *flattened* dict of key → value.
|
||||
|
||||
Raises ValueError on syntactically invalid list indices.
|
||||
"""
|
||||
merged: BlockInput = {}
|
||||
|
||||
for key, value in data.items():
|
||||
# Split off the base name (before the first delimiter, if any)
|
||||
delim, pos = _next_delim(key)
|
||||
if delim is None:
|
||||
merged[key] = value
|
||||
continue
|
||||
|
||||
base, path = key[:pos], key[pos:]
|
||||
tokens = _tokenise(path)
|
||||
if tokens is None:
|
||||
# Invalid key; treat as scalar under the raw name
|
||||
merged[key] = value
|
||||
continue
|
||||
|
||||
merged[base] = _assign(merged.get(base), tokens, value)
|
||||
|
||||
data.update(merged)
|
||||
return data
|
||||
# Dynamic field utilities are now imported from backend.data.dynamic_fields
|
||||
|
||||
|
||||
def validate_exec(
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import cast
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from backend.executor.utils import merge_execution_input, parse_execution_output
|
||||
from backend.data.dynamic_fields import merge_execution_input, parse_execution_output
|
||||
from backend.util.mock import MockObject
|
||||
|
||||
|
||||
|
||||
@@ -151,7 +151,10 @@ class IntegrationCredentialsManager:
|
||||
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
|
||||
await self.store.update_creds(user_id, fresh_credentials)
|
||||
if _lock and (await _lock.locked()) and (await _lock.owned()):
|
||||
await _lock.release()
|
||||
try:
|
||||
await _lock.release()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to release OAuth refresh lock: {e}")
|
||||
|
||||
credentials = fresh_credentials
|
||||
return credentials
|
||||
@@ -184,7 +187,10 @@ class IntegrationCredentialsManager:
|
||||
yield
|
||||
finally:
|
||||
if (await lock.locked()) and (await lock.owned()):
|
||||
await lock.release()
|
||||
try:
|
||||
await lock.release()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to release credentials lock: {e}")
|
||||
|
||||
async def release_all_locks(self):
|
||||
"""Call this on process termination to ensure all locks are released"""
|
||||
|
||||
@@ -25,7 +25,11 @@ from backend.data.notifications import (
|
||||
get_summary_params_type,
|
||||
)
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.data.user import generate_unsubscribe_link
|
||||
from backend.data.user import (
|
||||
disable_all_user_notifications,
|
||||
generate_unsubscribe_link,
|
||||
set_user_email_verification,
|
||||
)
|
||||
from backend.notifications.email import EmailSender
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
from backend.util.logging import TruncatedLogger
|
||||
@@ -38,7 +42,7 @@ from backend.util.service import (
|
||||
endpoint_to_sync,
|
||||
expose,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.settings import AppEnvironment, Settings
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[NotificationManager]")
|
||||
settings = Settings()
|
||||
@@ -124,6 +128,12 @@ def get_routing_key(event_type: NotificationType) -> str:
|
||||
|
||||
def queue_notification(event: NotificationEventModel) -> NotificationResult:
|
||||
"""Queue a notification - exposed method for other services to call"""
|
||||
# Disable in production
|
||||
if settings.config.app_env == AppEnvironment.PRODUCTION:
|
||||
return NotificationResult(
|
||||
success=True,
|
||||
message="Queueing notifications is disabled in production",
|
||||
)
|
||||
try:
|
||||
logger.debug(f"Received Request to queue {event=}")
|
||||
|
||||
@@ -151,6 +161,12 @@ def queue_notification(event: NotificationEventModel) -> NotificationResult:
|
||||
|
||||
async def queue_notification_async(event: NotificationEventModel) -> NotificationResult:
|
||||
"""Queue a notification - exposed method for other services to call"""
|
||||
# Disable in production
|
||||
if settings.config.app_env == AppEnvironment.PRODUCTION:
|
||||
return NotificationResult(
|
||||
success=True,
|
||||
message="Queueing notifications is disabled in production",
|
||||
)
|
||||
try:
|
||||
logger.debug(f"Received Request to queue {event=}")
|
||||
|
||||
@@ -213,6 +229,9 @@ class NotificationManager(AppService):
|
||||
|
||||
@expose
|
||||
async def queue_weekly_summary(self):
|
||||
# disable in prod
|
||||
if settings.config.app_env == AppEnvironment.PRODUCTION:
|
||||
return
|
||||
# Use the existing event loop instead of creating a new one with asyncio.run()
|
||||
asyncio.create_task(self._queue_weekly_summary())
|
||||
|
||||
@@ -226,7 +245,9 @@ class NotificationManager(AppService):
|
||||
logger.info(
|
||||
f"Querying for active users between {start_time} and {current_time}"
|
||||
)
|
||||
users = await get_database_manager_async_client().get_active_user_ids_in_timerange(
|
||||
users = await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).get_active_user_ids_in_timerange(
|
||||
end_time=current_time.isoformat(),
|
||||
start_time=start_time.isoformat(),
|
||||
)
|
||||
@@ -253,6 +274,9 @@ class NotificationManager(AppService):
|
||||
async def process_existing_batches(
|
||||
self, notification_types: list[NotificationType]
|
||||
):
|
||||
# disable in prod
|
||||
if settings.config.app_env == AppEnvironment.PRODUCTION:
|
||||
return
|
||||
# Use the existing event loop instead of creating a new process
|
||||
asyncio.create_task(self._process_existing_batches(notification_types))
|
||||
|
||||
@@ -266,15 +290,15 @@ class NotificationManager(AppService):
|
||||
|
||||
for notification_type in notification_types:
|
||||
# Get all batches for this notification type
|
||||
batches = (
|
||||
await get_database_manager_async_client().get_all_batches_by_type(
|
||||
notification_type
|
||||
)
|
||||
)
|
||||
batches = await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).get_all_batches_by_type(notification_type)
|
||||
|
||||
for batch in batches:
|
||||
# Check if batch has aged out
|
||||
oldest_message = await get_database_manager_async_client().get_user_notification_oldest_message_in_batch(
|
||||
oldest_message = await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).get_user_notification_oldest_message_in_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
|
||||
@@ -289,9 +313,9 @@ class NotificationManager(AppService):
|
||||
|
||||
# If batch has aged out, process it
|
||||
if oldest_message.created_at + max_delay < current_time:
|
||||
recipient_email = await get_database_manager_async_client().get_user_email_by_id(
|
||||
batch.user_id
|
||||
)
|
||||
recipient_email = await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).get_user_email_by_id(batch.user_id)
|
||||
|
||||
if not recipient_email:
|
||||
logger.error(
|
||||
@@ -308,21 +332,25 @@ class NotificationManager(AppService):
|
||||
f"User {batch.user_id} does not want to receive {notification_type} notifications"
|
||||
)
|
||||
# Clear the batch
|
||||
await get_database_manager_async_client().empty_user_notification_batch(
|
||||
await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).empty_user_notification_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
continue
|
||||
|
||||
batch_data = await get_database_manager_async_client().get_user_notification_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
batch_data = await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).get_user_notification_batch(batch.user_id, notification_type)
|
||||
|
||||
if not batch_data or not batch_data.notifications:
|
||||
logger.error(
|
||||
f"Batch data not found for user {batch.user_id}"
|
||||
)
|
||||
# Clear the batch
|
||||
await get_database_manager_async_client().empty_user_notification_batch(
|
||||
await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).empty_user_notification_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
continue
|
||||
@@ -358,7 +386,9 @@ class NotificationManager(AppService):
|
||||
)
|
||||
|
||||
# Clear the batch
|
||||
await get_database_manager_async_client().empty_user_notification_batch(
|
||||
await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).empty_user_notification_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
|
||||
@@ -413,15 +443,13 @@ class NotificationManager(AppService):
|
||||
self, user_id: str, event_type: NotificationType
|
||||
) -> bool:
|
||||
"""Check if a user wants to receive a notification based on their preferences and email verification status"""
|
||||
validated_email = (
|
||||
await get_database_manager_async_client().get_user_email_verification(
|
||||
user_id
|
||||
)
|
||||
)
|
||||
validated_email = await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).get_user_email_verification(user_id)
|
||||
preference = (
|
||||
await get_database_manager_async_client().get_user_notification_preference(
|
||||
user_id
|
||||
)
|
||||
await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).get_user_notification_preference(user_id)
|
||||
).preferences.get(event_type, True)
|
||||
# only if both are true, should we email this person
|
||||
return validated_email and preference
|
||||
@@ -437,7 +465,9 @@ class NotificationManager(AppService):
|
||||
|
||||
try:
|
||||
# Get summary data from the database
|
||||
summary_data = await get_database_manager_async_client().get_user_execution_summary_data(
|
||||
summary_data = await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).get_user_execution_summary_data(
|
||||
user_id=user_id,
|
||||
start_time=params.start_date,
|
||||
end_time=params.end_date,
|
||||
@@ -524,13 +554,13 @@ class NotificationManager(AppService):
|
||||
self, user_id: str, event_type: NotificationType, event: NotificationEventModel
|
||||
) -> bool:
|
||||
|
||||
await get_database_manager_async_client().create_or_add_to_user_notification_batch(
|
||||
user_id, event_type, event
|
||||
)
|
||||
await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).create_or_add_to_user_notification_batch(user_id, event_type, event)
|
||||
|
||||
oldest_message = await get_database_manager_async_client().get_user_notification_oldest_message_in_batch(
|
||||
user_id, event_type
|
||||
)
|
||||
oldest_message = await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).get_user_notification_oldest_message_in_batch(user_id, event_type)
|
||||
if not oldest_message:
|
||||
logger.error(
|
||||
f"Batch for user {user_id} and type {event_type} has no oldest message whichshould never happen!!!!!!!!!!!!!!!!"
|
||||
@@ -580,11 +610,9 @@ class NotificationManager(AppService):
|
||||
return False
|
||||
logger.debug(f"Processing immediate notification: {event}")
|
||||
|
||||
recipient_email = (
|
||||
await get_database_manager_async_client().get_user_email_by_id(
|
||||
event.user_id
|
||||
)
|
||||
)
|
||||
recipient_email = await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).get_user_email_by_id(event.user_id)
|
||||
if not recipient_email:
|
||||
logger.error(f"User email not found for user {event.user_id}")
|
||||
return False
|
||||
@@ -619,11 +647,9 @@ class NotificationManager(AppService):
|
||||
return False
|
||||
logger.info(f"Processing batch notification: {event}")
|
||||
|
||||
recipient_email = (
|
||||
await get_database_manager_async_client().get_user_email_by_id(
|
||||
event.user_id
|
||||
)
|
||||
)
|
||||
recipient_email = await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).get_user_email_by_id(event.user_id)
|
||||
if not recipient_email:
|
||||
logger.error(f"User email not found for user {event.user_id}")
|
||||
return False
|
||||
@@ -642,11 +668,9 @@ class NotificationManager(AppService):
|
||||
if not should_send:
|
||||
logger.info("Batch not old enough to send")
|
||||
return False
|
||||
batch = (
|
||||
await get_database_manager_async_client().get_user_notification_batch(
|
||||
event.user_id, event.type
|
||||
)
|
||||
)
|
||||
batch = await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).get_user_notification_batch(event.user_id, event.type)
|
||||
if not batch or not batch.notifications:
|
||||
logger.error(f"Batch not found for user {event.user_id}")
|
||||
return False
|
||||
@@ -657,6 +681,7 @@ class NotificationManager(AppService):
|
||||
get_notif_data_type(db_event.type)
|
||||
].model_validate(
|
||||
{
|
||||
"id": db_event.id, # Include ID from database
|
||||
"user_id": event.user_id,
|
||||
"type": db_event.type,
|
||||
"data": db_event.data,
|
||||
@@ -679,6 +704,9 @@ class NotificationManager(AppService):
|
||||
chunk_sent = False
|
||||
for attempt_size in [chunk_size, 50, 25, 10, 5, 1]:
|
||||
chunk = batch_messages[i : i + attempt_size]
|
||||
chunk_ids = [
|
||||
msg.id for msg in chunk if msg.id
|
||||
] # Extract IDs for removal
|
||||
|
||||
try:
|
||||
# Try to render the email to check its size
|
||||
@@ -705,6 +733,23 @@ class NotificationManager(AppService):
|
||||
user_unsub_link=unsub_link,
|
||||
)
|
||||
|
||||
# Remove successfully sent notifications immediately
|
||||
if chunk_ids:
|
||||
try:
|
||||
await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).remove_notifications_from_batch(
|
||||
event.user_id, event.type, chunk_ids
|
||||
)
|
||||
logger.info(
|
||||
f"Removed {len(chunk_ids)} sent notifications from batch"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to remove sent notifications: {e}"
|
||||
)
|
||||
# Continue anyway - better to risk duplicates than lose emails
|
||||
|
||||
# Track successful sends
|
||||
successfully_sent_count += len(chunk)
|
||||
|
||||
@@ -722,13 +767,137 @@ class NotificationManager(AppService):
|
||||
i += len(chunk)
|
||||
chunk_sent = True
|
||||
break
|
||||
else:
|
||||
# Message is too large even after size reduction
|
||||
if attempt_size == 1:
|
||||
logger.error(
|
||||
f"Failed to send notification at index {i}: "
|
||||
f"Single notification exceeds email size limit "
|
||||
f"({len(test_message):,} chars > {MAX_EMAIL_SIZE:,} chars). "
|
||||
f"Removing permanently from batch - will not retry."
|
||||
)
|
||||
|
||||
# Remove the oversized notification permanently - it will NEVER fit
|
||||
if chunk_ids:
|
||||
try:
|
||||
await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).remove_notifications_from_batch(
|
||||
event.user_id, event.type, chunk_ids
|
||||
)
|
||||
logger.info(
|
||||
f"Removed oversized notification {chunk_ids[0]} from batch permanently"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to remove oversized notification: {e}"
|
||||
)
|
||||
|
||||
failed_indices.append(i)
|
||||
i += 1
|
||||
chunk_sent = True
|
||||
break
|
||||
# Try smaller chunk size
|
||||
continue
|
||||
except Exception as e:
|
||||
# Check if it's a Postmark API error
|
||||
if attempt_size == 1:
|
||||
# Even single notification is too large
|
||||
logger.error(
|
||||
f"Single notification too large to send: {e}. "
|
||||
f"Skipping notification at index {i}"
|
||||
)
|
||||
# Single notification failed - determine the actual cause
|
||||
error_message = str(e).lower()
|
||||
error_type = type(e).__name__
|
||||
|
||||
# Check for HTTP 406 - Inactive recipient (common in Postmark errors)
|
||||
if "406" in error_message or "inactive" in error_message:
|
||||
logger.warning(
|
||||
f"Failed to send notification at index {i}: "
|
||||
f"Recipient marked as inactive by Postmark. "
|
||||
f"Error: {e}. Disabling ALL notifications for this user."
|
||||
)
|
||||
|
||||
# 1. Mark email as unverified
|
||||
try:
|
||||
await set_user_email_verification(
|
||||
event.user_id, False
|
||||
)
|
||||
logger.info(
|
||||
f"Set email verification to false for user {event.user_id}"
|
||||
)
|
||||
except Exception as deactivation_error:
|
||||
logger.error(
|
||||
f"Failed to deactivate email for user {event.user_id}: "
|
||||
f"{deactivation_error}"
|
||||
)
|
||||
|
||||
# 2. Disable all notification preferences
|
||||
try:
|
||||
await disable_all_user_notifications(event.user_id)
|
||||
logger.info(
|
||||
f"Disabled all notification preferences for user {event.user_id}"
|
||||
)
|
||||
except Exception as disable_error:
|
||||
logger.error(
|
||||
f"Failed to disable notification preferences: {disable_error}"
|
||||
)
|
||||
|
||||
# 3. Clear ALL notification batches for this user
|
||||
try:
|
||||
await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).clear_all_user_notification_batches(event.user_id)
|
||||
logger.info(
|
||||
f"Cleared ALL notification batches for user {event.user_id}"
|
||||
)
|
||||
except Exception as remove_error:
|
||||
logger.error(
|
||||
f"Failed to clear batches for inactive recipient: {remove_error}"
|
||||
)
|
||||
|
||||
# Stop processing - we've nuked everything for this user
|
||||
return True
|
||||
# Check for HTTP 422 - Malformed data
|
||||
elif (
|
||||
"422" in error_message
|
||||
or "unprocessable" in error_message
|
||||
):
|
||||
logger.error(
|
||||
f"Failed to send notification at index {i}: "
|
||||
f"Malformed notification data rejected by Postmark. "
|
||||
f"Error: {e}. Removing from batch permanently."
|
||||
)
|
||||
|
||||
# Remove from batch - 422 means bad data that won't fix itself
|
||||
if chunk_ids:
|
||||
try:
|
||||
await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).remove_notifications_from_batch(
|
||||
event.user_id, event.type, chunk_ids
|
||||
)
|
||||
logger.info(
|
||||
"Removed malformed notification from batch permanently"
|
||||
)
|
||||
except Exception as remove_error:
|
||||
logger.error(
|
||||
f"Failed to remove malformed notification: {remove_error}"
|
||||
)
|
||||
# Check if it's a ValueError for size limit
|
||||
elif (
|
||||
isinstance(e, ValueError)
|
||||
and "too large" in error_message
|
||||
):
|
||||
logger.error(
|
||||
f"Failed to send notification at index {i}: "
|
||||
f"Notification size exceeds email limit. "
|
||||
f"Error: {e}. Skipping this notification."
|
||||
)
|
||||
# Other API errors
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to send notification at index {i}: "
|
||||
f"Email API error ({error_type}): {e}. "
|
||||
f"Skipping this notification."
|
||||
)
|
||||
|
||||
failed_indices.append(i)
|
||||
i += 1
|
||||
chunk_sent = True
|
||||
@@ -742,18 +911,20 @@ class NotificationManager(AppService):
|
||||
failed_indices.append(i)
|
||||
i += 1
|
||||
|
||||
# Only empty the batch if ALL notifications were sent successfully
|
||||
if successfully_sent_count == len(batch_messages):
|
||||
# Check what remains in the batch (notifications are removed as sent)
|
||||
remaining_batch = await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).get_user_notification_batch(event.user_id, event.type)
|
||||
|
||||
if not remaining_batch or not remaining_batch.notifications:
|
||||
logger.info(
|
||||
f"Successfully sent all {successfully_sent_count} notifications, clearing batch"
|
||||
)
|
||||
await get_database_manager_async_client().empty_user_notification_batch(
|
||||
event.user_id, event.type
|
||||
f"All {successfully_sent_count} notifications sent and removed from batch"
|
||||
)
|
||||
else:
|
||||
remaining_count = len(remaining_batch.notifications)
|
||||
logger.warning(
|
||||
f"Only sent {successfully_sent_count} of {len(batch_messages)} notifications. "
|
||||
f"Failed indices: {failed_indices}. Batch will be retained for retry."
|
||||
f"Sent {successfully_sent_count} notifications. "
|
||||
f"{remaining_count} remain in batch for retry due to errors."
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -771,11 +942,9 @@ class NotificationManager(AppService):
|
||||
|
||||
logger.info(f"Processing summary notification: {model}")
|
||||
|
||||
recipient_email = (
|
||||
await get_database_manager_async_client().get_user_email_by_id(
|
||||
event.user_id
|
||||
)
|
||||
)
|
||||
recipient_email = await get_database_manager_async_client(
|
||||
should_retry=False
|
||||
).get_user_email_by_id(event.user_id)
|
||||
if not recipient_email:
|
||||
logger.error(f"User email not found for user {event.user_id}")
|
||||
return False
|
||||
|
||||
@@ -0,0 +1,598 @@
|
||||
"""Tests for notification error handling in NotificationManager."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import NotificationType
|
||||
|
||||
from backend.data.notifications import AgentRunData, NotificationEventModel
|
||||
from backend.notifications.notifications import NotificationManager
|
||||
|
||||
|
||||
class TestNotificationErrorHandling:
|
||||
"""Test cases for notification error handling in NotificationManager."""
|
||||
|
||||
@pytest.fixture
|
||||
def notification_manager(self):
|
||||
"""Create a NotificationManager instance for testing."""
|
||||
with patch("backend.notifications.notifications.AppService.__init__"):
|
||||
manager = NotificationManager()
|
||||
manager.email_sender = MagicMock()
|
||||
# Mock the _get_template method used by _process_batch
|
||||
template_mock = Mock()
|
||||
template_mock.base_template = "base"
|
||||
template_mock.subject_template = "subject"
|
||||
template_mock.body_template = "body"
|
||||
manager.email_sender._get_template = Mock(return_value=template_mock)
|
||||
# Mock the formatter
|
||||
manager.email_sender.formatter = Mock()
|
||||
manager.email_sender.formatter.format_email = Mock(
|
||||
return_value=("subject", "body content")
|
||||
)
|
||||
manager.email_sender.formatter.env = Mock()
|
||||
manager.email_sender.formatter.env.globals = {
|
||||
"base_url": "http://example.com"
|
||||
}
|
||||
return manager
|
||||
|
||||
@pytest.fixture
|
||||
def sample_batch_event(self):
|
||||
"""Create a sample batch event for testing."""
|
||||
return NotificationEventModel(
|
||||
type=NotificationType.AGENT_RUN,
|
||||
user_id="user_1",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
data=AgentRunData(
|
||||
agent_name="Test Agent",
|
||||
credits_used=10.0,
|
||||
execution_time=5.0,
|
||||
node_count=3,
|
||||
graph_id="graph_1",
|
||||
outputs=[],
|
||||
),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_batch_notifications(self):
|
||||
"""Create sample batch notifications for testing."""
|
||||
notifications = []
|
||||
for i in range(3):
|
||||
notification = Mock()
|
||||
notification.type = NotificationType.AGENT_RUN
|
||||
notification.data = {
|
||||
"agent_name": f"Test Agent {i}",
|
||||
"credits_used": 10.0 * (i + 1),
|
||||
"execution_time": 5.0 * (i + 1),
|
||||
"node_count": 3 + i,
|
||||
"graph_id": f"graph_{i}",
|
||||
"outputs": [],
|
||||
}
|
||||
notification.created_at = datetime.now(timezone.utc)
|
||||
notifications.append(notification)
|
||||
return notifications
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_406_stops_all_processing_for_user(
|
||||
self, notification_manager, sample_batch_event
|
||||
):
|
||||
"""Test that 406 inactive recipient error stops ALL processing for that user."""
|
||||
with patch("backend.notifications.notifications.logger"), patch(
|
||||
"backend.notifications.notifications.set_user_email_verification",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_set_verification, patch(
|
||||
"backend.notifications.notifications.disable_all_user_notifications",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_disable_all, patch(
|
||||
"backend.notifications.notifications.get_database_manager_async_client"
|
||||
) as mock_db_client, patch(
|
||||
"backend.notifications.notifications.generate_unsubscribe_link"
|
||||
) as mock_unsub_link:
|
||||
|
||||
# Create batch of 5 notifications
|
||||
notifications = []
|
||||
for i in range(5):
|
||||
notification = Mock()
|
||||
notification.id = f"notif_{i}"
|
||||
notification.type = NotificationType.AGENT_RUN
|
||||
notification.data = {
|
||||
"agent_name": f"Test Agent {i}",
|
||||
"credits_used": 10.0 * (i + 1),
|
||||
"execution_time": 5.0 * (i + 1),
|
||||
"node_count": 3 + i,
|
||||
"graph_id": f"graph_{i}",
|
||||
"outputs": [],
|
||||
}
|
||||
notification.created_at = datetime.now(timezone.utc)
|
||||
notifications.append(notification)
|
||||
|
||||
# Setup mocks
|
||||
mock_db = mock_db_client.return_value
|
||||
mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
|
||||
mock_db.get_user_notification_batch = AsyncMock(
|
||||
return_value=Mock(notifications=notifications)
|
||||
)
|
||||
mock_db.clear_all_user_notification_batches = AsyncMock()
|
||||
mock_db.remove_notifications_from_batch = AsyncMock()
|
||||
mock_unsub_link.return_value = "http://example.com/unsub"
|
||||
|
||||
# Mock internal methods
|
||||
notification_manager._should_email_user_based_on_preference = AsyncMock(
|
||||
return_value=True
|
||||
)
|
||||
notification_manager._should_batch = AsyncMock(return_value=True)
|
||||
notification_manager._parse_message = Mock(return_value=sample_batch_event)
|
||||
|
||||
# Track calls
|
||||
call_count = [0]
|
||||
|
||||
def send_side_effect(*args, **kwargs):
|
||||
data = kwargs.get("data", [])
|
||||
if isinstance(data, list) and len(data) == 1:
|
||||
current_call = call_count[0]
|
||||
call_count[0] += 1
|
||||
|
||||
# First two succeed, third hits 406
|
||||
if current_call < 2:
|
||||
return None
|
||||
else:
|
||||
raise Exception("Recipient marked as inactive (406)")
|
||||
# Force single processing
|
||||
raise Exception("Force single processing")
|
||||
|
||||
notification_manager.email_sender.send_templated.side_effect = (
|
||||
send_side_effect
|
||||
)
|
||||
|
||||
# Act
|
||||
result = await notification_manager._process_batch(
|
||||
sample_batch_event.model_dump_json()
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
# Only 3 calls should have been made (2 successful, 1 failed with 406)
|
||||
assert call_count[0] == 3
|
||||
|
||||
# User should be deactivated
|
||||
mock_set_verification.assert_called_once_with("user_1", False)
|
||||
mock_disable_all.assert_called_once_with("user_1")
|
||||
mock_db.clear_all_user_notification_batches.assert_called_once_with(
|
||||
"user_1"
|
||||
)
|
||||
|
||||
# No further processing should occur after 406
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_422_permanently_removes_malformed_notification(
|
||||
self, notification_manager, sample_batch_event
|
||||
):
|
||||
"""Test that 422 error permanently removes the malformed notification from batch and continues with others."""
|
||||
with patch("backend.notifications.notifications.logger") as mock_logger, patch(
|
||||
"backend.notifications.notifications.get_database_manager_async_client"
|
||||
) as mock_db_client, patch(
|
||||
"backend.notifications.notifications.generate_unsubscribe_link"
|
||||
) as mock_unsub_link:
|
||||
|
||||
# Create batch of 5 notifications
|
||||
notifications = []
|
||||
for i in range(5):
|
||||
notification = Mock()
|
||||
notification.id = f"notif_{i}"
|
||||
notification.type = NotificationType.AGENT_RUN
|
||||
notification.data = {
|
||||
"agent_name": f"Test Agent {i}",
|
||||
"credits_used": 10.0 * (i + 1),
|
||||
"execution_time": 5.0 * (i + 1),
|
||||
"node_count": 3 + i,
|
||||
"graph_id": f"graph_{i}",
|
||||
"outputs": [],
|
||||
}
|
||||
notification.created_at = datetime.now(timezone.utc)
|
||||
notifications.append(notification)
|
||||
|
||||
# Setup mocks
|
||||
mock_db = mock_db_client.return_value
|
||||
mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
|
||||
mock_db.get_user_notification_batch = AsyncMock(
|
||||
side_effect=[
|
||||
Mock(notifications=notifications),
|
||||
Mock(notifications=[]), # Empty after processing
|
||||
]
|
||||
)
|
||||
mock_db.remove_notifications_from_batch = AsyncMock()
|
||||
mock_unsub_link.return_value = "http://example.com/unsub"
|
||||
|
||||
# Mock internal methods
|
||||
notification_manager._should_email_user_based_on_preference = AsyncMock(
|
||||
return_value=True
|
||||
)
|
||||
notification_manager._should_batch = AsyncMock(return_value=True)
|
||||
notification_manager._parse_message = Mock(return_value=sample_batch_event)
|
||||
|
||||
# Track calls
|
||||
call_count = [0]
|
||||
successful_indices = []
|
||||
removed_notification_ids = []
|
||||
|
||||
# Capture what gets removed
|
||||
def remove_side_effect(user_id, notif_type, notif_ids):
|
||||
removed_notification_ids.extend(notif_ids)
|
||||
return None
|
||||
|
||||
mock_db.remove_notifications_from_batch.side_effect = remove_side_effect
|
||||
|
||||
def send_side_effect(*args, **kwargs):
|
||||
data = kwargs.get("data", [])
|
||||
if isinstance(data, list) and len(data) == 1:
|
||||
current_call = call_count[0]
|
||||
call_count[0] += 1
|
||||
|
||||
# Index 2 has malformed data (422)
|
||||
if current_call == 2:
|
||||
raise Exception(
|
||||
"Unprocessable entity (422): Malformed email data"
|
||||
)
|
||||
else:
|
||||
successful_indices.append(current_call)
|
||||
return None
|
||||
# Force single processing
|
||||
raise Exception("Force single processing")
|
||||
|
||||
notification_manager.email_sender.send_templated.side_effect = (
|
||||
send_side_effect
|
||||
)
|
||||
|
||||
# Act
|
||||
result = await notification_manager._process_batch(
|
||||
sample_batch_event.model_dump_json()
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert call_count[0] == 5 # All 5 attempted
|
||||
assert len(successful_indices) == 4 # 4 succeeded (all except index 2)
|
||||
assert 2 not in successful_indices # Index 2 failed
|
||||
|
||||
# Verify 422 error was logged
|
||||
error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
|
||||
assert any(
|
||||
"422" in call or "malformed" in call.lower() for call in error_calls
|
||||
)
|
||||
|
||||
# Verify all notifications were removed (4 successful + 1 malformed)
|
||||
assert mock_db.remove_notifications_from_batch.call_count == 5
|
||||
assert (
|
||||
"notif_2" in removed_notification_ids
|
||||
) # Malformed one was removed permanently
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oversized_notification_permanently_removed(
|
||||
self, notification_manager, sample_batch_event
|
||||
):
|
||||
"""Test that oversized notifications are permanently removed from batch but others continue."""
|
||||
with patch("backend.notifications.notifications.logger") as mock_logger, patch(
|
||||
"backend.notifications.notifications.get_database_manager_async_client"
|
||||
) as mock_db_client, patch(
|
||||
"backend.notifications.notifications.generate_unsubscribe_link"
|
||||
) as mock_unsub_link:
|
||||
|
||||
# Create batch of 5 notifications
|
||||
notifications = []
|
||||
for i in range(5):
|
||||
notification = Mock()
|
||||
notification.id = f"notif_{i}"
|
||||
notification.type = NotificationType.AGENT_RUN
|
||||
notification.data = {
|
||||
"agent_name": f"Test Agent {i}",
|
||||
"credits_used": 10.0 * (i + 1),
|
||||
"execution_time": 5.0 * (i + 1),
|
||||
"node_count": 3 + i,
|
||||
"graph_id": f"graph_{i}",
|
||||
"outputs": [],
|
||||
}
|
||||
notification.created_at = datetime.now(timezone.utc)
|
||||
notifications.append(notification)
|
||||
|
||||
# Setup mocks
|
||||
mock_db = mock_db_client.return_value
|
||||
mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
|
||||
mock_db.get_user_notification_batch = AsyncMock(
|
||||
side_effect=[
|
||||
Mock(notifications=notifications),
|
||||
Mock(notifications=[]), # Empty after processing
|
||||
]
|
||||
)
|
||||
mock_db.remove_notifications_from_batch = AsyncMock()
|
||||
mock_unsub_link.return_value = "http://example.com/unsub"
|
||||
|
||||
# Mock internal methods
|
||||
notification_manager._should_email_user_based_on_preference = AsyncMock(
|
||||
return_value=True
|
||||
)
|
||||
notification_manager._should_batch = AsyncMock(return_value=True)
|
||||
notification_manager._parse_message = Mock(return_value=sample_batch_event)
|
||||
|
||||
# Override formatter to simulate oversized on index 3
|
||||
# original_format = notification_manager.email_sender.formatter.format_email
|
||||
|
||||
def format_side_effect(*args, **kwargs):
|
||||
# Check if we're formatting index 3
|
||||
data = kwargs.get("data", {}).get("notifications", [])
|
||||
if data and len(data) == 1:
|
||||
# Check notification content to identify index 3
|
||||
if any(
|
||||
"Test Agent 3" in str(n.data)
|
||||
for n in data
|
||||
if hasattr(n, "data")
|
||||
):
|
||||
# Return oversized message for index 3
|
||||
return ("subject", "x" * 5_000_000) # Over 4.5MB limit
|
||||
return ("subject", "normal sized content")
|
||||
|
||||
notification_manager.email_sender.formatter.format_email = Mock(
|
||||
side_effect=format_side_effect
|
||||
)
|
||||
|
||||
# Track calls
|
||||
successful_indices = []
|
||||
|
||||
def send_side_effect(*args, **kwargs):
|
||||
data = kwargs.get("data", [])
|
||||
if isinstance(data, list) and len(data) == 1:
|
||||
# Track which notification was sent based on content
|
||||
for i, notif in enumerate(notifications):
|
||||
if any(
|
||||
f"Test Agent {i}" in str(n.data)
|
||||
for n in data
|
||||
if hasattr(n, "data")
|
||||
):
|
||||
successful_indices.append(i)
|
||||
return None
|
||||
return None
|
||||
# Force single processing
|
||||
raise Exception("Force single processing")
|
||||
|
||||
notification_manager.email_sender.send_templated.side_effect = (
|
||||
send_side_effect
|
||||
)
|
||||
|
||||
# Act
|
||||
result = await notification_manager._process_batch(
|
||||
sample_batch_event.model_dump_json()
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert (
|
||||
len(successful_indices) == 4
|
||||
) # Only 4 sent (index 3 skipped due to size)
|
||||
assert 3 not in successful_indices # Index 3 was not sent
|
||||
|
||||
# Verify oversized error was logged
|
||||
error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
|
||||
assert any(
|
||||
"exceeds email size limit" in call or "oversized" in call.lower()
|
||||
for call in error_calls
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generic_api_error_keeps_notification_for_retry(
|
||||
self, notification_manager, sample_batch_event
|
||||
):
|
||||
"""Test that generic API errors keep notifications in batch for retry while others continue."""
|
||||
with patch("backend.notifications.notifications.logger") as mock_logger, patch(
|
||||
"backend.notifications.notifications.get_database_manager_async_client"
|
||||
) as mock_db_client, patch(
|
||||
"backend.notifications.notifications.generate_unsubscribe_link"
|
||||
) as mock_unsub_link:
|
||||
|
||||
# Create batch of 5 notifications
|
||||
notifications = []
|
||||
for i in range(5):
|
||||
notification = Mock()
|
||||
notification.id = f"notif_{i}"
|
||||
notification.type = NotificationType.AGENT_RUN
|
||||
notification.data = {
|
||||
"agent_name": f"Test Agent {i}",
|
||||
"credits_used": 10.0 * (i + 1),
|
||||
"execution_time": 5.0 * (i + 1),
|
||||
"node_count": 3 + i,
|
||||
"graph_id": f"graph_{i}",
|
||||
"outputs": [],
|
||||
}
|
||||
notification.created_at = datetime.now(timezone.utc)
|
||||
notifications.append(notification)
|
||||
|
||||
# Notification that failed with generic error
|
||||
failed_notifications = [notifications[1]] # Only index 1 remains for retry
|
||||
|
||||
# Setup mocks
|
||||
mock_db = mock_db_client.return_value
|
||||
mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
|
||||
mock_db.get_user_notification_batch = AsyncMock(
|
||||
side_effect=[
|
||||
Mock(notifications=notifications),
|
||||
Mock(
|
||||
notifications=failed_notifications
|
||||
), # Failed ones remain for retry
|
||||
]
|
||||
)
|
||||
mock_db.remove_notifications_from_batch = AsyncMock()
|
||||
mock_unsub_link.return_value = "http://example.com/unsub"
|
||||
|
||||
# Mock internal methods
|
||||
notification_manager._should_email_user_based_on_preference = AsyncMock(
|
||||
return_value=True
|
||||
)
|
||||
notification_manager._should_batch = AsyncMock(return_value=True)
|
||||
notification_manager._parse_message = Mock(return_value=sample_batch_event)
|
||||
|
||||
# Track calls
|
||||
successful_indices = []
|
||||
failed_indices = []
|
||||
removed_notification_ids = []
|
||||
|
||||
# Capture what gets removed
|
||||
def remove_side_effect(user_id, notif_type, notif_ids):
|
||||
removed_notification_ids.extend(notif_ids)
|
||||
return None
|
||||
|
||||
mock_db.remove_notifications_from_batch.side_effect = remove_side_effect
|
||||
|
||||
def send_side_effect(*args, **kwargs):
|
||||
data = kwargs.get("data", [])
|
||||
if isinstance(data, list) and len(data) == 1:
|
||||
# Track which notification based on content
|
||||
for i, notif in enumerate(notifications):
|
||||
if any(
|
||||
f"Test Agent {i}" in str(n.data)
|
||||
for n in data
|
||||
if hasattr(n, "data")
|
||||
):
|
||||
# Index 1 has generic API error
|
||||
if i == 1:
|
||||
failed_indices.append(i)
|
||||
raise Exception("Network timeout - temporary failure")
|
||||
else:
|
||||
successful_indices.append(i)
|
||||
return None
|
||||
return None
|
||||
# Force single processing
|
||||
raise Exception("Force single processing")
|
||||
|
||||
notification_manager.email_sender.send_templated.side_effect = (
|
||||
send_side_effect
|
||||
)
|
||||
|
||||
# Act
|
||||
result = await notification_manager._process_batch(
|
||||
sample_batch_event.model_dump_json()
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert len(successful_indices) == 4 # 4 succeeded (0, 2, 3, 4)
|
||||
assert len(failed_indices) == 1 # 1 failed
|
||||
assert 1 in failed_indices # Index 1 failed
|
||||
|
||||
# Verify generic error was logged
|
||||
error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
|
||||
assert any(
|
||||
"api error" in call.lower() or "skipping" in call.lower()
|
||||
for call in error_calls
|
||||
)
|
||||
|
||||
# Only successful ones should be removed from batch (failed one stays for retry)
|
||||
assert mock_db.remove_notifications_from_batch.call_count == 4
|
||||
assert (
|
||||
"notif_1" not in removed_notification_ids
|
||||
) # Failed one NOT removed (stays for retry)
|
||||
assert "notif_0" in removed_notification_ids # Successful one removed
|
||||
assert "notif_2" in removed_notification_ids # Successful one removed
|
||||
assert "notif_3" in removed_notification_ids # Successful one removed
|
||||
assert "notif_4" in removed_notification_ids # Successful one removed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_all_notifications_sent_successfully(
|
||||
self, notification_manager, sample_batch_event
|
||||
):
|
||||
"""Test successful batch processing where all notifications are sent without errors."""
|
||||
with patch("backend.notifications.notifications.logger") as mock_logger, patch(
|
||||
"backend.notifications.notifications.get_database_manager_async_client"
|
||||
) as mock_db_client, patch(
|
||||
"backend.notifications.notifications.generate_unsubscribe_link"
|
||||
) as mock_unsub_link:
|
||||
|
||||
# Create batch of 5 notifications
|
||||
notifications = []
|
||||
for i in range(5):
|
||||
notification = Mock()
|
||||
notification.id = f"notif_{i}"
|
||||
notification.type = NotificationType.AGENT_RUN
|
||||
notification.data = {
|
||||
"agent_name": f"Test Agent {i}",
|
||||
"credits_used": 10.0 * (i + 1),
|
||||
"execution_time": 5.0 * (i + 1),
|
||||
"node_count": 3 + i,
|
||||
"graph_id": f"graph_{i}",
|
||||
"outputs": [],
|
||||
}
|
||||
notification.created_at = datetime.now(timezone.utc)
|
||||
notifications.append(notification)
|
||||
|
||||
# Setup mocks
|
||||
mock_db = mock_db_client.return_value
|
||||
mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
|
||||
mock_db.get_user_notification_batch = AsyncMock(
|
||||
side_effect=[
|
||||
Mock(notifications=notifications),
|
||||
Mock(notifications=[]), # Empty after all sent successfully
|
||||
]
|
||||
)
|
||||
mock_db.remove_notifications_from_batch = AsyncMock()
|
||||
mock_unsub_link.return_value = "http://example.com/unsub"
|
||||
|
||||
# Mock internal methods
|
||||
notification_manager._should_email_user_based_on_preference = AsyncMock(
|
||||
return_value=True
|
||||
)
|
||||
notification_manager._should_batch = AsyncMock(return_value=True)
|
||||
notification_manager._parse_message = Mock(return_value=sample_batch_event)
|
||||
|
||||
# Track successful sends
|
||||
successful_indices = []
|
||||
removed_notification_ids = []
|
||||
|
||||
# Capture what gets removed
|
||||
def remove_side_effect(user_id, notif_type, notif_ids):
|
||||
removed_notification_ids.extend(notif_ids)
|
||||
return None
|
||||
|
||||
mock_db.remove_notifications_from_batch.side_effect = remove_side_effect
|
||||
|
||||
def send_side_effect(*args, **kwargs):
|
||||
data = kwargs.get("data", [])
|
||||
if isinstance(data, list) and len(data) == 1:
|
||||
# Track which notification was sent
|
||||
for i, notif in enumerate(notifications):
|
||||
if any(
|
||||
f"Test Agent {i}" in str(n.data)
|
||||
for n in data
|
||||
if hasattr(n, "data")
|
||||
):
|
||||
successful_indices.append(i)
|
||||
return None
|
||||
return None # Success
|
||||
# Force single processing
|
||||
raise Exception("Force single processing")
|
||||
|
||||
notification_manager.email_sender.send_templated.side_effect = (
|
||||
send_side_effect
|
||||
)
|
||||
|
||||
# Act
|
||||
result = await notification_manager._process_batch(
|
||||
sample_batch_event.model_dump_json()
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
# All 5 notifications should be sent successfully
|
||||
assert len(successful_indices) == 5
|
||||
assert successful_indices == [0, 1, 2, 3, 4]
|
||||
|
||||
# All notifications should be removed from batch
|
||||
assert mock_db.remove_notifications_from_batch.call_count == 5
|
||||
assert len(removed_notification_ids) == 5
|
||||
for i in range(5):
|
||||
assert f"notif_{i}" in removed_notification_ids
|
||||
|
||||
# No errors should be logged
|
||||
assert mock_logger.error.call_count == 0
|
||||
|
||||
# Info message about successful sends should be logged
|
||||
info_calls = [call[0][0] for call in mock_logger.info.call_args_list]
|
||||
assert any("sent and removed" in call.lower() for call in info_calls)
|
||||
@@ -32,6 +32,7 @@ from backend.data.model import (
|
||||
OAuth2Credentials,
|
||||
UserIntegrations,
|
||||
)
|
||||
from backend.data.onboarding import complete_webhook_trigger_step
|
||||
from backend.data.user import get_user_integrations
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||
@@ -63,7 +64,7 @@ class LoginResponse(BaseModel):
|
||||
state_token: str
|
||||
|
||||
|
||||
@router.get("/{provider}/login")
|
||||
@router.get("/{provider}/login", summary="Initiate OAuth flow")
|
||||
async def login(
|
||||
provider: Annotated[
|
||||
ProviderName, Path(title="The provider to initiate an OAuth flow for")
|
||||
@@ -101,7 +102,7 @@ class CredentialsMetaResponse(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{provider}/callback")
|
||||
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
|
||||
async def callback(
|
||||
provider: Annotated[
|
||||
ProviderName, Path(title="The target provider for this OAuth exchange")
|
||||
@@ -179,7 +180,7 @@ async def callback(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/credentials")
|
||||
@router.get("/credentials", summary="List Credentials")
|
||||
async def list_credentials(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
@@ -220,7 +221,9 @@ async def list_credentials_by_provider(
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{provider}/credentials/{cred_id}")
|
||||
@router.get(
|
||||
"/{provider}/credentials/{cred_id}", summary="Get Specific Credential By ID"
|
||||
)
|
||||
async def get_credential(
|
||||
provider: Annotated[
|
||||
ProviderName, Path(title="The provider to retrieve credentials for")
|
||||
@@ -241,7 +244,7 @@ async def get_credential(
|
||||
return credential
|
||||
|
||||
|
||||
@router.post("/{provider}/credentials", status_code=201)
|
||||
@router.post("/{provider}/credentials", status_code=201, summary="Create Credentials")
|
||||
async def create_credentials(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
provider: Annotated[
|
||||
@@ -367,6 +370,8 @@ async def webhook_ingress_generic(
|
||||
return
|
||||
|
||||
executions: list[Awaitable] = []
|
||||
await complete_webhook_trigger_step(user_id)
|
||||
|
||||
for node in webhook.triggered_nodes:
|
||||
logger.debug(f"Webhook-attached node: {node}")
|
||||
if not node.is_triggered_by_event_type(event_type):
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import re
|
||||
from typing import Set
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
class SecurityHeadersMiddleware:
|
||||
"""
|
||||
Middleware to add security headers to responses, with cache control
|
||||
disabled by default for all endpoints except those explicitly allowed.
|
||||
@@ -25,6 +23,8 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"/api/health",
|
||||
"/api/v1/health",
|
||||
"/api/status",
|
||||
"/api/blocks",
|
||||
"/api/v1/blocks",
|
||||
# Public store/marketplace pages (read-only)
|
||||
"/api/store/agents",
|
||||
"/api/v1/store/agents",
|
||||
@@ -49,7 +49,7 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
}
|
||||
|
||||
def __init__(self, app: ASGIApp):
|
||||
super().__init__(app)
|
||||
self.app = app
|
||||
# Compile regex patterns for wildcard matching
|
||||
self.cacheable_patterns = [
|
||||
re.compile(pattern.replace("*", "[^/]+"))
|
||||
@@ -72,26 +72,42 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
return False
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
response: Response = await call_next(request)
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""Pure ASGI middleware implementation for better performance than BaseHTTPMiddleware."""
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
# Add general security headers
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
# Extract path from scope
|
||||
path = scope["path"]
|
||||
|
||||
# Add noindex header for shared execution pages
|
||||
if "/public/shared" in request.url.path:
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow"
|
||||
async def send_wrapper(message: Message) -> None:
|
||||
if message["type"] == "http.response.start":
|
||||
# Add security headers to the response
|
||||
headers = dict(message.get("headers", []))
|
||||
|
||||
# Default: Disable caching for all endpoints
|
||||
# Only allow caching for explicitly permitted paths
|
||||
if not self.is_cacheable_path(request.url.path):
|
||||
response.headers["Cache-Control"] = (
|
||||
"no-store, no-cache, must-revalidate, private"
|
||||
)
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
response.headers["Expires"] = "0"
|
||||
# Add general security headers (HTTP spec requires proper capitalization)
|
||||
headers[b"X-Content-Type-Options"] = b"nosniff"
|
||||
headers[b"X-Frame-Options"] = b"DENY"
|
||||
headers[b"X-XSS-Protection"] = b"1; mode=block"
|
||||
headers[b"Referrer-Policy"] = b"strict-origin-when-cross-origin"
|
||||
|
||||
return response
|
||||
# Add noindex header for shared execution pages
|
||||
if "/public/shared" in path:
|
||||
headers[b"X-Robots-Tag"] = b"noindex, nofollow"
|
||||
|
||||
# Default: Disable caching for all endpoints
|
||||
# Only allow caching for explicitly permitted paths
|
||||
if not self.is_cacheable_path(path):
|
||||
headers[b"Cache-Control"] = (
|
||||
b"no-store, no-cache, must-revalidate, private"
|
||||
)
|
||||
headers[b"Pragma"] = b"no-cache"
|
||||
headers[b"Expires"] = b"0"
|
||||
|
||||
# Convert headers back to list format
|
||||
message["headers"] = list(headers.items())
|
||||
|
||||
await send(message)
|
||||
|
||||
await self.app(scope, receive, send_wrapper)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import platform
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -11,6 +12,7 @@ import uvicorn
|
||||
from autogpt_libs.auth import add_auth_responses_to_openapi
|
||||
from autogpt_libs.auth import verify_settings as verify_auth_settings
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.routing import APIRoute
|
||||
from prisma.errors import PrismaError
|
||||
|
||||
@@ -70,6 +72,26 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
await backend.data.db.connect()
|
||||
|
||||
# Configure thread pool for FastAPI sync operation performance
|
||||
# CRITICAL: FastAPI automatically runs ALL sync functions in this thread pool:
|
||||
# - Any endpoint defined with 'def' (not async def)
|
||||
# - Any dependency function defined with 'def' (not async def)
|
||||
# - Manual run_in_threadpool() calls (like JWT decoding)
|
||||
# Default pool size is only 40 threads, causing bottlenecks under high concurrency
|
||||
config = backend.util.settings.Config()
|
||||
try:
|
||||
import anyio.to_thread
|
||||
|
||||
anyio.to_thread.current_default_thread_limiter().total_tokens = (
|
||||
config.fastapi_thread_pool_size
|
||||
)
|
||||
logger.info(
|
||||
f"Thread pool size set to {config.fastapi_thread_pool_size} for sync endpoint/dependency performance"
|
||||
)
|
||||
except (ImportError, AttributeError) as e:
|
||||
logger.warning(f"Could not configure thread pool size: {e}")
|
||||
# Continue without thread pool configuration
|
||||
|
||||
# Ensure SDK auto-registration is patched before initializing blocks
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
|
||||
@@ -140,6 +162,9 @@ app = fastapi.FastAPI(
|
||||
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
# Add GZip compression middleware for large responses (like /api/blocks)
|
||||
app.add_middleware(GZipMiddleware, minimum_size=50_000) # 50KB threshold
|
||||
|
||||
# Add 401 responses to authenticated endpoints in OpenAPI spec
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
@@ -273,12 +298,28 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
allow_methods=["*"], # Allows all methods
|
||||
allow_headers=["*"], # Allows all headers
|
||||
)
|
||||
uvicorn.run(
|
||||
server_app,
|
||||
host=backend.util.settings.Config().agent_api_host,
|
||||
port=backend.util.settings.Config().agent_api_port,
|
||||
log_config=None,
|
||||
)
|
||||
config = backend.util.settings.Config()
|
||||
|
||||
# Configure uvicorn with performance optimizations from Kludex FastAPI tips
|
||||
uvicorn_config = {
|
||||
"app": server_app,
|
||||
"host": config.agent_api_host,
|
||||
"port": config.agent_api_port,
|
||||
"log_config": None,
|
||||
# Use httptools for HTTP parsing (if available)
|
||||
"http": "httptools",
|
||||
# Only use uvloop on Unix-like systems (not supported on Windows)
|
||||
"loop": "uvloop" if platform.system() != "Windows" else "auto",
|
||||
}
|
||||
|
||||
# Only add debug in local environment (not supported in all uvicorn versions)
|
||||
if config.app_env == backend.util.settings.AppEnvironment.LOCAL:
|
||||
import os
|
||||
|
||||
# Enable asyncio debug mode via environment variable
|
||||
os.environ["PYTHONASYNCIODEBUG"] = "1"
|
||||
|
||||
uvicorn.run(**uvicorn_config)
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
|
||||
@@ -24,6 +24,8 @@ from fastapi import (
|
||||
Security,
|
||||
UploadFile,
|
||||
)
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from pydantic import BaseModel
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
|
||||
@@ -85,6 +87,7 @@ from backend.server.model import (
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.cloud_storage import get_cloud_storage_handler
|
||||
from backend.util.exceptions import GraphValidationError, NotFoundError
|
||||
from backend.util.json import dumps
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.timezone_utils import (
|
||||
convert_utc_time_to_user_timezone,
|
||||
@@ -263,13 +266,10 @@ async def is_onboarding_enabled():
|
||||
########################################################
|
||||
|
||||
|
||||
@cached()
|
||||
def _get_cached_blocks() -> Sequence[dict[Any, Any]]:
|
||||
def _compute_blocks_sync() -> str:
|
||||
"""
|
||||
Get cached blocks with thundering herd protection.
|
||||
|
||||
Uses sync_cache decorator to prevent multiple concurrent requests
|
||||
from all executing the expensive block loading operation.
|
||||
Synchronous function to compute blocks data.
|
||||
This does the heavy lifting: instantiate 226+ blocks, compute costs, serialize.
|
||||
"""
|
||||
from backend.data.credit import get_block_cost
|
||||
|
||||
@@ -279,11 +279,27 @@ def _get_cached_blocks() -> Sequence[dict[Any, Any]]:
|
||||
for block_class in block_classes.values():
|
||||
block_instance = block_class()
|
||||
if not block_instance.disabled:
|
||||
# Get costs for this specific block class without creating another instance
|
||||
costs = get_block_cost(block_instance)
|
||||
result.append({**block_instance.to_dict(), "costs": costs})
|
||||
# Convert BlockCost BaseModel objects to dictionaries for JSON serialization
|
||||
costs_dict = [
|
||||
cost.model_dump() if isinstance(cost, BaseModel) else cost
|
||||
for cost in costs
|
||||
]
|
||||
result.append({**block_instance.to_dict(), "costs": costs_dict})
|
||||
|
||||
return result
|
||||
# Use our JSON utility which properly handles complex types through to_dict conversion
|
||||
return dumps(result)
|
||||
|
||||
|
||||
@cached()
|
||||
async def _get_cached_blocks() -> str:
|
||||
"""
|
||||
Async cached function with thundering herd protection.
|
||||
On cache miss: runs heavy work in thread pool
|
||||
On cache hit: returns cached string immediately (no thread pool needed)
|
||||
"""
|
||||
# Only run in thread pool on cache miss - cache hits return immediately
|
||||
return await run_in_threadpool(_compute_blocks_sync)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -291,9 +307,28 @@ def _get_cached_blocks() -> Sequence[dict[Any, Any]]:
|
||||
summary="List available blocks",
|
||||
tags=["blocks"],
|
||||
dependencies=[Security(requires_user)],
|
||||
responses={
|
||||
200: {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"items": {"additionalProperties": True, "type": "object"},
|
||||
"type": "array",
|
||||
"title": "Response Getv1List Available Blocks",
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
return _get_cached_blocks()
|
||||
async def get_graph_blocks() -> Response:
|
||||
# Cache hit: returns immediately, Cache miss: runs in thread pool
|
||||
content = await _get_cached_blocks()
|
||||
return Response(
|
||||
content=content,
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
|
||||
@@ -7,6 +7,7 @@ import fastapi
|
||||
import fastapi.responses
|
||||
import prisma.enums
|
||||
|
||||
import backend.server.v2.store.cache as store_cache
|
||||
import backend.server.v2.store.db
|
||||
import backend.server.v2.store.model
|
||||
import backend.util.json
|
||||
@@ -86,6 +87,11 @@ async def review_submission(
|
||||
StoreSubmission with updated review information
|
||||
"""
|
||||
try:
|
||||
already_approved = (
|
||||
await backend.server.v2.store.db.check_submission_already_approved(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
)
|
||||
submission = await backend.server.v2.store.db.review_store_submission(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
is_approved=request.is_approved,
|
||||
@@ -93,6 +99,11 @@ async def review_submission(
|
||||
internal_comments=request.internal_comments or "",
|
||||
reviewer_id=user_id,
|
||||
)
|
||||
|
||||
state_changed = already_approved != request.is_approved
|
||||
# Clear caches when the request is approved as it updates what is shown on the store
|
||||
if state_changed:
|
||||
store_cache.clear_all_caches()
|
||||
return submission
|
||||
except Exception as e:
|
||||
logger.exception("Error reviewing submission: %s", e)
|
||||
|
||||
@@ -118,6 +118,17 @@ def get_blocks(
|
||||
)
|
||||
|
||||
|
||||
def get_block_by_id(block_id: str) -> BlockInfo | None:
|
||||
"""
|
||||
Get a specific block by its ID.
|
||||
"""
|
||||
for block_type in load_all_blocks().values():
|
||||
block: Block[BlockSchema, BlockSchema] = block_type()
|
||||
if block.id == block_id:
|
||||
return block.get_info()
|
||||
return None
|
||||
|
||||
|
||||
def search_blocks(
|
||||
include_blocks: bool = True,
|
||||
include_integrations: bool = True,
|
||||
|
||||
@@ -53,16 +53,6 @@ class ProviderResponse(BaseModel):
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
# Search
|
||||
class SearchRequest(BaseModel):
|
||||
search_query: str | None = None
|
||||
filter: list[FilterType] | None = None
|
||||
by_creator: list[str] | None = None
|
||||
search_id: str | None = None
|
||||
page: int | None = None
|
||||
page_size: int | None = None
|
||||
|
||||
|
||||
class SearchBlocksResponse(BaseModel):
|
||||
blocks: BlockResponse
|
||||
total_block_count: int
|
||||
|
||||
@@ -110,6 +110,25 @@ async def get_blocks(
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/blocks/batch",
|
||||
summary="Get specific blocks",
|
||||
response_model=list[builder_model.BlockInfo],
|
||||
)
|
||||
async def get_specific_blocks(
|
||||
block_ids: Annotated[list[str], fastapi.Query()],
|
||||
) -> list[builder_model.BlockInfo]:
|
||||
"""
|
||||
Get specific blocks by their IDs.
|
||||
"""
|
||||
blocks = []
|
||||
for block_id in block_ids:
|
||||
block = builder_db.get_block_by_id(block_id)
|
||||
if block:
|
||||
blocks.append(block)
|
||||
return blocks
|
||||
|
||||
|
||||
@router.get(
|
||||
"/providers",
|
||||
summary="Get Builder integration providers",
|
||||
@@ -128,30 +147,34 @@ async def get_providers(
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
# Not using post method because on frontend, orval doesn't support Infinite Query with POST method.
|
||||
@router.get(
|
||||
"/search",
|
||||
summary="Builder search",
|
||||
tags=["store", "private"],
|
||||
response_model=builder_model.SearchResponse,
|
||||
)
|
||||
async def search(
|
||||
options: builder_model.SearchRequest,
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
search_query: Annotated[str | None, fastapi.Query()] = None,
|
||||
filter: Annotated[list[str] | None, fastapi.Query()] = None,
|
||||
search_id: Annotated[str | None, fastapi.Query()] = None,
|
||||
by_creator: Annotated[list[str] | None, fastapi.Query()] = None,
|
||||
page: Annotated[int, fastapi.Query()] = 1,
|
||||
page_size: Annotated[int, fastapi.Query()] = 50,
|
||||
) -> builder_model.SearchResponse:
|
||||
"""
|
||||
Search for blocks (including integrations), marketplace agents, and user library agents.
|
||||
"""
|
||||
# If no filters are provided, then we will return all types
|
||||
if not options.filter:
|
||||
options.filter = [
|
||||
if not filter:
|
||||
filter = [
|
||||
"blocks",
|
||||
"integrations",
|
||||
"marketplace_agents",
|
||||
"my_agents",
|
||||
]
|
||||
options.search_query = sanitize_query(options.search_query)
|
||||
options.page = options.page or 1
|
||||
options.page_size = options.page_size or 50
|
||||
search_query = sanitize_query(search_query)
|
||||
|
||||
# Blocks&Integrations
|
||||
blocks = builder_model.SearchBlocksResponse(
|
||||
@@ -162,13 +185,13 @@ async def search(
|
||||
total_block_count=0,
|
||||
total_integration_count=0,
|
||||
)
|
||||
if "blocks" in options.filter or "integrations" in options.filter:
|
||||
if "blocks" in filter or "integrations" in filter:
|
||||
blocks = builder_db.search_blocks(
|
||||
include_blocks="blocks" in options.filter,
|
||||
include_integrations="integrations" in options.filter,
|
||||
query=options.search_query or "",
|
||||
page=options.page,
|
||||
page_size=options.page_size,
|
||||
include_blocks="blocks" in filter,
|
||||
include_integrations="integrations" in filter,
|
||||
query=search_query or "",
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
# Library Agents
|
||||
@@ -176,12 +199,12 @@ async def search(
|
||||
agents=[],
|
||||
pagination=Pagination.empty(),
|
||||
)
|
||||
if "my_agents" in options.filter:
|
||||
if "my_agents" in filter:
|
||||
my_agents = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=options.search_query,
|
||||
page=options.page,
|
||||
page_size=options.page_size,
|
||||
search_term=search_query,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
# Marketplace Agents
|
||||
@@ -189,12 +212,12 @@ async def search(
|
||||
agents=[],
|
||||
pagination=Pagination.empty(),
|
||||
)
|
||||
if "marketplace_agents" in options.filter:
|
||||
if "marketplace_agents" in filter:
|
||||
marketplace_agents = await store_db.get_store_agents(
|
||||
creators=options.by_creator,
|
||||
search_query=options.search_query,
|
||||
page=options.page,
|
||||
page_size=options.page_size,
|
||||
creators=by_creator,
|
||||
search_query=search_query,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
more_pages = False
|
||||
@@ -214,7 +237,7 @@ async def search(
|
||||
"marketplace_agents": marketplace_agents.pagination.total_items,
|
||||
"my_agents": my_agents.pagination.total_items,
|
||||
},
|
||||
page=options.page,
|
||||
page=page,
|
||||
more_pages=more_pages,
|
||||
)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
from backend.util.settings import Config
|
||||
@@ -61,11 +61,11 @@ async def list_library_agents(
|
||||
|
||||
if page < 1 or page_size < 1:
|
||||
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
|
||||
raise store_exceptions.DatabaseError("Invalid pagination input")
|
||||
raise DatabaseError("Invalid pagination input")
|
||||
|
||||
if search_term and len(search_term.strip()) > 100:
|
||||
logger.warning(f"Search term too long: {repr(search_term)}")
|
||||
raise store_exceptions.DatabaseError("Search term is too long")
|
||||
raise DatabaseError("Search term is too long")
|
||||
|
||||
where_clause: prisma.types.LibraryAgentWhereInput = {
|
||||
"userId": user_id,
|
||||
@@ -101,7 +101,9 @@ async def list_library_agents(
|
||||
try:
|
||||
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
include=library_agent_include(user_id),
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
@@ -141,7 +143,7 @@ async def list_library_agents(
|
||||
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error fetching library agents: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to fetch library agents") from e
|
||||
raise DatabaseError("Failed to fetch library agents") from e
|
||||
|
||||
|
||||
async def list_favorite_library_agents(
|
||||
@@ -170,7 +172,7 @@ async def list_favorite_library_agents(
|
||||
|
||||
if page < 1 or page_size < 1:
|
||||
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
|
||||
raise store_exceptions.DatabaseError("Invalid pagination input")
|
||||
raise DatabaseError("Invalid pagination input")
|
||||
|
||||
where_clause: prisma.types.LibraryAgentWhereInput = {
|
||||
"userId": user_id,
|
||||
@@ -185,7 +187,9 @@ async def list_favorite_library_agents(
|
||||
try:
|
||||
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
include=library_agent_include(user_id),
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
@@ -225,9 +229,7 @@ async def list_favorite_library_agents(
|
||||
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error fetching favorite library agents: {e}")
|
||||
raise store_exceptions.DatabaseError(
|
||||
"Failed to fetch favorite library agents"
|
||||
) from e
|
||||
raise DatabaseError("Failed to fetch favorite library agents") from e
|
||||
|
||||
|
||||
async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent:
|
||||
@@ -269,7 +271,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
|
||||
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error fetching library agent: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to fetch library agent") from e
|
||||
raise DatabaseError("Failed to fetch library agent") from e
|
||||
|
||||
|
||||
async def get_library_agent_by_store_version_id(
|
||||
@@ -334,7 +336,7 @@ async def get_library_agent_by_graph_id(
|
||||
return library_model.LibraryAgent.from_db(agent, sub_graphs=sub_graphs)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error fetching library agent by graph ID: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to fetch library agent") from e
|
||||
raise DatabaseError("Failed to fetch library agent") from e
|
||||
|
||||
|
||||
async def add_generated_agent_image(
|
||||
@@ -417,7 +419,9 @@ async def create_library_agent(
|
||||
}
|
||||
},
|
||||
),
|
||||
include=library_agent_include(user_id),
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
)
|
||||
for graph_entry in graph_entries
|
||||
)
|
||||
@@ -473,9 +477,7 @@ async def update_agent_version_in_library(
|
||||
)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error updating agent version in library: {e}")
|
||||
raise store_exceptions.DatabaseError(
|
||||
"Failed to update agent version in library"
|
||||
) from e
|
||||
raise DatabaseError("Failed to update agent version in library") from e
|
||||
|
||||
|
||||
async def update_library_agent(
|
||||
@@ -538,7 +540,7 @@ async def update_library_agent(
|
||||
)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error updating library agent: {str(e)}")
|
||||
raise store_exceptions.DatabaseError("Failed to update library agent") from e
|
||||
raise DatabaseError("Failed to update library agent") from e
|
||||
|
||||
|
||||
async def delete_library_agent(
|
||||
@@ -566,7 +568,7 @@ async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
|
||||
)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error deleting library agent: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to delete library agent") from e
|
||||
raise DatabaseError("Failed to delete library agent") from e
|
||||
|
||||
|
||||
async def add_store_agent_to_library(
|
||||
@@ -642,7 +644,9 @@ async def add_store_agent_to_library(
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
},
|
||||
include=library_agent_include(user_id),
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
)
|
||||
logger.debug(
|
||||
f"Added graph #{graph.id} v{graph.version}"
|
||||
@@ -655,7 +659,7 @@ async def add_store_agent_to_library(
|
||||
raise
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error adding agent to library: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to add agent to library") from e
|
||||
raise DatabaseError("Failed to add agent to library") from e
|
||||
|
||||
|
||||
##############################################
|
||||
@@ -689,7 +693,7 @@ async def list_presets(
|
||||
logger.warning(
|
||||
"Invalid pagination input: page=%d, page_size=%d", page, page_size
|
||||
)
|
||||
raise store_exceptions.DatabaseError("Invalid pagination parameters")
|
||||
raise DatabaseError("Invalid pagination parameters")
|
||||
|
||||
query_filter: prisma.types.AgentPresetWhereInput = {
|
||||
"userId": user_id,
|
||||
@@ -725,7 +729,7 @@ async def list_presets(
|
||||
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error getting presets: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to fetch presets") from e
|
||||
raise DatabaseError("Failed to fetch presets") from e
|
||||
|
||||
|
||||
async def get_preset(
|
||||
@@ -755,7 +759,7 @@ async def get_preset(
|
||||
return library_model.LibraryAgentPreset.from_db(preset)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error getting preset: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to fetch preset") from e
|
||||
raise DatabaseError("Failed to fetch preset") from e
|
||||
|
||||
|
||||
async def create_preset(
|
||||
@@ -805,7 +809,7 @@ async def create_preset(
|
||||
return library_model.LibraryAgentPreset.from_db(new_preset)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error creating preset: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to create preset") from e
|
||||
raise DatabaseError("Failed to create preset") from e
|
||||
|
||||
|
||||
async def create_preset_from_graph_execution(
|
||||
@@ -943,7 +947,7 @@ async def update_preset(
|
||||
return library_model.LibraryAgentPreset.from_db(updated)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error updating preset: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to update preset") from e
|
||||
raise DatabaseError("Failed to update preset") from e
|
||||
|
||||
|
||||
async def set_preset_webhook(
|
||||
@@ -989,7 +993,7 @@ async def delete_preset(user_id: str, preset_id: str) -> None:
|
||||
)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error deleting preset: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to delete preset") from e
|
||||
raise DatabaseError("Failed to delete preset") from e
|
||||
|
||||
|
||||
async def fork_library_agent(
|
||||
@@ -1017,7 +1021,7 @@ async def fork_library_agent(
|
||||
# TODO: once we have open/closed sourced agents this needs to be enabled ~kcze
|
||||
# + update library/agents/[id]/page.tsx agent actions
|
||||
# if not original_agent.can_access_graph:
|
||||
# raise store_exceptions.DatabaseError(
|
||||
# raise DatabaseError(
|
||||
# f"User {user_id} cannot access library agent graph {library_agent_id}"
|
||||
# )
|
||||
|
||||
@@ -1031,4 +1035,4 @@ async def fork_library_agent(
|
||||
return (await create_library_agent(new_graph, user_id))[0]
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error cloning library agent: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to fork library agent") from e
|
||||
raise DatabaseError("Failed to fork library agent") from e
|
||||
|
||||
@@ -177,7 +177,9 @@ async def test_add_agent_to_library(mocker):
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
},
|
||||
include=library_agent_include("test-user"),
|
||||
include=library_agent_include(
|
||||
"test-user", include_nodes=False, include_executions=False
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from fastapi.responses import Response
|
||||
import backend.server.v2.library.db as library_db
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.exceptions as store_exceptions
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -221,7 +221,7 @@ async def add_marketplace_agent_to_library(
|
||||
"to add to library"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
except store_exceptions.DatabaseError as e:
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Database error while adding agent to library: {e}", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
@@ -275,7 +275,7 @@ async def update_library_agent(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except store_exceptions.DatabaseError as e:
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Database error while updating library agent: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
||||
76
autogpt_platform/backend/backend/server/v2/store/cache.py
Normal file
76
autogpt_platform/backend/backend/server/v2/store/cache.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from autogpt_libs.utils.cache import cached
|
||||
|
||||
import backend.server.v2.store.db
|
||||
|
||||
##############################################
|
||||
############### Caches #######################
|
||||
##############################################
|
||||
|
||||
|
||||
def clear_all_caches():
|
||||
"""Clear all caches."""
|
||||
_get_cached_store_agents.cache_clear()
|
||||
_get_cached_agent_details.cache_clear()
|
||||
_get_cached_store_creators.cache_clear()
|
||||
_get_cached_creator_details.cache_clear()
|
||||
|
||||
|
||||
# Cache store agents list for 5 minutes
|
||||
# Different cache entries for different query combinations
|
||||
@cached(maxsize=5000, ttl_seconds=300)
|
||||
async def _get_cached_store_agents(
|
||||
featured: bool,
|
||||
creator: str | None,
|
||||
sorted_by: str | None,
|
||||
search_query: str | None,
|
||||
category: str | None,
|
||||
page: int,
|
||||
page_size: int,
|
||||
):
|
||||
"""Cached helper to get store agents."""
|
||||
return await backend.server.v2.store.db.get_store_agents(
|
||||
featured=featured,
|
||||
creators=[creator] if creator else None,
|
||||
sorted_by=sorted_by,
|
||||
search_query=search_query,
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
# Cache individual agent details for 15 minutes
|
||||
@cached(maxsize=200, ttl_seconds=300)
|
||||
async def _get_cached_agent_details(username: str, agent_name: str):
|
||||
"""Cached helper to get agent details."""
|
||||
return await backend.server.v2.store.db.get_store_agent_details(
|
||||
username=username, agent_name=agent_name
|
||||
)
|
||||
|
||||
|
||||
# Cache creators list for 5 minutes
|
||||
@cached(maxsize=200, ttl_seconds=300)
|
||||
async def _get_cached_store_creators(
|
||||
featured: bool,
|
||||
search_query: str | None,
|
||||
sorted_by: str | None,
|
||||
page: int,
|
||||
page_size: int,
|
||||
):
|
||||
"""Cached helper to get store creators."""
|
||||
return await backend.server.v2.store.db.get_store_creators(
|
||||
featured=featured,
|
||||
search_query=search_query,
|
||||
sorted_by=sorted_by,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
# Cache individual creator details for 5 minutes
|
||||
@cached(maxsize=100, ttl_seconds=300)
|
||||
async def _get_cached_creator_details(username: str):
|
||||
"""Cached helper to get creator details."""
|
||||
return await backend.server.v2.store.db.get_store_creator_details(
|
||||
username=username.lower()
|
||||
)
|
||||
@@ -25,6 +25,7 @@ from backend.data.notifications import (
|
||||
NotificationEventModel,
|
||||
)
|
||||
from backend.notifications.notifications import queue_notification_async
|
||||
from backend.util.exceptions import DatabaseError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -70,8 +71,7 @@ async def get_store_agents(
|
||||
logger.debug(
|
||||
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
|
||||
)
|
||||
sanitized_query = sanitize_query(search_query)
|
||||
|
||||
search_term = sanitize_query(search_query)
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
@@ -80,10 +80,10 @@ async def get_store_agents(
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
if sanitized_query:
|
||||
if search_term:
|
||||
where_clause["OR"] = [
|
||||
{"agent_name": {"contains": sanitized_query, "mode": "insensitive"}},
|
||||
{"description": {"contains": sanitized_query, "mode": "insensitive"}},
|
||||
{"agent_name": {"contains": search_term, "mode": "insensitive"}},
|
||||
{"description": {"contains": search_term, "mode": "insensitive"}},
|
||||
]
|
||||
|
||||
order_by = []
|
||||
@@ -142,9 +142,25 @@ async def get_store_agents(
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting store agents: {e}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to fetch store agents"
|
||||
) from e
|
||||
raise DatabaseError("Failed to fetch store agents") from e
|
||||
# TODO: commenting this out as we concerned about potential db load issues
|
||||
# finally:
|
||||
# if search_term:
|
||||
# await log_search_term(search_query=search_term)
|
||||
|
||||
|
||||
async def log_search_term(search_query: str):
|
||||
"""Log a search term to the database"""
|
||||
|
||||
# Anonymize the data by preventing correlation with other logs
|
||||
date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
try:
|
||||
await prisma.models.SearchTerms.prisma().create(
|
||||
data={"searchTerm": search_query, "createdDate": date}
|
||||
)
|
||||
except Exception as e:
|
||||
# Fail silently here so that logging search terms doesn't break the app
|
||||
logger.error(f"Error logging search term: {e}")
|
||||
|
||||
|
||||
async def get_store_agent_details(
|
||||
@@ -237,9 +253,7 @@ async def get_store_agent_details(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting store agent details: {e}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to fetch agent details"
|
||||
) from e
|
||||
raise DatabaseError("Failed to fetch agent details") from e
|
||||
|
||||
|
||||
async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
|
||||
@@ -266,9 +280,7 @@ async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting agent: {e}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to fetch agent"
|
||||
) from e
|
||||
raise DatabaseError("Failed to fetch agent") from e
|
||||
|
||||
|
||||
async def get_store_agent_by_version_id(
|
||||
@@ -308,9 +320,7 @@ async def get_store_agent_by_version_id(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting store agent details: {e}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to fetch agent details"
|
||||
) from e
|
||||
raise DatabaseError("Failed to fetch agent details") from e
|
||||
|
||||
|
||||
async def get_store_creators(
|
||||
@@ -336,9 +346,7 @@ async def get_store_creators(
|
||||
# Sanitize and validate search query by escaping special characters
|
||||
sanitized_query = search_query.strip()
|
||||
if not sanitized_query or len(sanitized_query) > 100: # Reasonable length limit
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Invalid search query"
|
||||
)
|
||||
raise DatabaseError("Invalid search query")
|
||||
|
||||
# Escape special SQL characters
|
||||
sanitized_query = (
|
||||
@@ -364,11 +372,9 @@ async def get_store_creators(
|
||||
try:
|
||||
# Validate pagination parameters
|
||||
if not isinstance(page, int) or page < 1:
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Invalid page number"
|
||||
)
|
||||
raise DatabaseError("Invalid page number")
|
||||
if not isinstance(page_size, int) or page_size < 1 or page_size > 100:
|
||||
raise backend.server.v2.store.exceptions.DatabaseError("Invalid page size")
|
||||
raise DatabaseError("Invalid page size")
|
||||
|
||||
# Get total count for pagination using sanitized where clause
|
||||
total = await prisma.models.Creator.prisma().count(
|
||||
@@ -423,9 +429,7 @@ async def get_store_creators(
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting store creators: {e}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to fetch store creators"
|
||||
) from e
|
||||
raise DatabaseError("Failed to fetch store creators") from e
|
||||
|
||||
|
||||
async def get_store_creator_details(
|
||||
@@ -460,9 +464,7 @@ async def get_store_creator_details(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting store creator details: {e}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to fetch creator details"
|
||||
) from e
|
||||
raise DatabaseError("Failed to fetch creator details") from e
|
||||
|
||||
|
||||
async def get_store_submissions(
|
||||
@@ -725,7 +727,21 @@ async def create_store_submission(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
changes_summary=changes_summary,
|
||||
)
|
||||
|
||||
except prisma.errors.UniqueViolationError as exc:
|
||||
# Attempt to check if the error was due to the slug field being unique
|
||||
error_str = str(exc)
|
||||
if "slug" in error_str.lower():
|
||||
logger.debug(
|
||||
f"Slug '{slug}' is already in use by another agent (agent_id: {agent_id}) for user {user_id}"
|
||||
)
|
||||
raise backend.server.v2.store.exceptions.SlugAlreadyInUseError(
|
||||
f"The URL slug '{slug}' is already in use by another one of your agents. Please choose a different slug."
|
||||
) from exc
|
||||
else:
|
||||
# Reraise as a generic database error for other unique violations
|
||||
raise DatabaseError(
|
||||
f"Unique constraint violated (not slug): {error_str}"
|
||||
) from exc
|
||||
except (
|
||||
backend.server.v2.store.exceptions.AgentNotFoundError,
|
||||
backend.server.v2.store.exceptions.ListingExistsError,
|
||||
@@ -733,9 +749,7 @@ async def create_store_submission(
|
||||
raise
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error creating store submission: {e}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to create store submission"
|
||||
) from e
|
||||
raise DatabaseError("Failed to create store submission") from e
|
||||
|
||||
|
||||
async def edit_store_submission(
|
||||
@@ -856,9 +870,7 @@ async def edit_store_submission(
|
||||
)
|
||||
|
||||
if not updated_version:
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to update store listing version"
|
||||
)
|
||||
raise DatabaseError("Failed to update store listing version")
|
||||
return backend.server.v2.store.model.StoreSubmission(
|
||||
agent_id=current_version.agentGraphId,
|
||||
agent_version=current_version.agentGraphVersion,
|
||||
@@ -894,9 +906,7 @@ async def edit_store_submission(
|
||||
raise
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error editing store submission: {e}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to edit store submission"
|
||||
) from e
|
||||
raise DatabaseError("Failed to edit store submission") from e
|
||||
|
||||
|
||||
async def create_store_version(
|
||||
@@ -1011,9 +1021,7 @@ async def create_store_version(
|
||||
)
|
||||
|
||||
except prisma.errors.PrismaError as e:
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to create new store version"
|
||||
) from e
|
||||
raise DatabaseError("Failed to create new store version") from e
|
||||
|
||||
|
||||
async def create_store_review(
|
||||
@@ -1053,9 +1061,7 @@ async def create_store_review(
|
||||
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error creating store review: {e}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to create store review"
|
||||
) from e
|
||||
raise DatabaseError("Failed to create store review") from e
|
||||
|
||||
|
||||
async def get_user_profile(
|
||||
@@ -1079,9 +1085,7 @@ async def get_user_profile(
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user profile: {e}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to get user profile"
|
||||
) from e
|
||||
raise DatabaseError("Failed to get user profile") from e
|
||||
|
||||
|
||||
async def update_profile(
|
||||
@@ -1118,7 +1122,7 @@ async def update_profile(
|
||||
logger.error(
|
||||
f"Unauthorized update attempt for profile {existing_profile.id} by user {user_id}"
|
||||
)
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
raise DatabaseError(
|
||||
f"Unauthorized update attempt for profile {existing_profile.id} by user {user_id}"
|
||||
)
|
||||
|
||||
@@ -1143,9 +1147,7 @@ async def update_profile(
|
||||
)
|
||||
if updated_profile is None:
|
||||
logger.error(f"Failed to update profile for user {user_id}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to update profile"
|
||||
)
|
||||
raise DatabaseError("Failed to update profile")
|
||||
|
||||
return backend.server.v2.store.model.CreatorDetails(
|
||||
name=updated_profile.name,
|
||||
@@ -1160,9 +1162,7 @@ async def update_profile(
|
||||
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error updating profile: {e}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to update profile"
|
||||
) from e
|
||||
raise DatabaseError("Failed to update profile") from e
|
||||
|
||||
|
||||
async def get_my_agents(
|
||||
@@ -1230,9 +1230,7 @@ async def get_my_agents(
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting my agents: {e}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to fetch my agents"
|
||||
) from e
|
||||
raise DatabaseError("Failed to fetch my agents") from e
|
||||
|
||||
|
||||
async def get_agent(store_listing_version_id: str) -> GraphModel:
|
||||
@@ -1494,7 +1492,7 @@ async def review_store_submission(
|
||||
)
|
||||
|
||||
if not submission:
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
raise DatabaseError(
|
||||
f"Failed to update store listing version {store_listing_version_id}"
|
||||
)
|
||||
|
||||
@@ -1609,9 +1607,7 @@ async def review_store_submission(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Could not create store submission review: {e}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to create store submission review"
|
||||
) from e
|
||||
raise DatabaseError("Failed to create store submission review") from e
|
||||
|
||||
|
||||
async def get_admin_listings_with_versions(
|
||||
@@ -1790,6 +1786,27 @@ async def get_admin_listings_with_versions(
|
||||
)
|
||||
|
||||
|
||||
async def check_submission_already_approved(
|
||||
store_listing_version_id: str,
|
||||
) -> bool:
|
||||
"""Check the submission status of a store listing version."""
|
||||
try:
|
||||
store_listing_version = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id}
|
||||
)
|
||||
)
|
||||
if not store_listing_version:
|
||||
return False
|
||||
return (
|
||||
store_listing_version.submissionStatus
|
||||
== prisma.enums.SubmissionStatus.APPROVED
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking submission status: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_agent_as_admin(
|
||||
user_id: str | None,
|
||||
store_listing_version_id: str,
|
||||
|
||||
@@ -42,6 +42,7 @@ async def test_get_store_agents(mocker):
|
||||
versions=["1.0"],
|
||||
updated_at=datetime.now(),
|
||||
is_available=False,
|
||||
useForOnboarding=False,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -84,6 +85,7 @@ async def test_get_store_agent_details(mocker):
|
||||
versions=["1.0"],
|
||||
updated_at=datetime.now(),
|
||||
is_available=False,
|
||||
useForOnboarding=False,
|
||||
)
|
||||
|
||||
# Mock active version agent (what we want to return for active version)
|
||||
@@ -105,6 +107,7 @@ async def test_get_store_agent_details(mocker):
|
||||
versions=["1.0", "2.0"],
|
||||
updated_at=datetime.now(),
|
||||
is_available=True,
|
||||
useForOnboarding=False,
|
||||
)
|
||||
|
||||
# Create a mock StoreListing result
|
||||
@@ -248,6 +251,7 @@ async def test_create_store_submission(mocker):
|
||||
isAvailable=True,
|
||||
)
|
||||
],
|
||||
useForOnboarding=False,
|
||||
)
|
||||
|
||||
# Mock prisma calls
|
||||
@@ -275,7 +279,6 @@ async def test_create_store_submission(mocker):
|
||||
|
||||
# Verify mocks called correctly
|
||||
mock_agent_graph.return_value.find_first.assert_called_once()
|
||||
mock_store_listing.return_value.find_first.assert_called_once()
|
||||
mock_store_listing.return_value.create.assert_called_once()
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
class MediaUploadError(Exception):
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
|
||||
class MediaUploadError(ValueError):
|
||||
"""Base exception for media upload errors"""
|
||||
|
||||
pass
|
||||
@@ -48,19 +51,19 @@ class VirusScanError(MediaUploadError):
|
||||
pass
|
||||
|
||||
|
||||
class StoreError(Exception):
|
||||
class StoreError(ValueError):
|
||||
"""Base exception for store-related errors"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AgentNotFoundError(StoreError):
|
||||
class AgentNotFoundError(NotFoundError):
|
||||
"""Raised when an agent is not found"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CreatorNotFoundError(StoreError):
|
||||
class CreatorNotFoundError(NotFoundError):
|
||||
"""Raised when a creator is not found"""
|
||||
|
||||
pass
|
||||
@@ -72,25 +75,19 @@ class ListingExistsError(StoreError):
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseError(StoreError):
|
||||
"""Raised when there is an error interacting with the database"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ProfileNotFoundError(StoreError):
|
||||
class ProfileNotFoundError(NotFoundError):
|
||||
"""Raised when a profile is not found"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ListingNotFoundError(StoreError):
|
||||
class ListingNotFoundError(NotFoundError):
|
||||
"""Raised when a store listing is not found"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SubmissionNotFoundError(StoreError):
|
||||
class SubmissionNotFoundError(NotFoundError):
|
||||
"""Raised when a submission is not found"""
|
||||
|
||||
pass
|
||||
@@ -106,3 +103,9 @@ class UnauthorizedError(StoreError):
|
||||
"""Raised when a user is not authorized to perform an action"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SlugAlreadyInUseError(StoreError):
|
||||
"""Raised when a slug is already in use by another agent owned by the user"""
|
||||
|
||||
pass
|
||||
|
||||
@@ -6,9 +6,9 @@ import urllib.parse
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
from autogpt_libs.utils.cache import cached
|
||||
|
||||
import backend.data.graph
|
||||
import backend.server.v2.store.cache as store_cache
|
||||
import backend.server.v2.store.db
|
||||
import backend.server.v2.store.exceptions
|
||||
import backend.server.v2.store.image_gen
|
||||
@@ -21,117 +21,6 @@ logger = logging.getLogger(__name__)
|
||||
router = fastapi.APIRouter()
|
||||
|
||||
|
||||
##############################################
|
||||
############### Caches #######################
|
||||
##############################################
|
||||
|
||||
|
||||
# Cache user profiles for 1 hour per user
|
||||
@cached(maxsize=1000, ttl_seconds=3600)
|
||||
async def _get_cached_user_profile(user_id: str):
|
||||
"""Cached helper to get user profile."""
|
||||
return await backend.server.v2.store.db.get_user_profile(user_id)
|
||||
|
||||
|
||||
# Cache store agents list for 15 minutes
|
||||
# Different cache entries for different query combinations
|
||||
@cached(maxsize=5000, ttl_seconds=900)
|
||||
async def _get_cached_store_agents(
|
||||
featured: bool,
|
||||
creator: str | None,
|
||||
sorted_by: str | None,
|
||||
search_query: str | None,
|
||||
category: str | None,
|
||||
page: int,
|
||||
page_size: int,
|
||||
):
|
||||
"""Cached helper to get store agents."""
|
||||
return await backend.server.v2.store.db.get_store_agents(
|
||||
featured=featured,
|
||||
creators=[creator] if creator else None,
|
||||
sorted_by=sorted_by,
|
||||
search_query=search_query,
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
# Cache individual agent details for 15 minutes
|
||||
@cached(maxsize=200, ttl_seconds=900)
|
||||
async def _get_cached_agent_details(username: str, agent_name: str):
|
||||
"""Cached helper to get agent details."""
|
||||
return await backend.server.v2.store.db.get_store_agent_details(
|
||||
username=username, agent_name=agent_name
|
||||
)
|
||||
|
||||
|
||||
# Cache agent graphs for 1 hour
|
||||
@cached(maxsize=200, ttl_seconds=3600)
|
||||
async def _get_cached_agent_graph(store_listing_version_id: str):
|
||||
"""Cached helper to get agent graph."""
|
||||
return await backend.server.v2.store.db.get_available_graph(
|
||||
store_listing_version_id
|
||||
)
|
||||
|
||||
|
||||
# Cache agent by version for 1 hour
|
||||
@cached(maxsize=200, ttl_seconds=3600)
|
||||
async def _get_cached_store_agent_by_version(store_listing_version_id: str):
|
||||
"""Cached helper to get store agent by version ID."""
|
||||
return await backend.server.v2.store.db.get_store_agent_by_version_id(
|
||||
store_listing_version_id
|
||||
)
|
||||
|
||||
|
||||
# Cache creators list for 1 hour
|
||||
@cached(maxsize=200, ttl_seconds=3600)
|
||||
async def _get_cached_store_creators(
|
||||
featured: bool,
|
||||
search_query: str | None,
|
||||
sorted_by: str | None,
|
||||
page: int,
|
||||
page_size: int,
|
||||
):
|
||||
"""Cached helper to get store creators."""
|
||||
return await backend.server.v2.store.db.get_store_creators(
|
||||
featured=featured,
|
||||
search_query=search_query,
|
||||
sorted_by=sorted_by,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
# Cache individual creator details for 1 hour
|
||||
@cached(maxsize=100, ttl_seconds=3600)
|
||||
async def _get_cached_creator_details(username: str):
|
||||
"""Cached helper to get creator details."""
|
||||
return await backend.server.v2.store.db.get_store_creator_details(
|
||||
username=username.lower()
|
||||
)
|
||||
|
||||
|
||||
# Cache user's own agents for 5 mins (shorter TTL as this changes more frequently)
|
||||
@cached(maxsize=500, ttl_seconds=300)
|
||||
async def _get_cached_my_agents(user_id: str, page: int, page_size: int):
|
||||
"""Cached helper to get user's agents."""
|
||||
return await backend.server.v2.store.db.get_my_agents(
|
||||
user_id, page=page, page_size=page_size
|
||||
)
|
||||
|
||||
|
||||
# Cache user's submissions for 1 hour (shorter TTL as this changes frequently)
|
||||
@cached(maxsize=500, ttl_seconds=3600)
|
||||
async def _get_cached_submissions(user_id: str, page: int, page_size: int):
|
||||
"""Cached helper to get user's submissions."""
|
||||
return await backend.server.v2.store.db.get_store_submissions(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
############### Profile Endpoints ############
|
||||
##############################################
|
||||
@@ -152,7 +41,7 @@ async def get_profile(
|
||||
Cached for 1 hour per user.
|
||||
"""
|
||||
try:
|
||||
profile = await _get_cached_user_profile(user_id)
|
||||
profile = await backend.server.v2.store.db.get_user_profile(user_id)
|
||||
if profile is None:
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=404,
|
||||
@@ -198,8 +87,6 @@ async def update_or_create_profile(
|
||||
updated_profile = await backend.server.v2.store.db.update_profile(
|
||||
user_id=user_id, profile=profile
|
||||
)
|
||||
# Clear the cache for this user after profile update
|
||||
_get_cached_user_profile.cache_delete(user_id)
|
||||
return updated_profile
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update profile for user %s: %s", user_id, e)
|
||||
@@ -234,7 +121,6 @@ async def get_agents(
|
||||
):
|
||||
"""
|
||||
Get a paginated list of agents from the store with optional filtering and sorting.
|
||||
Results are cached for 15 minutes.
|
||||
|
||||
Args:
|
||||
featured (bool, optional): Filter to only show featured agents. Defaults to False.
|
||||
@@ -270,7 +156,7 @@ async def get_agents(
|
||||
)
|
||||
|
||||
try:
|
||||
agents = await _get_cached_store_agents(
|
||||
agents = await store_cache._get_cached_store_agents(
|
||||
featured=featured,
|
||||
creator=creator,
|
||||
sorted_by=sorted_by,
|
||||
@@ -300,7 +186,6 @@ async def get_agents(
|
||||
async def get_agent(username: str, agent_name: str):
|
||||
"""
|
||||
This is only used on the AgentDetails Page.
|
||||
Results are cached for 15 minutes.
|
||||
|
||||
It returns the store listing agents details.
|
||||
"""
|
||||
@@ -308,7 +193,7 @@ async def get_agent(username: str, agent_name: str):
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
# URL decode the agent name since it comes from the URL path
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
agent = await _get_cached_agent_details(
|
||||
agent = await store_cache._get_cached_agent_details(
|
||||
username=username, agent_name=agent_name
|
||||
)
|
||||
return agent
|
||||
@@ -331,10 +216,11 @@ async def get_agent(username: str, agent_name: str):
|
||||
async def get_graph_meta_by_store_listing_version_id(store_listing_version_id: str):
|
||||
"""
|
||||
Get Agent Graph from Store Listing Version ID.
|
||||
Results are cached for 1 hour.
|
||||
"""
|
||||
try:
|
||||
graph = await _get_cached_agent_graph(store_listing_version_id)
|
||||
graph = await backend.server.v2.store.db.get_available_graph(
|
||||
store_listing_version_id
|
||||
)
|
||||
return graph
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst getting agent graph")
|
||||
@@ -354,10 +240,12 @@ async def get_graph_meta_by_store_listing_version_id(store_listing_version_id: s
|
||||
async def get_store_agent(store_listing_version_id: str):
|
||||
"""
|
||||
Get Store Agent Details from Store Listing Version ID.
|
||||
Results are cached for 1 hour.
|
||||
"""
|
||||
try:
|
||||
agent = await _get_cached_store_agent_by_version(store_listing_version_id)
|
||||
agent = await backend.server.v2.store.db.get_store_agent_by_version_id(
|
||||
store_listing_version_id
|
||||
)
|
||||
|
||||
return agent
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst getting store agent")
|
||||
@@ -435,8 +323,6 @@ async def get_creators(
|
||||
- Home Page Featured Creators
|
||||
- Search Results Page
|
||||
|
||||
Results are cached for 1 hour.
|
||||
|
||||
---
|
||||
|
||||
To support this functionality we need:
|
||||
@@ -455,7 +341,7 @@ async def get_creators(
|
||||
)
|
||||
|
||||
try:
|
||||
creators = await _get_cached_store_creators(
|
||||
creators = await store_cache._get_cached_store_creators(
|
||||
featured=featured,
|
||||
search_query=search_query,
|
||||
sorted_by=sorted_by,
|
||||
@@ -482,12 +368,11 @@ async def get_creator(
|
||||
):
|
||||
"""
|
||||
Get the details of a creator.
|
||||
Results are cached for 1 hour.
|
||||
- Creator Details Page
|
||||
"""
|
||||
try:
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
creator = await _get_cached_creator_details(username=username)
|
||||
creator = await store_cache._get_cached_creator_details(username=username)
|
||||
return creator
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst getting creator details")
|
||||
@@ -518,10 +403,11 @@ async def get_my_agents(
|
||||
):
|
||||
"""
|
||||
Get user's own agents.
|
||||
Results are cached for 5 minutes per user.
|
||||
"""
|
||||
try:
|
||||
agents = await _get_cached_my_agents(user_id, page=page, page_size=page_size)
|
||||
agents = await backend.server.v2.store.db.get_my_agents(
|
||||
user_id, page=page, page_size=page_size
|
||||
)
|
||||
return agents
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst getting my agents")
|
||||
@@ -558,13 +444,6 @@ async def delete_submission(
|
||||
submission_id=submission_id,
|
||||
)
|
||||
|
||||
# Clear submissions cache for this specific user after deletion
|
||||
if result:
|
||||
# Clear user's own agents cache - we don't know all page/size combinations
|
||||
for page in range(1, 20):
|
||||
# Clear user's submissions cache for common defaults
|
||||
_get_cached_submissions.cache_delete(user_id, page=page, page_size=20)
|
||||
|
||||
return result
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst deleting store submission")
|
||||
@@ -588,7 +467,6 @@ async def get_submissions(
|
||||
):
|
||||
"""
|
||||
Get a paginated list of store submissions for the authenticated user.
|
||||
Results are cached for 1 hour per user.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the authenticated user
|
||||
@@ -611,8 +489,10 @@ async def get_submissions(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
try:
|
||||
listings = await _get_cached_submissions(
|
||||
user_id, page=page, page_size=page_size
|
||||
listings = await backend.server.v2.store.db.get_store_submissions(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return listings
|
||||
except Exception:
|
||||
@@ -666,12 +546,13 @@ async def create_submission(
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
|
||||
# Clear user's own agents cache - we don't know all page/size combinations
|
||||
for page in range(1, 20):
|
||||
# Clear user's submissions cache for common defaults
|
||||
_get_cached_submissions.cache_delete(user_id, page=page, page_size=20)
|
||||
|
||||
return result
|
||||
except backend.server.v2.store.exceptions.SlugAlreadyInUseError as e:
|
||||
logger.warning("Slug already in use: %s", str(e))
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=409,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst creating store submission")
|
||||
return fastapi.responses.JSONResponse(
|
||||
@@ -720,11 +601,6 @@ async def edit_submission(
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
|
||||
# Clear user's own agents cache - we don't know all page/size combinations
|
||||
for page in range(1, 20):
|
||||
# Clear user's submissions cache for common defaults
|
||||
_get_cached_submissions.cache_delete(user_id, page=page, page_size=20)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -917,15 +793,10 @@ async def get_cache_metrics():
|
||||
)
|
||||
|
||||
# Add metrics for each cache
|
||||
add_cache_metrics("user_profile", _get_cached_user_profile)
|
||||
add_cache_metrics("store_agents", _get_cached_store_agents)
|
||||
add_cache_metrics("agent_details", _get_cached_agent_details)
|
||||
add_cache_metrics("agent_graph", _get_cached_agent_graph)
|
||||
add_cache_metrics("agent_by_version", _get_cached_store_agent_by_version)
|
||||
add_cache_metrics("store_creators", _get_cached_store_creators)
|
||||
add_cache_metrics("creator_details", _get_cached_creator_details)
|
||||
add_cache_metrics("my_agents", _get_cached_my_agents)
|
||||
add_cache_metrics("submissions", _get_cached_submissions)
|
||||
add_cache_metrics("store_agents", store_cache._get_cached_store_agents)
|
||||
add_cache_metrics("agent_details", store_cache._get_cached_agent_details)
|
||||
add_cache_metrics("store_creators", store_cache._get_cached_store_creators)
|
||||
add_cache_metrics("creator_details", store_cache._get_cached_creator_details)
|
||||
|
||||
# Add metadata/help text at the beginning
|
||||
prometheus_output = [
|
||||
|
||||
@@ -4,18 +4,12 @@ Test suite for verifying cache_delete functionality in store routes.
|
||||
Tests that specific cache entries can be deleted while preserving others.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.store import routes
|
||||
from backend.server.v2.store.model import (
|
||||
ProfileDetails,
|
||||
StoreAgent,
|
||||
StoreAgentDetails,
|
||||
StoreAgentsResponse,
|
||||
)
|
||||
from backend.server.v2.store import cache as store_cache
|
||||
from backend.server.v2.store.model import StoreAgent, StoreAgentsResponse
|
||||
from backend.util.models import Pagination
|
||||
|
||||
|
||||
@@ -54,10 +48,10 @@ class TestCacheDeletion:
|
||||
return_value=mock_response,
|
||||
) as mock_db:
|
||||
# Clear cache first
|
||||
routes._get_cached_store_agents.cache_clear()
|
||||
store_cache._get_cached_store_agents.cache_clear()
|
||||
|
||||
# First call - should hit database
|
||||
result1 = await routes._get_cached_store_agents(
|
||||
result1 = await store_cache._get_cached_store_agents(
|
||||
featured=False,
|
||||
creator=None,
|
||||
sorted_by=None,
|
||||
@@ -70,7 +64,7 @@ class TestCacheDeletion:
|
||||
assert result1.agents[0].agent_name == "Test Agent"
|
||||
|
||||
# Second call with same params - should use cache
|
||||
await routes._get_cached_store_agents(
|
||||
await store_cache._get_cached_store_agents(
|
||||
featured=False,
|
||||
creator=None,
|
||||
sorted_by=None,
|
||||
@@ -82,7 +76,7 @@ class TestCacheDeletion:
|
||||
assert mock_db.call_count == 1 # No additional DB call
|
||||
|
||||
# Third call with different params - should hit database
|
||||
await routes._get_cached_store_agents(
|
||||
await store_cache._get_cached_store_agents(
|
||||
featured=True, # Different param
|
||||
creator=None,
|
||||
sorted_by=None,
|
||||
@@ -94,7 +88,7 @@ class TestCacheDeletion:
|
||||
assert mock_db.call_count == 2 # New DB call
|
||||
|
||||
# Delete specific cache entry
|
||||
deleted = routes._get_cached_store_agents.cache_delete(
|
||||
deleted = store_cache._get_cached_store_agents.cache_delete(
|
||||
featured=False,
|
||||
creator=None,
|
||||
sorted_by=None,
|
||||
@@ -106,7 +100,7 @@ class TestCacheDeletion:
|
||||
assert deleted is True # Entry was deleted
|
||||
|
||||
# Try to delete non-existent entry
|
||||
deleted = routes._get_cached_store_agents.cache_delete(
|
||||
deleted = store_cache._get_cached_store_agents.cache_delete(
|
||||
featured=False,
|
||||
creator="nonexistent",
|
||||
sorted_by=None,
|
||||
@@ -118,7 +112,7 @@ class TestCacheDeletion:
|
||||
assert deleted is False # Entry didn't exist
|
||||
|
||||
# Call with deleted params - should hit database again
|
||||
await routes._get_cached_store_agents(
|
||||
await store_cache._get_cached_store_agents(
|
||||
featured=False,
|
||||
creator=None,
|
||||
sorted_by=None,
|
||||
@@ -130,7 +124,7 @@ class TestCacheDeletion:
|
||||
assert mock_db.call_count == 3 # New DB call after deletion
|
||||
|
||||
# Call with featured=True - should still be cached
|
||||
await routes._get_cached_store_agents(
|
||||
await store_cache._get_cached_store_agents(
|
||||
featured=True,
|
||||
creator=None,
|
||||
sorted_by=None,
|
||||
@@ -141,105 +135,11 @@ class TestCacheDeletion:
|
||||
)
|
||||
assert mock_db.call_count == 3 # No additional DB call
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_details_cache_delete(self):
|
||||
"""Test that specific agent details cache entries can be deleted."""
|
||||
mock_response = StoreAgentDetails(
|
||||
store_listing_version_id="version1",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_video="https://example.com/video.mp4",
|
||||
agent_image=["https://example.com/image.jpg"],
|
||||
creator="testuser",
|
||||
creator_avatar="https://example.com/avatar.jpg",
|
||||
sub_heading="Test subheading",
|
||||
description="Test description",
|
||||
categories=["productivity"],
|
||||
runs=100,
|
||||
rating=4.5,
|
||||
versions=[],
|
||||
last_updated=datetime.datetime(2024, 1, 1),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.server.v2.store.db.get_store_agent_details",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_db:
|
||||
# Clear cache first
|
||||
routes._get_cached_agent_details.cache_clear()
|
||||
|
||||
# First call - should hit database
|
||||
await routes._get_cached_agent_details(
|
||||
username="testuser", agent_name="testagent"
|
||||
)
|
||||
assert mock_db.call_count == 1
|
||||
|
||||
# Second call - should use cache
|
||||
await routes._get_cached_agent_details(
|
||||
username="testuser", agent_name="testagent"
|
||||
)
|
||||
assert mock_db.call_count == 1 # No additional DB call
|
||||
|
||||
# Delete specific entry
|
||||
deleted = routes._get_cached_agent_details.cache_delete(
|
||||
username="testuser", agent_name="testagent"
|
||||
)
|
||||
assert deleted is True
|
||||
|
||||
# Call again - should hit database
|
||||
await routes._get_cached_agent_details(
|
||||
username="testuser", agent_name="testagent"
|
||||
)
|
||||
assert mock_db.call_count == 2 # New DB call after deletion
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_profile_cache_delete(self):
|
||||
"""Test that user profile cache entries can be deleted."""
|
||||
mock_response = ProfileDetails(
|
||||
name="Test User",
|
||||
username="testuser",
|
||||
description="Test profile",
|
||||
links=["https://example.com"],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.server.v2.store.db.get_user_profile",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_db:
|
||||
# Clear cache first
|
||||
routes._get_cached_user_profile.cache_clear()
|
||||
|
||||
# First call - should hit database
|
||||
await routes._get_cached_user_profile("user123")
|
||||
assert mock_db.call_count == 1
|
||||
|
||||
# Second call - should use cache
|
||||
await routes._get_cached_user_profile("user123")
|
||||
assert mock_db.call_count == 1
|
||||
|
||||
# Different user - should hit database
|
||||
await routes._get_cached_user_profile("user456")
|
||||
assert mock_db.call_count == 2
|
||||
|
||||
# Delete specific user's cache
|
||||
deleted = routes._get_cached_user_profile.cache_delete("user123")
|
||||
assert deleted is True
|
||||
|
||||
# user123 should hit database again
|
||||
await routes._get_cached_user_profile("user123")
|
||||
assert mock_db.call_count == 3
|
||||
|
||||
# user456 should still be cached
|
||||
await routes._get_cached_user_profile("user456")
|
||||
assert mock_db.call_count == 3 # No additional DB call
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_info_after_deletions(self):
|
||||
"""Test that cache_info correctly reflects deletions."""
|
||||
# Clear all caches first
|
||||
routes._get_cached_store_agents.cache_clear()
|
||||
store_cache._get_cached_store_agents.cache_clear()
|
||||
|
||||
mock_response = StoreAgentsResponse(
|
||||
agents=[],
|
||||
@@ -258,7 +158,7 @@ class TestCacheDeletion:
|
||||
):
|
||||
# Add multiple entries
|
||||
for i in range(5):
|
||||
await routes._get_cached_store_agents(
|
||||
await store_cache._get_cached_store_agents(
|
||||
featured=False,
|
||||
creator=f"creator{i}",
|
||||
sorted_by=None,
|
||||
@@ -269,12 +169,12 @@ class TestCacheDeletion:
|
||||
)
|
||||
|
||||
# Check cache size
|
||||
info = routes._get_cached_store_agents.cache_info()
|
||||
info = store_cache._get_cached_store_agents.cache_info()
|
||||
assert info["size"] == 5
|
||||
|
||||
# Delete some entries
|
||||
for i in range(2):
|
||||
deleted = routes._get_cached_store_agents.cache_delete(
|
||||
deleted = store_cache._get_cached_store_agents.cache_delete(
|
||||
featured=False,
|
||||
creator=f"creator{i}",
|
||||
sorted_by=None,
|
||||
@@ -286,7 +186,7 @@ class TestCacheDeletion:
|
||||
assert deleted is True
|
||||
|
||||
# Check cache size after deletion
|
||||
info = routes._get_cached_store_agents.cache_info()
|
||||
info = store_cache._get_cached_store_agents.cache_info()
|
||||
assert info["size"] == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -307,10 +207,10 @@ class TestCacheDeletion:
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_db:
|
||||
routes._get_cached_store_agents.cache_clear()
|
||||
store_cache._get_cached_store_agents.cache_clear()
|
||||
|
||||
# Test with all parameters
|
||||
await routes._get_cached_store_agents(
|
||||
await store_cache._get_cached_store_agents(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by="rating",
|
||||
@@ -322,7 +222,7 @@ class TestCacheDeletion:
|
||||
assert mock_db.call_count == 1
|
||||
|
||||
# Delete with exact same parameters
|
||||
deleted = routes._get_cached_store_agents.cache_delete(
|
||||
deleted = store_cache._get_cached_store_agents.cache_delete(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by="rating",
|
||||
@@ -334,7 +234,7 @@ class TestCacheDeletion:
|
||||
assert deleted is True
|
||||
|
||||
# Try to delete with slightly different parameters
|
||||
deleted = routes._get_cached_store_agents.cache_delete(
|
||||
deleted = store_cache._get_cached_store_agents.cache_delete(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by="rating",
|
||||
|
||||
@@ -34,12 +34,14 @@ def get_database_manager_client() -> "DatabaseManagerClient":
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_database_manager_async_client() -> "DatabaseManagerAsyncClient":
|
||||
def get_database_manager_async_client(
|
||||
should_retry: bool = True,
|
||||
) -> "DatabaseManagerAsyncClient":
|
||||
"""Get a thread-cached DatabaseManagerAsyncClient with request retry enabled."""
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManagerAsyncClient, request_retry=True)
|
||||
return get_service_client(DatabaseManagerAsyncClient, request_retry=should_retry)
|
||||
|
||||
|
||||
@thread_cached
|
||||
|
||||
124
autogpt_platform/backend/backend/util/dynamic_fields.py
Normal file
124
autogpt_platform/backend/backend/util/dynamic_fields.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Utilities for handling dynamic field names and delimiters in the AutoGPT Platform.
|
||||
|
||||
Dynamic fields allow graphs to connect complex data structures using special delimiters:
|
||||
- _#_ for dictionary keys (e.g., "values_#_name" → values["name"])
|
||||
- _$_ for list indices (e.g., "items_$_0" → items[0])
|
||||
- _@_ for object attributes (e.g., "obj_@_attr" → obj.attr)
|
||||
|
||||
This module provides utilities for:
|
||||
- Extracting base field names from dynamic field names
|
||||
- Generating proper schemas for base fields
|
||||
- Creating helper functions for field sanitization
|
||||
"""
|
||||
|
||||
from backend.data.dynamic_fields import DICT_SPLIT, LIST_SPLIT, OBJC_SPLIT
|
||||
|
||||
# All dynamic field delimiters
|
||||
DYNAMIC_DELIMITERS = (DICT_SPLIT, LIST_SPLIT, OBJC_SPLIT)
|
||||
|
||||
|
||||
def extract_base_field_name(field_name: str) -> str:
|
||||
"""
|
||||
Extract the base field name from a dynamic field name.
|
||||
|
||||
Examples:
|
||||
extract_base_field_name("values_#_name") → "values"
|
||||
extract_base_field_name("items_$_0") → "items"
|
||||
extract_base_field_name("obj_@_attr") → "obj"
|
||||
extract_base_field_name("regular_field") → "regular_field"
|
||||
|
||||
Args:
|
||||
field_name: The field name that may contain dynamic delimiters
|
||||
|
||||
Returns:
|
||||
The base field name without any dynamic suffixes
|
||||
"""
|
||||
base_name = field_name
|
||||
for delimiter in DYNAMIC_DELIMITERS:
|
||||
if delimiter in base_name:
|
||||
base_name = base_name.split(delimiter)[0]
|
||||
return base_name
|
||||
|
||||
|
||||
def is_dynamic_field(field_name: str) -> bool:
|
||||
"""
|
||||
Check if a field name contains dynamic delimiters.
|
||||
|
||||
Args:
|
||||
field_name: The field name to check
|
||||
|
||||
Returns:
|
||||
True if the field contains any dynamic delimiters, False otherwise
|
||||
"""
|
||||
return any(delimiter in field_name for delimiter in DYNAMIC_DELIMITERS)
|
||||
|
||||
|
||||
def get_dynamic_field_description(
|
||||
base_field_name: str, original_field_name: str
|
||||
) -> str:
|
||||
"""
|
||||
Generate a description for a dynamic field based on its base field and structure.
|
||||
|
||||
Args:
|
||||
base_field_name: The base field name (e.g., "values")
|
||||
original_field_name: The full dynamic field name (e.g., "values_#_name")
|
||||
|
||||
Returns:
|
||||
A descriptive string explaining what this dynamic field represents
|
||||
"""
|
||||
if DICT_SPLIT in original_field_name:
|
||||
key_part = (
|
||||
original_field_name.split(DICT_SPLIT, 1)[1].split(DICT_SPLIT[0])[0]
|
||||
if DICT_SPLIT in original_field_name
|
||||
else "key"
|
||||
)
|
||||
return f"Dictionary value for {base_field_name}['{key_part}']"
|
||||
elif LIST_SPLIT in original_field_name:
|
||||
index_part = (
|
||||
original_field_name.split(LIST_SPLIT, 1)[1].split(LIST_SPLIT[0])[0]
|
||||
if LIST_SPLIT in original_field_name
|
||||
else "index"
|
||||
)
|
||||
return f"List item for {base_field_name}[{index_part}]"
|
||||
elif OBJC_SPLIT in original_field_name:
|
||||
attr_part = (
|
||||
original_field_name.split(OBJC_SPLIT, 1)[1].split(OBJC_SPLIT[0])[0]
|
||||
if OBJC_SPLIT in original_field_name
|
||||
else "attr"
|
||||
)
|
||||
return f"Object attribute for {base_field_name}.{attr_part}"
|
||||
else:
|
||||
return f"Dynamic value for {base_field_name}"
|
||||
|
||||
|
||||
def group_fields_by_base_name(field_names: list[str]) -> dict[str, list[str]]:
|
||||
"""
|
||||
Group a list of field names by their base field names.
|
||||
|
||||
Args:
|
||||
field_names: List of field names that may contain dynamic delimiters
|
||||
|
||||
Returns:
|
||||
Dictionary mapping base field names to lists of original field names
|
||||
|
||||
Example:
|
||||
group_fields_by_base_name([
|
||||
"values_#_name",
|
||||
"values_#_age",
|
||||
"items_$_0",
|
||||
"regular_field"
|
||||
])
|
||||
→ {
|
||||
"values": ["values_#_name", "values_#_age"],
|
||||
"items": ["items_$_0"],
|
||||
"regular_field": ["regular_field"]
|
||||
}
|
||||
"""
|
||||
grouped = {}
|
||||
for field_name in field_names:
|
||||
base_name = extract_base_field_name(field_name)
|
||||
if base_name not in grouped:
|
||||
grouped[base_name] = []
|
||||
grouped[base_name].append(field_name)
|
||||
return grouped
|
||||
175
autogpt_platform/backend/backend/util/dynamic_fields_test.py
Normal file
175
autogpt_platform/backend/backend/util/dynamic_fields_test.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Tests for dynamic field utilities."""
|
||||
|
||||
from backend.util.dynamic_fields import (
|
||||
extract_base_field_name,
|
||||
get_dynamic_field_description,
|
||||
group_fields_by_base_name,
|
||||
is_dynamic_field,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractBaseFieldName:
|
||||
"""Test extracting base field names from dynamic field names."""
|
||||
|
||||
def test_extract_dict_field(self):
|
||||
"""Test extracting base name from dictionary fields."""
|
||||
assert extract_base_field_name("values_#_name") == "values"
|
||||
assert extract_base_field_name("data_#_key1_#_key2") == "data"
|
||||
assert extract_base_field_name("config_#_database_#_host") == "config"
|
||||
|
||||
def test_extract_list_field(self):
|
||||
"""Test extracting base name from list fields."""
|
||||
assert extract_base_field_name("items_$_0") == "items"
|
||||
assert extract_base_field_name("results_$_5_$_10") == "results"
|
||||
assert extract_base_field_name("nested_$_0_$_1_$_2") == "nested"
|
||||
|
||||
def test_extract_object_field(self):
|
||||
"""Test extracting base name from object fields."""
|
||||
assert extract_base_field_name("user_@_name") == "user"
|
||||
assert extract_base_field_name("response_@_data_@_items") == "response"
|
||||
assert extract_base_field_name("obj_@_attr1_@_attr2") == "obj"
|
||||
|
||||
def test_extract_mixed_fields(self):
|
||||
"""Test extracting base name from mixed dynamic fields."""
|
||||
assert extract_base_field_name("data_$_0_#_key") == "data"
|
||||
assert extract_base_field_name("items_#_user_@_name") == "items"
|
||||
assert extract_base_field_name("complex_$_0_@_attr_#_key") == "complex"
|
||||
|
||||
def test_extract_regular_field(self):
|
||||
"""Test extracting base name from regular (non-dynamic) fields."""
|
||||
assert extract_base_field_name("regular_field") == "regular_field"
|
||||
assert extract_base_field_name("simple") == "simple"
|
||||
assert extract_base_field_name("") == ""
|
||||
|
||||
def test_extract_field_with_underscores(self):
|
||||
"""Test fields with regular underscores (not dynamic delimiters)."""
|
||||
assert extract_base_field_name("field_name_here") == "field_name_here"
|
||||
assert extract_base_field_name("my_field_#_key") == "my_field"
|
||||
|
||||
|
||||
class TestIsDynamicField:
|
||||
"""Test identifying dynamic fields."""
|
||||
|
||||
def test_is_dynamic_dict_field(self):
|
||||
"""Test identifying dictionary dynamic fields."""
|
||||
assert is_dynamic_field("values_#_name") is True
|
||||
assert is_dynamic_field("data_#_key1_#_key2") is True
|
||||
|
||||
def test_is_dynamic_list_field(self):
|
||||
"""Test identifying list dynamic fields."""
|
||||
assert is_dynamic_field("items_$_0") is True
|
||||
assert is_dynamic_field("results_$_5_$_10") is True
|
||||
|
||||
def test_is_dynamic_object_field(self):
|
||||
"""Test identifying object dynamic fields."""
|
||||
assert is_dynamic_field("user_@_name") is True
|
||||
assert is_dynamic_field("response_@_data_@_items") is True
|
||||
|
||||
def test_is_dynamic_mixed_field(self):
|
||||
"""Test identifying mixed dynamic fields."""
|
||||
assert is_dynamic_field("data_$_0_#_key") is True
|
||||
assert is_dynamic_field("items_#_user_@_name") is True
|
||||
|
||||
def test_is_not_dynamic_field(self):
|
||||
"""Test identifying non-dynamic fields."""
|
||||
assert is_dynamic_field("regular_field") is False
|
||||
assert is_dynamic_field("field_name_here") is False
|
||||
assert is_dynamic_field("simple") is False
|
||||
assert is_dynamic_field("") is False
|
||||
|
||||
|
||||
class TestGetDynamicFieldDescription:
|
||||
"""Test generating descriptions for dynamic fields."""
|
||||
|
||||
def test_dict_field_description(self):
|
||||
"""Test descriptions for dictionary fields."""
|
||||
desc = get_dynamic_field_description("values", "values_#_name")
|
||||
assert "Dictionary value for values['name']" == desc
|
||||
|
||||
desc = get_dynamic_field_description("config", "config_#_database")
|
||||
assert "Dictionary value for config['database']" == desc
|
||||
|
||||
def test_list_field_description(self):
|
||||
"""Test descriptions for list fields."""
|
||||
desc = get_dynamic_field_description("items", "items_$_0")
|
||||
assert "List item for items[0]" == desc
|
||||
|
||||
desc = get_dynamic_field_description("results", "results_$_5")
|
||||
assert "List item for results[5]" == desc
|
||||
|
||||
def test_object_field_description(self):
|
||||
"""Test descriptions for object fields."""
|
||||
desc = get_dynamic_field_description("user", "user_@_name")
|
||||
assert "Object attribute for user.name" == desc
|
||||
|
||||
desc = get_dynamic_field_description("response", "response_@_data")
|
||||
assert "Object attribute for response.data" == desc
|
||||
|
||||
def test_fallback_description(self):
|
||||
"""Test fallback description for non-dynamic fields."""
|
||||
desc = get_dynamic_field_description("field", "field")
|
||||
assert "Dynamic value for field" == desc
|
||||
|
||||
|
||||
class TestGroupFieldsByBaseName:
|
||||
"""Test grouping fields by their base names."""
|
||||
|
||||
def test_group_mixed_fields(self):
|
||||
"""Test grouping a mix of dynamic and regular fields."""
|
||||
fields = [
|
||||
"values_#_name",
|
||||
"values_#_age",
|
||||
"items_$_0",
|
||||
"items_$_1",
|
||||
"user_@_email",
|
||||
"regular_field",
|
||||
"another_field",
|
||||
]
|
||||
|
||||
result = group_fields_by_base_name(fields)
|
||||
|
||||
expected = {
|
||||
"values": ["values_#_name", "values_#_age"],
|
||||
"items": ["items_$_0", "items_$_1"],
|
||||
"user": ["user_@_email"],
|
||||
"regular_field": ["regular_field"],
|
||||
"another_field": ["another_field"],
|
||||
}
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_group_empty_list(self):
|
||||
"""Test grouping an empty list."""
|
||||
result = group_fields_by_base_name([])
|
||||
assert result == {}
|
||||
|
||||
def test_group_single_field(self):
|
||||
"""Test grouping a single field."""
|
||||
result = group_fields_by_base_name(["values_#_name"])
|
||||
assert result == {"values": ["values_#_name"]}
|
||||
|
||||
def test_group_complex_dynamic_fields(self):
|
||||
"""Test grouping complex nested dynamic fields."""
|
||||
fields = [
|
||||
"data_$_0_#_key1",
|
||||
"data_$_0_#_key2",
|
||||
"data_$_1_#_key1",
|
||||
"other_@_attr",
|
||||
]
|
||||
|
||||
result = group_fields_by_base_name(fields)
|
||||
|
||||
expected = {
|
||||
"data": ["data_$_0_#_key1", "data_$_0_#_key2", "data_$_1_#_key1"],
|
||||
"other": ["other_@_attr"],
|
||||
}
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_preserve_order(self):
|
||||
"""Test that field order is preserved within groups."""
|
||||
fields = ["values_#_c", "values_#_a", "values_#_b"]
|
||||
result = group_fields_by_base_name(fields)
|
||||
|
||||
# Should preserve the original order
|
||||
assert result["values"] == ["values_#_c", "values_#_a", "values_#_b"]
|
||||
@@ -86,3 +86,9 @@ class GraphValidationError(ValueError):
|
||||
for node_id, errors in self.node_errors.items()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class DatabaseError(Exception):
|
||||
"""Raised when there is an error interacting with the database"""
|
||||
|
||||
pass
|
||||
|
||||
@@ -37,6 +37,11 @@ class Flag(str, Enum):
|
||||
AGENT_ACTIVITY = "agent-activity"
|
||||
|
||||
|
||||
def is_configured() -> bool:
|
||||
"""Check if LaunchDarkly is configured with an SDK key."""
|
||||
return bool(settings.secrets.launch_darkly_sdk_key)
|
||||
|
||||
|
||||
def get_client() -> LDClient:
|
||||
"""Get the LaunchDarkly client singleton."""
|
||||
if not _is_initialized:
|
||||
|
||||
@@ -66,6 +66,18 @@ async def store_media_file(
|
||||
base_path = Path(get_exec_file_path(graph_exec_id, ""))
|
||||
base_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Security fix: Add disk space limits to prevent DoS
|
||||
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB per file
|
||||
MAX_TOTAL_DISK_USAGE = 1024 * 1024 * 1024 # 1GB total per execution directory
|
||||
|
||||
# Check total disk usage in base_path
|
||||
if base_path.exists():
|
||||
current_usage = get_dir_size(base_path)
|
||||
if current_usage > MAX_TOTAL_DISK_USAGE:
|
||||
raise ValueError(
|
||||
f"Disk usage limit exceeded: {current_usage} bytes > {MAX_TOTAL_DISK_USAGE} bytes"
|
||||
)
|
||||
|
||||
# Helper functions
|
||||
def _extension_from_mime(mime: str) -> str:
|
||||
ext = mimetypes.guess_extension(mime, strict=False)
|
||||
@@ -108,6 +120,12 @@ async def store_media_file(
|
||||
filename = Path(path_part).name or f"{uuid.uuid4()}.bin"
|
||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||
|
||||
# Check file size limit
|
||||
if len(cloud_content) > MAX_FILE_SIZE:
|
||||
raise ValueError(
|
||||
f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE} bytes"
|
||||
)
|
||||
|
||||
# Virus scan the cloud content before writing locally
|
||||
await scan_content_safe(cloud_content, filename=filename)
|
||||
target_path.write_bytes(cloud_content)
|
||||
@@ -129,6 +147,12 @@ async def store_media_file(
|
||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||
content = base64.b64decode(b64_content)
|
||||
|
||||
# Check file size limit
|
||||
if len(content) > MAX_FILE_SIZE:
|
||||
raise ValueError(
|
||||
f"File too large: {len(content)} bytes > {MAX_FILE_SIZE} bytes"
|
||||
)
|
||||
|
||||
# Virus scan the base64 content before writing
|
||||
await scan_content_safe(content, filename=filename)
|
||||
target_path.write_bytes(content)
|
||||
@@ -142,6 +166,12 @@ async def store_media_file(
|
||||
# Download and save
|
||||
resp = await Requests().get(file)
|
||||
|
||||
# Check file size limit
|
||||
if len(resp.content) > MAX_FILE_SIZE:
|
||||
raise ValueError(
|
||||
f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE} bytes"
|
||||
)
|
||||
|
||||
# Virus scan the downloaded content before writing
|
||||
await scan_content_safe(resp.content, filename=filename)
|
||||
target_path.write_bytes(resp.content)
|
||||
@@ -159,6 +189,18 @@ async def store_media_file(
|
||||
return MediaFileType(_strip_base_prefix(target_path, base_path))
|
||||
|
||||
|
||||
def get_dir_size(path: Path) -> int:
|
||||
"""Get total size of directory."""
|
||||
total = 0
|
||||
try:
|
||||
for entry in path.glob("**/*"):
|
||||
if entry.is_file():
|
||||
total += entry.stat().st_size
|
||||
except Exception:
|
||||
pass
|
||||
return total
|
||||
|
||||
|
||||
def get_mime_type(file: str) -> str:
|
||||
"""
|
||||
Get the MIME type of a file, whether it's a data URI, URL, or local path.
|
||||
|
||||
@@ -1,20 +1,25 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Type, TypeGuard, TypeVar, overload
|
||||
|
||||
import jsonschema
|
||||
import orjson
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from prisma import Json
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .type import type_match
|
||||
|
||||
# Try to import orjson for better performance
|
||||
try:
|
||||
import orjson
|
||||
# Precompiled regex to remove PostgreSQL-incompatible control characters
|
||||
# Removes \u0000-\u0008, \u000B-\u000C, \u000E-\u001F, \u007F (keeps tab \u0009, newline \u000A, carriage return \u000D)
|
||||
POSTGRES_CONTROL_CHARS = re.compile(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]")
|
||||
|
||||
HAS_ORJSON = True
|
||||
except ImportError:
|
||||
HAS_ORJSON = False
|
||||
# Comprehensive regex to remove all PostgreSQL-incompatible control character sequences in JSON
|
||||
# Handles both Unicode escapes (\\u0000-\\u0008, \\u000B-\\u000C, \\u000E-\\u001F, \\u007F)
|
||||
# and JSON single-char escapes (\\b, \\f) while preserving legitimate file paths
|
||||
POSTGRES_JSON_ESCAPES = re.compile(
|
||||
r"\\u000[0-8]|\\u000[bB]|\\u000[cC]|\\u00[0-1][0-9a-fA-F]|\\u007[fF]|(?<!\\)\\[bf](?!\\)"
|
||||
)
|
||||
|
||||
|
||||
def to_dict(data) -> dict:
|
||||
@@ -23,22 +28,28 @@ def to_dict(data) -> dict:
|
||||
return jsonable_encoder(data)
|
||||
|
||||
|
||||
def dumps(data: Any, *args: Any, **kwargs: Any) -> str:
|
||||
def dumps(
|
||||
data: Any, *args: Any, indent: int | None = None, option: int = 0, **kwargs: Any
|
||||
) -> str:
|
||||
"""
|
||||
Serialize data to JSON string with automatic conversion of Pydantic models and complex types.
|
||||
|
||||
This function converts the input data to a JSON-serializable format using FastAPI's
|
||||
jsonable_encoder before dumping to JSON. It handles Pydantic models, complex types,
|
||||
and ensures proper serialization. Uses orjson for better performance when available.
|
||||
and ensures proper serialization.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : Any
|
||||
The data to serialize. Can be any type including Pydantic models, dicts, lists, etc.
|
||||
*args : Any
|
||||
Additional positional arguments passed to json.dumps() (ignored if using orjson)
|
||||
Additional positional arguments
|
||||
indent : int | None
|
||||
If not None, pretty-print with indentation
|
||||
option : int
|
||||
orjson option flags (default: 0)
|
||||
**kwargs : Any
|
||||
Additional keyword arguments passed to json.dumps() (limited support with orjson)
|
||||
Additional keyword arguments. Supported: default, ensure_ascii, separators, indent
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -55,16 +66,19 @@ def dumps(data: Any, *args: Any, **kwargs: Any) -> str:
|
||||
"""
|
||||
serializable_data = to_dict(data)
|
||||
|
||||
if HAS_ORJSON:
|
||||
# orjson is faster but has limited options support
|
||||
option = 0
|
||||
if kwargs.get("indent") is not None:
|
||||
option |= orjson.OPT_INDENT_2
|
||||
# orjson.dumps returns bytes, so we decode to str
|
||||
return orjson.dumps(serializable_data, option=option).decode("utf-8")
|
||||
else:
|
||||
# Fallback to standard json
|
||||
return json.dumps(serializable_data, *args, **kwargs)
|
||||
# Handle indent parameter
|
||||
if indent is not None or kwargs.get("indent") is not None:
|
||||
option |= orjson.OPT_INDENT_2
|
||||
|
||||
# orjson only accepts specific parameters, filter out stdlib json params
|
||||
# ensure_ascii: orjson always produces UTF-8 (better than ASCII)
|
||||
# separators: orjson uses compact separators by default
|
||||
supported_orjson_params = {"default"}
|
||||
orjson_kwargs = {k: v for k, v in kwargs.items() if k in supported_orjson_params}
|
||||
|
||||
return orjson.dumps(serializable_data, option=option, **orjson_kwargs).decode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -81,14 +95,7 @@ def loads(data: str | bytes, *args, **kwargs) -> Any: ...
|
||||
def loads(
|
||||
data: str | bytes, *args, target_type: Type[T] | None = None, **kwargs
|
||||
) -> Any:
|
||||
if HAS_ORJSON:
|
||||
# orjson can handle both str and bytes directly
|
||||
parsed = orjson.loads(data)
|
||||
else:
|
||||
# Standard json requires string input
|
||||
if isinstance(data, bytes):
|
||||
data = data.decode("utf-8")
|
||||
parsed = json.loads(data, *args, **kwargs)
|
||||
parsed = orjson.loads(data)
|
||||
|
||||
if target_type:
|
||||
return type_match(parsed, target_type)
|
||||
@@ -123,41 +130,24 @@ def convert_pydantic_to_json(output_data: Any) -> Any:
|
||||
return output_data
|
||||
|
||||
|
||||
def _sanitize_null_bytes(data: Any) -> Any:
|
||||
"""
|
||||
Recursively sanitize null bytes from data structures to prevent PostgreSQL 22P05 errors.
|
||||
PostgreSQL cannot store null bytes (\u0000) in text fields.
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
return data.replace("\u0000", "")
|
||||
elif isinstance(data, dict):
|
||||
return {key: _sanitize_null_bytes(value) for key, value in data.items()}
|
||||
elif isinstance(data, list):
|
||||
return [_sanitize_null_bytes(item) for item in data]
|
||||
elif isinstance(data, tuple):
|
||||
return tuple(_sanitize_null_bytes(item) for item in data)
|
||||
else:
|
||||
# For other types (int, float, bool, None, etc.), return as-is
|
||||
return data
|
||||
|
||||
|
||||
def SafeJson(data: Any) -> Json:
|
||||
"""
|
||||
Safely serialize data and return Prisma's Json type.
|
||||
Sanitizes null bytes to prevent PostgreSQL 22P05 errors.
|
||||
"""
|
||||
# Sanitize null bytes before serialization
|
||||
sanitized_data = _sanitize_null_bytes(data)
|
||||
|
||||
if isinstance(sanitized_data, BaseModel):
|
||||
return Json(
|
||||
sanitized_data.model_dump(
|
||||
mode="json",
|
||||
warnings="error",
|
||||
exclude_none=True,
|
||||
fallback=lambda v: None,
|
||||
)
|
||||
if isinstance(data, BaseModel):
|
||||
json_string = data.model_dump_json(
|
||||
warnings="error",
|
||||
exclude_none=True,
|
||||
fallback=lambda v: None,
|
||||
)
|
||||
# Round-trip through JSON to ensure proper serialization with fallback for non-serializable values
|
||||
json_string = dumps(sanitized_data, default=lambda v: None)
|
||||
return Json(json.loads(json_string))
|
||||
else:
|
||||
json_string = dumps(data, default=lambda v: None)
|
||||
|
||||
# Remove PostgreSQL-incompatible control characters in JSON string
|
||||
# Single comprehensive regex handles all control character sequences
|
||||
sanitized_json = POSTGRES_JSON_ESCAPES.sub("", json_string)
|
||||
|
||||
# Remove any remaining raw control characters (fallback safety net)
|
||||
sanitized_json = POSTGRES_CONTROL_CHARS.sub("", sanitized_json)
|
||||
return Json(json.loads(sanitized_json))
|
||||
|
||||
@@ -4,8 +4,11 @@ from enum import Enum
|
||||
import sentry_sdk
|
||||
from pydantic import SecretStr
|
||||
from sentry_sdk.integrations.anthropic import AnthropicIntegration
|
||||
from sentry_sdk.integrations.asyncio import AsyncioIntegration
|
||||
from sentry_sdk.integrations.launchdarkly import LaunchDarklyIntegration
|
||||
from sentry_sdk.integrations.logging import LoggingIntegration
|
||||
|
||||
from backend.util.feature_flag import get_client, is_configured
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
@@ -18,6 +21,9 @@ class DiscordChannel(str, Enum):
|
||||
|
||||
def sentry_init():
|
||||
sentry_dsn = settings.secrets.sentry_dsn
|
||||
integrations = []
|
||||
if is_configured():
|
||||
integrations.append(LaunchDarklyIntegration(get_client()))
|
||||
sentry_sdk.init(
|
||||
dsn=sentry_dsn,
|
||||
traces_sample_rate=1.0,
|
||||
@@ -25,11 +31,13 @@ def sentry_init():
|
||||
environment=f"app:{settings.config.app_env.value}-behave:{settings.config.behave_as.value}",
|
||||
_experiments={"enable_logs": True},
|
||||
integrations=[
|
||||
AsyncioIntegration(),
|
||||
LoggingIntegration(sentry_logs_level=logging.INFO),
|
||||
AnthropicIntegration(
|
||||
include_prompts=False,
|
||||
),
|
||||
],
|
||||
]
|
||||
+ integrations,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -19,9 +19,48 @@ def _msg_tokens(msg: dict, enc) -> int:
|
||||
"""
|
||||
OpenAI counts ≈3 wrapper tokens per chat message, plus 1 if "name"
|
||||
is present, plus the tokenised content length.
|
||||
For tool calls, we need to count tokens in tool_calls and content fields.
|
||||
"""
|
||||
WRAPPER = 3 + (1 if "name" in msg else 0)
|
||||
return WRAPPER + _tok_len(msg.get("content") or "", enc)
|
||||
|
||||
# Count content tokens
|
||||
content_tokens = _tok_len(msg.get("content") or "", enc)
|
||||
|
||||
# Count tool call tokens for both OpenAI and Anthropic formats
|
||||
tool_call_tokens = 0
|
||||
|
||||
# OpenAI format: tool_calls array at message level
|
||||
if "tool_calls" in msg and isinstance(msg["tool_calls"], list):
|
||||
for tool_call in msg["tool_calls"]:
|
||||
# Count the tool call structure tokens
|
||||
tool_call_tokens += _tok_len(tool_call.get("id", ""), enc)
|
||||
tool_call_tokens += _tok_len(tool_call.get("type", ""), enc)
|
||||
if "function" in tool_call:
|
||||
tool_call_tokens += _tok_len(tool_call["function"].get("name", ""), enc)
|
||||
tool_call_tokens += _tok_len(
|
||||
tool_call["function"].get("arguments", ""), enc
|
||||
)
|
||||
|
||||
# Anthropic format: tool_use within content array
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get("type") == "tool_use":
|
||||
# Count the tool use structure tokens
|
||||
tool_call_tokens += _tok_len(item.get("id", ""), enc)
|
||||
tool_call_tokens += _tok_len(item.get("name", ""), enc)
|
||||
tool_call_tokens += _tok_len(json.dumps(item.get("input", {})), enc)
|
||||
elif isinstance(item, dict) and item.get("type") == "tool_result":
|
||||
# Count tool result tokens
|
||||
tool_call_tokens += _tok_len(item.get("tool_use_id", ""), enc)
|
||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||
elif isinstance(item, dict) and "content" in item:
|
||||
# Other content types with content field
|
||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||
# For list content, override content_tokens since we counted everything above
|
||||
content_tokens = 0
|
||||
|
||||
return WRAPPER + content_tokens + tool_call_tokens
|
||||
|
||||
|
||||
def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
|
||||
|
||||
278
autogpt_platform/backend/backend/util/prompt_test.py
Normal file
278
autogpt_platform/backend/backend/util/prompt_test.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Tests for prompt utility functions, especially tool call token counting."""
|
||||
|
||||
import pytest
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
from backend.util import json
|
||||
from backend.util.prompt import _msg_tokens, estimate_token_count
|
||||
|
||||
|
||||
class TestMsgTokens:
|
||||
"""Test the _msg_tokens function with various message types."""
|
||||
|
||||
@pytest.fixture
|
||||
def enc(self):
|
||||
"""Get the encoding for gpt-4o model."""
|
||||
return encoding_for_model("gpt-4o")
|
||||
|
||||
def test_regular_message_token_counting(self, enc):
|
||||
"""Test that regular messages are counted correctly (backward compatibility)."""
|
||||
msg = {"role": "user", "content": "What's the weather like in San Francisco?"}
|
||||
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
|
||||
# Should be wrapper (3) + content tokens
|
||||
expected = 3 + len(enc.encode(msg["content"]))
|
||||
assert tokens == expected
|
||||
assert tokens > 3 # Has content
|
||||
|
||||
def test_regular_message_with_name(self, enc):
|
||||
"""Test that messages with name field get extra wrapper token."""
|
||||
msg = {"role": "user", "name": "test_user", "content": "Hello!"}
|
||||
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
|
||||
# Should be wrapper (3 + 1 for name) + content tokens
|
||||
expected = 4 + len(enc.encode(msg["content"]))
|
||||
assert tokens == expected
|
||||
|
||||
def test_openai_tool_call_token_counting(self, enc):
|
||||
"""Test OpenAI format tool call token counting."""
|
||||
msg = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abc123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco", "unit": "celsius"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
|
||||
# Should count wrapper + all tool call components
|
||||
expected_tool_tokens = (
|
||||
len(enc.encode("call_abc123"))
|
||||
+ len(enc.encode("function"))
|
||||
+ len(enc.encode("get_weather"))
|
||||
+ len(enc.encode('{"location": "San Francisco", "unit": "celsius"}'))
|
||||
)
|
||||
expected = 3 + expected_tool_tokens # wrapper + tool tokens
|
||||
|
||||
assert tokens == expected
|
||||
assert tokens > 8 # Should be significantly more than just wrapper
|
||||
|
||||
def test_openai_multiple_tool_calls(self, enc):
|
||||
"""Test OpenAI format with multiple tool calls."""
|
||||
msg = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "func1", "arguments": '{"arg": "value1"}'},
|
||||
},
|
||||
{
|
||||
"id": "call_2",
|
||||
"type": "function",
|
||||
"function": {"name": "func2", "arguments": '{"arg": "value2"}'},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
|
||||
# Should count all tool calls
|
||||
assert tokens > 20 # Should be more than single tool call
|
||||
|
||||
def test_anthropic_tool_use_token_counting(self, enc):
|
||||
"""Test Anthropic format tool use token counting."""
|
||||
msg = {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_xyz456",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "San Francisco", "unit": "celsius"},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
|
||||
# Should count wrapper + tool use components
|
||||
expected_tool_tokens = (
|
||||
len(enc.encode("toolu_xyz456"))
|
||||
+ len(enc.encode("get_weather"))
|
||||
+ len(
|
||||
enc.encode(json.dumps({"location": "San Francisco", "unit": "celsius"}))
|
||||
)
|
||||
)
|
||||
expected = 3 + expected_tool_tokens # wrapper + tool tokens
|
||||
|
||||
assert tokens == expected
|
||||
assert tokens > 8 # Should be significantly more than just wrapper
|
||||
|
||||
def test_anthropic_tool_result_token_counting(self, enc):
|
||||
"""Test Anthropic format tool result token counting."""
|
||||
msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_xyz456",
|
||||
"content": "The weather in San Francisco is 22°C and sunny.",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
|
||||
# Should count wrapper + tool result components
|
||||
expected_tool_tokens = len(enc.encode("toolu_xyz456")) + len(
|
||||
enc.encode("The weather in San Francisco is 22°C and sunny.")
|
||||
)
|
||||
expected = 3 + expected_tool_tokens # wrapper + tool tokens
|
||||
|
||||
assert tokens == expected
|
||||
assert tokens > 8 # Should be significantly more than just wrapper
|
||||
|
||||
def test_anthropic_mixed_content(self, enc):
|
||||
"""Test Anthropic format with mixed content types."""
|
||||
msg = {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "content": "I'll check the weather for you."},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_123",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "SF"},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
|
||||
# Should count all content items
|
||||
assert tokens > 15 # Should count both text and tool use
|
||||
|
||||
def test_empty_content(self, enc):
|
||||
"""Test message with empty or None content."""
|
||||
msg = {"role": "assistant", "content": None}
|
||||
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
assert tokens == 3 # Just wrapper tokens
|
||||
|
||||
msg["content"] = ""
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
assert tokens == 3 # Just wrapper tokens
|
||||
|
||||
def test_string_content_with_tool_calls(self, enc):
|
||||
"""Test OpenAI format where content is string but tool_calls exist."""
|
||||
msg = {
|
||||
"role": "assistant",
|
||||
"content": "Let me check that for you.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_123",
|
||||
"type": "function",
|
||||
"function": {"name": "test_func", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
|
||||
# Should count both content and tool calls
|
||||
content_tokens = len(enc.encode("Let me check that for you."))
|
||||
tool_tokens = (
|
||||
len(enc.encode("call_123"))
|
||||
+ len(enc.encode("function"))
|
||||
+ len(enc.encode("test_func"))
|
||||
+ len(enc.encode("{}"))
|
||||
)
|
||||
expected = 3 + content_tokens + tool_tokens
|
||||
|
||||
assert tokens == expected
|
||||
|
||||
|
||||
class TestEstimateTokenCount:
|
||||
"""Test the estimate_token_count function with conversations containing tool calls."""
|
||||
|
||||
def test_conversation_with_tool_calls(self):
|
||||
"""Test token counting for a complete conversation with tool calls."""
|
||||
conversation = [
|
||||
{"role": "user", "content": "What's the weather like in San Francisco?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_123",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "San Francisco"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_123",
|
||||
"content": "22°C and sunny",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The weather in San Francisco is 22°C and sunny.",
|
||||
},
|
||||
]
|
||||
|
||||
total_tokens = estimate_token_count(conversation)
|
||||
|
||||
# Verify total equals sum of individual messages
|
||||
enc = encoding_for_model("gpt-4o")
|
||||
expected_total = sum(_msg_tokens(msg, enc) for msg in conversation)
|
||||
|
||||
assert total_tokens == expected_total
|
||||
assert total_tokens > 40 # Should be substantial for this conversation
|
||||
|
||||
def test_openai_conversation(self):
|
||||
"""Test token counting for OpenAI format conversation."""
|
||||
conversation = [
|
||||
{"role": "user", "content": "Calculate 2 + 2"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_calc",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calculate",
|
||||
"arguments": '{"expression": "2 + 2"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_calc", "content": "4"},
|
||||
{"role": "assistant", "content": "The result is 4."},
|
||||
]
|
||||
|
||||
total_tokens = estimate_token_count(conversation)
|
||||
|
||||
# Verify total equals sum of individual messages
|
||||
enc = encoding_for_model("gpt-4o")
|
||||
expected_total = sum(_msg_tokens(msg, enc) for msg in conversation)
|
||||
|
||||
assert total_tokens == expected_total
|
||||
assert total_tokens > 20 # Should be substantial
|
||||
@@ -20,33 +20,72 @@ logger = logging.getLogger(__name__)
|
||||
# Alert threshold for excessive retries
|
||||
EXCESSIVE_RETRY_THRESHOLD = 50
|
||||
|
||||
# Rate limiting for alerts - track last alert time per function+error combination
|
||||
_alert_rate_limiter = {}
|
||||
_rate_limiter_lock = threading.Lock()
|
||||
ALERT_RATE_LIMIT_SECONDS = 300 # 5 minutes between same alerts
|
||||
|
||||
|
||||
def should_send_alert(func_name: str, exception: Exception, context: str = "") -> bool:
|
||||
"""Check if we should send an alert based on rate limiting."""
|
||||
# Create a unique key for this function+error+context combination
|
||||
error_signature = (
|
||||
f"{context}:{func_name}:{type(exception).__name__}:{str(exception)[:100]}"
|
||||
)
|
||||
current_time = time.time()
|
||||
|
||||
with _rate_limiter_lock:
|
||||
last_alert_time = _alert_rate_limiter.get(error_signature, 0)
|
||||
if current_time - last_alert_time >= ALERT_RATE_LIMIT_SECONDS:
|
||||
_alert_rate_limiter[error_signature] = current_time
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def send_rate_limited_discord_alert(
|
||||
func_name: str, exception: Exception, context: str, alert_msg: str, channel=None
|
||||
) -> bool:
|
||||
"""
|
||||
Send a Discord alert with rate limiting.
|
||||
|
||||
Returns True if alert was sent, False if rate limited.
|
||||
"""
|
||||
if not should_send_alert(func_name, exception, context):
|
||||
return False
|
||||
|
||||
def _send_retry_alert(
|
||||
func_name: str, attempt_number: int, exception: Exception, context: str = ""
|
||||
):
|
||||
"""Send alert for excessive retry attempts."""
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from backend.util.clients import get_notification_manager_client
|
||||
from backend.util.metrics import DiscordChannel
|
||||
|
||||
notification_client = get_notification_manager_client()
|
||||
|
||||
prefix = f"{context}: " if context else ""
|
||||
alert_msg = (
|
||||
f"🚨 Excessive Retry Alert: {prefix}'{func_name}' has failed {attempt_number} times!\n\n"
|
||||
f"Error: {type(exception).__name__}: {exception}\n\n"
|
||||
f"This indicates a persistent issue that requires investigation. "
|
||||
f"The operation has been retrying for an extended period."
|
||||
)
|
||||
|
||||
notification_client.discord_system_alert(alert_msg)
|
||||
logger.critical(
|
||||
f"ALERT SENT: Excessive retries detected for {func_name} after {attempt_number} attempts"
|
||||
notification_client.discord_system_alert(
|
||||
alert_msg, channel or DiscordChannel.PLATFORM
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as alert_error:
|
||||
logger.error(f"Failed to send retry alert: {alert_error}")
|
||||
# Don't let alerting failures break the main flow
|
||||
logger.error(f"Failed to send Discord alert: {alert_error}")
|
||||
return False
|
||||
|
||||
|
||||
def _send_critical_retry_alert(
|
||||
func_name: str, attempt_number: int, exception: Exception, context: str = ""
|
||||
):
|
||||
"""Send alert when a function is approaching the retry failure threshold."""
|
||||
|
||||
prefix = f"{context}: " if context else ""
|
||||
if send_rate_limited_discord_alert(
|
||||
func_name,
|
||||
exception,
|
||||
context,
|
||||
f"🚨 CRITICAL: Operation Approaching Failure Threshold: {prefix}'{func_name}'\n\n"
|
||||
f"Current attempt: {attempt_number}/{EXCESSIVE_RETRY_THRESHOLD}\n"
|
||||
f"Error: {type(exception).__name__}: {exception}\n\n"
|
||||
f"This operation is about to fail permanently. Investigate immediately.",
|
||||
):
|
||||
logger.critical(
|
||||
f"CRITICAL ALERT SENT: Operation {func_name} at attempt {attempt_number}"
|
||||
)
|
||||
|
||||
|
||||
def _create_retry_callback(context: str = ""):
|
||||
@@ -59,22 +98,23 @@ def _create_retry_callback(context: str = ""):
|
||||
|
||||
prefix = f"{context}: " if context else ""
|
||||
|
||||
# Send alert if we've exceeded the threshold
|
||||
if attempt_number >= EXCESSIVE_RETRY_THRESHOLD:
|
||||
_send_retry_alert(func_name, attempt_number, exception, context)
|
||||
|
||||
if retry_state.outcome.failed and retry_state.next_action is None:
|
||||
# Final failure
|
||||
# Final failure - just log the error (alert was already sent at excessive threshold)
|
||||
logger.error(
|
||||
f"{prefix}Giving up after {attempt_number} attempts for '{func_name}': "
|
||||
f"{type(exception).__name__}: {exception}"
|
||||
)
|
||||
else:
|
||||
# Retry attempt
|
||||
logger.warning(
|
||||
f"{prefix}Retry attempt {attempt_number} for '{func_name}': "
|
||||
f"{type(exception).__name__}: {exception}"
|
||||
)
|
||||
# Retry attempt - send critical alert only once at threshold (rate limited)
|
||||
if attempt_number == EXCESSIVE_RETRY_THRESHOLD:
|
||||
_send_critical_retry_alert(
|
||||
func_name, attempt_number, exception, context
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"{prefix}Retry attempt {attempt_number} for '{func_name}': "
|
||||
f"{type(exception).__name__}: {exception}"
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
@@ -130,7 +170,7 @@ def _log_prefix(resource_name: str, conn_id: str):
|
||||
def conn_retry(
|
||||
resource_name: str,
|
||||
action_name: str,
|
||||
max_retry: int = 5,
|
||||
max_retry: int = 100,
|
||||
max_wait: float = 30,
|
||||
):
|
||||
conn_id = str(uuid4())
|
||||
@@ -139,16 +179,28 @@ def conn_retry(
|
||||
prefix = _log_prefix(resource_name, conn_id)
|
||||
exception = retry_state.outcome.exception()
|
||||
attempt_number = retry_state.attempt_number
|
||||
|
||||
# Send alert if we've exceeded the threshold
|
||||
if attempt_number >= EXCESSIVE_RETRY_THRESHOLD:
|
||||
func_name = f"{resource_name}:{action_name}"
|
||||
context = f"Connection retry {resource_name}"
|
||||
_send_retry_alert(func_name, attempt_number, exception, context)
|
||||
func_name = getattr(retry_state.fn, "__name__", "unknown")
|
||||
|
||||
if retry_state.outcome.failed and retry_state.next_action is None:
|
||||
logger.error(f"{prefix} {action_name} failed after retries: {exception}")
|
||||
else:
|
||||
if attempt_number == EXCESSIVE_RETRY_THRESHOLD:
|
||||
if send_rate_limited_discord_alert(
|
||||
func_name,
|
||||
exception,
|
||||
f"{resource_name}_infrastructure",
|
||||
f"🚨 **Critical Infrastructure Connection Issue**\n"
|
||||
f"Resource: {resource_name}\n"
|
||||
f"Action: {action_name}\n"
|
||||
f"Function: {func_name}\n"
|
||||
f"Current attempt: {attempt_number}/{max_retry + 1}\n"
|
||||
f"Error: {type(exception).__name__}: {str(exception)[:200]}{'...' if len(str(exception)) > 200 else ''}\n\n"
|
||||
f"Infrastructure component is approaching failure threshold. Investigate immediately.",
|
||||
):
|
||||
logger.critical(
|
||||
f"INFRASTRUCTURE ALERT SENT: {resource_name} at {attempt_number} attempts"
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"{prefix} {action_name} failed: {exception}. Retrying now..."
|
||||
)
|
||||
@@ -224,8 +276,8 @@ def continuous_retry(*, retry_delay: float = 1.0):
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
counter = 0
|
||||
while True:
|
||||
counter = 0
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as exc:
|
||||
|
||||
@@ -1,8 +1,19 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.retry import conn_retry
|
||||
from backend.util.retry import (
|
||||
ALERT_RATE_LIMIT_SECONDS,
|
||||
_alert_rate_limiter,
|
||||
_rate_limiter_lock,
|
||||
_send_critical_retry_alert,
|
||||
conn_retry,
|
||||
create_retry_decorator,
|
||||
should_send_alert,
|
||||
)
|
||||
|
||||
|
||||
def test_conn_retry_sync_function():
|
||||
@@ -47,3 +58,194 @@ async def test_conn_retry_async_function():
|
||||
with pytest.raises(ValueError) as e:
|
||||
await test_function()
|
||||
assert str(e.value) == "Test error"
|
||||
|
||||
|
||||
class TestRetryRateLimiting:
|
||||
"""Test the rate limiting functionality for critical retry alerts."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset rate limiter state before each test."""
|
||||
with _rate_limiter_lock:
|
||||
_alert_rate_limiter.clear()
|
||||
|
||||
def test_should_send_alert_allows_first_occurrence(self):
|
||||
"""Test that the first occurrence of an error allows alert."""
|
||||
exc = ValueError("test error")
|
||||
assert should_send_alert("test_func", exc, "test_context") is True
|
||||
|
||||
def test_should_send_alert_rate_limits_duplicate(self):
|
||||
"""Test that duplicate errors are rate limited."""
|
||||
exc = ValueError("test error")
|
||||
|
||||
# First call should be allowed
|
||||
assert should_send_alert("test_func", exc, "test_context") is True
|
||||
|
||||
# Second call should be rate limited
|
||||
assert should_send_alert("test_func", exc, "test_context") is False
|
||||
|
||||
def test_should_send_alert_allows_different_errors(self):
|
||||
"""Test that different errors are allowed even if same function."""
|
||||
exc1 = ValueError("error 1")
|
||||
exc2 = ValueError("error 2")
|
||||
|
||||
# First error should be allowed
|
||||
assert should_send_alert("test_func", exc1, "test_context") is True
|
||||
|
||||
# Different error should also be allowed
|
||||
assert should_send_alert("test_func", exc2, "test_context") is True
|
||||
|
||||
def test_should_send_alert_allows_different_contexts(self):
|
||||
"""Test that same error in different contexts is allowed."""
|
||||
exc = ValueError("test error")
|
||||
|
||||
# First context should be allowed
|
||||
assert should_send_alert("test_func", exc, "context1") is True
|
||||
|
||||
# Different context should also be allowed
|
||||
assert should_send_alert("test_func", exc, "context2") is True
|
||||
|
||||
def test_should_send_alert_allows_different_functions(self):
|
||||
"""Test that same error in different functions is allowed."""
|
||||
exc = ValueError("test error")
|
||||
|
||||
# First function should be allowed
|
||||
assert should_send_alert("func1", exc, "test_context") is True
|
||||
|
||||
# Different function should also be allowed
|
||||
assert should_send_alert("func2", exc, "test_context") is True
|
||||
|
||||
def test_should_send_alert_respects_time_window(self):
|
||||
"""Test that alerts are allowed again after the rate limit window."""
|
||||
exc = ValueError("test error")
|
||||
|
||||
# First call should be allowed
|
||||
assert should_send_alert("test_func", exc, "test_context") is True
|
||||
|
||||
# Immediately after should be rate limited
|
||||
assert should_send_alert("test_func", exc, "test_context") is False
|
||||
|
||||
# Mock time to simulate passage of rate limit window
|
||||
current_time = time.time()
|
||||
with patch("backend.util.retry.time.time") as mock_time:
|
||||
# Simulate time passing beyond rate limit window
|
||||
mock_time.return_value = current_time + ALERT_RATE_LIMIT_SECONDS + 1
|
||||
assert should_send_alert("test_func", exc, "test_context") is True
|
||||
|
||||
def test_should_send_alert_thread_safety(self):
|
||||
"""Test that rate limiting is thread-safe."""
|
||||
exc = ValueError("test error")
|
||||
results = []
|
||||
|
||||
def check_alert():
|
||||
result = should_send_alert("test_func", exc, "test_context")
|
||||
results.append(result)
|
||||
|
||||
# Create multiple threads trying to send the same alert
|
||||
threads = [threading.Thread(target=check_alert) for _ in range(10)]
|
||||
|
||||
# Start all threads
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
|
||||
# Wait for all threads to complete
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# Only one thread should have been allowed to send the alert
|
||||
assert sum(results) == 1
|
||||
assert len([r for r in results if r is True]) == 1
|
||||
assert len([r for r in results if r is False]) == 9
|
||||
|
||||
@patch("backend.util.clients.get_notification_manager_client")
|
||||
def test_send_critical_retry_alert_rate_limiting(self, mock_get_client):
|
||||
"""Test that _send_critical_retry_alert respects rate limiting."""
|
||||
mock_client = Mock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
exc = ValueError("spend_credits API error")
|
||||
|
||||
# First alert should be sent
|
||||
_send_critical_retry_alert("spend_credits", 50, exc, "Service communication")
|
||||
assert mock_client.discord_system_alert.call_count == 1
|
||||
|
||||
# Second identical alert should be rate limited (not sent)
|
||||
_send_critical_retry_alert("spend_credits", 50, exc, "Service communication")
|
||||
assert mock_client.discord_system_alert.call_count == 1 # Still 1, not 2
|
||||
|
||||
# Different error should be allowed
|
||||
exc2 = ValueError("different API error")
|
||||
_send_critical_retry_alert("spend_credits", 50, exc2, "Service communication")
|
||||
assert mock_client.discord_system_alert.call_count == 2
|
||||
|
||||
@patch("backend.util.clients.get_notification_manager_client")
|
||||
def test_send_critical_retry_alert_handles_notification_failure(
|
||||
self, mock_get_client
|
||||
):
|
||||
"""Test that notification failures don't break the rate limiter."""
|
||||
mock_client = Mock()
|
||||
mock_client.discord_system_alert.side_effect = Exception("Notification failed")
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
exc = ValueError("test error")
|
||||
|
||||
# Should not raise exception even if notification fails
|
||||
_send_critical_retry_alert("test_func", 50, exc, "test_context")
|
||||
|
||||
# Rate limiter should still work for subsequent calls
|
||||
assert should_send_alert("test_func", exc, "test_context") is False
|
||||
|
||||
def test_error_signature_generation(self):
|
||||
"""Test that error signatures are generated correctly for rate limiting."""
|
||||
# Test with long exception message (should be truncated to 100 chars)
|
||||
long_message = "x" * 200
|
||||
exc = ValueError(long_message)
|
||||
|
||||
# Should not raise exception and should work normally
|
||||
assert should_send_alert("test_func", exc, "test_context") is True
|
||||
assert should_send_alert("test_func", exc, "test_context") is False
|
||||
|
||||
def test_real_world_scenario_spend_credits_spam(self):
|
||||
"""Test the real-world scenario that was causing spam."""
|
||||
# Simulate the exact error that was causing issues
|
||||
exc = Exception(
|
||||
"HTTP 500: Server error '500 Internal Server Error' for url 'http://autogpt-database-manager.prod-agpt.svc.cluster.local:8005/spend_credits'"
|
||||
)
|
||||
|
||||
# First 50 attempts reach threshold - should send alert
|
||||
with patch(
|
||||
"backend.util.clients.get_notification_manager_client"
|
||||
) as mock_get_client:
|
||||
mock_client = Mock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
_send_critical_retry_alert(
|
||||
"_call_method_sync", 50, exc, "Service communication"
|
||||
)
|
||||
assert mock_client.discord_system_alert.call_count == 1
|
||||
|
||||
# Next 950 failures should not send alerts (rate limited)
|
||||
for _ in range(950):
|
||||
_send_critical_retry_alert(
|
||||
"_call_method_sync", 50, exc, "Service communication"
|
||||
)
|
||||
|
||||
# Still only 1 alert sent total
|
||||
assert mock_client.discord_system_alert.call_count == 1
|
||||
|
||||
@patch("backend.util.clients.get_notification_manager_client")
|
||||
def test_retry_decorator_with_excessive_failures(self, mock_get_client):
|
||||
"""Test retry decorator behavior when it hits the alert threshold."""
|
||||
mock_client = Mock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
@create_retry_decorator(
|
||||
max_attempts=60, max_wait=0.1
|
||||
) # More than EXCESSIVE_RETRY_THRESHOLD, but fast
|
||||
def always_failing_function():
|
||||
raise ValueError("persistent failure")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
always_failing_function()
|
||||
|
||||
# Should have sent exactly one alert at the threshold
|
||||
assert mock_client.discord_system_alert.call_count == 1
|
||||
|
||||
@@ -7,7 +7,7 @@ import os
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property, update_wrapper
|
||||
from functools import update_wrapper
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
@@ -28,6 +28,7 @@ from fastapi import FastAPI, Request, responses
|
||||
from pydantic import BaseModel, TypeAdapter, create_model
|
||||
|
||||
import backend.util.exceptions as exceptions
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
from backend.util.json import to_dict
|
||||
from backend.util.metrics import sentry_init
|
||||
from backend.util.process import AppProcess, get_service_name
|
||||
@@ -283,6 +284,24 @@ class AppService(BaseAppService, ABC):
|
||||
super().run()
|
||||
self.fastapi_app = FastAPI()
|
||||
|
||||
# Add Prometheus instrumentation to all services
|
||||
try:
|
||||
instrument_fastapi(
|
||||
self.fastapi_app,
|
||||
service_name=self.service_name,
|
||||
expose_endpoint=True,
|
||||
endpoint="/metrics",
|
||||
include_in_schema=False,
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
f"Prometheus instrumentation not available for {self.service_name}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to instrument {self.service_name} with Prometheus: {e}"
|
||||
)
|
||||
|
||||
# Register the exposed API routes.
|
||||
for attr_name, attr in vars(type(self)).items():
|
||||
if getattr(attr, EXPOSED_FLAG, False):
|
||||
@@ -375,6 +394,8 @@ def get_service_client(
|
||||
self.base_url = f"http://{host}:{port}".rstrip("/")
|
||||
self._connection_failure_count = 0
|
||||
self._last_client_reset = 0
|
||||
self._async_clients = {} # None key for default async client
|
||||
self._sync_clients = {} # For sync clients (no event loop concept)
|
||||
|
||||
def _create_sync_client(self) -> httpx.Client:
|
||||
return httpx.Client(
|
||||
@@ -398,13 +419,33 @@ def get_service_client(
|
||||
),
|
||||
)
|
||||
|
||||
@cached_property
|
||||
@property
|
||||
def sync_client(self) -> httpx.Client:
|
||||
return self._create_sync_client()
|
||||
"""Get the sync client (thread-safe singleton)."""
|
||||
# Use service name as key for better identification
|
||||
service_name = service_client_type.get_service_type().__name__
|
||||
if client := self._sync_clients.get(service_name):
|
||||
return client
|
||||
return self._sync_clients.setdefault(
|
||||
service_name, self._create_sync_client()
|
||||
)
|
||||
|
||||
@cached_property
|
||||
@property
|
||||
def async_client(self) -> httpx.AsyncClient:
|
||||
return self._create_async_client()
|
||||
"""Get the appropriate async client for the current context.
|
||||
|
||||
Returns per-event-loop client when in async context,
|
||||
falls back to default client otherwise.
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# No event loop, use None as default key
|
||||
loop = None
|
||||
|
||||
if client := self._async_clients.get(loop):
|
||||
return client
|
||||
return self._async_clients.setdefault(loop, self._create_async_client())
|
||||
|
||||
def _handle_connection_error(self, error: Exception) -> None:
|
||||
"""Handle connection errors and implement self-healing"""
|
||||
@@ -423,10 +464,8 @@ def get_service_client(
|
||||
|
||||
# Clear cached clients to force recreation on next access
|
||||
# Only recreate when there's actually a problem
|
||||
if hasattr(self, "sync_client"):
|
||||
delattr(self, "sync_client")
|
||||
if hasattr(self, "async_client"):
|
||||
delattr(self, "async_client")
|
||||
self._sync_clients.clear()
|
||||
self._async_clients.clear()
|
||||
|
||||
# Reset counters
|
||||
self._connection_failure_count = 0
|
||||
@@ -492,28 +531,37 @@ def get_service_client(
|
||||
raise
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if hasattr(self, "sync_client"):
|
||||
self.sync_client.close()
|
||||
if hasattr(self, "async_client"):
|
||||
await self.async_client.aclose()
|
||||
# Close all sync clients
|
||||
for client in self._sync_clients.values():
|
||||
client.close()
|
||||
self._sync_clients.clear()
|
||||
|
||||
# Close all async clients (including default with None key)
|
||||
for client in self._async_clients.values():
|
||||
await client.aclose()
|
||||
self._async_clients.clear()
|
||||
|
||||
def close(self) -> None:
|
||||
if hasattr(self, "sync_client"):
|
||||
self.sync_client.close()
|
||||
# Note: Cannot close async client synchronously
|
||||
# Close all sync clients
|
||||
for client in self._sync_clients.values():
|
||||
client.close()
|
||||
self._sync_clients.clear()
|
||||
# Note: Cannot close async clients synchronously
|
||||
# They will be cleaned up by garbage collection
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup HTTP clients on garbage collection to prevent resource leaks."""
|
||||
try:
|
||||
if hasattr(self, "sync_client"):
|
||||
self.sync_client.close()
|
||||
if hasattr(self, "async_client"):
|
||||
# Note: Can't await in __del__, so we just close sync
|
||||
# The async client will be cleaned up by garbage collection
|
||||
# Close any remaining sync clients
|
||||
for client in self._sync_clients.values():
|
||||
client.close()
|
||||
|
||||
# Warn if async clients weren't properly closed
|
||||
if self._async_clients:
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"DynamicClient async client not explicitly closed. "
|
||||
"DynamicClient async clients not explicitly closed. "
|
||||
"Call aclose() before destroying the client.",
|
||||
ResourceWarning,
|
||||
stacklevel=2,
|
||||
|
||||
@@ -59,6 +59,19 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
le=1000,
|
||||
description="Maximum number of workers to use for graph execution.",
|
||||
)
|
||||
|
||||
# FastAPI Thread Pool Configuration
|
||||
# IMPORTANT: FastAPI automatically offloads ALL sync functions to a thread pool:
|
||||
# - Sync endpoint functions (def instead of async def)
|
||||
# - Sync dependency functions (def instead of async def)
|
||||
# - Manually called run_in_threadpool() operations
|
||||
# Default thread pool size is only 40, which becomes a bottleneck under high concurrency
|
||||
fastapi_thread_pool_size: int = Field(
|
||||
default=60,
|
||||
ge=40,
|
||||
le=500,
|
||||
description="Thread pool size for FastAPI sync operations. All sync endpoints and dependencies automatically use this pool. Higher values support more concurrent sync operations but use more memory.",
|
||||
)
|
||||
pyro_host: str = Field(
|
||||
default="localhost",
|
||||
description="The default hostname of the Pyro server.",
|
||||
@@ -127,10 +140,20 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default=5 * 60,
|
||||
description="Time in seconds after which the execution stuck on QUEUED status is considered late.",
|
||||
)
|
||||
cluster_lock_timeout: int = Field(
|
||||
default=300,
|
||||
description="Cluster lock timeout in seconds for graph execution coordination.",
|
||||
)
|
||||
execution_late_notification_checkrange_secs: int = Field(
|
||||
default=60 * 60,
|
||||
description="Time in seconds for how far back to check for the late executions.",
|
||||
)
|
||||
max_concurrent_graph_executions_per_user: int = Field(
|
||||
default=25,
|
||||
ge=1,
|
||||
le=1000,
|
||||
description="Maximum number of concurrent graph executions allowed per user per graph.",
|
||||
)
|
||||
|
||||
block_error_rate_threshold: float = Field(
|
||||
default=0.5,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from prisma import Json
|
||||
from pydantic import BaseModel
|
||||
@@ -215,3 +215,199 @@ class TestSafeJson:
|
||||
}
|
||||
result = SafeJson(data)
|
||||
assert isinstance(result, Json)
|
||||
|
||||
def test_control_character_sanitization(self):
|
||||
"""Test that PostgreSQL-incompatible control characters are sanitized by SafeJson."""
|
||||
# Test data with problematic control characters that would cause PostgreSQL errors
|
||||
problematic_data = {
|
||||
"null_byte": "data with \x00 null",
|
||||
"bell_char": "data with \x07 bell",
|
||||
"form_feed": "data with \x0C feed",
|
||||
"escape_char": "data with \x1B escape",
|
||||
"delete_char": "data with \x7F delete",
|
||||
}
|
||||
|
||||
# SafeJson should successfully process data with control characters
|
||||
result = SafeJson(problematic_data)
|
||||
assert isinstance(result, Json)
|
||||
|
||||
# Verify that dangerous control characters are actually removed
|
||||
result_data = result.data
|
||||
assert "\x00" not in str(result_data) # null byte removed
|
||||
assert "\x07" not in str(result_data) # bell removed
|
||||
assert "\x0C" not in str(result_data) # form feed removed
|
||||
assert "\x1B" not in str(result_data) # escape removed
|
||||
assert "\x7F" not in str(result_data) # delete removed
|
||||
|
||||
# Test that safe whitespace characters are preserved
|
||||
safe_data = {
|
||||
"with_tab": "text with \t tab",
|
||||
"with_newline": "text with \n newline",
|
||||
"with_carriage_return": "text with \r carriage return",
|
||||
"normal_text": "completely normal text",
|
||||
}
|
||||
|
||||
safe_result = SafeJson(safe_data)
|
||||
assert isinstance(safe_result, Json)
|
||||
|
||||
# Verify safe characters are preserved
|
||||
safe_result_data = cast(dict[str, Any], safe_result.data)
|
||||
assert isinstance(safe_result_data, dict)
|
||||
with_tab = safe_result_data.get("with_tab", "")
|
||||
with_newline = safe_result_data.get("with_newline", "")
|
||||
with_carriage_return = safe_result_data.get("with_carriage_return", "")
|
||||
assert "\t" in str(with_tab) # tab preserved
|
||||
assert "\n" in str(with_newline) # newline preserved
|
||||
assert "\r" in str(with_carriage_return) # carriage return preserved
|
||||
|
||||
def test_web_scraping_content_sanitization(self):
|
||||
"""Test sanitization of typical web scraping content with null characters."""
|
||||
# Simulate web content that might contain null bytes from SearchTheWebBlock
|
||||
web_content = "Article title\x00Hidden null\x01Start of heading\x08Backspace\x0CForm feed content\x1FUnit separator\x7FDelete char"
|
||||
|
||||
result = SafeJson(web_content)
|
||||
assert isinstance(result, Json)
|
||||
|
||||
# Verify all problematic characters are removed
|
||||
sanitized_content = str(result.data)
|
||||
assert "\x00" not in sanitized_content
|
||||
assert "\x01" not in sanitized_content
|
||||
assert "\x08" not in sanitized_content
|
||||
assert "\x0C" not in sanitized_content
|
||||
assert "\x1F" not in sanitized_content
|
||||
assert "\x7F" not in sanitized_content
|
||||
|
||||
# Verify the content is still readable
|
||||
assert "Article title" in sanitized_content
|
||||
assert "Hidden null" in sanitized_content
|
||||
assert "content" in sanitized_content
|
||||
|
||||
def test_legitimate_code_preservation(self):
|
||||
"""Test that legitimate code with backslashes and escapes is preserved."""
|
||||
# File paths with backslashes should be preserved
|
||||
file_paths = {
|
||||
"windows_path": "C:\\Users\\test\\file.txt",
|
||||
"network_path": "\\\\server\\share\\folder",
|
||||
"escaped_backslashes": "String with \\\\ double backslashes",
|
||||
}
|
||||
|
||||
result = SafeJson(file_paths)
|
||||
result_data = cast(dict[str, Any], result.data)
|
||||
assert isinstance(result_data, dict)
|
||||
|
||||
# Verify file paths are preserved correctly (JSON converts \\\\ back to \\)
|
||||
windows_path = result_data.get("windows_path", "")
|
||||
network_path = result_data.get("network_path", "")
|
||||
escaped_backslashes = result_data.get("escaped_backslashes", "")
|
||||
assert "C:\\Users\\test\\file.txt" in str(windows_path)
|
||||
assert "\\server\\share" in str(network_path)
|
||||
assert "\\" in str(escaped_backslashes)
|
||||
|
||||
def test_legitimate_json_escapes_preservation(self):
|
||||
"""Test that legitimate JSON escape sequences are preserved."""
|
||||
# These should all be preserved as they're valid and useful
|
||||
legitimate_escapes = {
|
||||
"quotes": 'He said "Hello world!"',
|
||||
"newlines": "Line 1\\nLine 2\\nLine 3",
|
||||
"tabs": "Column1\\tColumn2\\tColumn3",
|
||||
"unicode_chars": "Unicode: \u0048\u0065\u006c\u006c\u006f", # "Hello"
|
||||
"mixed_content": "Path: C:\\\\temp\\\\file.txt\\nSize: 1024 bytes",
|
||||
}
|
||||
|
||||
result = SafeJson(legitimate_escapes)
|
||||
result_data = cast(dict[str, Any], result.data)
|
||||
assert isinstance(result_data, dict)
|
||||
|
||||
# Verify all legitimate content is preserved
|
||||
quotes = result_data.get("quotes", "")
|
||||
newlines = result_data.get("newlines", "")
|
||||
tabs = result_data.get("tabs", "")
|
||||
unicode_chars = result_data.get("unicode_chars", "")
|
||||
mixed_content = result_data.get("mixed_content", "")
|
||||
|
||||
assert '"' in str(quotes)
|
||||
assert "Line 1" in str(newlines) and "Line 2" in str(newlines)
|
||||
assert "Column1" in str(tabs) and "Column2" in str(tabs)
|
||||
assert "Hello" in str(unicode_chars) # Unicode should be decoded
|
||||
assert "C:" in str(mixed_content) and "temp" in str(mixed_content)
|
||||
|
||||
def test_regex_patterns_dont_over_match(self):
|
||||
"""Test that our regex patterns don't accidentally match legitimate sequences."""
|
||||
# Edge cases that could be problematic for regex
|
||||
edge_cases = {
|
||||
"file_with_b": "C:\\\\mybfile.txt", # Contains 'bf' but not escape sequence
|
||||
"file_with_f": "C:\\\\folder\\\\file.txt", # Contains 'f' after backslashes
|
||||
"json_like_string": '{"text": "\\\\bolder text"}', # Looks like JSON escape but isn't
|
||||
"unicode_like": "Code: \\\\u0040 (not a real escape)", # Looks like Unicode escape
|
||||
}
|
||||
|
||||
result = SafeJson(edge_cases)
|
||||
result_data = cast(dict[str, Any], result.data)
|
||||
assert isinstance(result_data, dict)
|
||||
|
||||
# Verify edge cases are handled correctly - no content should be lost
|
||||
file_with_b = result_data.get("file_with_b", "")
|
||||
file_with_f = result_data.get("file_with_f", "")
|
||||
json_like_string = result_data.get("json_like_string", "")
|
||||
unicode_like = result_data.get("unicode_like", "")
|
||||
|
||||
assert "mybfile.txt" in str(file_with_b)
|
||||
assert "folder" in str(file_with_f) and "file.txt" in str(file_with_f)
|
||||
assert "bolder text" in str(json_like_string)
|
||||
assert "\\u0040" in str(unicode_like)
|
||||
|
||||
def test_programming_code_preservation(self):
|
||||
"""Test that programming code with various escapes is preserved."""
|
||||
# Common programming patterns that should be preserved
|
||||
code_samples = {
|
||||
"python_string": 'print("Hello\\\\nworld")',
|
||||
"regex_pattern": "\\\\b[A-Za-z]+\\\\b", # Word boundary regex
|
||||
"json_string": '{"name": "test", "path": "C:\\\\\\\\folder"}',
|
||||
"sql_escape": "WHERE name LIKE '%\\\\%%'",
|
||||
"javascript": 'var path = "C:\\\\\\\\Users\\\\\\\\file.js";',
|
||||
}
|
||||
|
||||
result = SafeJson(code_samples)
|
||||
result_data = cast(dict[str, Any], result.data)
|
||||
assert isinstance(result_data, dict)
|
||||
|
||||
# Verify programming code is preserved
|
||||
python_string = result_data.get("python_string", "")
|
||||
regex_pattern = result_data.get("regex_pattern", "")
|
||||
json_string = result_data.get("json_string", "")
|
||||
sql_escape = result_data.get("sql_escape", "")
|
||||
javascript = result_data.get("javascript", "")
|
||||
|
||||
assert "print(" in str(python_string)
|
||||
assert "Hello" in str(python_string)
|
||||
assert "[A-Za-z]+" in str(regex_pattern)
|
||||
assert "name" in str(json_string)
|
||||
assert "LIKE" in str(sql_escape)
|
||||
assert "var path" in str(javascript)
|
||||
|
||||
def test_only_problematic_sequences_removed(self):
|
||||
"""Test that ONLY PostgreSQL-problematic sequences are removed, nothing else."""
|
||||
# Mix of problematic and safe content (using actual control characters)
|
||||
mixed_content = {
|
||||
"safe_and_unsafe": "Good text\twith tab\x00NULL BYTE\nand newline\x08BACKSPACE",
|
||||
"file_path_with_null": "C:\\temp\\file\x00.txt",
|
||||
"json_with_controls": '{"text": "data\x01\x0C\x1F"}',
|
||||
}
|
||||
|
||||
result = SafeJson(mixed_content)
|
||||
result_data = cast(dict[str, Any], result.data)
|
||||
assert isinstance(result_data, dict)
|
||||
|
||||
# Verify only problematic characters are removed
|
||||
safe_and_unsafe = result_data.get("safe_and_unsafe", "")
|
||||
file_path_with_null = result_data.get("file_path_with_null", "")
|
||||
|
||||
assert "Good text" in str(safe_and_unsafe)
|
||||
assert "\t" in str(safe_and_unsafe) # Tab preserved
|
||||
assert "\n" in str(safe_and_unsafe) # Newline preserved
|
||||
assert "\x00" not in str(safe_and_unsafe) # Null removed
|
||||
assert "\x08" not in str(safe_and_unsafe) # Backspace removed
|
||||
|
||||
assert "C:\\temp\\file" in str(file_path_with_null)
|
||||
assert ".txt" in str(file_path_with_null)
|
||||
assert "\x00" not in str(file_path_with_null) # Null removed from path
|
||||
|
||||
@@ -2,9 +2,16 @@ import asyncio
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import aioclamd
|
||||
# Suppress the specific pkg_resources deprecation warning from aioclamd
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore", message="pkg_resources is deprecated", category=UserWarning
|
||||
)
|
||||
import aioclamd
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Clean the test database by removing all data while preserving the schema.
|
||||
|
||||
Usage:
|
||||
poetry run python clean_test_db.py [--yes]
|
||||
|
||||
Options:
|
||||
--yes Skip confirmation prompt
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from prisma import Prisma
|
||||
|
||||
|
||||
async def main():
|
||||
db = Prisma()
|
||||
await db.connect()
|
||||
|
||||
print("=" * 60)
|
||||
print("Cleaning Test Database")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
# Get initial counts
|
||||
user_count = await db.user.count()
|
||||
agent_count = await db.agentgraph.count()
|
||||
|
||||
print(f"Current data: {user_count} users, {agent_count} agent graphs")
|
||||
|
||||
if user_count == 0 and agent_count == 0:
|
||||
print("Database is already clean!")
|
||||
await db.disconnect()
|
||||
return
|
||||
|
||||
# Check for --yes flag
|
||||
skip_confirm = "--yes" in sys.argv
|
||||
|
||||
if not skip_confirm:
|
||||
response = input("\nDo you want to clean all data? (yes/no): ")
|
||||
if response.lower() != "yes":
|
||||
print("Aborted.")
|
||||
await db.disconnect()
|
||||
return
|
||||
|
||||
print("\nCleaning database...")
|
||||
|
||||
# Delete in reverse order of dependencies
|
||||
tables = [
|
||||
("UserNotificationBatch", db.usernotificationbatch),
|
||||
("NotificationEvent", db.notificationevent),
|
||||
("CreditRefundRequest", db.creditrefundrequest),
|
||||
("StoreListingReview", db.storelistingreview),
|
||||
("StoreListingVersion", db.storelistingversion),
|
||||
("StoreListing", db.storelisting),
|
||||
("AgentNodeExecutionInputOutput", db.agentnodeexecutioninputoutput),
|
||||
("AgentNodeExecution", db.agentnodeexecution),
|
||||
("AgentGraphExecution", db.agentgraphexecution),
|
||||
("AgentNodeLink", db.agentnodelink),
|
||||
("LibraryAgent", db.libraryagent),
|
||||
("AgentPreset", db.agentpreset),
|
||||
("IntegrationWebhook", db.integrationwebhook),
|
||||
("AgentNode", db.agentnode),
|
||||
("AgentGraph", db.agentgraph),
|
||||
("AgentBlock", db.agentblock),
|
||||
("APIKey", db.apikey),
|
||||
("CreditTransaction", db.credittransaction),
|
||||
("AnalyticsMetrics", db.analyticsmetrics),
|
||||
("AnalyticsDetails", db.analyticsdetails),
|
||||
("Profile", db.profile),
|
||||
("UserOnboarding", db.useronboarding),
|
||||
("User", db.user),
|
||||
]
|
||||
|
||||
for table_name, table in tables:
|
||||
try:
|
||||
count = await table.count()
|
||||
if count > 0:
|
||||
await table.delete_many()
|
||||
print(f"✓ Deleted {count} records from {table_name}")
|
||||
except Exception as e:
|
||||
print(f"⚠ Error cleaning {table_name}: {e}")
|
||||
|
||||
# Refresh materialized views (they should be empty now)
|
||||
try:
|
||||
await db.execute_raw("SELECT refresh_store_materialized_views();")
|
||||
print("\n✓ Refreshed materialized views")
|
||||
except Exception as e:
|
||||
print(f"\n⚠ Could not refresh materialized views: {e}")
|
||||
|
||||
await db.disconnect()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Database cleaned successfully!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -8,26 +8,23 @@ Clean, streamlined load testing infrastructure for the AutoGPT Platform using k6
|
||||
# 1. Set up Supabase service key (required for token generation)
|
||||
export SUPABASE_SERVICE_KEY="your-supabase-service-key"
|
||||
|
||||
# 2. Generate pre-authenticated tokens (first time setup - creates 150+ tokens with 24-hour expiry)
|
||||
node generate-tokens.js
|
||||
# 2. Generate pre-authenticated tokens (first time setup - creates 160+ tokens with 24-hour expiry)
|
||||
node generate-tokens.js --count=160
|
||||
|
||||
# 3. Set up k6 cloud credentials (for cloud testing)
|
||||
export K6_CLOUD_TOKEN="your-k6-cloud-token"
|
||||
# 3. Set up k6 cloud credentials (for cloud testing - see Credential Setup section below)
|
||||
export K6_CLOUD_TOKEN="your-k6-cloud-token"
|
||||
export K6_CLOUD_PROJECT_ID="4254406"
|
||||
|
||||
# 4. Verify setup and run quick test
|
||||
node run-tests.js verify
|
||||
# 4. Run orchestrated load tests locally
|
||||
node orchestrator/orchestrator.js DEV local
|
||||
|
||||
# 5. Run tests locally (development/debugging)
|
||||
node run-tests.js run all DEV
|
||||
|
||||
# 6. Run tests in k6 cloud (performance testing)
|
||||
node run-tests.js cloud all DEV
|
||||
# 5. Run orchestrated load tests in k6 cloud (recommended)
|
||||
node orchestrator/orchestrator.js DEV cloud
|
||||
```
|
||||
|
||||
## 📋 Unified Test Runner
|
||||
## 📋 Load Test Orchestrator
|
||||
|
||||
The AutoGPT Platform uses a single unified test runner (`run-tests.js`) for both local and cloud execution:
|
||||
The AutoGPT Platform uses a comprehensive load test orchestrator (`orchestrator/orchestrator.js`) that runs 12 optimized tests with maximum VU counts:
|
||||
|
||||
### Available Tests
|
||||
|
||||
@@ -53,45 +50,33 @@ The AutoGPT Platform uses a single unified test runner (`run-tests.js`) for both
|
||||
### Test Modes
|
||||
|
||||
- **Local Mode**: 5 VUs × 30s - Quick validation and debugging
|
||||
- **Cloud Mode**: 80-150 VUs × 3-5m - Real performance testing
|
||||
- **Cloud Mode**: 80-160 VUs × 3-6m - Real performance testing
|
||||
|
||||
## 🛠️ Usage
|
||||
|
||||
### Basic Commands
|
||||
|
||||
```bash
|
||||
# List available tests and show cloud credentials status
|
||||
node run-tests.js list
|
||||
# Run 12 optimized tests locally (for debugging)
|
||||
node orchestrator/orchestrator.js DEV local
|
||||
|
||||
# Quick setup verification
|
||||
node run-tests.js verify
|
||||
# Run 12 optimized tests in k6 cloud (recommended for performance testing)
|
||||
node orchestrator/orchestrator.js DEV cloud
|
||||
|
||||
# Run specific test locally
|
||||
node run-tests.js run core-api-test DEV
|
||||
# Run against production (coordinate with team!)
|
||||
node orchestrator/orchestrator.js PROD cloud
|
||||
|
||||
# Run multiple tests sequentially (comma-separated)
|
||||
node run-tests.js run connectivity-test,core-api-test,marketplace-public-test DEV
|
||||
|
||||
# Run all tests locally
|
||||
node run-tests.js run all DEV
|
||||
|
||||
# Run specific test in k6 cloud
|
||||
node run-tests.js cloud core-api-test DEV
|
||||
|
||||
# Run all tests in k6 cloud
|
||||
node run-tests.js cloud all DEV
|
||||
# Run individual test directly with k6
|
||||
K6_ENVIRONMENT=DEV VUS=100 DURATION=3m k6 run tests/api/core-api-test.js
|
||||
```
|
||||
|
||||
### NPM Scripts
|
||||
|
||||
```bash
|
||||
# Quick verification
|
||||
npm run verify
|
||||
# Run orchestrator locally
|
||||
npm run local
|
||||
|
||||
# Run all tests locally
|
||||
npm test
|
||||
|
||||
# Run all tests in k6 cloud
|
||||
# Run orchestrator in k6 cloud
|
||||
npm run cloud
|
||||
```
|
||||
|
||||
@@ -99,12 +84,12 @@ npm run cloud
|
||||
|
||||
### Pre-Authenticated Tokens
|
||||
|
||||
- **Generation**: Run `node generate-tokens.js` to create tokens
|
||||
- **File**: `configs/pre-authenticated-tokens.js` (gitignored for security)
|
||||
- **Capacity**: 150+ tokens supporting high-concurrency testing
|
||||
- **Generation**: Run `node generate-tokens.js --count=160` to create tokens
|
||||
- **File**: `configs/pre-authenticated-tokens.js` (gitignored for security)
|
||||
- **Capacity**: 160+ tokens supporting high-concurrency testing
|
||||
- **Expiry**: 24 hours (86400 seconds) - extended for long-duration testing
|
||||
- **Benefit**: Eliminates Supabase auth rate limiting at scale
|
||||
- **Regeneration**: Run `node generate-tokens.js` when tokens expire after 24 hours
|
||||
- **Regeneration**: Run `node generate-tokens.js --count=160` when tokens expire after 24 hours
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
@@ -124,8 +109,8 @@ npm run cloud
|
||||
|
||||
### Load Testing Capabilities
|
||||
|
||||
- **High Concurrency**: Up to 150+ virtual users per test
|
||||
- **Authentication Scaling**: Pre-auth tokens support 150+ concurrent users (10 tokens generated by default)
|
||||
- **High Concurrency**: Up to 160+ virtual users per test
|
||||
- **Authentication Scaling**: Pre-auth tokens support 160+ concurrent users
|
||||
- **Sequential Execution**: Multiple tests run one after another with proper delays
|
||||
- **Cloud Infrastructure**: Tests run on k6 cloud servers for consistent results
|
||||
- **ES Module Support**: Full ES module compatibility with modern JavaScript features
|
||||
@@ -134,10 +119,11 @@ npm run cloud
|
||||
|
||||
### Validated Performance Limits
|
||||
|
||||
- **Core API**: 100 VUs successfully handling `/api/credits`, `/api/graphs`, `/api/blocks`, `/api/executions`
|
||||
- **Graph Execution**: 80 VUs for complete workflow pipeline
|
||||
- **Marketplace Browsing**: 150 VUs for public marketplace access
|
||||
- **Authentication**: 150+ concurrent users with pre-authenticated tokens
|
||||
- **Core API**: 100+ VUs successfully handling `/api/credits`, `/api/graphs`, `/api/blocks`, `/api/executions`
|
||||
- **Graph Execution**: 80+ VUs for complete workflow pipeline
|
||||
- **Marketplace Browsing**: 160 VUs for public marketplace access (verified)
|
||||
- **Marketplace Library**: 160 VUs for authenticated library operations (verified)
|
||||
- **Authentication**: 160+ concurrent users with pre-authenticated tokens
|
||||
|
||||
### Target Metrics
|
||||
|
||||
@@ -146,6 +132,14 @@ npm run cloud
|
||||
- **Success Rate**: Target > 95% under normal load
|
||||
- **Error Rate**: Target < 5% for all endpoints
|
||||
|
||||
### Recent Performance Results (160 VU Test - Verified)
|
||||
|
||||
- **Marketplace Library Operations**: 500-1000ms response times at 160 VUs
|
||||
- **Authentication**: 100% success rate with pre-authenticated tokens
|
||||
- **Library Journeys**: 5 operations per journey completing successfully
|
||||
- **Test Duration**: 6+ minutes sustained load without degradation
|
||||
- **k6 Cloud Execution**: Stable performance on Amazon US Columbus infrastructure
|
||||
|
||||
## 🔍 Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
@@ -157,8 +151,8 @@ npm run cloud
|
||||
❌ Token has expired
|
||||
```
|
||||
|
||||
- **Solution**: Run `node generate-tokens.js` to create fresh 24-hour tokens
|
||||
- **Note**: Default generates 10 tokens (increase with `--count=50` for higher concurrency)
|
||||
- **Solution**: Run `node generate-tokens.js --count=160` to create fresh 24-hour tokens
|
||||
- **Note**: Use `--count` parameter to generate appropriate number of tokens for your test scale
|
||||
|
||||
**2. Cloud Credentials Missing**
|
||||
|
||||
@@ -168,7 +162,17 @@ npm run cloud
|
||||
|
||||
- **Solution**: Set `K6_CLOUD_TOKEN` and `K6_CLOUD_PROJECT_ID=4254406`
|
||||
|
||||
**3. Setup Verification Failed**
|
||||
**3. k6 Cloud VU Scaling Issue**
|
||||
|
||||
```
|
||||
❌ Test shows only 5 VUs instead of requested 100+ VUs
|
||||
```
|
||||
|
||||
- **Problem**: Using `K6_ENVIRONMENT=DEV VUS=160 k6 cloud run test.js` (incorrect)
|
||||
- **Solution**: Use `k6 cloud run --env K6_ENVIRONMENT=DEV --env VUS=160 test.js` (correct)
|
||||
- **Note**: The unified test runner (`run-tests.js`) already uses the correct syntax
|
||||
|
||||
**4. Setup Verification Failed**
|
||||
|
||||
```
|
||||
❌ Verification failed
|
||||
@@ -181,28 +185,38 @@ npm run cloud
|
||||
**1. Supabase Service Key (Required for all testing):**
|
||||
|
||||
```bash
|
||||
# Get service key from environment or Kubernetes
|
||||
# Option 1: From your local environment (if available)
|
||||
export SUPABASE_SERVICE_KEY="your-supabase-service-key"
|
||||
|
||||
# Option 2: From Kubernetes secret (for platform developers)
|
||||
kubectl get secret supabase-service-key -o jsonpath='{.data.service-key}' | base64 -d
|
||||
|
||||
# Option 3: From Supabase dashboard
|
||||
# Go to Project Settings > API > service_role key (never commit this!)
|
||||
```
|
||||
|
||||
**2. Generate Pre-Authenticated Tokens (Required):**
|
||||
|
||||
```bash
|
||||
# Creates 10 tokens with 24-hour expiry - prevents auth rate limiting
|
||||
node generate-tokens.js
|
||||
# Creates 160 tokens with 24-hour expiry - prevents auth rate limiting
|
||||
node generate-tokens.js --count=160
|
||||
|
||||
# Generate more tokens for higher concurrency
|
||||
# Generate fewer tokens for smaller tests (minimum 10)
|
||||
node generate-tokens.js --count=50
|
||||
|
||||
# Regenerate when tokens expire (every 24 hours)
|
||||
node generate-tokens.js
|
||||
node generate-tokens.js --count=160
|
||||
```
|
||||
|
||||
**3. k6 Cloud Credentials (Required for cloud testing):**
|
||||
|
||||
```bash
|
||||
# Get from k6 cloud dashboard: https://app.k6.io/account/api-token
|
||||
export K6_CLOUD_TOKEN="your-k6-cloud-token"
|
||||
export K6_CLOUD_PROJECT_ID="4254406" # AutoGPT Platform project ID
|
||||
|
||||
# Verify credentials work by running orchestrator
|
||||
node orchestrator/orchestrator.js DEV cloud
|
||||
```
|
||||
|
||||
## 📂 File Structure
|
||||
@@ -210,9 +224,10 @@ export K6_CLOUD_PROJECT_ID="4254406" # AutoGPT Platform project ID
|
||||
```
|
||||
load-tests/
|
||||
├── README.md # This documentation
|
||||
├── run-tests.js # Unified test runner (MAIN ENTRY POINT)
|
||||
├── generate-tokens.js # Generate pre-auth tokens
|
||||
├── generate-tokens.js # Generate pre-auth tokens (MAIN TOKEN SETUP)
|
||||
├── package.json # Node.js dependencies and scripts
|
||||
├── orchestrator/
|
||||
│ └── orchestrator.js # Main test orchestrator (MAIN ENTRY POINT)
|
||||
├── configs/
|
||||
│ ├── environment.js # Environment URLs and configuration
|
||||
│ └── pre-authenticated-tokens.js # Generated tokens (gitignored)
|
||||
@@ -228,21 +243,19 @@ load-tests/
|
||||
│ │ └── library-access-test.js # Authenticated marketplace/library
|
||||
│ └── comprehensive/
|
||||
│ └── platform-journey-test.js # Complete user journey simulation
|
||||
├── orchestrator/
|
||||
│ └── comprehensive-orchestrator.js # Full 25-test orchestration suite
|
||||
├── results/ # Local test results (auto-created)
|
||||
├── k6-cloud-results.txt # Cloud test URLs (auto-created)
|
||||
└── *.json # Test output files (auto-created)
|
||||
├── unified-results-*.json # Orchestrator results (auto-created)
|
||||
└── *.log # Test execution logs (auto-created)
|
||||
```
|
||||
|
||||
## 🎯 Best Practices
|
||||
|
||||
1. **Start with Verification**: Always run `node run-tests.js verify` first
|
||||
2. **Local for Development**: Use `run` command for debugging and development
|
||||
3. **Cloud for Performance**: Use `cloud` command for actual performance testing
|
||||
1. **Generate Tokens First**: Always run `node generate-tokens.js --count=160` before testing
|
||||
2. **Local for Development**: Use `DEV local` for debugging and development
|
||||
3. **Cloud for Performance**: Use `DEV cloud` for actual performance testing
|
||||
4. **Monitor Real-Time**: Check k6 cloud dashboards during test execution
|
||||
5. **Regenerate Tokens**: Refresh tokens every 24 hours when they expire
|
||||
6. **Sequential Testing**: Use comma-separated tests for organized execution
|
||||
6. **Unified Testing**: Orchestrator runs 12 optimized tests automatically
|
||||
|
||||
## 🚀 Advanced Usage
|
||||
|
||||
@@ -252,8 +265,8 @@ For granular control over individual test scripts:
|
||||
|
||||
```bash
|
||||
# k6 Cloud execution (recommended for performance testing)
|
||||
K6_ENVIRONMENT=DEV VUS=100 DURATION=5m \
|
||||
k6 cloud run --env K6_ENVIRONMENT=DEV --env VUS=100 --env DURATION=5m tests/api/core-api-test.js
|
||||
# IMPORTANT: Use --env syntax for k6 cloud to ensure proper VU scaling
|
||||
k6 cloud run --env K6_ENVIRONMENT=DEV --env VUS=160 --env DURATION=5m --env RAMP_UP=30s --env RAMP_DOWN=30s tests/marketplace/library-access-test.js
|
||||
|
||||
# Local execution with cloud output (debugging)
|
||||
K6_ENVIRONMENT=DEV VUS=10 DURATION=1m \
|
||||
|
||||
@@ -1,611 +0,0 @@
|
||||
#!/usr/bin/env node
|
||||
|
||||
// AutoGPT Platform Load Test Orchestrator
|
||||
// Runs comprehensive test suite locally or in k6 cloud
|
||||
// Collects URLs, statistics, and generates reports
|
||||
|
||||
const { spawn } = require("child_process");
|
||||
const fs = require("fs");
|
||||
const path = require("path");
|
||||
|
||||
console.log("🎯 AUTOGPT PLATFORM LOAD TEST ORCHESTRATOR\n");
|
||||
console.log("===========================================\n");
|
||||
|
||||
// Parse command line arguments
|
||||
const args = process.argv.slice(2);
|
||||
const environment = args[0] || "DEV"; // LOCAL, DEV, PROD
|
||||
const executionMode = args[1] || "cloud"; // local, cloud
|
||||
const testScale = args[2] || "full"; // small, full
|
||||
|
||||
console.log(`🌍 Target Environment: ${environment}`);
|
||||
console.log(`🚀 Execution Mode: ${executionMode}`);
|
||||
console.log(`📏 Test Scale: ${testScale}`);
|
||||
|
||||
// Test scenario definitions
|
||||
const testScenarios = {
|
||||
// Small scale for validation (3 tests, ~5 minutes)
|
||||
small: [
|
||||
{
|
||||
name: "Basic_Connectivity_Test",
|
||||
file: "tests/basic/connectivity-test.js",
|
||||
vus: 5,
|
||||
duration: "30s",
|
||||
},
|
||||
{
|
||||
name: "Core_API_Quick_Test",
|
||||
file: "tests/api/core-api-test.js",
|
||||
vus: 10,
|
||||
duration: "1m",
|
||||
},
|
||||
{
|
||||
name: "Marketplace_Quick_Test",
|
||||
file: "tests/marketplace/public-access-test.js",
|
||||
vus: 15,
|
||||
duration: "1m",
|
||||
},
|
||||
],
|
||||
|
||||
// Full comprehensive test suite (25 tests, ~2 hours)
|
||||
full: [
|
||||
// Marketplace Viewing Tests
|
||||
{
|
||||
name: "Viewing_Marketplace_Logged_Out_Day1",
|
||||
file: "tests/marketplace/public-access-test.js",
|
||||
vus: 106,
|
||||
duration: "3m",
|
||||
},
|
||||
{
|
||||
name: "Viewing_Marketplace_Logged_Out_VeryHigh",
|
||||
file: "tests/marketplace/public-access-test.js",
|
||||
vus: 314,
|
||||
duration: "3m",
|
||||
},
|
||||
{
|
||||
name: "Viewing_Marketplace_Logged_In_Day1",
|
||||
file: "tests/marketplace/library-access-test.js",
|
||||
vus: 53,
|
||||
duration: "3m",
|
||||
},
|
||||
{
|
||||
name: "Viewing_Marketplace_Logged_In_VeryHigh",
|
||||
file: "tests/marketplace/library-access-test.js",
|
||||
vus: 157,
|
||||
duration: "3m",
|
||||
},
|
||||
|
||||
// Library Management Tests
|
||||
{
|
||||
name: "Adding_Agent_to_Library_Day1",
|
||||
file: "tests/marketplace/library-access-test.js",
|
||||
vus: 32,
|
||||
duration: "3m",
|
||||
},
|
||||
{
|
||||
name: "Adding_Agent_to_Library_VeryHigh",
|
||||
file: "tests/marketplace/library-access-test.js",
|
||||
vus: 95,
|
||||
duration: "3m",
|
||||
},
|
||||
{
|
||||
name: "Viewing_Library_Home_0_Agents_Day1",
|
||||
file: "tests/marketplace/library-access-test.js",
|
||||
vus: 53,
|
||||
duration: "3m",
|
||||
},
|
||||
{
|
||||
name: "Viewing_Library_Home_0_Agents_VeryHigh",
|
||||
file: "tests/marketplace/library-access-test.js",
|
||||
vus: 157,
|
||||
duration: "3m",
|
||||
},
|
||||
|
||||
// Core API Tests
|
||||
{
|
||||
name: "Core_API_Load_Test",
|
||||
file: "tests/api/core-api-test.js",
|
||||
vus: 100,
|
||||
duration: "3m",
|
||||
},
|
||||
{
|
||||
name: "Graph_Execution_Load_Test",
|
||||
file: "tests/api/graph-execution-test.js",
|
||||
vus: 100,
|
||||
duration: "3m",
|
||||
},
|
||||
|
||||
// Single API Endpoint Tests
|
||||
{
|
||||
name: "Credits_API_Single_Endpoint",
|
||||
file: "tests/basic/single-endpoint-test.js",
|
||||
vus: 50,
|
||||
duration: "3m",
|
||||
env: { ENDPOINT: "credits", CONCURRENT_REQUESTS: 10 },
|
||||
},
|
||||
{
|
||||
name: "Graphs_API_Single_Endpoint",
|
||||
file: "tests/basic/single-endpoint-test.js",
|
||||
vus: 50,
|
||||
duration: "3m",
|
||||
env: { ENDPOINT: "graphs", CONCURRENT_REQUESTS: 10 },
|
||||
},
|
||||
{
|
||||
name: "Blocks_API_Single_Endpoint",
|
||||
file: "tests/basic/single-endpoint-test.js",
|
||||
vus: 50,
|
||||
duration: "3m",
|
||||
env: { ENDPOINT: "blocks", CONCURRENT_REQUESTS: 10 },
|
||||
},
|
||||
{
|
||||
name: "Executions_API_Single_Endpoint",
|
||||
file: "tests/basic/single-endpoint-test.js",
|
||||
vus: 50,
|
||||
duration: "3m",
|
||||
env: { ENDPOINT: "executions", CONCURRENT_REQUESTS: 10 },
|
||||
},
|
||||
|
||||
// Comprehensive Platform Tests
|
||||
{
|
||||
name: "Comprehensive_Platform_Low",
|
||||
file: "tests/comprehensive/platform-journey-test.js",
|
||||
vus: 25,
|
||||
duration: "3m",
|
||||
},
|
||||
{
|
||||
name: "Comprehensive_Platform_Medium",
|
||||
file: "tests/comprehensive/platform-journey-test.js",
|
||||
vus: 50,
|
||||
duration: "3m",
|
||||
},
|
||||
{
|
||||
name: "Comprehensive_Platform_High",
|
||||
file: "tests/comprehensive/platform-journey-test.js",
|
||||
vus: 100,
|
||||
duration: "3m",
|
||||
},
|
||||
|
||||
// User Authentication Workflows
|
||||
{
|
||||
name: "User_Auth_Workflows_Day1",
|
||||
file: "tests/basic/connectivity-test.js",
|
||||
vus: 50,
|
||||
duration: "3m",
|
||||
},
|
||||
{
|
||||
name: "User_Auth_Workflows_VeryHigh",
|
||||
file: "tests/basic/connectivity-test.js",
|
||||
vus: 100,
|
||||
duration: "3m",
|
||||
},
|
||||
|
||||
// Mixed Load Tests
|
||||
{
|
||||
name: "Mixed_Load_Light",
|
||||
file: "tests/api/core-api-test.js",
|
||||
vus: 75,
|
||||
duration: "5m",
|
||||
},
|
||||
{
|
||||
name: "Mixed_Load_Heavy",
|
||||
file: "tests/marketplace/public-access-test.js",
|
||||
vus: 200,
|
||||
duration: "5m",
|
||||
},
|
||||
|
||||
// Stress Tests
|
||||
{
|
||||
name: "Marketplace_Stress_Test",
|
||||
file: "tests/marketplace/public-access-test.js",
|
||||
vus: 500,
|
||||
duration: "3m",
|
||||
},
|
||||
{
|
||||
name: "Core_API_Stress_Test",
|
||||
file: "tests/api/core-api-test.js",
|
||||
vus: 300,
|
||||
duration: "3m",
|
||||
},
|
||||
|
||||
// Extended Duration Tests
|
||||
{
|
||||
name: "Long_Duration_Marketplace",
|
||||
file: "tests/marketplace/library-access-test.js",
|
||||
vus: 100,
|
||||
duration: "10m",
|
||||
},
|
||||
{
|
||||
name: "Long_Duration_Core_API",
|
||||
file: "tests/api/core-api-test.js",
|
||||
vus: 100,
|
||||
duration: "10m",
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
// Resolve the scenario list for the requested scale and announce the run.
const scenarios = testScenarios[testScale];
console.log(`📊 Running ${scenarios.length} test scenarios`);

// Results collection (populated as individual tests finish)
const results = [];
const cloudUrls = [];
const detailedMetrics = [];

// Create a unique results directory, e.g.
// results-dev-cloud-small-2024-01-01T12-00 (ISO timestamp with ":"/"."
// replaced so it is filesystem-safe, truncated to minute precision).
const timestamp = new Date()
  .toISOString()
  .replace(/[:.]/g, "-")
  .substring(0, 16);
const resultsDir = `results-${environment.toLowerCase()}-${executionMode}-${testScale}-${timestamp}`;
// { recursive: true } is idempotent (no error if the directory already
// exists), which avoids the existsSync/mkdirSync check-then-act race.
fs.mkdirSync(resultsDir, { recursive: true });
|
||||
|
||||
// Function to run a single test
|
||||
/**
 * Run a single k6 test scenario as a child process.
 *
 * Resolves with a result record when k6 exits 0; rejects on a non-zero
 * exit code or a spawn error. In every case a record is pushed onto the
 * shared `results` array; cloud-mode dashboard URLs are collected into
 * `cloudUrls`, and per-test output stats into `detailedMetrics`.
 *
 * @param {{name: string, file: string, vus: number, duration: string, env?: Object}} scenario
 * @param {number} testIndex - 1-based index, used only for progress logging
 * @returns {Promise<Object>} the success result record
 */
function runTest(scenario, testIndex) {
  return new Promise((resolve, reject) => {
    console.log(`\n🚀 Test ${testIndex}/${scenarios.length}: ${scenario.name}`);
    console.log(
      `📊 Config: ${scenario.vus} VUs × ${scenario.duration} (${executionMode} mode)`,
    );
    console.log(`📁 Script: ${scenario.file}`);

    // Build k6 command
    let k6Command, k6Args;

    // Determine k6 binary location (Kubernetes pod image vs local install)
    const isInPod = fs.existsSync("/app/k6-v0.54.0-linux-amd64/k6");
    const k6Binary = isInPod ? "/app/k6-v0.54.0-linux-amd64/k6" : "k6";

    // Build environment variables as KEY=VALUE strings
    const envVars = [
      `K6_ENVIRONMENT=${environment}`,
      `VUS=${scenario.vus}`,
      `DURATION=${scenario.duration}`,
      `RAMP_UP=30s`,
      `RAMP_DOWN=30s`,
      `THRESHOLD_P95=60000`,
      `THRESHOLD_P99=60000`,
    ];

    // Add scenario-specific environment variables
    if (scenario.env) {
      Object.keys(scenario.env).forEach((key) => {
        envVars.push(`${key}=${scenario.env[key]}`);
      });
    }

    // Configure command based on execution mode
    if (executionMode === "cloud") {
      k6Command = k6Binary;
      k6Args = ["cloud", "run", scenario.file];
      // Add environment variables as --env flags
      envVars.forEach((env) => {
        k6Args.push("--env", env);
      });
    } else {
      k6Command = k6Binary;
      k6Args = ["run", scenario.file];

      // Add local output files
      const outputFile = path.join(resultsDir, `${scenario.name}.json`);
      const summaryFile = path.join(
        resultsDir,
        `${scenario.name}_summary.json`,
      );
      k6Args.push("--out", `json=${outputFile}`);
      k6Args.push("--summary-export", summaryFile);
    }

    const startTime = Date.now();
    let testUrl = "";
    let stdout = "";
    let stderr = "";

    console.log(`⏱️ Test started: ${new Date().toISOString()}`);

    // Set environment variables for the spawned process.
    // Split on the FIRST "=" only — String.prototype.split("=") would
    // drop everything after a second "=", silently truncating values
    // that contain "=" (URLs with query strings, base64 tokens, etc.).
    const processEnv = { ...process.env };
    envVars.forEach((env) => {
      const sep = env.indexOf("=");
      processEnv[env.slice(0, sep)] = env.slice(sep + 1);
    });

    const childProcess = spawn(k6Command, k6Args, {
      env: processEnv,
      stdio: ["ignore", "pipe", "pipe"],
    });

    // Handle stdout: accumulate output, sniff for the cloud dashboard
    // URL, and surface k6's progress percentage inline.
    childProcess.stdout.on("data", (data) => {
      const output = data.toString();
      stdout += output;

      // Extract k6 cloud URL
      if (executionMode === "cloud") {
        const urlMatch = output.match(/output:\s*(https:\/\/[^\s]+)/);
        if (urlMatch) {
          testUrl = urlMatch[1];
          console.log(`🔗 Test URL: ${testUrl}`);
        }
      }

      // Show progress indicators
      if (output.includes("Run [")) {
        const progressMatch = output.match(/Run\s+\[\s*(\d+)%\s*\]/);
        if (progressMatch) {
          process.stdout.write(`\r⏳ Progress: ${progressMatch[1]}%`);
        }
      }
    });

    // Handle stderr
    childProcess.stderr.on("data", (data) => {
      stderr += data.toString();
    });

    // Handle process completion
    childProcess.on("close", (code) => {
      const endTime = Date.now();
      const duration = Math.round((endTime - startTime) / 1000);

      console.log(`\n⏱️ Completed in ${duration}s`);

      if (code === 0) {
        console.log(`✅ ${scenario.name} SUCCESS`);

        const result = {
          test: scenario.name,
          status: "SUCCESS",
          duration: `${duration}s`,
          vus: scenario.vus,
          target_duration: scenario.duration,
          url: testUrl || "N/A",
          execution_mode: executionMode,
          environment: environment,
          completed_at: new Date().toISOString(),
        };

        results.push(result);

        if (testUrl) {
          cloudUrls.push(`${scenario.name}: ${testUrl}`);
        }

        // Store detailed output for analysis
        detailedMetrics.push({
          test: scenario.name,
          stdout_lines: stdout.split("\n").length,
          stderr_lines: stderr.split("\n").length,
          has_url: !!testUrl,
        });

        resolve(result);
      } else {
        console.error(`❌ ${scenario.name} FAILED (exit code ${code})`);

        const result = {
          test: scenario.name,
          status: "FAILED",
          error: `Exit code ${code}`,
          duration: `${duration}s`,
          vus: scenario.vus,
          // Included so the CSV report's "Target Duration" column is
          // populated for failed tests too (it reads r.target_duration).
          target_duration: scenario.duration,
          execution_mode: executionMode,
          environment: environment,
          completed_at: new Date().toISOString(),
        };

        results.push(result);
        reject(new Error(`Test failed with exit code ${code}`));
      }
    });

    // Handle spawn errors (e.g. k6 binary not installed / not on PATH)
    childProcess.on("error", (error) => {
      console.error(`❌ ${scenario.name} ERROR:`, error.message);

      results.push({
        test: scenario.name,
        status: "ERROR",
        error: error.message,
        execution_mode: executionMode,
        environment: environment,
      });

      reject(error);
    });
  });
}
|
||||
|
||||
// Main orchestration function
|
||||
// Main orchestration function: runs every scenario sequentially,
// pausing between tests to avoid hammering the k6 cloud API, then hands
// the pass/fail tallies to generateReports().
async function runOrchestrator() {
  const estimatedMinutes = scenarios.length * (testScale === "small" ? 2 : 5);
  console.log(`\n🎯 Starting ${testScale} test suite on ${environment}`);
  console.log(`📈 Estimated time: ~${estimatedMinutes} minutes`);
  console.log(`🌩️ Execution: ${executionMode} mode\n`);

  const suiteStart = Date.now();
  let passed = 0;
  let failed = 0;

  const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));

  // Run tests sequentially
  for (const [index, scenario] of scenarios.entries()) {
    const isLast = index === scenarios.length - 1;
    try {
      await runTest(scenario, index + 1);
      passed += 1;

      // Pause between tests (avoid overwhelming k6 cloud API)
      if (!isLast) {
        const pauseSeconds = testScale === "small" ? 10 : 30;
        console.log(`\n⏸️ Pausing ${pauseSeconds}s before next test...\n`);
        await sleep(pauseSeconds * 1000);
      }
    } catch (error) {
      failed += 1;
      console.log(`💥 Continuing after failure...\n`);

      // Brief pause before continuing
      if (!isLast) {
        await sleep(15000);
      }
    }
  }

  const totalTime = Math.round((Date.now() - suiteStart) / 1000);
  await generateReports(passed, failed, totalTime);
}
|
||||
|
||||
// Generate comprehensive reports
|
||||
/**
 * Generate CSV, JSON, and (cloud mode) URL reports for a completed run.
 *
 * Reads the module-level `results`, `cloudUrls`, and `detailedMetrics`
 * collections and writes the report files into `resultsDir`.
 *
 * @param {number} successCount - tests that exited with code 0
 * @param {number} failureCount - tests that exited non-zero
 * @param {number} totalTime - wall-clock suite duration in seconds
 */
async function generateReports(successCount, failureCount, totalTime) {
  console.log("\n🎉 LOAD TEST ORCHESTRATOR COMPLETE\n");
  console.log("===================================\n");

  // Summary statistics
  const successRate = Math.round((successCount / scenarios.length) * 100);
  console.log("📊 EXECUTION SUMMARY:");
  console.log(
    `✅ Successful tests: ${successCount}/${scenarios.length} (${successRate}%)`,
  );
  console.log(`❌ Failed tests: ${failureCount}/${scenarios.length}`);
  console.log(`⏱️ Total execution time: ${Math.round(totalTime / 60)} minutes`);
  console.log(`🌍 Environment: ${environment}`);
  console.log(`🚀 Mode: ${executionMode}`);

  // Generate CSV report
  const csvHeaders =
    "Test Name,Status,VUs,Target Duration,Actual Duration,Environment,Mode,Test URL,Error,Completed At";
  const csvRows = results.map(
    (r) =>
      `"${r.test}","${r.status}",${r.vus},"${r.target_duration || "N/A"}","${r.duration || "N/A"}","${r.environment}","${r.execution_mode}","${r.url || "N/A"}","${r.error || "None"}","${r.completed_at || "N/A"}"`,
  );

  const csvContent = [csvHeaders, ...csvRows].join("\n");
  const csvFile = path.join(resultsDir, "orchestrator_results.csv");
  fs.writeFileSync(csvFile, csvContent);
  console.log(`\n📁 CSV Report: ${csvFile}`);

  // The URLs file path is needed both where the file is written and in
  // the dashboard summary further down. It was previously declared
  // `const` inside the first `if` block, so the later reference threw a
  // ReferenceError whenever more than 5 cloud URLs were collected —
  // declare it once at function scope instead.
  const urlsFile = path.join(resultsDir, "cloud_test_urls.txt");

  // Generate cloud URLs file
  if (executionMode === "cloud" && cloudUrls.length > 0) {
    const urlsContent = [
      `# AutoGPT Platform Load Test URLs`,
      `# Environment: ${environment}`,
      `# Generated: ${new Date().toISOString()}`,
      `# Dashboard: https://significantgravitas.grafana.net/a/k6-app/`,
      "",
      ...cloudUrls,
      "",
      "# Direct Dashboard Access:",
      "https://significantgravitas.grafana.net/a/k6-app/",
    ].join("\n");

    fs.writeFileSync(urlsFile, urlsContent);
    console.log(`📁 Cloud URLs: ${urlsFile}`);
  }

  // Generate detailed JSON report
  const jsonReport = {
    meta: {
      orchestrator_version: "1.0",
      environment: environment,
      execution_mode: executionMode,
      test_scale: testScale,
      total_scenarios: scenarios.length,
      generated_at: new Date().toISOString(),
      results_directory: resultsDir,
    },
    summary: {
      successful_tests: successCount,
      failed_tests: failureCount,
      success_rate: `${successRate}%`,
      total_execution_time_seconds: totalTime,
      total_execution_time_minutes: Math.round(totalTime / 60),
    },
    test_results: results,
    detailed_metrics: detailedMetrics,
    cloud_urls: cloudUrls,
  };

  const jsonFile = path.join(resultsDir, "orchestrator_results.json");
  fs.writeFileSync(jsonFile, JSON.stringify(jsonReport, null, 2));
  console.log(`📁 JSON Report: ${jsonFile}`);

  // Display immediate results (first 5 URLs inline, rest in the file)
  if (executionMode === "cloud" && cloudUrls.length > 0) {
    console.log("\n🔗 K6 CLOUD TEST DASHBOARD URLS:");
    console.log("================================");
    cloudUrls.slice(0, 5).forEach((url) => console.log(url));
    if (cloudUrls.length > 5) {
      console.log(`... and ${cloudUrls.length - 5} more URLs in ${urlsFile}`);
    }
    console.log(
      "\n📈 Main Dashboard: https://significantgravitas.grafana.net/a/k6-app/",
    );
  }

  console.log(`\n📂 All results saved in: ${resultsDir}/`);
  console.log("🏁 Load Test Orchestrator finished successfully!");
}
|
||||
|
||||
// Show usage help
|
||||
// Print CLI usage: positional arguments, examples, and prerequisites.
function showUsage() {
  const usageLines = [
    "🎯 AutoGPT Platform Load Test Orchestrator\n",
    "Usage: node load-test-orchestrator.js [ENVIRONMENT] [MODE] [SCALE]\n",
    "ENVIRONMENT:",
    " LOCAL - http://localhost:8006 (local development)",
    " DEV - https://dev-api.agpt.co (development server)",
    " PROD - https://api.agpt.co (production - coordinate with team!)\n",
    "MODE:",
    " local - Run locally with JSON output files",
    " cloud - Run in k6 cloud with dashboard monitoring\n",
    "SCALE:",
    " small - 3 validation tests (~5 minutes)",
    " full - 25 comprehensive tests (~2 hours)\n",
    "Examples:",
    " node load-test-orchestrator.js DEV cloud small",
    " node load-test-orchestrator.js LOCAL local small",
    " node load-test-orchestrator.js DEV cloud full",
    " node load-test-orchestrator.js PROD cloud full # Coordinate with team!\n",
    "Requirements:",
    " - Pre-authenticated tokens generated (node generate-tokens.js)",
    " - k6 installed locally or run from Kubernetes pod",
    " - For cloud mode: K6_CLOUD_TOKEN and K6_CLOUD_PROJECT_ID set",
  ];
  usageLines.forEach((line) => console.log(line));
}
|
||||
|
||||
// Handle command line help
|
||||
// Handle command line help: print usage and exit before any test runs.
const helpRequested = args.includes("--help") || args.includes("-h");
if (helpRequested) {
  showUsage();
  process.exit(0);
}
|
||||
|
||||
// Handle graceful shutdown
|
||||
// Handle graceful shutdown: on Ctrl-C, write reports for whatever has
// finished so far, then exit. The .catch/.finally ensure the process
// still terminates even if report generation itself fails — previously
// a rejected generateReports() left an unhandled rejection and the
// process hanging after the interrupt.
process.on("SIGINT", () => {
  console.log("\n🛑 Orchestrator interrupted by user");
  console.log("📊 Generating partial results...");
  generateReports(
    results.filter((r) => r.status === "SUCCESS").length,
    results.filter((r) => r.status === "FAILED").length,
    0,
  )
    .then(() => {
      console.log("🏃♂️ Partial results saved");
    })
    .catch((error) => {
      console.error("⚠️ Failed to write partial results:", error.message);
    })
    .finally(() => process.exit(0));
});
|
||||
|
||||
// Start orchestrator
|
||||
// Report a fatal orchestrator error and exit non-zero.
const abortWithError = (error) => {
  console.error("💥 Orchestrator failed:", error);
  process.exit(1);
};

// Start the orchestrator only when executed directly, not when require()d.
if (require.main === module) {
  runOrchestrator().catch(abortWithError);
}

module.exports = { runOrchestrator, testScenarios };
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user