Compare commits

..

4 Commits

Author SHA1 Message Date
Otto
9b20f4cd13 refactor: simplify ExecutionQueue docstrings and move test file
- Trim verbose BUG FIX docstring to concise 3-line note
- Remove redundant method docstrings (add, get, empty)
- Move test file to backend/data/ with proper pytest conventions
- Add note about ProcessPoolExecutor migration for future devs

Co-authored-by: Zamil Majdy <majdyz@users.noreply.github.com>
2026-02-08 16:11:35 +00:00
Nikhil Bhagat
a3d0f9cbd2 fix(backend): format test_execution_queue.py and remove unused variable 2025-12-14 19:37:29 +05:45
Nikhil Bhagat
02ddb51446 Added test_execution_queue.py and test the execution part and the test got passed 2025-12-14 19:05:14 +05:45
Nikhil Bhagat
750e096f15 fix(backend): replace multiprocessing.Manager().Queue() with queue.Queue()
ExecutionQueue was unnecessarily using multiprocessing.Manager().Queue() which
spawns a subprocess for IPC. Since ExecutionQueue is only accessed from threads
within the same process, queue.Queue() is sufficient and more efficient.

- Eliminates unnecessary subprocess spawning per graph execution
- Removes IPC overhead for queue operations
- Prevents potential resource leaks from Manager processes
- Improves scalability for concurrent graph executions
2025-12-14 19:04:14 +05:45
651 changed files with 15836 additions and 40702 deletions

View File

@@ -1,37 +0,0 @@
{
"worktreeCopyPatterns": [
".env*",
".vscode/**",
".auth/**",
".claude/**",
"autogpt_platform/.env*",
"autogpt_platform/backend/.env*",
"autogpt_platform/frontend/.env*",
"autogpt_platform/frontend/.auth/**",
"autogpt_platform/db/docker/.env*"
],
"worktreeCopyIgnores": [
"**/node_modules/**",
"**/dist/**",
"**/.git/**",
"**/Thumbs.db",
"**/.DS_Store",
"**/.next/**",
"**/__pycache__/**",
"**/.ruff_cache/**",
"**/.pytest_cache/**",
"**/*.pyc",
"**/playwright-report/**",
"**/logs/**",
"**/site/**"
],
"worktreePathTemplate": "$BASE_PATH.worktree",
"postCreateCmd": [
"cd autogpt_platform/autogpt_libs && poetry install",
"cd autogpt_platform/backend && poetry install && poetry run prisma generate",
"cd autogpt_platform/frontend && pnpm install",
"cd docs && pip install -r requirements.txt"
],
"terminalCommand": "code .",
"deleteBranchWithWorktree": false
}

View File

@@ -16,7 +16,6 @@
!autogpt_platform/backend/poetry.lock
!autogpt_platform/backend/README.md
!autogpt_platform/backend/.env
!autogpt_platform/backend/gen_prisma_types_stub.py
# Platform - Market
!autogpt_platform/market/market/

View File

@@ -74,7 +74,7 @@ jobs:
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
run: poetry run prisma generate
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js

View File

@@ -90,7 +90,7 @@ jobs:
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
run: poetry run prisma generate
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js

View File

@@ -72,7 +72,7 @@ jobs:
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
run: poetry run prisma generate
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
@@ -108,16 +108,6 @@ jobs:
# run: pnpm playwright install --with-deps chromium
# Docker setup for development environment
- name: Free up disk space
run: |
# Remove large unused tools to free disk space for Docker builds
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
sudo docker system prune -af
df -h
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -134,7 +134,7 @@ jobs:
run: poetry install
- name: Generate Prisma Client
run: poetry run prisma generate && poetry run gen-prisma-stub
run: poetry run prisma generate
- id: supabase
name: Start Supabase

View File

@@ -11,7 +11,7 @@ jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v10
- uses: actions/stale@v9
with:
# operations-per-run: 5000
stale-issue-message: >

View File

@@ -61,6 +61,6 @@ jobs:
pull-requests: write
runs-on: ubuntu-latest
steps:
- uses: actions/labeler@v6
- uses: actions/labeler@v5
with:
sync-labels: true

View File

@@ -12,7 +12,6 @@ reset-db:
rm -rf db/docker/volumes/db/data
cd backend && poetry run prisma migrate deploy
cd backend && poetry run prisma generate
cd backend && poetry run gen-prisma-stub
# View logs for core services
logs-core:
@@ -34,7 +33,6 @@ init-env:
migrate:
cd backend && poetry run prisma migrate deploy
cd backend && poetry run prisma generate
cd backend && poetry run gen-prisma-stub
run-backend:
cd backend && poetry run app

View File

@@ -57,9 +57,6 @@ class APIKeySmith:
def hash_key(self, raw_key: str) -> tuple[str, str]:
"""Migrate a legacy hash to secure hash format."""
if not raw_key.startswith(self.PREFIX):
raise ValueError("Key without 'agpt_' prefix would fail validation")
salt = self._generate_salt()
hash = self._hash_key_with_salt(raw_key, salt)
return hash, salt.hex()

View File

@@ -1,25 +1,29 @@
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from .jwt_utils import bearer_jwt_auth
def add_auth_responses_to_openapi(app: FastAPI) -> None:
"""
Patch a FastAPI instance's `openapi()` method to add 401 responses
Set up custom OpenAPI schema generation that adds 401 responses
to all authenticated endpoints.
This is needed when using HTTPBearer with auto_error=False to get proper
401 responses instead of 403, but FastAPI only automatically adds security
responses when auto_error=True.
"""
# Wrap current method to allow stacking OpenAPI schema modifiers like this
wrapped_openapi = app.openapi
def custom_openapi():
if app.openapi_schema:
return app.openapi_schema
openapi_schema = wrapped_openapi()
openapi_schema = get_openapi(
title=app.title,
version=app.version,
description=app.description,
routes=app.routes,
)
# Add 401 response to all endpoints that have security requirements
for path, methods in openapi_schema["paths"].items():

View File

@@ -48,8 +48,7 @@ RUN poetry install --no-ansi --no-root
# Generate Prisma client
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
RUN poetry run prisma generate && poetry run gen-prisma-stub
RUN poetry run prisma generate
FROM debian:13-slim AS server_dependencies

View File

@@ -108,7 +108,7 @@ import fastapi.testclient
import pytest
from pytest_snapshot.plugin import Snapshot
from backend.api.features.myroute import router
from backend.server.v2.myroute import router
app = fastapi.FastAPI()
app.include_router(router)
@@ -149,7 +149,7 @@ These provide the easiest way to set up authentication mocking in test modules:
import fastapi
import fastapi.testclient
import pytest
from backend.api.features.myroute import router
from backend.server.v2.myroute import router
app = fastapi.FastAPI()
app.include_router(router)

View File

@@ -1,25 +0,0 @@
from fastapi import FastAPI
from backend.api.middleware.security import SecurityHeadersMiddleware
from backend.monitoring.instrumentation import instrument_fastapi
from .v1.routes import v1_router
external_api = FastAPI(
title="AutoGPT External API",
description="External API for AutoGPT integrations",
docs_url="/docs",
version="1.0",
)
external_api.add_middleware(SecurityHeadersMiddleware)
external_api.include_router(v1_router, prefix="/v1")
# Add Prometheus instrumentation
instrument_fastapi(
external_api,
service_name="external-api",
expose_endpoint=True,
endpoint="/metrics",
include_in_schema=True,
)

View File

@@ -1,107 +0,0 @@
from fastapi import HTTPException, Security, status
from fastapi.security import APIKeyHeader, HTTPAuthorizationCredentials, HTTPBearer
from prisma.enums import APIKeyPermission
from backend.data.auth.api_key import APIKeyInfo, validate_api_key
from backend.data.auth.base import APIAuthorizationInfo
from backend.data.auth.oauth import (
InvalidClientError,
InvalidTokenError,
OAuthAccessTokenInfo,
validate_access_token,
)
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
bearer_auth = HTTPBearer(auto_error=False)
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
"""Middleware for API key authentication only"""
if api_key is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing API key"
)
api_key_obj = await validate_api_key(api_key)
if not api_key_obj:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
)
return api_key_obj
async def require_access_token(
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
) -> OAuthAccessTokenInfo:
"""Middleware for OAuth access token authentication only"""
if bearer is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing Authorization header",
)
try:
token_info, _ = await validate_access_token(bearer.credentials)
except (InvalidClientError, InvalidTokenError) as e:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
return token_info
async def require_auth(
api_key: str | None = Security(api_key_header),
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
) -> APIAuthorizationInfo:
"""
Unified authentication middleware supporting both API keys and OAuth tokens.
Supports two authentication methods, which are checked in order:
1. X-API-Key header (existing API key authentication)
2. Authorization: Bearer <token> header (OAuth access token)
Returns:
APIAuthorizationInfo: base class of both APIKeyInfo and OAuthAccessTokenInfo.
"""
# Try API key first
if api_key is not None:
api_key_info = await validate_api_key(api_key)
if api_key_info:
return api_key_info
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
)
# Try OAuth bearer token
if bearer is not None:
try:
token_info, _ = await validate_access_token(bearer.credentials)
return token_info
except (InvalidClientError, InvalidTokenError) as e:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
# No credentials provided
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing authentication. Provide API key or access token.",
)
def require_permission(permission: APIKeyPermission):
"""
Dependency function for checking specific permissions
(works with API keys and OAuth tokens)
"""
async def check_permission(
auth: APIAuthorizationInfo = Security(require_auth),
) -> APIAuthorizationInfo:
if permission not in auth.scopes:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Missing required permission: {permission.value}",
)
return auth
return check_permission

View File

@@ -1,340 +0,0 @@
"""Tests for analytics API endpoints."""
import json
from unittest.mock import AsyncMock, Mock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from pytest_snapshot.plugin import Snapshot
from .analytics import router as analytics_router
app = fastapi.FastAPI()
app.include_router(analytics_router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module."""
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
# =============================================================================
# /log_raw_metric endpoint tests
# =============================================================================
def test_log_raw_metric_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
test_user_id: str,
) -> None:
"""Test successful raw metric logging."""
mock_result = Mock(id="metric-123-uuid")
mock_log_metric = mocker.patch(
"backend.data.analytics.log_raw_metric",
new_callable=AsyncMock,
return_value=mock_result,
)
request_data = {
"metric_name": "page_load_time",
"metric_value": 2.5,
"data_string": "/dashboard",
}
response = client.post("/log_raw_metric", json=request_data)
assert response.status_code == 200, f"Unexpected response: {response.text}"
assert response.json() == "metric-123-uuid"
mock_log_metric.assert_called_once_with(
user_id=test_user_id,
metric_name="page_load_time",
metric_value=2.5,
data_string="/dashboard",
)
configured_snapshot.assert_match(
json.dumps({"metric_id": response.json()}, indent=2, sort_keys=True),
"analytics_log_metric_success",
)
@pytest.mark.parametrize(
"metric_value,metric_name,data_string,test_id",
[
(100, "api_calls_count", "external_api", "integer_value"),
(0, "error_count", "no_errors", "zero_value"),
(-5.2, "temperature_delta", "cooling", "negative_value"),
(1.23456789, "precision_test", "float_precision", "float_precision"),
(999999999, "large_number", "max_value", "large_number"),
(0.0000001, "tiny_number", "min_value", "tiny_number"),
],
)
def test_log_raw_metric_various_values(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
metric_value: float,
metric_name: str,
data_string: str,
test_id: str,
) -> None:
"""Test raw metric logging with various metric values."""
mock_result = Mock(id=f"metric-{test_id}-uuid")
mocker.patch(
"backend.data.analytics.log_raw_metric",
new_callable=AsyncMock,
return_value=mock_result,
)
request_data = {
"metric_name": metric_name,
"metric_value": metric_value,
"data_string": data_string,
}
response = client.post("/log_raw_metric", json=request_data)
assert response.status_code == 200, f"Failed for {test_id}: {response.text}"
configured_snapshot.assert_match(
json.dumps(
{"metric_id": response.json(), "test_case": test_id},
indent=2,
sort_keys=True,
),
f"analytics_metric_{test_id}",
)
@pytest.mark.parametrize(
"invalid_data,expected_error",
[
({}, "Field required"),
({"metric_name": "test"}, "Field required"),
(
{"metric_name": "test", "metric_value": "not_a_number", "data_string": "x"},
"Input should be a valid number",
),
(
{"metric_name": "", "metric_value": 1.0, "data_string": "test"},
"String should have at least 1 character",
),
(
{"metric_name": "test", "metric_value": 1.0, "data_string": ""},
"String should have at least 1 character",
),
],
ids=[
"empty_request",
"missing_metric_value_and_data_string",
"invalid_metric_value_type",
"empty_metric_name",
"empty_data_string",
],
)
def test_log_raw_metric_validation_errors(
invalid_data: dict,
expected_error: str,
) -> None:
"""Test validation errors for invalid metric requests."""
response = client.post("/log_raw_metric", json=invalid_data)
assert response.status_code == 422
error_detail = response.json()
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
error_text = json.dumps(error_detail)
assert (
expected_error in error_text
), f"Expected '{expected_error}' in error response: {error_text}"
def test_log_raw_metric_service_error(
mocker: pytest_mock.MockFixture,
test_user_id: str,
) -> None:
"""Test error handling when analytics service fails."""
mocker.patch(
"backend.data.analytics.log_raw_metric",
new_callable=AsyncMock,
side_effect=Exception("Database connection failed"),
)
request_data = {
"metric_name": "test_metric",
"metric_value": 1.0,
"data_string": "test",
}
response = client.post("/log_raw_metric", json=request_data)
assert response.status_code == 500
error_detail = response.json()["detail"]
assert "Database connection failed" in error_detail["message"]
assert "hint" in error_detail
# =============================================================================
# /log_raw_analytics endpoint tests
# =============================================================================
def test_log_raw_analytics_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
test_user_id: str,
) -> None:
"""Test successful raw analytics logging."""
mock_result = Mock(id="analytics-789-uuid")
mock_log_analytics = mocker.patch(
"backend.data.analytics.log_raw_analytics",
new_callable=AsyncMock,
return_value=mock_result,
)
request_data = {
"type": "user_action",
"data": {
"action": "button_click",
"button_id": "submit_form",
"timestamp": "2023-01-01T00:00:00Z",
"metadata": {"form_type": "registration", "fields_filled": 5},
},
"data_index": "button_click_submit_form",
}
response = client.post("/log_raw_analytics", json=request_data)
assert response.status_code == 200, f"Unexpected response: {response.text}"
assert response.json() == "analytics-789-uuid"
mock_log_analytics.assert_called_once_with(
test_user_id,
"user_action",
request_data["data"],
"button_click_submit_form",
)
configured_snapshot.assert_match(
json.dumps({"analytics_id": response.json()}, indent=2, sort_keys=True),
"analytics_log_analytics_success",
)
def test_log_raw_analytics_complex_data(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test raw analytics logging with complex nested data structures."""
mock_result = Mock(id="analytics-complex-uuid")
mocker.patch(
"backend.data.analytics.log_raw_analytics",
new_callable=AsyncMock,
return_value=mock_result,
)
request_data = {
"type": "agent_execution",
"data": {
"agent_id": "agent_123",
"execution_id": "exec_456",
"status": "completed",
"duration_ms": 3500,
"nodes_executed": 15,
"blocks_used": [
{"block_id": "llm_block", "count": 3},
{"block_id": "http_block", "count": 5},
{"block_id": "code_block", "count": 2},
],
"errors": [],
"metadata": {
"trigger": "manual",
"user_tier": "premium",
"environment": "production",
},
},
"data_index": "agent_123_exec_456",
}
response = client.post("/log_raw_analytics", json=request_data)
assert response.status_code == 200
configured_snapshot.assert_match(
json.dumps(
{"analytics_id": response.json(), "logged_data": request_data["data"]},
indent=2,
sort_keys=True,
),
"analytics_log_analytics_complex_data",
)
@pytest.mark.parametrize(
"invalid_data,expected_error",
[
({}, "Field required"),
({"type": "test"}, "Field required"),
(
{"type": "test", "data": "not_a_dict", "data_index": "test"},
"Input should be a valid dictionary",
),
({"type": "test", "data": {"key": "value"}}, "Field required"),
],
ids=[
"empty_request",
"missing_data_and_data_index",
"invalid_data_type",
"missing_data_index",
],
)
def test_log_raw_analytics_validation_errors(
invalid_data: dict,
expected_error: str,
) -> None:
"""Test validation errors for invalid analytics requests."""
response = client.post("/log_raw_analytics", json=invalid_data)
assert response.status_code == 422
error_detail = response.json()
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
error_text = json.dumps(error_detail)
assert (
expected_error in error_text
), f"Expected '{expected_error}' in error response: {error_text}"
def test_log_raw_analytics_service_error(
mocker: pytest_mock.MockFixture,
test_user_id: str,
) -> None:
"""Test error handling when analytics service fails."""
mocker.patch(
"backend.data.analytics.log_raw_analytics",
new_callable=AsyncMock,
side_effect=Exception("Analytics DB unreachable"),
)
request_data = {
"type": "test_event",
"data": {"key": "value"},
"data_index": "test_index",
}
response = client.post("/log_raw_analytics", json=request_data)
assert response.status_code == 500
error_detail = response.json()["detail"]
assert "Analytics DB unreachable" in error_detail["message"]
assert "hint" in error_detail

View File

@@ -1,833 +0,0 @@
"""
OAuth 2.0 Provider Endpoints
Implements OAuth 2.0 Authorization Code flow with PKCE support.
Flow:
1. User clicks "Login with AutoGPT" in 3rd party app
2. App redirects user to /auth/authorize with client_id, redirect_uri, scope, state
3. User sees consent screen (if not already logged in, redirects to login first)
4. User approves → backend creates authorization code
5. User redirected back to app with code
6. App exchanges code for access/refresh tokens at /api/oauth/token
7. App uses access token to call external API endpoints
"""
import io
import logging
import os
import uuid
from datetime import datetime
from typing import Literal, Optional
from urllib.parse import urlencode
from autogpt_libs.auth import get_user_id
from fastapi import APIRouter, Body, HTTPException, Security, UploadFile, status
from gcloud.aio import storage as async_storage
from PIL import Image
from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field
from backend.data.auth.oauth import (
InvalidClientError,
InvalidGrantError,
OAuthApplicationInfo,
TokenIntrospectionResult,
consume_authorization_code,
create_access_token,
create_authorization_code,
create_refresh_token,
get_oauth_application,
get_oauth_application_by_id,
introspect_token,
list_user_oauth_applications,
refresh_tokens,
revoke_access_token,
revoke_refresh_token,
update_oauth_application,
validate_client_credentials,
validate_redirect_uri,
validate_scopes,
)
from backend.util.settings import Settings
from backend.util.virus_scanner import scan_content_safe
settings = Settings()
logger = logging.getLogger(__name__)
router = APIRouter()
# ============================================================================
# Request/Response Models
# ============================================================================
class TokenResponse(BaseModel):
"""OAuth 2.0 token response"""
token_type: Literal["Bearer"] = "Bearer"
access_token: str
access_token_expires_at: datetime
refresh_token: str
refresh_token_expires_at: datetime
scopes: list[str]
class ErrorResponse(BaseModel):
"""OAuth 2.0 error response"""
error: str
error_description: Optional[str] = None
class OAuthApplicationPublicInfo(BaseModel):
"""Public information about an OAuth application (for consent screen)"""
name: str
description: Optional[str] = None
logo_url: Optional[str] = None
scopes: list[str]
# ============================================================================
# Application Info Endpoint
# ============================================================================
@router.get(
"/app/{client_id}",
responses={
404: {"description": "Application not found or disabled"},
},
)
async def get_oauth_app_info(
client_id: str, user_id: str = Security(get_user_id)
) -> OAuthApplicationPublicInfo:
"""
Get public information about an OAuth application.
This endpoint is used by the consent screen to display application details
to the user before they authorize access.
Returns:
- name: Application name
- description: Application description (if provided)
- scopes: List of scopes the application is allowed to request
"""
app = await get_oauth_application(client_id)
if not app or not app.is_active:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Application not found",
)
return OAuthApplicationPublicInfo(
name=app.name,
description=app.description,
logo_url=app.logo_url,
scopes=[s.value for s in app.scopes],
)
# ============================================================================
# Authorization Endpoint
# ============================================================================
class AuthorizeRequest(BaseModel):
"""OAuth 2.0 authorization request"""
client_id: str = Field(description="Client identifier")
redirect_uri: str = Field(description="Redirect URI")
scopes: list[str] = Field(description="List of scopes")
state: str = Field(description="Anti-CSRF token from client")
response_type: str = Field(
default="code", description="Must be 'code' for authorization code flow"
)
code_challenge: str = Field(description="PKCE code challenge (required)")
code_challenge_method: Literal["S256", "plain"] = Field(
default="S256", description="PKCE code challenge method (S256 recommended)"
)
class AuthorizeResponse(BaseModel):
"""OAuth 2.0 authorization response with redirect URL"""
redirect_url: str = Field(description="URL to redirect the user to")
@router.post("/authorize")
async def authorize(
request: AuthorizeRequest = Body(),
user_id: str = Security(get_user_id),
) -> AuthorizeResponse:
"""
OAuth 2.0 Authorization Endpoint
User must be logged in (authenticated with Supabase JWT).
This endpoint creates an authorization code and returns a redirect URL.
PKCE (Proof Key for Code Exchange) is REQUIRED for all authorization requests.
The frontend consent screen should call this endpoint after the user approves,
then redirect the user to the returned `redirect_url`.
Request Body:
- client_id: The OAuth application's client ID
- redirect_uri: Where to redirect after authorization (must match registered URI)
- scopes: List of permissions (e.g., "EXECUTE_GRAPH READ_GRAPH")
- state: Anti-CSRF token provided by client (will be returned in redirect)
- response_type: Must be "code" (for authorization code flow)
- code_challenge: PKCE code challenge (required)
- code_challenge_method: "S256" (recommended) or "plain"
Returns:
- redirect_url: The URL to redirect the user to (includes authorization code)
Error cases return a redirect_url with error parameters, or raise HTTPException
for critical errors (like invalid redirect_uri).
"""
try:
# Validate response_type
if request.response_type != "code":
return _error_redirect_url(
request.redirect_uri,
request.state,
"unsupported_response_type",
"Only 'code' response type is supported",
)
# Get application
app = await get_oauth_application(request.client_id)
if not app:
return _error_redirect_url(
request.redirect_uri,
request.state,
"invalid_client",
"Unknown client_id",
)
if not app.is_active:
return _error_redirect_url(
request.redirect_uri,
request.state,
"invalid_client",
"Application is not active",
)
# Validate redirect URI
if not validate_redirect_uri(app, request.redirect_uri):
# For invalid redirect_uri, we can't redirect safely
# Must return error instead
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(
"Invalid redirect_uri. "
f"Must be one of: {', '.join(app.redirect_uris)}"
),
)
# Parse and validate scopes
try:
requested_scopes = [APIKeyPermission(s.strip()) for s in request.scopes]
except ValueError as e:
return _error_redirect_url(
request.redirect_uri,
request.state,
"invalid_scope",
f"Invalid scope: {e}",
)
if not requested_scopes:
return _error_redirect_url(
request.redirect_uri,
request.state,
"invalid_scope",
"At least one scope is required",
)
if not validate_scopes(app, requested_scopes):
return _error_redirect_url(
request.redirect_uri,
request.state,
"invalid_scope",
"Application is not authorized for all requested scopes. "
f"Allowed: {', '.join(s.value for s in app.scopes)}",
)
# Create authorization code
auth_code = await create_authorization_code(
application_id=app.id,
user_id=user_id,
scopes=requested_scopes,
redirect_uri=request.redirect_uri,
code_challenge=request.code_challenge,
code_challenge_method=request.code_challenge_method,
)
# Build redirect URL with authorization code
params = {
"code": auth_code.code,
"state": request.state,
}
redirect_url = f"{request.redirect_uri}?{urlencode(params)}"
logger.info(
f"Authorization code issued for user #{user_id} "
f"and app {app.name} (#{app.id})"
)
return AuthorizeResponse(redirect_url=redirect_url)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in authorization endpoint: {e}", exc_info=True)
return _error_redirect_url(
request.redirect_uri,
request.state,
"server_error",
"An unexpected error occurred",
)
def _error_redirect_url(
redirect_uri: str,
state: str,
error: str,
error_description: Optional[str] = None,
) -> AuthorizeResponse:
"""Helper to build redirect URL with OAuth error parameters"""
params = {
"error": error,
"state": state,
}
if error_description:
params["error_description"] = error_description
redirect_url = f"{redirect_uri}?{urlencode(params)}"
return AuthorizeResponse(redirect_url=redirect_url)
# ============================================================================
# Token Endpoint
# ============================================================================
class TokenRequestByCode(BaseModel):
grant_type: Literal["authorization_code"]
code: str = Field(description="Authorization code")
redirect_uri: str = Field(
description="Redirect URI (must match authorization request)"
)
client_id: str
client_secret: str
code_verifier: str = Field(description="PKCE code verifier")
class TokenRequestByRefreshToken(BaseModel):
grant_type: Literal["refresh_token"]
refresh_token: str
client_id: str
client_secret: str
@router.post("/token")
async def token(
request: TokenRequestByCode | TokenRequestByRefreshToken = Body(),
) -> TokenResponse:
"""
OAuth 2.0 Token Endpoint
Exchanges authorization code or refresh token for access token.
Grant Types:
1. authorization_code: Exchange authorization code for tokens
- Required: grant_type, code, redirect_uri, client_id, client_secret
- Optional: code_verifier (required if PKCE was used)
2. refresh_token: Exchange refresh token for new access token
- Required: grant_type, refresh_token, client_id, client_secret
Returns:
- access_token: Bearer token for API access (1 hour TTL)
- token_type: "Bearer"
- expires_in: Seconds until access token expires
- refresh_token: Token for refreshing access (30 days TTL)
- scopes: List of scopes
"""
# Validate client credentials
try:
app = await validate_client_credentials(
request.client_id, request.client_secret
)
except InvalidClientError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
)
# Handle authorization_code grant
if request.grant_type == "authorization_code":
# Consume authorization code
try:
user_id, scopes = await consume_authorization_code(
code=request.code,
application_id=app.id,
redirect_uri=request.redirect_uri,
code_verifier=request.code_verifier,
)
except InvalidGrantError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
# Create access and refresh tokens
access_token = await create_access_token(app.id, user_id, scopes)
refresh_token = await create_refresh_token(app.id, user_id, scopes)
logger.info(
f"Access token issued for user #{user_id} and app {app.name} (#{app.id})"
"via authorization code"
)
if not access_token.token or not refresh_token.token:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to generate tokens",
)
return TokenResponse(
token_type="Bearer",
access_token=access_token.token.get_secret_value(),
access_token_expires_at=access_token.expires_at,
refresh_token=refresh_token.token.get_secret_value(),
refresh_token_expires_at=refresh_token.expires_at,
scopes=list(s.value for s in scopes),
)
# Handle refresh_token grant
elif request.grant_type == "refresh_token":
# Refresh access token
try:
new_access_token, new_refresh_token = await refresh_tokens(
request.refresh_token, app.id
)
except InvalidGrantError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
logger.info(
f"Tokens refreshed for user #{new_access_token.user_id} "
f"by app {app.name} (#{app.id})"
)
if not new_access_token.token or not new_refresh_token.token:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to generate tokens",
)
return TokenResponse(
token_type="Bearer",
access_token=new_access_token.token.get_secret_value(),
access_token_expires_at=new_access_token.expires_at,
refresh_token=new_refresh_token.token.get_secret_value(),
refresh_token_expires_at=new_refresh_token.expires_at,
scopes=list(s.value for s in new_access_token.scopes),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported grant_type: {request.grant_type}. "
"Must be 'authorization_code' or 'refresh_token'",
)
# ============================================================================
# Token Introspection Endpoint
# ============================================================================
@router.post("/introspect")
async def introspect(
token: str = Body(description="Token to introspect"),
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
None, description="Hint about token type ('access_token' or 'refresh_token')"
),
client_id: str = Body(description="Client identifier"),
client_secret: str = Body(description="Client secret"),
) -> TokenIntrospectionResult:
"""
OAuth 2.0 Token Introspection Endpoint (RFC 7662)
Allows clients to check if a token is valid and get its metadata.
Returns:
- active: Whether the token is currently active
- scopes: List of authorized scopes (if active)
- client_id: The client the token was issued to (if active)
- user_id: The user the token represents (if active)
- exp: Expiration timestamp (if active)
- token_type: "access_token" or "refresh_token" (if active)
"""
# Validate client credentials
try:
await validate_client_credentials(client_id, client_secret)
except InvalidClientError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
)
# Introspect the token
return await introspect_token(token, token_type_hint)
# ============================================================================
# Token Revocation Endpoint
# ============================================================================
@router.post("/revoke")
async def revoke(
token: str = Body(description="Token to revoke"),
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
None, description="Hint about token type ('access_token' or 'refresh_token')"
),
client_id: str = Body(description="Client identifier"),
client_secret: str = Body(description="Client secret"),
):
"""
OAuth 2.0 Token Revocation Endpoint (RFC 7009)
Allows clients to revoke an access or refresh token.
Note: Revoking a refresh token does NOT revoke associated access tokens.
Revoking an access token does NOT revoke the associated refresh token.
"""
# Validate client credentials
try:
app = await validate_client_credentials(client_id, client_secret)
except InvalidClientError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
)
# Try to revoke as access token first
# Note: We pass app.id to ensure the token belongs to the authenticated app
if token_type_hint != "refresh_token":
revoked = await revoke_access_token(token, app.id)
if revoked:
logger.info(
f"Access token revoked for app {app.name} (#{app.id}); "
f"user #{revoked.user_id}"
)
return {"status": "ok"}
# Try to revoke as refresh token
revoked = await revoke_refresh_token(token, app.id)
if revoked:
logger.info(
f"Refresh token revoked for app {app.name} (#{app.id}); "
f"user #{revoked.user_id}"
)
return {"status": "ok"}
# Per RFC 7009, revocation endpoint returns 200 even if token not found
# or if token belongs to a different application.
# This prevents token scanning attacks.
logger.warning(f"Unsuccessful token revocation attempt by app {app.name} #{app.id}")
return {"status": "ok"}
# ============================================================================
# Application Management Endpoints (for app owners)
# ============================================================================
@router.get("/apps/mine")
async def list_my_oauth_apps(
user_id: str = Security(get_user_id),
) -> list[OAuthApplicationInfo]:
"""
List all OAuth applications owned by the current user.
Returns a list of OAuth applications with their details including:
- id, name, description, logo_url
- client_id (public identifier)
- redirect_uris, grant_types, scopes
- is_active status
- created_at, updated_at timestamps
Note: client_secret is never returned for security reasons.
"""
return await list_user_oauth_applications(user_id)
@router.patch("/apps/{app_id}/status")
async def update_app_status(
app_id: str,
user_id: str = Security(get_user_id),
is_active: bool = Body(description="Whether the app should be active", embed=True),
) -> OAuthApplicationInfo:
"""
Enable or disable an OAuth application.
Only the application owner can update the status.
When disabled, the application cannot be used for new authorizations
and existing access tokens will fail validation.
Returns the updated application info.
"""
updated_app = await update_oauth_application(
app_id=app_id,
owner_id=user_id,
is_active=is_active,
)
if not updated_app:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Application not found or you don't have permission to update it",
)
action = "enabled" if is_active else "disabled"
logger.info(f"OAuth app {updated_app.name} (#{app_id}) {action} by user #{user_id}")
return updated_app
class UpdateAppLogoRequest(BaseModel):
logo_url: str = Field(description="URL of the uploaded logo image")
@router.patch("/apps/{app_id}/logo")
async def update_app_logo(
app_id: str,
request: UpdateAppLogoRequest = Body(),
user_id: str = Security(get_user_id),
) -> OAuthApplicationInfo:
"""
Update the logo URL for an OAuth application.
Only the application owner can update the logo.
The logo should be uploaded first using the media upload endpoint,
then this endpoint is called with the resulting URL.
Logo requirements:
- Must be square (1:1 aspect ratio)
- Minimum 512x512 pixels
- Maximum 2048x2048 pixels
Returns the updated application info.
"""
if (
not (app := await get_oauth_application_by_id(app_id))
or app.owner_id != user_id
):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth App not found",
)
# Delete the current app logo file (if any and it's in our cloud storage)
await _delete_app_current_logo_file(app)
updated_app = await update_oauth_application(
app_id=app_id,
owner_id=user_id,
logo_url=request.logo_url,
)
if not updated_app:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Application not found or you don't have permission to update it",
)
logger.info(
f"OAuth app {updated_app.name} (#{app_id}) logo updated by user #{user_id}"
)
return updated_app
# Logo upload constraints
LOGO_MIN_SIZE = 512
LOGO_MAX_SIZE = 2048
LOGO_ALLOWED_TYPES = {"image/jpeg", "image/png", "image/webp"}
LOGO_MAX_FILE_SIZE = 3 * 1024 * 1024 # 3MB
@router.post("/apps/{app_id}/logo/upload")
async def upload_app_logo(
app_id: str,
file: UploadFile,
user_id: str = Security(get_user_id),
) -> OAuthApplicationInfo:
"""
Upload a logo image for an OAuth application.
Requirements:
- Image must be square (1:1 aspect ratio)
- Minimum 512x512 pixels
- Maximum 2048x2048 pixels
- Allowed formats: JPEG, PNG, WebP
- Maximum file size: 3MB
The image is uploaded to cloud storage and the app's logoUrl is updated.
Returns the updated application info.
"""
# Verify ownership to reduce vulnerability to DoS(torage) or DoM(oney) attacks
if (
not (app := await get_oauth_application_by_id(app_id))
or app.owner_id != user_id
):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth App not found",
)
# Check GCS configuration
if not settings.config.media_gcs_bucket_name:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Media storage is not configured",
)
# Validate content type
content_type = file.content_type
if content_type not in LOGO_ALLOWED_TYPES:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid file type. Allowed: JPEG, PNG, WebP. Got: {content_type}",
)
# Read file content
try:
file_bytes = await file.read()
except Exception as e:
logger.error(f"Error reading logo file: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to read uploaded file",
)
# Check file size
if len(file_bytes) > LOGO_MAX_FILE_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(
"File too large. "
f"Maximum size is {LOGO_MAX_FILE_SIZE // 1024 // 1024}MB"
),
)
# Validate image dimensions
try:
image = Image.open(io.BytesIO(file_bytes))
width, height = image.size
if width != height:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Logo must be square. Got {width}x{height}",
)
if width < LOGO_MIN_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Logo too small. Minimum {LOGO_MIN_SIZE}x{LOGO_MIN_SIZE}. "
f"Got {width}x{height}",
)
if width > LOGO_MAX_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Logo too large. Maximum {LOGO_MAX_SIZE}x{LOGO_MAX_SIZE}. "
f"Got {width}x{height}",
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error validating logo image: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid image file",
)
# Scan for viruses
filename = file.filename or "logo"
await scan_content_safe(file_bytes, filename=filename)
# Generate unique filename
file_ext = os.path.splitext(filename)[1].lower() or ".png"
unique_filename = f"{uuid.uuid4()}{file_ext}"
storage_path = f"oauth-apps/{app_id}/logo/{unique_filename}"
# Upload to GCS
try:
async with async_storage.Storage() as async_client:
bucket_name = settings.config.media_gcs_bucket_name
await async_client.upload(
bucket_name, storage_path, file_bytes, content_type=content_type
)
logo_url = f"https://storage.googleapis.com/{bucket_name}/{storage_path}"
except Exception as e:
logger.error(f"Error uploading logo to GCS: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to upload logo",
)
# Delete the current app logo file (if any and it's in our cloud storage)
await _delete_app_current_logo_file(app)
# Update the app with the new logo URL
updated_app = await update_oauth_application(
app_id=app_id,
owner_id=user_id,
logo_url=logo_url,
)
if not updated_app:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Application not found or you don't have permission to update it",
)
logger.info(
f"OAuth app {updated_app.name} (#{app_id}) logo uploaded by user #{user_id}"
)
return updated_app
async def _delete_app_current_logo_file(app: OAuthApplicationInfo):
"""
Delete the current logo file for the given app, if there is one in our cloud storage
"""
bucket_name = settings.config.media_gcs_bucket_name
storage_base_url = f"https://storage.googleapis.com/{bucket_name}/"
if app.logo_url and app.logo_url.startswith(storage_base_url):
# Parse blob path from URL: https://storage.googleapis.com/{bucket}/{path}
old_path = app.logo_url.replace(storage_base_url, "")
try:
async with async_storage.Storage() as async_client:
await async_client.delete(bucket_name, old_path)
logger.info(f"Deleted old logo for OAuth app #{app.id}: {old_path}")
except Exception as e:
# Log but don't fail - the new logo was uploaded successfully
logger.warning(
f"Failed to delete old logo for OAuth app #{app.id}: {e}", exc_info=e
)

File diff suppressed because it is too large Load Diff

View File

@@ -1,41 +0,0 @@
from fastapi import FastAPI
def sort_openapi(app: FastAPI) -> None:
"""
Patch a FastAPI instance's `openapi()` method to sort the endpoints,
schemas, and responses.
"""
wrapped_openapi = app.openapi
def custom_openapi():
if app.openapi_schema:
return app.openapi_schema
openapi_schema = wrapped_openapi()
# Sort endpoints
openapi_schema["paths"] = dict(sorted(openapi_schema["paths"].items()))
# Sort endpoints -> methods
for p in openapi_schema["paths"].keys():
openapi_schema["paths"][p] = dict(
sorted(openapi_schema["paths"][p].items())
)
# Sort endpoints -> methods -> responses
for m in openapi_schema["paths"][p].keys():
openapi_schema["paths"][p][m]["responses"] = dict(
sorted(openapi_schema["paths"][p][m]["responses"].items())
)
# Sort schemas and responses as well
for k in openapi_schema["components"].keys():
openapi_schema["components"][k] = dict(
sorted(openapi_schema["components"][k].items())
)
app.openapi_schema = openapi_schema
return openapi_schema
app.openapi = custom_openapi

View File

@@ -36,10 +36,10 @@ def main(**kwargs):
Run all the processes required for the AutoGPT-server (REST and WebSocket APIs).
"""
from backend.api.rest_api import AgentServer
from backend.api.ws_api import WebsocketServer
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
from backend.notifications import NotificationManager
from backend.server.rest_api import AgentServer
from backend.server.ws_api import WebsocketServer
run_processes(
DatabaseManager().set_log_level("warning"),

View File

@@ -1,7 +1,6 @@
from typing import Any
from backend.blocks.llm import (
DEFAULT_LLM_MODEL,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
AIBlockBase,
@@ -50,7 +49,7 @@ class AIConditionBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for evaluating the condition.",
advanced=False,
)
@@ -82,7 +81,7 @@ class AIConditionBlock(AIBlockBase):
"condition": "the input is an email address",
"yes_value": "Valid email",
"no_value": "Not an email",
"model": DEFAULT_LLM_MODEL,
"model": LlmModel.GPT4O,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,

View File

@@ -20,7 +20,6 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.exceptions import BlockExecutionError
from backend.util.request import Requests
TEST_CREDENTIALS = APIKeyCredentials(
@@ -247,11 +246,7 @@ class AIShortformVideoCreatorBlock(Block):
await asyncio.sleep(10)
logger.error("Video creation timed out")
raise BlockExecutionError(
message="Video creation timed out",
block_name=self.name,
block_id=self.id,
)
raise TimeoutError("Video creation timed out")
def __init__(self):
super().__init__(
@@ -427,11 +422,7 @@ class AIAdMakerVideoCreatorBlock(Block):
await asyncio.sleep(10)
logger.error("Video creation timed out")
raise BlockExecutionError(
message="Video creation timed out",
block_name=self.name,
block_id=self.id,
)
raise TimeoutError("Video creation timed out")
def __init__(self):
super().__init__(
@@ -608,11 +599,7 @@ class AIScreenshotToVideoAdBlock(Block):
await asyncio.sleep(10)
logger.error("Video creation timed out")
raise BlockExecutionError(
message="Video creation timed out",
block_name=self.name,
block_id=self.id,
)
raise TimeoutError("Video creation timed out")
def __init__(self):
super().__init__(

View File

@@ -6,9 +6,6 @@ import hashlib
import hmac
import logging
from enum import Enum
from typing import cast
from prisma.types import Serializable
from backend.sdk import (
BaseWebhooksManager,
@@ -87,9 +84,7 @@ class AirtableWebhookManager(BaseWebhooksManager):
# update webhook config
await update_webhook(
webhook.id,
config=cast(
dict[str, Serializable], {"base_id": base_id, "cursor": response.cursor}
),
config={"base_id": base_id, "cursor": response.cursor},
)
event_type = "notification"

View File

@@ -106,10 +106,7 @@ class ConditionBlock(Block):
ComparisonOperator.LESS_THAN_OR_EQUAL: lambda a, b: a <= b,
}
try:
result = comparison_funcs[operator](value1, value2)
except Exception as e:
raise ValueError(f"Comparison failed: {e}") from e
result = comparison_funcs[operator](value1, value2)
yield "result", result

View File

@@ -182,10 +182,13 @@ class DataForSeoRelatedKeywordsBlock(Block):
if results and len(results) > 0:
# results is a list, get the first element
first_result = results[0] if isinstance(results, list) else results
# Handle missing key, null value, or valid list value
if isinstance(first_result, dict):
items = first_result.get("items") or []
else:
items = (
first_result.get("items", [])
if isinstance(first_result, dict)
else []
)
# Ensure items is never None
if items is None:
items = []
for item in items:
# Extract keyword_data from the item

View File

@@ -15,7 +15,6 @@ from backend.sdk import (
SchemaField,
cost,
)
from backend.util.exceptions import BlockExecutionError
from ._config import firecrawl
@@ -60,18 +59,11 @@ class FirecrawlExtractBlock(Block):
) -> BlockOutput:
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
try:
extract_result = app.extract(
urls=input_data.urls,
prompt=input_data.prompt,
schema=input_data.output_schema,
enable_web_search=input_data.enable_web_search,
)
except Exception as e:
raise BlockExecutionError(
message=f"Extract failed: {e}",
block_name=self.name,
block_id=self.id,
) from e
extract_result = app.extract(
urls=input_data.urls,
prompt=input_data.prompt,
schema=input_data.output_schema,
enable_web_search=input_data.enable_web_search,
)
yield "data", extract_result.data

View File

@@ -19,7 +19,6 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.exceptions import ModerationError
from backend.util.file import MediaFileType, store_media_file
TEST_CREDENTIALS = APIKeyCredentials(
@@ -154,8 +153,6 @@ class AIImageEditorBlock(Block):
),
aspect_ratio=input_data.aspect_ratio.value,
seed=input_data.seed,
user_id=user_id,
graph_exec_id=graph_exec_id,
)
yield "output_image", result
@@ -167,8 +164,6 @@ class AIImageEditorBlock(Block):
input_image_b64: Optional[str],
aspect_ratio: str,
seed: Optional[int],
user_id: str,
graph_exec_id: str,
) -> MediaFileType:
client = ReplicateClient(api_token=api_key.get_secret_value())
input_params = {
@@ -178,21 +173,11 @@ class AIImageEditorBlock(Block):
**({"seed": seed} if seed is not None else {}),
}
try:
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore
model_name,
input=input_params,
wait=False,
)
except Exception as e:
if "flagged as sensitive" in str(e).lower():
raise ModerationError(
message="Content was flagged as sensitive by the model provider",
user_id=user_id,
graph_exec_id=graph_exec_id,
moderation_type="model_provider",
)
raise ValueError(f"Model execution failed: {e}") from e
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore
model_name,
input=input_params,
wait=False,
)
if isinstance(output, list) and output:
output = output[0]

File diff suppressed because it is too large Load Diff

View File

@@ -1,184 +0,0 @@
"""
Shared helpers for Human-In-The-Loop (HITL) review functionality.
Used by both the dedicated HumanInTheLoopBlock and blocks that require human review.
"""
import logging
from typing import Any, Optional
from prisma.enums import ReviewStatus
from pydantic import BaseModel
from backend.data.execution import ExecutionContext, ExecutionStatus
from backend.data.human_review import ReviewResult
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client
logger = logging.getLogger(__name__)
class ReviewDecision(BaseModel):
"""Result of a review decision."""
should_proceed: bool
message: str
review_result: ReviewResult
class HITLReviewHelper:
"""Helper class for Human-In-The-Loop review operations."""
@staticmethod
async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]:
"""Create or retrieve a human review from the database."""
return await get_database_manager_async_client().get_or_create_human_review(
**kwargs
)
@staticmethod
async def update_node_execution_status(**kwargs) -> None:
"""Update the execution status of a node."""
await async_update_node_execution_status(
db_client=get_database_manager_async_client(), **kwargs
)
@staticmethod
async def update_review_processed_status(
node_exec_id: str, processed: bool
) -> None:
"""Update the processed status of a review."""
return await get_database_manager_async_client().update_review_processed_status(
node_exec_id, processed
)
@staticmethod
async def _handle_review_request(
input_data: Any,
user_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
block_name: str = "Block",
editable: bool = False,
) -> Optional[ReviewResult]:
"""
Handle a review request for a block that requires human review.
Args:
input_data: The input data to be reviewed
user_id: ID of the user requesting the review
node_exec_id: ID of the node execution
graph_exec_id: ID of the graph execution
graph_id: ID of the graph
graph_version: Version of the graph
execution_context: Current execution context
block_name: Name of the block requesting review
editable: Whether the reviewer can edit the data
Returns:
ReviewResult if review is complete, None if waiting for human input
Raises:
Exception: If review creation or status update fails
"""
# Skip review if safe mode is disabled - return auto-approved result
if not execution_context.safe_mode:
logger.info(
f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled"
)
return ReviewResult(
data=input_data,
status=ReviewStatus.APPROVED,
message="Auto-approved (safe mode disabled)",
processed=True,
node_exec_id=node_exec_id,
)
result = await HITLReviewHelper.get_or_create_human_review(
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
input_data=input_data,
message=f"Review required for {block_name} execution",
editable=editable,
)
if result is None:
logger.info(
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
)
await HITLReviewHelper.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
return None # Signal that execution should pause
# Mark review as processed if not already done
if not result.processed:
await HITLReviewHelper.update_review_processed_status(
node_exec_id=node_exec_id, processed=True
)
return result
@staticmethod
async def handle_review_decision(
input_data: Any,
user_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
block_name: str = "Block",
editable: bool = False,
) -> Optional[ReviewDecision]:
"""
Handle a review request and return the decision in a single call.
Args:
input_data: The input data to be reviewed
user_id: ID of the user requesting the review
node_exec_id: ID of the node execution
graph_exec_id: ID of the graph execution
graph_id: ID of the graph
graph_version: Version of the graph
execution_context: Current execution context
block_name: Name of the block requesting review
editable: Whether the reviewer can edit the data
Returns:
ReviewDecision if review is complete (approved/rejected),
None if execution should pause (awaiting review)
"""
review_result = await HITLReviewHelper._handle_review_request(
input_data=input_data,
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=block_name,
editable=editable,
)
if review_result is None:
# Still awaiting review - return None to pause execution
return None
# Review is complete, determine outcome
should_proceed = review_result.status == ReviewStatus.APPROVED
message = review_result.message or (
"Execution approved by reviewer"
if should_proceed
else "Execution rejected by reviewer"
)
return ReviewDecision(
should_proceed=should_proceed, message=message, review_result=review_result
)

View File

@@ -1,9 +1,8 @@
import logging
from typing import Any
from typing import Any, Literal
from prisma.enums import ReviewStatus
from backend.blocks.helpers.review import HITLReviewHelper
from backend.data.block import (
Block,
BlockCategory,
@@ -12,9 +11,11 @@ from backend.data.block import (
BlockSchemaOutput,
BlockType,
)
from backend.data.execution import ExecutionContext
from backend.data.execution import ExecutionContext, ExecutionStatus
from backend.data.human_review import ReviewResult
from backend.data.model import SchemaField
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client
logger = logging.getLogger(__name__)
@@ -44,11 +45,11 @@ class HumanInTheLoopBlock(Block):
)
class Output(BlockSchemaOutput):
approved_data: Any = SchemaField(
description="The data when approved (may be modified by reviewer)"
reviewed_data: Any = SchemaField(
description="The data after human review (may be modified)"
)
rejected_data: Any = SchemaField(
description="The data when rejected (may be modified by reviewer)"
status: Literal["approved", "rejected"] = SchemaField(
description="Status of the review: 'approved' or 'rejected'"
)
review_message: str = SchemaField(
description="Any message provided by the reviewer", default=""
@@ -68,29 +69,36 @@ class HumanInTheLoopBlock(Block):
"editable": True,
},
test_output=[
("approved_data", {"name": "John Doe", "age": 30}),
("status", "approved"),
("reviewed_data", {"name": "John Doe", "age": 30}),
],
test_mock={
"handle_review_decision": lambda **kwargs: type(
"ReviewDecision",
(),
{
"should_proceed": True,
"message": "Test approval message",
"review_result": ReviewResult(
data={"name": "John Doe", "age": 30},
status=ReviewStatus.APPROVED,
message="",
processed=False,
node_exec_id="test-node-exec-id",
),
},
)(),
"get_or_create_human_review": lambda *_args, **_kwargs: ReviewResult(
data={"name": "John Doe", "age": 30},
status=ReviewStatus.APPROVED,
message="",
processed=False,
node_exec_id="test-node-exec-id",
),
"update_node_execution_status": lambda *_args, **_kwargs: None,
"update_review_processed_status": lambda *_args, **_kwargs: None,
},
)
async def handle_review_decision(self, **kwargs):
return await HITLReviewHelper.handle_review_decision(**kwargs)
async def get_or_create_human_review(self, **kwargs):
return await get_database_manager_async_client().get_or_create_human_review(
**kwargs
)
async def update_node_execution_status(self, **kwargs):
return await async_update_node_execution_status(
db_client=get_database_manager_async_client(), **kwargs
)
async def update_review_processed_status(self, node_exec_id: str, processed: bool):
return await get_database_manager_async_client().update_review_processed_status(
node_exec_id, processed
)
async def run(
self,
@@ -102,38 +110,60 @@ class HumanInTheLoopBlock(Block):
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
**_kwargs,
**kwargs,
) -> BlockOutput:
if not execution_context.safe_mode:
logger.info(
f"HITL block skipping review for node {node_exec_id} - safe mode disabled"
)
yield "approved_data", input_data.data
yield "status", "approved"
yield "reviewed_data", input_data.data
yield "review_message", "Auto-approved (safe mode disabled)"
return
decision = await self.handle_review_decision(
input_data=input_data.data,
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=self.name,
editable=input_data.editable,
)
try:
result = await self.get_or_create_human_review(
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
input_data=input_data.data,
message=input_data.name,
editable=input_data.editable,
)
except Exception as e:
logger.error(f"Error in HITL block for node {node_exec_id}: {str(e)}")
raise
if decision is None:
return
if result is None:
logger.info(
f"HITL block pausing execution for node {node_exec_id} - awaiting human review"
)
try:
await self.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
return
except Exception as e:
logger.error(
f"Failed to update node status for HITL block {node_exec_id}: {str(e)}"
)
raise
status = decision.review_result.status
if status == ReviewStatus.APPROVED:
yield "approved_data", decision.review_result.data
elif status == ReviewStatus.REJECTED:
yield "rejected_data", decision.review_result.data
else:
raise RuntimeError(f"Unexpected review status: {status}")
if not result.processed:
await self.update_review_processed_status(
node_exec_id=node_exec_id, processed=True
)
if decision.message:
yield "review_message", decision.message
if result.status == ReviewStatus.APPROVED:
yield "status", "approved"
yield "reviewed_data", result.data
if result.message:
yield "review_message", result.message
elif result.status == ReviewStatus.REJECTED:
yield "status", "rejected"
if result.message:
yield "review_message", result.message

View File

@@ -2,6 +2,7 @@ from enum import Enum
from typing import Any, Dict, Literal, Optional
from pydantic import SecretStr
from requests.exceptions import RequestException
from backend.data.block import (
Block,
@@ -331,8 +332,8 @@ class IdeogramModelBlock(Block):
try:
response = await Requests().post(url, headers=headers, json=data)
return response.json()["data"][0]["url"]
except Exception as e:
raise ValueError(f"Failed to fetch image with V3 endpoint: {e}") from e
except RequestException as e:
raise Exception(f"Failed to fetch image with V3 endpoint: {str(e)}")
async def _run_model_legacy(
self,
@@ -384,8 +385,8 @@ class IdeogramModelBlock(Block):
try:
response = await Requests().post(url, headers=headers, json=data)
return response.json()["data"][0]["url"]
except Exception as e:
raise ValueError(f"Failed to fetch image with legacy endpoint: {e}") from e
except RequestException as e:
raise Exception(f"Failed to fetch image with legacy endpoint: {str(e)}")
async def upscale_image(self, api_key: SecretStr, image_url: str):
url = "https://api.ideogram.ai/upscale"
@@ -412,5 +413,5 @@ class IdeogramModelBlock(Block):
return (response.json())["data"][0]["url"]
except Exception as e:
raise ValueError(f"Failed to upscale image: {e}") from e
except RequestException as e:
raise Exception(f"Failed to upscale image: {str(e)}")

View File

@@ -16,7 +16,6 @@ from backend.data.block import (
BlockSchemaOutput,
)
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
class SearchTheWebBlock(Block, GetRequest):
@@ -57,17 +56,7 @@ class SearchTheWebBlock(Block, GetRequest):
# Prepend the Jina Search URL to the encoded query
jina_search_url = f"https://s.jina.ai/{encoded_query}"
try:
results = await self.get_request(
jina_search_url, headers=headers, json=False
)
except Exception as e:
raise BlockExecutionError(
message=f"Search failed: {e}",
block_name=self.name,
block_id=self.id,
) from e
results = await self.get_request(jina_search_url, headers=headers, json=False)
# Output the search results
yield "results", results

View File

@@ -92,9 +92,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
O1 = "o1"
O1_MINI = "o1-mini"
# GPT-5 models
GPT5_2 = "gpt-5.2-2025-12-11"
GPT5_1 = "gpt-5.1-2025-11-13"
GPT5 = "gpt-5-2025-08-07"
GPT5_1 = "gpt-5.1-2025-11-13"
GPT5_MINI = "gpt-5-mini-2025-08-07"
GPT5_NANO = "gpt-5-nano-2025-08-07"
GPT5_CHAT = "gpt-5-chat-latest"
@@ -195,9 +194,8 @@ MODEL_METADATA = {
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
# GPT-5 models
LlmModel.GPT5_2: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_MINI: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_NANO: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_CHAT: ModelMetadata("openai", 400000, 16384),
@@ -305,8 +303,6 @@ MODEL_METADATA = {
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000),
}
DEFAULT_LLM_MODEL = LlmModel.GPT5_2
for model in LlmModel:
if model not in MODEL_METADATA:
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
@@ -794,7 +790,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -859,7 +855,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
input_schema=AIStructuredResponseGeneratorBlock.Input,
output_schema=AIStructuredResponseGeneratorBlock.Output,
test_input={
"model": DEFAULT_LLM_MODEL,
"model": LlmModel.GPT4O,
"credentials": TEST_CREDENTIALS_INPUT,
"expected_format": {
"key1": "value1",
@@ -1225,7 +1221,7 @@ class AITextGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -1321,7 +1317,7 @@ class AITextSummarizerBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for summarizing the text.",
)
focus: str = SchemaField(
@@ -1538,7 +1534,7 @@ class AIConversationBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for the conversation.",
)
credentials: AICredentials = AICredentialsField()
@@ -1576,7 +1572,7 @@ class AIConversationBlock(AIBlockBase):
},
{"role": "user", "content": "Where was it played?"},
],
"model": DEFAULT_LLM_MODEL,
"model": LlmModel.GPT4O,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
@@ -1639,7 +1635,7 @@ class AIListGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for generating the list.",
advanced=True,
)
@@ -1696,7 +1692,7 @@ class AIListGeneratorBlock(AIBlockBase):
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
"fictional worlds."
),
"model": DEFAULT_LLM_MODEL,
"model": LlmModel.GPT4O,
"credentials": TEST_CREDENTIALS_INPUT,
"max_retries": 3,
"force_json_output": False,

File diff suppressed because it is too large Load Diff

View File

@@ -18,7 +18,6 @@ from backend.data.block import (
BlockSchemaOutput,
)
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
from backend.util.exceptions import BlockExecutionError, BlockInputError
logger = logging.getLogger(__name__)
@@ -112,27 +111,9 @@ class ReplicateModelBlock(Block):
yield "status", "succeeded"
yield "model_name", input_data.model_name
except Exception as e:
error_msg = str(e)
logger.error(f"Error running Replicate model: {error_msg}")
# Input validation errors (422, 400) → BlockInputError
if (
"422" in error_msg
or "Input validation failed" in error_msg
or "400" in error_msg
):
raise BlockInputError(
message=f"Invalid model inputs: {error_msg}",
block_name=self.name,
block_id=self.id,
) from e
# Everything else → BlockExecutionError
else:
raise BlockExecutionError(
message=f"Replicate model error: {error_msg}",
block_name=self.name,
block_id=self.id,
) from e
error_msg = f"Unexpected error running Replicate model: {str(e)}"
logger.error(error_msg)
raise RuntimeError(error_msg)
async def run_model(self, model_ref: str, model_inputs: dict, api_key: SecretStr):
"""

View File

@@ -45,16 +45,10 @@ class GetWikipediaSummaryBlock(Block, GetRequest):
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
topic = input_data.topic
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"
# Note: User-Agent is now automatically set by the request library
# to comply with Wikimedia's robot policy (https://w.wiki/4wJS)
try:
response = await self.get_request(url, json=True)
if "extract" not in response:
raise ValueError(f"Unable to parse Wikipedia response: {response}")
yield "summary", response["extract"]
except Exception as e:
raise ValueError(f"Failed to fetch Wikipedia summary: {e}") from e
response = await self.get_request(url, json=True)
if "extract" not in response:
raise RuntimeError(f"Unable to parse Wikipedia response: {response}")
yield "summary", response["extract"]
TEST_CREDENTIALS = APIKeyCredentials(

View File

@@ -226,7 +226,7 @@ class SmartDecisionMakerBlock(Block):
)
model: llm.LlmModel = SchemaField(
title="LLM Model",
default=llm.DEFAULT_LLM_MODEL,
default=llm.LlmModel.GPT4O,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -391,12 +391,8 @@ class SmartDecisionMakerBlock(Block):
"""
block = sink_node.block
# Use custom name from node metadata if set, otherwise fall back to block.name
custom_name = sink_node.metadata.get("customized_name")
tool_name = custom_name if custom_name else block.name
tool_function: dict[str, Any] = {
"name": SmartDecisionMakerBlock.cleanup(tool_name),
"name": SmartDecisionMakerBlock.cleanup(block.name),
"description": block.description,
}
sink_block_input_schema = block.input_schema
@@ -493,12 +489,8 @@ class SmartDecisionMakerBlock(Block):
f"Sink graph metadata not found: {graph_id} {graph_version}"
)
# Use custom name from node metadata if set, otherwise fall back to graph name
custom_name = sink_node.metadata.get("customized_name")
tool_name = custom_name if custom_name else sink_graph_meta.name
tool_function: dict[str, Any] = {
"name": SmartDecisionMakerBlock.cleanup(tool_name),
"name": SmartDecisionMakerBlock.cleanup(sink_graph_meta.name),
"description": sink_graph_meta.description,
}
@@ -983,28 +975,10 @@ class SmartDecisionMakerBlock(Block):
graph_version: int,
execution_context: ExecutionContext,
execution_processor: "ExecutionProcessor",
nodes_to_skip: set[str] | None = None,
**kwargs,
) -> BlockOutput:
tool_functions = await self._create_tool_node_signatures(node_id)
original_tool_count = len(tool_functions)
# Filter out tools for nodes that should be skipped (e.g., missing optional credentials)
if nodes_to_skip:
tool_functions = [
tf
for tf in tool_functions
if tf.get("function", {}).get("_sink_node_id") not in nodes_to_skip
]
# Only raise error if we had tools but they were all filtered out
if original_tool_count > 0 and not tool_functions:
raise ValueError(
"No available tools to execute - all downstream nodes are unavailable "
"(possibly due to missing optional credentials)"
)
yield "tool_functions", json.dumps(tool_functions)
conversation_history = input_data.conversation_history or []

View File

@@ -196,15 +196,6 @@ class TestXMLParserBlockSecurity:
async for _ in block.run(XMLParserBlock.Input(input_xml=large_xml)):
pass
async def test_rejects_text_outside_root(self):
"""Ensure parser surfaces readable errors for invalid root text."""
block = XMLParserBlock()
invalid_xml = "<root><child>value</child></root> trailing"
with pytest.raises(ValueError, match="text outside the root element"):
async for _ in block.run(XMLParserBlock.Input(input_xml=invalid_xml)):
pass
class TestStoreMediaFileSecurity:
"""Test file storage security limits."""

View File

@@ -28,7 +28,7 @@ class TestLLMStatsTracking:
response = await llm.llm_call(
credentials=llm.TEST_CREDENTIALS,
llm_model=llm.DEFAULT_LLM_MODEL,
llm_model=llm.LlmModel.GPT4O,
prompt=[{"role": "user", "content": "Hello"}],
max_tokens=100,
)
@@ -65,7 +65,7 @@ class TestLLMStatsTracking:
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test prompt",
expected_format={"key1": "desc1", "key2": "desc2"},
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore # type: ignore
)
@@ -109,7 +109,7 @@ class TestLLMStatsTracking:
# Run the block
input_data = llm.AITextGeneratorBlock.Input(
prompt="Generate text",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)
@@ -170,7 +170,7 @@ class TestLLMStatsTracking:
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test prompt",
expected_format={"key1": "desc1", "key2": "desc2"},
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2,
)
@@ -228,7 +228,7 @@ class TestLLMStatsTracking:
input_data = llm.AITextSummarizerBlock.Input(
text=long_text,
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_tokens=100, # Small chunks
chunk_overlap=10,
@@ -299,7 +299,7 @@ class TestLLMStatsTracking:
# Test with very short text (should only need 1 chunk + 1 final summary)
input_data = llm.AITextSummarizerBlock.Input(
text="This is a short text.",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_tokens=1000, # Large enough to avoid chunking
)
@@ -346,7 +346,7 @@ class TestLLMStatsTracking:
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"},
],
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)
@@ -387,7 +387,7 @@ class TestLLMStatsTracking:
# Run the block
input_data = llm.AIListGeneratorBlock.Input(
focus="test items",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_retries=3,
)
@@ -469,7 +469,7 @@ class TestLLMStatsTracking:
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test",
expected_format={"result": "desc"},
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)
@@ -513,7 +513,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
style=llm.SummaryStyle.BULLET_POINTS,
)
@@ -558,7 +558,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
style=llm.SummaryStyle.BULLET_POINTS,
max_tokens=1000,
@@ -593,7 +593,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)
@@ -623,7 +623,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_tokens=1000,
)
@@ -654,7 +654,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)

View File

@@ -1,246 +0,0 @@
"""
Standalone tests for pin name sanitization that can run without full backend dependencies.
These tests verify the core sanitization logic independently of the full system.
Run with: python -m pytest test_pin_sanitization_standalone.py -v
Or simply: python test_pin_sanitization_standalone.py
"""
import re
from typing import Any
# Simulate the exact cleanup function from SmartDecisionMakerBlock
def cleanup(s: str) -> str:
"""Clean up names for use as tool function names."""
return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()
# Simulate the key parts of parse_execution_output
def simulate_tool_routing(
emit_key: str,
sink_node_id: str,
sink_pin_name: str,
) -> bool:
"""
Simulate the routing comparison from parse_execution_output.
Returns True if routing would succeed, False otherwise.
"""
if not emit_key.startswith("tools_^_") or "_~_" not in emit_key:
return False
# Extract routing info from emit key: tools_^_{node_id}_~_{field}
selector = emit_key[8:] # Remove "tools_^_"
target_node_id, target_input_pin = selector.split("_~_", 1)
# Current (buggy) comparison - direct string comparison
return target_node_id == sink_node_id and target_input_pin == sink_pin_name
def simulate_fixed_tool_routing(
emit_key: str,
sink_node_id: str,
sink_pin_name: str,
) -> bool:
"""
Simulate the FIXED routing comparison.
The fix: sanitize sink_pin_name before comparison.
"""
if not emit_key.startswith("tools_^_") or "_~_" not in emit_key:
return False
selector = emit_key[8:]
target_node_id, target_input_pin = selector.split("_~_", 1)
# Fixed comparison - sanitize sink_pin_name
return target_node_id == sink_node_id and target_input_pin == cleanup(sink_pin_name)
class TestCleanupFunction:
"""Tests for the cleanup function."""
def test_spaces_to_underscores(self):
assert cleanup("Max Keyword Difficulty") == "max_keyword_difficulty"
def test_mixed_case_to_lowercase(self):
assert cleanup("MaxKeywordDifficulty") == "maxkeyworddifficulty"
def test_special_chars_to_underscores(self):
assert cleanup("field@name!") == "field_name_"
assert cleanup("CPC ($)") == "cpc____"
def test_preserves_valid_chars(self):
assert cleanup("valid_name-123") == "valid_name-123"
def test_empty_string(self):
assert cleanup("") == ""
def test_consecutive_spaces(self):
assert cleanup("a b") == "a___b"
def test_unicode(self):
assert cleanup("café") == "caf_"
class TestCurrentRoutingBehavior:
"""Tests demonstrating the current (buggy) routing behavior."""
def test_exact_match_works(self):
"""When names match exactly, routing works."""
emit_key = "tools_^_node-123_~_query"
assert simulate_tool_routing(emit_key, "node-123", "query") is True
def test_spaces_cause_failure(self):
"""When sink_pin has spaces, routing fails."""
sanitized = cleanup("Max Keyword Difficulty")
emit_key = f"tools_^_node-123_~_{sanitized}"
assert simulate_tool_routing(emit_key, "node-123", "Max Keyword Difficulty") is False
def test_special_chars_cause_failure(self):
"""When sink_pin has special chars, routing fails."""
sanitized = cleanup("CPC ($)")
emit_key = f"tools_^_node-123_~_{sanitized}"
assert simulate_tool_routing(emit_key, "node-123", "CPC ($)") is False
class TestFixedRoutingBehavior:
"""Tests demonstrating the fixed routing behavior."""
def test_exact_match_still_works(self):
"""When names match exactly, routing still works."""
emit_key = "tools_^_node-123_~_query"
assert simulate_fixed_tool_routing(emit_key, "node-123", "query") is True
def test_spaces_work_with_fix(self):
"""With the fix, spaces in sink_pin work."""
sanitized = cleanup("Max Keyword Difficulty")
emit_key = f"tools_^_node-123_~_{sanitized}"
assert simulate_fixed_tool_routing(emit_key, "node-123", "Max Keyword Difficulty") is True
def test_special_chars_work_with_fix(self):
"""With the fix, special chars in sink_pin work."""
sanitized = cleanup("CPC ($)")
emit_key = f"tools_^_node-123_~_{sanitized}"
assert simulate_fixed_tool_routing(emit_key, "node-123", "CPC ($)") is True
class TestBugReproduction:
"""Exact reproduction of the reported bug."""
def test_max_keyword_difficulty_bug(self):
"""
Reproduce the exact bug from the issue:
"For this agent specifically the input pin has space and unsanitized,
the frontend somehow connect without sanitizing creating a link like:
tools_^_767682f5-..._~_Max Keyword Difficulty
but what's produced by backend is
tools_^_767682f5-..._~_max_keyword_difficulty
so the tool calls go into the void"
"""
node_id = "767682f5-fake-uuid"
original_field = "Max Keyword Difficulty"
sanitized_field = cleanup(original_field)
# What backend produces (emit key)
emit_key = f"tools_^_{node_id}_~_{sanitized_field}"
assert emit_key == f"tools_^_{node_id}_~_max_keyword_difficulty"
# What frontend link has (sink_pin_name)
frontend_sink = original_field
# Current behavior: FAILS
assert simulate_tool_routing(emit_key, node_id, frontend_sink) is False
# With fix: WORKS
assert simulate_fixed_tool_routing(emit_key, node_id, frontend_sink) is True
class TestCommonFieldNamePatterns:
"""Test common field name patterns that could cause issues."""
FIELD_NAMES = [
"Max Keyword Difficulty",
"Search Volume (Monthly)",
"CPC ($)",
"User's Input",
"Target URL",
"API Response",
"Query #1",
"First Name",
"Last Name",
"Email Address",
"Phone Number",
"Total Cost ($)",
"Discount (%)",
"Created At",
"Updated At",
"Is Active",
]
def test_current_behavior_fails_for_special_names(self):
"""Current behavior fails for names with spaces/special chars."""
failed = []
for name in self.FIELD_NAMES:
sanitized = cleanup(name)
emit_key = f"tools_^_node_~_{sanitized}"
if not simulate_tool_routing(emit_key, "node", name):
failed.append(name)
# All names with spaces should fail
names_with_spaces = [n for n in self.FIELD_NAMES if " " in n or any(c in n for c in "()$%#'")]
assert set(failed) == set(names_with_spaces)
def test_fixed_behavior_works_for_all_names(self):
"""Fixed behavior works for all names."""
for name in self.FIELD_NAMES:
sanitized = cleanup(name)
emit_key = f"tools_^_node_~_{sanitized}"
assert simulate_fixed_tool_routing(emit_key, "node", name) is True, f"Failed for: {name}"
def run_tests():
"""Run all tests manually without pytest."""
import traceback
test_classes = [
TestCleanupFunction,
TestCurrentRoutingBehavior,
TestFixedRoutingBehavior,
TestBugReproduction,
TestCommonFieldNamePatterns,
]
total = 0
passed = 0
failed = 0
for test_class in test_classes:
print(f"\n{test_class.__name__}:")
instance = test_class()
for name in dir(instance):
if name.startswith("test_"):
total += 1
try:
getattr(instance, name)()
print(f"{name}")
passed += 1
except AssertionError as e:
print(f"{name}: {e}")
failed += 1
except Exception as e:
print(f"{name}: {e}")
traceback.print_exc()
failed += 1
print(f"\n{'='*50}")
print(f"Total: {total}, Passed: {passed}, Failed: {failed}")
return failed == 0
if __name__ == "__main__":
import sys
success = run_tests()
sys.exit(0 if success else 1)

View File

@@ -5,10 +5,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.api.model import CreateGraph
from backend.api.rest_api import AgentServer
from backend.data.execution import ExecutionContext
from backend.data.model import ProviderName, User
from backend.server.model import CreateGraph
from backend.server.rest_api import AgentServer
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.test import SpinTestServer, wait_execution
@@ -233,7 +233,7 @@ async def test_smart_decision_maker_tracks_llm_stats():
# Create test input
input_data = SmartDecisionMakerBlock.Input(
prompt="Should I continue with this task?",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
@@ -335,7 +335,7 @@ async def test_smart_decision_maker_parameter_validation():
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2, # Set retry to 2 for testing
agent_mode_max_iterations=0,
@@ -402,7 +402,7 @@ async def test_smart_decision_maker_parameter_validation():
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
@@ -462,7 +462,7 @@ async def test_smart_decision_maker_parameter_validation():
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
@@ -526,7 +526,7 @@ async def test_smart_decision_maker_parameter_validation():
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
@@ -648,7 +648,7 @@ async def test_smart_decision_maker_raw_response_conversion():
input_data = SmartDecisionMakerBlock.Input(
prompt="Test prompt",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2,
agent_mode_max_iterations=0,
@@ -722,7 +722,7 @@ async def test_smart_decision_maker_raw_response_conversion():
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Simple prompt",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
@@ -778,7 +778,7 @@ async def test_smart_decision_maker_raw_response_conversion():
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Another test",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
@@ -931,7 +931,7 @@ async def test_smart_decision_maker_agent_mode():
# Test agent mode with max_iterations = 3
input_data = SmartDecisionMakerBlock.Input(
prompt="Complete this task using tools",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=3, # Enable agent mode with 3 max iterations
)
@@ -1020,7 +1020,7 @@ async def test_smart_decision_maker_traditional_mode_default():
# Test default behavior (traditional mode)
input_data = SmartDecisionMakerBlock.Input(
prompt="Test prompt",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0, # Traditional mode
)
@@ -1057,153 +1057,3 @@ async def test_smart_decision_maker_traditional_mode_default():
) # Should yield individual tool parameters
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
assert "conversations" in outputs
@pytest.mark.asyncio
async def test_smart_decision_maker_uses_customized_name_for_blocks():
"""Test that SmartDecisionMakerBlock uses customized_name from node metadata for tool names."""
from unittest.mock import MagicMock
from backend.blocks.basic import StoreValueBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node with customized_name in metadata
mock_node = MagicMock(spec=Node)
mock_node.id = "test-node-id"
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
mock_node.block = StoreValueBlock()
# Create a mock link
mock_link = MagicMock(spec=Link)
mock_link.sink_name = "input"
# Call the function directly
result = await SmartDecisionMakerBlock._create_block_function_signature(
mock_node, [mock_link]
)
# Verify the tool name uses the customized name (cleaned up)
assert result["type"] == "function"
assert result["function"]["name"] == "my_custom_tool_name" # Cleaned version
assert result["function"]["_sink_node_id"] == "test-node-id"
@pytest.mark.asyncio
async def test_smart_decision_maker_falls_back_to_block_name():
"""Test that SmartDecisionMakerBlock falls back to block.name when no customized_name."""
from unittest.mock import MagicMock
from backend.blocks.basic import StoreValueBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node without customized_name
mock_node = MagicMock(spec=Node)
mock_node.id = "test-node-id"
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {} # No customized_name
mock_node.block = StoreValueBlock()
# Create a mock link
mock_link = MagicMock(spec=Link)
mock_link.sink_name = "input"
# Call the function directly
result = await SmartDecisionMakerBlock._create_block_function_signature(
mock_node, [mock_link]
)
# Verify the tool name uses the block's default name
assert result["type"] == "function"
assert result["function"]["name"] == "storevalueblock" # Default block name cleaned
assert result["function"]["_sink_node_id"] == "test-node-id"
@pytest.mark.asyncio
async def test_smart_decision_maker_uses_customized_name_for_agents():
"""Test that SmartDecisionMakerBlock uses customized_name from metadata for agent nodes."""
from unittest.mock import AsyncMock, MagicMock, patch
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node with customized_name in metadata
mock_node = MagicMock(spec=Node)
mock_node.id = "test-agent-node-id"
mock_node.metadata = {"customized_name": "My Custom Agent"}
mock_node.input_default = {
"graph_id": "test-graph-id",
"graph_version": 1,
"input_schema": {"properties": {"test_input": {"description": "Test input"}}},
}
# Create a mock link
mock_link = MagicMock(spec=Link)
mock_link.sink_name = "test_input"
# Mock the database client
mock_graph_meta = MagicMock()
mock_graph_meta.name = "Original Agent Name"
mock_graph_meta.description = "Agent description"
mock_db_client = AsyncMock()
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
with patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db_client,
):
result = await SmartDecisionMakerBlock._create_agent_function_signature(
mock_node, [mock_link]
)
# Verify the tool name uses the customized name (cleaned up)
assert result["type"] == "function"
assert result["function"]["name"] == "my_custom_agent" # Cleaned version
assert result["function"]["_sink_node_id"] == "test-agent-node-id"
@pytest.mark.asyncio
async def test_smart_decision_maker_agent_falls_back_to_graph_name():
"""Test that agent node falls back to graph name when no customized_name."""
from unittest.mock import AsyncMock, MagicMock, patch
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node without customized_name
mock_node = MagicMock(spec=Node)
mock_node.id = "test-agent-node-id"
mock_node.metadata = {} # No customized_name
mock_node.input_default = {
"graph_id": "test-graph-id",
"graph_version": 1,
"input_schema": {"properties": {"test_input": {"description": "Test input"}}},
}
# Create a mock link
mock_link = MagicMock(spec=Link)
mock_link.sink_name = "test_input"
# Mock the database client
mock_graph_meta = MagicMock()
mock_graph_meta.name = "Original Agent Name"
mock_graph_meta.description = "Agent description"
mock_db_client = AsyncMock()
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
with patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db_client,
):
result = await SmartDecisionMakerBlock._create_agent_function_signature(
mock_node, [mock_link]
)
# Verify the tool name uses the graph's default name
assert result["type"] == "function"
assert result["function"]["name"] == "original_agent_name" # Graph name cleaned
assert result["function"]["_sink_node_id"] == "test-agent-node-id"

View File

@@ -1,916 +0,0 @@
"""
Tests for SmartDecisionMaker agent mode specific failure modes.
Covers failure modes:
2. Silent Tool Failures in Agent Mode
3. Unbounded Agent Mode Iterations
10. Unbounded Agent Iterations
12. Stale Credentials in Agent Mode
13. Tool Signature Cache Invalidation
"""
import asyncio
import json
import threading
from collections import defaultdict
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from backend.blocks.smart_decision_maker import (
SmartDecisionMakerBlock,
ExecutionParams,
ToolInfo,
)
class TestSilentToolFailuresInAgentMode:
"""
Tests for Failure Mode #2: Silent Tool Failures in Agent Mode
When tool execution fails in agent mode, the error is converted to a
tool response and execution continues silently.
"""
@pytest.mark.asyncio
async def test_tool_execution_failure_converted_to_response(self):
"""
Test that tool execution failures are silently converted to responses.
"""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
# First response: tool call
mock_tool_call = MagicMock()
mock_tool_call.id = "call_1"
mock_tool_call.function.name = "failing_tool"
mock_tool_call.function.arguments = json.dumps({"param": "value"})
mock_response_1 = MagicMock()
mock_response_1.response = None
mock_response_1.tool_calls = [mock_tool_call]
mock_response_1.prompt_tokens = 50
mock_response_1.completion_tokens = 25
mock_response_1.reasoning = None
mock_response_1.raw_response = {
"role": "assistant",
"content": [{"type": "tool_use", "id": "call_1"}]
}
# Second response: finish after seeing error
mock_response_2 = MagicMock()
mock_response_2.response = "I encountered an error"
mock_response_2.tool_calls = []
mock_response_2.prompt_tokens = 30
mock_response_2.completion_tokens = 15
mock_response_2.reasoning = None
mock_response_2.raw_response = {"role": "assistant", "content": "I encountered an error"}
llm_call_count = 0
async def mock_llm_call(**kwargs):
nonlocal llm_call_count
llm_call_count += 1
if llm_call_count == 1:
return mock_response_1
return mock_response_2
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "failing_tool",
"_sink_node_id": "sink-node",
"_field_mapping": {"param": "param"},
"parameters": {
"properties": {"param": {"type": "string"}},
"required": ["param"],
},
},
}
]
# Mock database client that will fail
mock_db_client = AsyncMock()
mock_db_client.get_node.side_effect = Exception("Database connection failed!")
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = AsyncMock()
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = threading.Lock()
input_data = SmartDecisionMakerBlock.Input(
prompt="Do something",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=5,
)
outputs = {}
async for name, value in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[name] = value
# The execution completed (didn't crash)
assert "finished" in outputs or "conversations" in outputs
# BUG: The tool failure was silent - user doesn't know what happened
# The error was just logged and converted to a tool response
@pytest.mark.asyncio
async def test_tool_failure_causes_infinite_retry_loop(self):
"""
Test scenario where LLM keeps calling the same failing tool.
If tool fails but LLM doesn't realize it, it may keep trying.
"""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
call_count = 0
max_calls = 10 # Limit for test
def create_tool_call_response():
mock_tool_call = MagicMock()
mock_tool_call.id = f"call_{call_count}"
mock_tool_call.function.name = "persistent_tool"
mock_tool_call.function.arguments = json.dumps({"retry": call_count})
mock_response = MagicMock()
mock_response.response = None
mock_response.tool_calls = [mock_tool_call]
mock_response.prompt_tokens = 50
mock_response.completion_tokens = 25
mock_response.reasoning = None
mock_response.raw_response = {
"role": "assistant",
"content": [{"type": "tool_use", "id": f"call_{call_count}"}]
}
return mock_response
async def mock_llm_call(**kwargs):
nonlocal call_count
call_count += 1
if call_count >= max_calls:
# Eventually finish to prevent actual infinite loop in test
final = MagicMock()
final.response = "Giving up"
final.tool_calls = []
final.prompt_tokens = 10
final.completion_tokens = 5
final.reasoning = None
final.raw_response = {"role": "assistant", "content": "Giving up"}
return final
return create_tool_call_response()
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "persistent_tool",
"_sink_node_id": "sink-node",
"_field_mapping": {"retry": "retry"},
"parameters": {
"properties": {"retry": {"type": "integer"}},
"required": ["retry"],
},
},
}
]
mock_db_client = AsyncMock()
mock_db_client.get_node.side_effect = Exception("Always fails!")
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = AsyncMock()
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = threading.Lock()
input_data = SmartDecisionMakerBlock.Input(
prompt="Keep trying",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=-1, # Infinite mode!
)
# Use timeout to prevent actual infinite loop
try:
async with asyncio.timeout(5):
outputs = {}
async for name, value in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[name] = value
except asyncio.TimeoutError:
pass # Expected if we hit infinite loop
# Document that many calls were made before we gave up
assert call_count >= max_calls - 1, \
f"Expected many retries, got {call_count}"
class TestUnboundedAgentIterations:
"""
Tests for Failure Mode #3 and #10: Unbounded Agent Mode Iterations
With max_iterations = -1, the agent can run forever, consuming
unlimited tokens and compute resources.
"""
@pytest.mark.asyncio
async def test_infinite_mode_requires_llm_to_stop(self):
"""
Test that infinite mode (-1) only stops when LLM stops making tool calls.
"""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
iterations = 0
max_test_iterations = 20
async def mock_llm_call(**kwargs):
nonlocal iterations
iterations += 1
if iterations >= max_test_iterations:
# Stop to prevent actual infinite loop
resp = MagicMock()
resp.response = "Finally done"
resp.tool_calls = []
resp.prompt_tokens = 10
resp.completion_tokens = 5
resp.reasoning = None
resp.raw_response = {"role": "assistant", "content": "Done"}
return resp
# Keep making tool calls
tool_call = MagicMock()
tool_call.id = f"call_{iterations}"
tool_call.function.name = "counter_tool"
tool_call.function.arguments = json.dumps({"count": iterations})
resp = MagicMock()
resp.response = None
resp.tool_calls = [tool_call]
resp.prompt_tokens = 50
resp.completion_tokens = 25
resp.reasoning = None
resp.raw_response = {
"role": "assistant",
"content": [{"type": "tool_use", "id": f"call_{iterations}"}]
}
return resp
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "counter_tool",
"_sink_node_id": "sink",
"_field_mapping": {"count": "count"},
"parameters": {
"properties": {"count": {"type": "integer"}},
"required": ["count"],
},
},
}
]
mock_db_client = AsyncMock()
mock_node = MagicMock()
mock_node.block_id = "test-block"
mock_db_client.get_node.return_value = mock_node
mock_exec_result = MagicMock()
mock_exec_result.node_exec_id = "exec-id"
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {"count": 1})
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {"result": "ok"}
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = AsyncMock()
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = threading.Lock()
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
input_data = SmartDecisionMakerBlock.Input(
prompt="Count forever",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=-1, # INFINITE MODE
)
async with asyncio.timeout(10):
outputs = {}
async for name, value in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[name] = value
# We ran many iterations before stopping
assert iterations == max_test_iterations
# BUG: No built-in safeguard against runaway iterations
@pytest.mark.asyncio
async def test_max_iterations_limit_enforced(self):
"""
Test that max_iterations limit is properly enforced.
"""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
iterations = 0
async def mock_llm_call(**kwargs):
nonlocal iterations
iterations += 1
# Always make tool calls (never finish voluntarily)
tool_call = MagicMock()
tool_call.id = f"call_{iterations}"
tool_call.function.name = "endless_tool"
tool_call.function.arguments = json.dumps({})
resp = MagicMock()
resp.response = None
resp.tool_calls = [tool_call]
resp.prompt_tokens = 50
resp.completion_tokens = 25
resp.reasoning = None
resp.raw_response = {
"role": "assistant",
"content": [{"type": "tool_use", "id": f"call_{iterations}"}]
}
return resp
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "endless_tool",
"_sink_node_id": "sink",
"_field_mapping": {},
"parameters": {"properties": {}, "required": []},
},
}
]
mock_db_client = AsyncMock()
mock_node = MagicMock()
mock_node.block_id = "test-block"
mock_db_client.get_node.return_value = mock_node
mock_exec_result = MagicMock()
mock_exec_result.node_exec_id = "exec-id"
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {})
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {}
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = AsyncMock()
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = threading.Lock()
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
MAX_ITERATIONS = 3
input_data = SmartDecisionMakerBlock.Input(
prompt="Run forever",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=MAX_ITERATIONS,
)
outputs = {}
async for name, value in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[name] = value
# Should have stopped at max iterations
assert iterations == MAX_ITERATIONS
assert "finished" in outputs
assert "limit reached" in outputs["finished"].lower()
class TestStaleCredentialsInAgentMode:
"""
Tests for Failure Mode #12: Stale Credentials in Agent Mode
Credentials are validated once at start but can expire during
long-running agent mode executions.
"""
@pytest.mark.asyncio
async def test_credentials_not_revalidated_between_iterations(self):
"""
Test that credentials are used without revalidation in agent mode.
"""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
credential_check_count = 0
iteration = 0
async def mock_llm_call(**kwargs):
nonlocal credential_check_count, iteration
iteration += 1
# Simulate credential check (in real code this happens in llm_call)
credential_check_count += 1
if iteration >= 3:
resp = MagicMock()
resp.response = "Done"
resp.tool_calls = []
resp.prompt_tokens = 10
resp.completion_tokens = 5
resp.reasoning = None
resp.raw_response = {"role": "assistant", "content": "Done"}
return resp
tool_call = MagicMock()
tool_call.id = f"call_{iteration}"
tool_call.function.name = "test_tool"
tool_call.function.arguments = json.dumps({})
resp = MagicMock()
resp.response = None
resp.tool_calls = [tool_call]
resp.prompt_tokens = 50
resp.completion_tokens = 25
resp.reasoning = None
resp.raw_response = {
"role": "assistant",
"content": [{"type": "tool_use", "id": f"call_{iteration}"}]
}
return resp
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "test_tool",
"_sink_node_id": "sink",
"_field_mapping": {},
"parameters": {"properties": {}, "required": []},
},
}
]
mock_db_client = AsyncMock()
mock_node = MagicMock()
mock_node.block_id = "test-block"
mock_db_client.get_node.return_value = mock_node
mock_exec_result = MagicMock()
mock_exec_result.node_exec_id = "exec-id"
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {})
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {}
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = AsyncMock()
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = threading.Lock()
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
input_data = SmartDecisionMakerBlock.Input(
prompt="Test credentials",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=5,
)
outputs = {}
async for name, value in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[name] = value
# Credentials were checked on each LLM call but not refreshed
# If they expired mid-execution, we'd get auth errors
assert credential_check_count == iteration
@pytest.mark.asyncio
async def test_credential_expiration_mid_execution(self):
"""
Test what happens when credentials expire during agent mode.
"""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
iteration = 0
async def mock_llm_call_with_expiration(**kwargs):
nonlocal iteration
iteration += 1
if iteration >= 3:
# Simulate credential expiration
raise Exception("401 Unauthorized: API key expired")
tool_call = MagicMock()
tool_call.id = f"call_{iteration}"
tool_call.function.name = "test_tool"
tool_call.function.arguments = json.dumps({})
resp = MagicMock()
resp.response = None
resp.tool_calls = [tool_call]
resp.prompt_tokens = 50
resp.completion_tokens = 25
resp.reasoning = None
resp.raw_response = {
"role": "assistant",
"content": [{"type": "tool_use", "id": f"call_{iteration}"}]
}
return resp
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "test_tool",
"_sink_node_id": "sink",
"_field_mapping": {},
"parameters": {"properties": {}, "required": []},
},
}
]
mock_db_client = AsyncMock()
mock_node = MagicMock()
mock_node.block_id = "test-block"
mock_db_client.get_node.return_value = mock_node
mock_exec_result = MagicMock()
mock_exec_result.node_exec_id = "exec-id"
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {})
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {}
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call_with_expiration), \
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = AsyncMock()
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = threading.Lock()
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
input_data = SmartDecisionMakerBlock.Input(
prompt="Test credentials",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=10,
)
outputs = {}
async for name, value in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[name] = value
# Should have an error output
assert "error" in outputs
assert "expired" in outputs["error"].lower() or "unauthorized" in outputs["error"].lower()
class TestToolSignatureCacheInvalidation:
"""
Tests for Failure Mode #13: Tool Signature Cache Invalidation
Tool signatures are created once at the start of run() but the
graph could change during agent mode execution.
"""
@pytest.mark.asyncio
async def test_signatures_created_once_at_start(self):
"""
Test that tool signatures are only created once, not refreshed.
"""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
signature_creation_count = 0
iteration = 0
original_create_signatures = block._create_tool_node_signatures
async def counting_create_signatures(node_id):
nonlocal signature_creation_count
signature_creation_count += 1
return [
{
"type": "function",
"function": {
"name": "tool_v1",
"_sink_node_id": "sink",
"_field_mapping": {},
"parameters": {"properties": {}, "required": []},
},
}
]
async def mock_llm_call(**kwargs):
nonlocal iteration
iteration += 1
if iteration >= 3:
resp = MagicMock()
resp.response = "Done"
resp.tool_calls = []
resp.prompt_tokens = 10
resp.completion_tokens = 5
resp.reasoning = None
resp.raw_response = {"role": "assistant", "content": "Done"}
return resp
tool_call = MagicMock()
tool_call.id = f"call_{iteration}"
tool_call.function.name = "tool_v1"
tool_call.function.arguments = json.dumps({})
resp = MagicMock()
resp.response = None
resp.tool_calls = [tool_call]
resp.prompt_tokens = 50
resp.completion_tokens = 25
resp.reasoning = None
resp.raw_response = {
"role": "assistant",
"content": [{"type": "tool_use", "id": f"call_{iteration}"}]
}
return resp
mock_db_client = AsyncMock()
mock_node = MagicMock()
mock_node.block_id = "test-block"
mock_db_client.get_node.return_value = mock_node
mock_exec_result = MagicMock()
mock_exec_result.node_exec_id = "exec-id"
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {})
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {}
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
patch.object(block, "_create_tool_node_signatures", side_effect=counting_create_signatures), \
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = AsyncMock()
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = threading.Lock()
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
input_data = SmartDecisionMakerBlock.Input(
prompt="Test signatures",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=5,
)
outputs = {}
async for name, value in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[name] = value
# Signatures were only created once, even though we had multiple iterations
assert signature_creation_count == 1
assert iteration >= 3 # We had multiple iterations
@pytest.mark.asyncio
async def test_stale_signatures_cause_tool_mismatch(self):
"""
Test scenario where tool definitions change but agent uses stale signatures.
"""
# This documents the potential issue:
# 1. Agent starts with tool_v1
# 2. User modifies graph, tool becomes tool_v2
# 3. Agent still thinks tool_v1 exists
# 4. LLM calls tool_v1, but it no longer exists
# Since signatures are created once at start and never refreshed,
# any changes to the graph during execution won't be reflected.
# This is more of a documentation test - the actual fix would
# require either:
# a) Refreshing signatures periodically
# b) Locking the graph during execution
# c) Checking tool existence before each call
pass
class TestAgentModeConversationManagement:
"""Tests for conversation management in agent mode."""
@pytest.mark.asyncio
async def test_conversation_grows_with_iterations(self):
"""
Test that conversation history grows correctly with each iteration.
"""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
iteration = 0
conversation_lengths = []
async def mock_llm_call(**kwargs):
nonlocal iteration
iteration += 1
# Record conversation length at each call
prompt = kwargs.get("prompt", [])
conversation_lengths.append(len(prompt))
if iteration >= 3:
resp = MagicMock()
resp.response = "Done"
resp.tool_calls = []
resp.prompt_tokens = 10
resp.completion_tokens = 5
resp.reasoning = None
resp.raw_response = {"role": "assistant", "content": "Done"}
return resp
tool_call = MagicMock()
tool_call.id = f"call_{iteration}"
tool_call.function.name = "test_tool"
tool_call.function.arguments = json.dumps({})
resp = MagicMock()
resp.response = None
resp.tool_calls = [tool_call]
resp.prompt_tokens = 50
resp.completion_tokens = 25
resp.reasoning = None
resp.raw_response = {
"role": "assistant",
"content": [{"type": "tool_use", "id": f"call_{iteration}"}]
}
return resp
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "test_tool",
"_sink_node_id": "sink",
"_field_mapping": {},
"parameters": {"properties": {}, "required": []},
},
}
]
mock_db_client = AsyncMock()
mock_node = MagicMock()
mock_node.block_id = "test-block"
mock_db_client.get_node.return_value = mock_node
mock_exec_result = MagicMock()
mock_exec_result.node_exec_id = "exec-id"
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {})
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {"result": "ok"}
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = AsyncMock()
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = threading.Lock()
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
input_data = SmartDecisionMakerBlock.Input(
prompt="Test conversation",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=5,
)
outputs = {}
async for name, value in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[name] = value
# Conversation should grow with each iteration
# Each iteration adds: assistant message + tool response
assert len(conversation_lengths) == 3
for i in range(1, len(conversation_lengths)):
assert conversation_lengths[i] > conversation_lengths[i-1], \
f"Conversation should grow: {conversation_lengths}"

View File

@@ -1,525 +0,0 @@
"""
Tests for SmartDecisionMaker concurrency issues and race conditions.
Covers failure modes:
1. Conversation History Race Condition
4. Concurrent Execution State Sharing
7. Race in Pending Tool Calls
11. Race in Pending Tool Call Retrieval
14. Concurrent State Sharing
"""
import asyncio
import json
import threading
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from backend.blocks.smart_decision_maker import (
SmartDecisionMakerBlock,
get_pending_tool_calls,
_create_tool_response,
_get_tool_requests,
_get_tool_responses,
)
class TestConversationHistoryRaceCondition:
"""
Tests for Failure Mode #1: Conversation History Race Condition
When multiple executions share conversation history, concurrent
modifications can cause data loss or corruption.
"""
def test_get_pending_tool_calls_with_concurrent_modification(self):
"""
Test that concurrent modifications to conversation history
can cause inconsistent pending tool call counts.
"""
# Shared conversation history
conversation_history = [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "toolu_1"},
{"type": "tool_use", "id": "toolu_2"},
{"type": "tool_use", "id": "toolu_3"},
]
}
]
results = []
errors = []
def reader_thread():
"""Repeatedly read pending calls."""
for _ in range(100):
try:
pending = get_pending_tool_calls(conversation_history)
results.append(len(pending))
except Exception as e:
errors.append(str(e))
def writer_thread():
"""Modify conversation while readers are active."""
for i in range(50):
# Add a tool response
conversation_history.append({
"role": "user",
"content": [{"type": "tool_result", "tool_use_id": f"toolu_{(i % 3) + 1}"}]
})
# Remove it
if len(conversation_history) > 1:
conversation_history.pop()
# Run concurrent readers and writers
threads = []
for _ in range(3):
threads.append(threading.Thread(target=reader_thread))
threads.append(threading.Thread(target=writer_thread))
for t in threads:
t.start()
for t in threads:
t.join()
# The issue: results may be inconsistent due to race conditions
# In a correct implementation, we'd expect consistent results
# Document that this CAN produce inconsistent results
assert len(results) > 0, "Should have some results"
# Note: This test documents the race condition exists
# When fixed, all results should be consistent
def test_prompt_list_mutation_race(self):
"""
Test that mutating prompt list during iteration can cause issues.
"""
prompt = []
errors = []
def appender():
for i in range(100):
prompt.append({"role": "user", "content": f"msg_{i}"})
def extender():
for i in range(100):
prompt.extend([{"role": "assistant", "content": f"resp_{i}"}])
def reader():
for _ in range(100):
try:
# Iterate while others modify
_ = [p for p in prompt if p.get("role") == "user"]
except RuntimeError as e:
# "dictionary changed size during iteration" or similar
errors.append(str(e))
threads = [
threading.Thread(target=appender),
threading.Thread(target=extender),
threading.Thread(target=reader),
]
for t in threads:
t.start()
for t in threads:
t.join()
# Document that race conditions can occur
# In production, this could cause silent data corruption
@pytest.mark.asyncio
async def test_concurrent_block_runs_share_state(self):
"""
Test that concurrent runs on same block instance can share state incorrectly.
This is Failure Mode #14: Concurrent State Sharing
"""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
# Track all outputs from all runs
all_outputs = []
lock = threading.Lock()
async def run_block(run_id: int):
"""Run the block with a unique run_id."""
mock_response = MagicMock()
mock_response.response = f"Response for run {run_id}"
mock_response.tool_calls = [] # No tool calls, just finish
mock_response.prompt_tokens = 50
mock_response.completion_tokens = 25
mock_response.reasoning = None
mock_response.raw_response = {"role": "assistant", "content": f"Run {run_id}"}
mock_tool_signatures = []
with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm:
mock_llm.return_value = mock_response
with patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures):
input_data = SmartDecisionMakerBlock.Input(
prompt=f"Prompt for run {run_id}",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=0,
)
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = MagicMock()
outputs = {}
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id=f"graph-{run_id}",
node_id=f"node-{run_id}",
graph_exec_id=f"exec-{run_id}",
node_exec_id=f"node-exec-{run_id}",
user_id=f"user-{run_id}",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
with lock:
all_outputs.append((run_id, outputs))
# Run multiple concurrent executions
tasks = [run_block(i) for i in range(5)]
await asyncio.gather(*tasks)
# Verify each run got its own response (no cross-contamination)
for run_id, outputs in all_outputs:
if "finished" in outputs:
assert f"run {run_id}" in outputs["finished"].lower() or outputs["finished"] == f"Response for run {run_id}", \
f"Run {run_id} may have received contaminated response: {outputs}"
class TestPendingToolCallRace:
"""
Tests for Failure Mode #7 and #11: Race in Pending Tool Calls
The get_pending_tool_calls function can race with modifications
to the conversation history, causing StopIteration or incorrect counts.
"""
def test_pending_tool_calls_counter_accuracy(self):
"""Test that pending tool call counting is accurate."""
conversation = [
# Assistant makes 3 tool calls
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "call_1"},
{"type": "tool_use", "id": "call_2"},
{"type": "tool_use", "id": "call_3"},
]
},
# User provides 1 response
{
"role": "user",
"content": [
{"type": "tool_result", "tool_use_id": "call_1"}
]
}
]
pending = get_pending_tool_calls(conversation)
# Should have 2 pending (call_2, call_3)
assert len(pending) == 2
assert "call_2" in pending
assert "call_3" in pending
assert pending["call_2"] == 1
assert pending["call_3"] == 1
def test_pending_tool_calls_duplicate_responses(self):
"""Test handling of duplicate tool responses."""
conversation = [
{
"role": "assistant",
"content": [{"type": "tool_use", "id": "call_1"}]
},
# Duplicate responses for same call
{
"role": "user",
"content": [{"type": "tool_result", "tool_use_id": "call_1"}]
},
{
"role": "user",
"content": [{"type": "tool_result", "tool_use_id": "call_1"}]
}
]
pending = get_pending_tool_calls(conversation)
# call_1 has count -1 (1 request - 2 responses)
# Should not be in pending (count <= 0)
assert "call_1" not in pending or pending.get("call_1", 0) <= 0
def test_empty_conversation_no_pending(self):
"""Test that empty conversation has no pending calls."""
assert get_pending_tool_calls([]) == {}
assert get_pending_tool_calls(None) == {}
def test_next_iter_on_empty_dict_raises_stop_iteration(self):
"""
Document the StopIteration vulnerability.
If pending_tool_calls becomes empty between the check and
next(iter(...)), StopIteration is raised.
"""
pending = {}
# This is the pattern used in smart_decision_maker.py:1019
# if pending_tool_calls and ...:
# first_call_id = next(iter(pending_tool_calls.keys()))
with pytest.raises(StopIteration):
next(iter(pending.keys()))
# Safe pattern should be:
# first_call_id = next(iter(pending_tool_calls.keys()), None)
safe_result = next(iter(pending.keys()), None)
assert safe_result is None
class TestToolRequestResponseParsing:
"""Tests for tool request/response parsing edge cases."""
def test_get_tool_requests_openai_format(self):
"""Test parsing OpenAI format tool requests."""
entry = {
"role": "assistant",
"tool_calls": [
{"id": "call_abc123"},
{"id": "call_def456"},
]
}
requests = _get_tool_requests(entry)
assert requests == ["call_abc123", "call_def456"]
def test_get_tool_requests_anthropic_format(self):
"""Test parsing Anthropic format tool requests."""
entry = {
"role": "assistant",
"content": [
{"type": "tool_use", "id": "toolu_abc123"},
{"type": "text", "text": "Let me call this tool"},
{"type": "tool_use", "id": "toolu_def456"},
]
}
requests = _get_tool_requests(entry)
assert requests == ["toolu_abc123", "toolu_def456"]
def test_get_tool_requests_non_assistant_role(self):
"""Non-assistant roles should return empty list."""
entry = {"role": "user", "tool_calls": [{"id": "call_123"}]}
assert _get_tool_requests(entry) == []
def test_get_tool_responses_openai_format(self):
"""Test parsing OpenAI format tool responses."""
entry = {
"role": "tool",
"tool_call_id": "call_abc123",
"content": "Result"
}
responses = _get_tool_responses(entry)
assert responses == ["call_abc123"]
def test_get_tool_responses_anthropic_format(self):
"""Test parsing Anthropic format tool responses."""
entry = {
"role": "user",
"content": [
{"type": "tool_result", "tool_use_id": "toolu_abc123"},
{"type": "tool_result", "tool_use_id": "toolu_def456"},
]
}
responses = _get_tool_responses(entry)
assert responses == ["toolu_abc123", "toolu_def456"]
def test_get_tool_responses_mixed_content(self):
"""Test parsing responses with mixed content types."""
entry = {
"role": "user",
"content": [
{"type": "text", "text": "Here are the results"},
{"type": "tool_result", "tool_use_id": "toolu_123"},
{"type": "image", "url": "http://example.com/img.png"},
]
}
responses = _get_tool_responses(entry)
assert responses == ["toolu_123"]
class TestConcurrentToolSignatureCreation:
"""Tests for concurrent tool signature creation."""
@pytest.mark.asyncio
async def test_concurrent_signature_creation_same_node(self):
"""
Test that concurrent signature creation for same node
doesn't cause issues.
"""
block = SmartDecisionMakerBlock()
mock_node = Mock()
mock_node.id = "test-node"
mock_node.block = Mock()
mock_node.block.name = "TestBlock"
mock_node.block.description = "Test"
mock_node.block.input_schema = Mock()
mock_node.block.input_schema.jsonschema = Mock(
return_value={"properties": {}, "required": []}
)
mock_node.block.input_schema.get_field_schema = Mock(
return_value={"type": "string", "description": "test"}
)
mock_links = [
Mock(sink_name="field1", sink_id="test-node", source_id="source"),
Mock(sink_name="field2", sink_id="test-node", source_id="source"),
]
# Run multiple concurrent signature creations
tasks = [
block._create_block_function_signature(mock_node, mock_links)
for _ in range(10)
]
results = await asyncio.gather(*tasks)
# All results should be identical
first = results[0]
for i, result in enumerate(results[1:], 1):
assert result["function"]["name"] == first["function"]["name"], \
f"Result {i} has different name"
assert set(result["function"]["parameters"]["properties"].keys()) == \
set(first["function"]["parameters"]["properties"].keys()), \
f"Result {i} has different properties"
class TestThreadSafetyOfCleanup:
"""Tests for thread safety of cleanup function."""
def test_cleanup_is_thread_safe(self):
"""
Test that cleanup function is thread-safe.
Since it's a pure function with no shared state, it should be safe.
"""
results = {}
lock = threading.Lock()
test_inputs = [
"Max Keyword Difficulty",
"Search Volume (Monthly)",
"CPC ($)",
"Target URL",
]
def worker(input_str: str, thread_id: int):
for _ in range(100):
result = SmartDecisionMakerBlock.cleanup(input_str)
with lock:
key = f"{thread_id}_{input_str}"
if key not in results:
results[key] = set()
results[key].add(result)
threads = []
for i, input_str in enumerate(test_inputs):
for j in range(3):
t = threading.Thread(target=worker, args=(input_str, i * 3 + j))
threads.append(t)
for t in threads:
t.start()
for t in threads:
t.join()
# Each input should produce exactly one unique output
for key, values in results.items():
assert len(values) == 1, f"Non-deterministic cleanup for {key}: {values}"
class TestAsyncConcurrencyPatterns:
"""Tests for async concurrency patterns in the block."""
@pytest.mark.asyncio
async def test_multiple_async_runs_isolation(self):
"""
Test that multiple async runs are properly isolated.
"""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
run_count = 5
results = []
async def single_run(run_id: int):
mock_response = MagicMock()
mock_response.response = f"Unique response {run_id}"
mock_response.tool_calls = []
mock_response.prompt_tokens = 10
mock_response.completion_tokens = 5
mock_response.reasoning = None
mock_response.raw_response = {"role": "assistant", "content": f"Run {run_id}"}
# Add small random delay to increase chance of interleaving
await asyncio.sleep(0.001 * (run_id % 3))
with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm:
mock_llm.return_value = mock_response
with patch.object(block, "_create_tool_node_signatures", return_value=[]):
input_data = SmartDecisionMakerBlock.Input(
prompt=f"Prompt {run_id}",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=0,
)
outputs = {}
async for name, value in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id=f"g{run_id}",
node_id=f"n{run_id}",
graph_exec_id=f"e{run_id}",
node_exec_id=f"ne{run_id}",
user_id=f"u{run_id}",
graph_version=1,
execution_context=ExecutionContext(safe_mode=False),
execution_processor=MagicMock(),
):
outputs[name] = value
return run_id, outputs
# Run all concurrently
tasks = [single_run(i) for i in range(run_count)]
results = await asyncio.gather(*tasks)
# Verify isolation
for run_id, outputs in results:
if "finished" in outputs:
assert str(run_id) in outputs["finished"], \
f"Run {run_id} got wrong response: {outputs['finished']}"

View File

@@ -1,667 +0,0 @@
"""
Tests for SmartDecisionMaker conversation handling and corruption scenarios.
Covers failure modes:
6. Conversation Corruption in Error Paths
And related conversation management issues.
"""
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from backend.blocks.smart_decision_maker import (
SmartDecisionMakerBlock,
get_pending_tool_calls,
_create_tool_response,
_combine_tool_responses,
_convert_raw_response_to_dict,
_get_tool_requests,
_get_tool_responses,
)
class TestConversationCorruptionInErrorPaths:
"""
Tests for Failure Mode #6: Conversation Corruption in Error Paths
When there's a logic error (orphaned tool output), the code appends
it as a "user" message instead of proper tool response format,
violating LLM conversation structure.
"""
@pytest.mark.asyncio
async def test_orphaned_tool_output_creates_user_message(self):
"""
Test that orphaned tool output (no pending calls) creates wrong message type.
"""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
# Response with no tool calls
mock_response = MagicMock()
mock_response.response = "No tools needed"
mock_response.tool_calls = []
mock_response.prompt_tokens = 50
mock_response.completion_tokens = 25
mock_response.reasoning = None
mock_response.raw_response = {"role": "assistant", "content": "No tools needed"}
with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm:
mock_llm.return_value = mock_response
with patch.object(block, "_create_tool_node_signatures", return_value=[]):
input_data = SmartDecisionMakerBlock.Input(
prompt="Test",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=0,
# Orphaned tool output - no pending calls but we have output
last_tool_output={"result": "orphaned data"},
conversation_history=[], # Empty - no pending calls
)
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = MagicMock()
outputs = {}
async for name, value in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[name] = value
# Check the conversation for the orphaned output handling
# The orphaned output is logged as error but may be added as user message
# This is the BUG: should not add orphaned outputs to conversation
def test_create_tool_response_anthropic_format(self):
"""Test that Anthropic format tool responses are created correctly."""
response = _create_tool_response(
"toolu_abc123",
{"result": "success"}
)
assert response["role"] == "user"
assert response["type"] == "message"
assert isinstance(response["content"], list)
assert response["content"][0]["type"] == "tool_result"
assert response["content"][0]["tool_use_id"] == "toolu_abc123"
def test_create_tool_response_openai_format(self):
"""Test that OpenAI format tool responses are created correctly."""
response = _create_tool_response(
"call_abc123",
{"result": "success"}
)
assert response["role"] == "tool"
assert response["tool_call_id"] == "call_abc123"
assert "content" in response
def test_tool_response_with_string_content(self):
"""Test tool response creation with string content."""
response = _create_tool_response(
"call_123",
"Simple string result"
)
assert response["content"] == "Simple string result"
def test_tool_response_with_complex_content(self):
"""Test tool response creation with complex JSON content."""
complex_data = {
"nested": {"key": "value"},
"list": [1, 2, 3],
"null": None,
}
response = _create_tool_response("call_123", complex_data)
# Content should be JSON string
parsed = json.loads(response["content"])
assert parsed == complex_data
class TestCombineToolResponses:
"""Tests for combining multiple tool responses."""
def test_combine_single_response_unchanged(self):
"""Test that single response is returned unchanged."""
responses = [
{
"role": "user",
"type": "message",
"content": [{"type": "tool_result", "tool_use_id": "123"}]
}
]
result = _combine_tool_responses(responses)
assert result == responses
def test_combine_multiple_anthropic_responses(self):
"""Test combining multiple Anthropic responses."""
responses = [
{
"role": "user",
"type": "message",
"content": [{"type": "tool_result", "tool_use_id": "123", "content": "a"}]
},
{
"role": "user",
"type": "message",
"content": [{"type": "tool_result", "tool_use_id": "456", "content": "b"}]
},
]
result = _combine_tool_responses(responses)
# Should be combined into single message
assert len(result) == 1
assert result[0]["role"] == "user"
assert len(result[0]["content"]) == 2
def test_combine_mixed_responses(self):
"""Test combining mixed Anthropic and OpenAI responses."""
responses = [
{
"role": "user",
"type": "message",
"content": [{"type": "tool_result", "tool_use_id": "123"}]
},
{
"role": "tool",
"tool_call_id": "call_456",
"content": "openai result"
},
]
result = _combine_tool_responses(responses)
# Anthropic response combined, OpenAI kept separate
assert len(result) == 2
def test_combine_empty_list(self):
"""Test combining empty list."""
result = _combine_tool_responses([])
assert result == []
class TestConversationHistoryValidation:
"""Tests for conversation history validation."""
def test_pending_tool_calls_basic(self):
"""Test basic pending tool call counting."""
history = [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "call_1"},
{"type": "tool_use", "id": "call_2"},
]
}
]
pending = get_pending_tool_calls(history)
assert len(pending) == 2
assert "call_1" in pending
assert "call_2" in pending
def test_pending_tool_calls_with_responses(self):
"""Test pending calls after some responses."""
history = [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "call_1"},
{"type": "tool_use", "id": "call_2"},
]
},
{
"role": "user",
"content": [
{"type": "tool_result", "tool_use_id": "call_1"}
]
}
]
pending = get_pending_tool_calls(history)
assert len(pending) == 1
assert "call_2" in pending
assert "call_1" not in pending
def test_pending_tool_calls_all_responded(self):
"""Test when all tool calls have responses."""
history = [
{
"role": "assistant",
"content": [{"type": "tool_use", "id": "call_1"}]
},
{
"role": "user",
"content": [{"type": "tool_result", "tool_use_id": "call_1"}]
}
]
pending = get_pending_tool_calls(history)
assert len(pending) == 0
def test_pending_tool_calls_openai_format(self):
"""Test pending calls with OpenAI format."""
history = [
{
"role": "assistant",
"tool_calls": [
{"id": "call_1"},
{"id": "call_2"},
]
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": "result"
}
]
pending = get_pending_tool_calls(history)
assert len(pending) == 1
assert "call_2" in pending
class TestConversationUpdateBehavior:
"""Tests for conversation update behavior."""
@pytest.mark.asyncio
async def test_conversation_includes_assistant_response(self):
"""Test that assistant responses are added to conversation."""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
mock_response = MagicMock()
mock_response.response = "Final answer"
mock_response.tool_calls = []
mock_response.prompt_tokens = 50
mock_response.completion_tokens = 25
mock_response.reasoning = None
mock_response.raw_response = {"role": "assistant", "content": "Final answer"}
with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm:
mock_llm.return_value = mock_response
with patch.object(block, "_create_tool_node_signatures", return_value=[]):
input_data = SmartDecisionMakerBlock.Input(
prompt="Test",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=0,
)
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = MagicMock()
outputs = {}
async for name, value in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[name] = value
# No conversations output when no tool calls (just finished)
assert "finished" in outputs
assert outputs["finished"] == "Final answer"
@pytest.mark.asyncio
async def test_conversation_with_tool_calls(self):
"""Test that tool calls are properly added to conversation."""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
mock_tool_call = MagicMock()
mock_tool_call.function.name = "test_tool"
mock_tool_call.function.arguments = json.dumps({"param": "value"})
mock_response = MagicMock()
mock_response.response = None
mock_response.tool_calls = [mock_tool_call]
mock_response.prompt_tokens = 50
mock_response.completion_tokens = 25
mock_response.reasoning = "I'll use the test tool"
mock_response.raw_response = {
"role": "assistant",
"content": None,
"tool_calls": [{"id": "call_1"}]
}
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "test_tool",
"_sink_node_id": "sink",
"_field_mapping": {"param": "param"},
"parameters": {
"properties": {"param": {"type": "string"}},
"required": ["param"],
},
},
}
]
with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm:
mock_llm.return_value = mock_response
with patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures):
input_data = SmartDecisionMakerBlock.Input(
prompt="Test",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=0,
)
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = MagicMock()
outputs = {}
async for name, value in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[name] = value
# Should have conversations output
assert "conversations" in outputs
# Conversation should include the assistant message
conversations = outputs["conversations"]
has_assistant = any(
msg.get("role") == "assistant"
for msg in conversations
)
assert has_assistant
class TestConversationHistoryPreservation:
"""Tests for conversation history preservation across calls."""
@pytest.mark.asyncio
async def test_existing_history_preserved(self):
"""Test that existing conversation history is preserved."""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
existing_history = [
{"role": "user", "content": "Previous message 1"},
{"role": "assistant", "content": "Previous response 1"},
{"role": "user", "content": "Previous message 2"},
]
mock_response = MagicMock()
mock_response.response = "New response"
mock_response.tool_calls = []
mock_response.prompt_tokens = 50
mock_response.completion_tokens = 25
mock_response.reasoning = None
mock_response.raw_response = {"role": "assistant", "content": "New response"}
captured_prompt = []
async def capture_llm_call(**kwargs):
captured_prompt.extend(kwargs.get("prompt", []))
return mock_response
with patch("backend.blocks.llm.llm_call", side_effect=capture_llm_call):
with patch.object(block, "_create_tool_node_signatures", return_value=[]):
input_data = SmartDecisionMakerBlock.Input(
prompt="New message",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=0,
conversation_history=existing_history,
)
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = MagicMock()
async for _ in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
pass
# Existing history should be in the prompt
assert len(captured_prompt) >= len(existing_history)
class TestRawResponseConversion:
"""Tests for raw response to dict conversion."""
def test_string_response(self):
"""Test conversion of string response."""
result = _convert_raw_response_to_dict("Hello world")
assert result == {"role": "assistant", "content": "Hello world"}
def test_dict_response(self):
"""Test that dict response is passed through."""
original = {"role": "assistant", "content": "test", "extra": "data"}
result = _convert_raw_response_to_dict(original)
assert result == original
def test_object_response(self):
"""Test conversion of object response."""
mock_obj = MagicMock()
with patch("backend.blocks.smart_decision_maker.json.to_dict") as mock_to_dict:
mock_to_dict.return_value = {"role": "assistant", "content": "converted"}
result = _convert_raw_response_to_dict(mock_obj)
mock_to_dict.assert_called_once_with(mock_obj)
assert result["role"] == "assistant"
class TestConversationMessageStructure:
"""Tests for correct conversation message structure."""
def test_system_message_not_duplicated(self):
"""Test that system messages are not duplicated."""
from backend.util.prompt import MAIN_OBJECTIVE_PREFIX
# Existing system message in history
existing_history = [
{"role": "system", "content": f"{MAIN_OBJECTIVE_PREFIX}Existing system prompt"},
]
# The block should not add another system message
# This is verified by checking the prompt passed to LLM
def test_user_message_not_duplicated(self):
"""Test that user messages are not duplicated."""
from backend.util.prompt import MAIN_OBJECTIVE_PREFIX
# Existing user message with MAIN_OBJECTIVE_PREFIX
existing_history = [
{"role": "user", "content": f"{MAIN_OBJECTIVE_PREFIX}Existing user prompt"},
]
# The block should not add another user message with same prefix
# This is verified by checking the prompt passed to LLM
def test_tool_response_after_tool_call(self):
"""Test that tool responses come after tool calls."""
# Valid conversation structure
valid_history = [
{
"role": "assistant",
"content": [{"type": "tool_use", "id": "call_1"}]
},
{
"role": "user",
"content": [{"type": "tool_result", "tool_use_id": "call_1"}]
}
]
# This should be valid - tool result follows tool use
pending = get_pending_tool_calls(valid_history)
assert len(pending) == 0
def test_orphaned_tool_response_detected(self):
"""Test detection of orphaned tool responses."""
# Invalid: tool response without matching tool call
invalid_history = [
{
"role": "user",
"content": [{"type": "tool_result", "tool_use_id": "orphan_call"}]
}
]
pending = get_pending_tool_calls(invalid_history)
# Orphan response creates negative count
# Should have count -1 for orphan_call
# But it's filtered out (count <= 0)
assert "orphan_call" not in pending
class TestValidationErrorInConversation:
"""Tests for validation error handling in conversation."""
@pytest.mark.asyncio
async def test_validation_error_feedback_not_in_final_conversation(self):
"""
Test that validation error feedback is not in final conversation output.
When retrying due to validation errors, the error feedback should
only be used for the retry prompt, not persisted in final conversation.
"""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
call_count = 0
async def mock_llm_call(**kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
# First call: invalid tool call
mock_tool_call = MagicMock()
mock_tool_call.function.name = "test_tool"
mock_tool_call.function.arguments = json.dumps({"wrong": "param"})
resp = MagicMock()
resp.response = None
resp.tool_calls = [mock_tool_call]
resp.prompt_tokens = 50
resp.completion_tokens = 25
resp.reasoning = None
resp.raw_response = {"role": "assistant", "content": None}
return resp
else:
# Second call: finish
resp = MagicMock()
resp.response = "Done"
resp.tool_calls = []
resp.prompt_tokens = 50
resp.completion_tokens = 25
resp.reasoning = None
resp.raw_response = {"role": "assistant", "content": "Done"}
return resp
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "test_tool",
"_sink_node_id": "sink",
"_field_mapping": {"correct": "correct"},
"parameters": {
"properties": {"correct": {"type": "string"}},
"required": ["correct"],
},
},
}
]
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call):
with patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures):
input_data = SmartDecisionMakerBlock.Input(
prompt="Test",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=0,
retry=3,
)
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = MagicMock()
outputs = {}
async for name, value in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[name] = value
# Should have finished successfully after retry
assert "finished" in outputs
# Note: In traditional mode (agent_mode_max_iterations=0),
# conversations are only output when there are tool calls
# After the retry succeeds with no tool calls, we just get "finished"

View File

@@ -1,671 +0,0 @@
"""
Tests for SmartDecisionMaker data integrity failure modes.
Covers failure modes:
6. Conversation Corruption in Error Paths
7. Field Name Collision Not Detected
8. No Type Validation in Dynamic Field Merging
9. Unhandled Field Mapping Keys
16. Silent Value Loss in Output Routing
"""
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
class TestFieldNameCollisionDetection:
"""
Tests for Failure Mode #7: Field Name Collision Not Detected
When multiple field names sanitize to the same value,
the last one silently overwrites previous mappings.
"""
def test_different_names_same_sanitized_result(self):
"""Test that different names can produce the same sanitized result."""
cleanup = SmartDecisionMakerBlock.cleanup
# All these sanitize to "test_field"
variants = [
"test_field",
"Test Field",
"test field",
"TEST_FIELD",
"Test_Field",
"test-field", # Note: hyphen is preserved, this is different
]
sanitized = [cleanup(v) for v in variants]
# Count unique sanitized values
unique = set(sanitized)
# Most should collide (except hyphenated one)
assert len(unique) < len(variants), \
f"Expected collisions, got {unique}"
@pytest.mark.asyncio
async def test_collision_last_one_wins(self):
"""Test that in case of collision, the last field mapping wins."""
block = SmartDecisionMakerBlock()
mock_node = Mock()
mock_node.id = "test-node"
mock_node.block = Mock()
mock_node.block.name = "TestBlock"
mock_node.block.description = "Test"
mock_node.block.input_schema = Mock()
mock_node.block.input_schema.jsonschema = Mock(
return_value={"properties": {}, "required": []}
)
mock_node.block.input_schema.get_field_schema = Mock(
return_value={"type": "string", "description": "test"}
)
# Two fields that sanitize to the same name
mock_links = [
Mock(sink_name="Test Field", sink_id="test-node", source_id="source"),
Mock(sink_name="test field", sink_id="test-node", source_id="source"),
]
signature = await block._create_block_function_signature(mock_node, mock_links)
field_mapping = signature["function"]["_field_mapping"]
properties = signature["function"]["parameters"]["properties"]
# Only one property (collision)
assert len(properties) == 1
assert "test_field" in properties
# The mapping has only the last one
# This is the BUG: first field's mapping is lost
assert field_mapping["test_field"] in ["Test Field", "test field"]
@pytest.mark.asyncio
async def test_collision_causes_data_loss(self):
"""
Test that field collision can cause actual data loss.
Scenario:
1. Two fields "Field A" and "field a" both map to "field_a"
2. LLM provides value for "field_a"
3. Only one original field gets the value
4. The other field's expected input is lost
"""
block = SmartDecisionMakerBlock()
# Simulate processing tool calls with collision
mock_response = Mock()
mock_tool_call = Mock()
mock_tool_call.function.name = "test_tool"
mock_tool_call.function.arguments = json.dumps({
"field_a": "value_for_both" # LLM uses sanitized name
})
mock_response.tool_calls = [mock_tool_call]
# Tool definition with collision in field mapping
tool_functions = [
{
"type": "function",
"function": {
"name": "test_tool",
"parameters": {
"properties": {
"field_a": {"type": "string"},
},
"required": ["field_a"],
},
"_sink_node_id": "sink",
# BUG: Only one original name is stored
# "Field A" was overwritten by "field a"
"_field_mapping": {"field_a": "field a"},
},
}
]
processed = block._process_tool_calls(mock_response, tool_functions)
assert len(processed) == 1
input_data = processed[0].input_data
# Only "field a" gets the value
assert "field a" in input_data
assert input_data["field a"] == "value_for_both"
# "Field A" is completely lost!
assert "Field A" not in input_data
class TestUnhandledFieldMappingKeys:
"""
Tests for Failure Mode #9: Unhandled Field Mapping Keys
When field_mapping is missing a key, the code falls back to
the clean name, which may not be what the sink expects.
"""
@pytest.mark.asyncio
async def test_missing_field_mapping_falls_back_to_clean_name(self):
"""Test that missing field mapping falls back to clean name."""
block = SmartDecisionMakerBlock()
mock_response = Mock()
mock_tool_call = Mock()
mock_tool_call.function.name = "test_tool"
mock_tool_call.function.arguments = json.dumps({
"unmapped_field": "value"
})
mock_response.tool_calls = [mock_tool_call]
# Tool definition with incomplete field mapping
tool_functions = [
{
"type": "function",
"function": {
"name": "test_tool",
"parameters": {
"properties": {
"unmapped_field": {"type": "string"},
},
"required": [],
},
"_sink_node_id": "sink",
"_field_mapping": {}, # Empty! No mapping for unmapped_field
},
}
]
processed = block._process_tool_calls(mock_response, tool_functions)
assert len(processed) == 1
input_data = processed[0].input_data
# Falls back to clean name (which IS the key since it's already clean)
assert "unmapped_field" in input_data
@pytest.mark.asyncio
async def test_partial_field_mapping(self):
"""Test behavior with partial field mapping."""
block = SmartDecisionMakerBlock()
mock_response = Mock()
mock_tool_call = Mock()
mock_tool_call.function.name = "test_tool"
mock_tool_call.function.arguments = json.dumps({
"mapped_field": "value1",
"unmapped_field": "value2",
})
mock_response.tool_calls = [mock_tool_call]
tool_functions = [
{
"type": "function",
"function": {
"name": "test_tool",
"parameters": {
"properties": {
"mapped_field": {"type": "string"},
"unmapped_field": {"type": "string"},
},
"required": [],
},
"_sink_node_id": "sink",
# Only one field is mapped
"_field_mapping": {
"mapped_field": "Original Mapped Field",
},
},
}
]
processed = block._process_tool_calls(mock_response, tool_functions)
assert len(processed) == 1
input_data = processed[0].input_data
# Mapped field uses original name
assert "Original Mapped Field" in input_data
# Unmapped field uses clean name (fallback)
assert "unmapped_field" in input_data
class TestSilentValueLossInRouting:
"""
Tests for Failure Mode #16: Silent Value Loss in Output Routing
When routing fails in parse_execution_output, it returns None
without any logging or indication of why it failed.
"""
def test_routing_mismatch_returns_none_silently(self):
"""Test that routing mismatch returns None without error."""
from backend.data.dynamic_fields import parse_execution_output
output_item = ("tools_^_node-123_~_sanitized_name", "important_value")
result = parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id="node-123",
sink_pin_name="Original Name", # Doesn't match sanitized_name
)
# Silently returns None
assert result is None
# No way to distinguish "value is None" from "routing failed"
def test_wrong_node_id_returns_none(self):
"""Test that wrong node ID returns None."""
from backend.data.dynamic_fields import parse_execution_output
output_item = ("tools_^_node-123_~_field", "value")
result = parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id="different-node", # Wrong node
sink_pin_name="field",
)
assert result is None
def test_wrong_selector_returns_none(self):
"""Test that wrong selector returns None."""
from backend.data.dynamic_fields import parse_execution_output
output_item = ("tools_^_node-123_~_field", "value")
result = parse_execution_output(
output_item,
link_output_selector="different_selector", # Wrong selector
sink_node_id="node-123",
sink_pin_name="field",
)
assert result is None
def test_cannot_distinguish_none_value_from_routing_failure(self):
"""
Test that None as actual value is indistinguishable from routing failure.
"""
from backend.data.dynamic_fields import parse_execution_output
# Case 1: Actual None value
output_with_none = ("field_name", None)
result1 = parse_execution_output(
output_with_none,
link_output_selector="field_name",
sink_node_id=None,
sink_pin_name=None,
)
# Case 2: Routing failure
output_mismatched = ("field_name", "value")
result2 = parse_execution_output(
output_mismatched,
link_output_selector="different_field",
sink_node_id=None,
sink_pin_name=None,
)
# Both return None - cannot distinguish!
assert result1 is None
assert result2 is None
class TestProcessToolCallsInputData:
"""Tests for _process_tool_calls input data generation."""
@pytest.mark.asyncio
async def test_all_expected_args_included(self):
"""Test that all expected arguments are included in input_data."""
block = SmartDecisionMakerBlock()
mock_response = Mock()
mock_tool_call = Mock()
mock_tool_call.function.name = "test_tool"
mock_tool_call.function.arguments = json.dumps({
"provided_field": "value",
# optional_field not provided
})
mock_response.tool_calls = [mock_tool_call]
tool_functions = [
{
"type": "function",
"function": {
"name": "test_tool",
"parameters": {
"properties": {
"provided_field": {"type": "string"},
"optional_field": {"type": "string"},
},
"required": ["provided_field"],
},
"_sink_node_id": "sink",
"_field_mapping": {
"provided_field": "Provided Field",
"optional_field": "Optional Field",
},
},
}
]
processed = block._process_tool_calls(mock_response, tool_functions)
assert len(processed) == 1
input_data = processed[0].input_data
# Both fields should be in input_data
assert "Provided Field" in input_data
assert "Optional Field" in input_data
# Provided has value, optional is None
assert input_data["Provided Field"] == "value"
assert input_data["Optional Field"] is None
@pytest.mark.asyncio
async def test_extra_args_from_llm_ignored(self):
"""Test that extra arguments from LLM not in schema are ignored."""
block = SmartDecisionMakerBlock()
mock_response = Mock()
mock_tool_call = Mock()
mock_tool_call.function.name = "test_tool"
mock_tool_call.function.arguments = json.dumps({
"expected_field": "value",
"unexpected_field": "should_be_ignored",
})
mock_response.tool_calls = [mock_tool_call]
tool_functions = [
{
"type": "function",
"function": {
"name": "test_tool",
"parameters": {
"properties": {
"expected_field": {"type": "string"},
# unexpected_field not in schema
},
"required": [],
},
"_sink_node_id": "sink",
"_field_mapping": {"expected_field": "Expected Field"},
},
}
]
processed = block._process_tool_calls(mock_response, tool_functions)
assert len(processed) == 1
input_data = processed[0].input_data
# Only expected field should be in input_data
assert "Expected Field" in input_data
assert "unexpected_field" not in input_data
assert "Unexpected Field" not in input_data
class TestToolCallMatching:
"""Tests for tool call matching logic."""
@pytest.mark.asyncio
async def test_tool_not_found_skipped(self):
"""Test that tool calls for unknown tools are skipped."""
block = SmartDecisionMakerBlock()
mock_response = Mock()
mock_tool_call = Mock()
mock_tool_call.function.name = "unknown_tool"
mock_tool_call.function.arguments = json.dumps({})
mock_response.tool_calls = [mock_tool_call]
tool_functions = [
{
"type": "function",
"function": {
"name": "known_tool", # Different name
"parameters": {"properties": {}, "required": []},
"_sink_node_id": "sink",
},
}
]
processed = block._process_tool_calls(mock_response, tool_functions)
# Unknown tool is skipped (not processed)
assert len(processed) == 0
@pytest.mark.asyncio
async def test_single_tool_fallback(self):
"""Test fallback when only one tool exists but name doesn't match."""
block = SmartDecisionMakerBlock()
mock_response = Mock()
mock_tool_call = Mock()
mock_tool_call.function.name = "wrong_name"
mock_tool_call.function.arguments = json.dumps({"field": "value"})
mock_response.tool_calls = [mock_tool_call]
# Only one tool defined
tool_functions = [
{
"type": "function",
"function": {
"name": "only_tool",
"parameters": {
"properties": {"field": {"type": "string"}},
"required": [],
},
"_sink_node_id": "sink",
"_field_mapping": {"field": "Field"},
},
}
]
processed = block._process_tool_calls(mock_response, tool_functions)
# Falls back to the only tool
assert len(processed) == 1
assert processed[0].input_data["Field"] == "value"
@pytest.mark.asyncio
async def test_multiple_tool_calls_processed(self):
"""Test that multiple tool calls are all processed."""
block = SmartDecisionMakerBlock()
mock_response = Mock()
mock_tool_call_1 = Mock()
mock_tool_call_1.function.name = "tool_a"
mock_tool_call_1.function.arguments = json.dumps({"a": "1"})
mock_tool_call_2 = Mock()
mock_tool_call_2.function.name = "tool_b"
mock_tool_call_2.function.arguments = json.dumps({"b": "2"})
mock_response.tool_calls = [mock_tool_call_1, mock_tool_call_2]
tool_functions = [
{
"type": "function",
"function": {
"name": "tool_a",
"parameters": {
"properties": {"a": {"type": "string"}},
"required": [],
},
"_sink_node_id": "sink_a",
"_field_mapping": {"a": "A"},
},
},
{
"type": "function",
"function": {
"name": "tool_b",
"parameters": {
"properties": {"b": {"type": "string"}},
"required": [],
},
"_sink_node_id": "sink_b",
"_field_mapping": {"b": "B"},
},
},
]
processed = block._process_tool_calls(mock_response, tool_functions)
assert len(processed) == 2
assert processed[0].input_data["A"] == "1"
assert processed[1].input_data["B"] == "2"
class TestOutputEmitKeyGeneration:
"""Tests for output emit key generation consistency."""
def test_emit_key_uses_sanitized_field_name(self):
"""Test that emit keys use sanitized field names."""
cleanup = SmartDecisionMakerBlock.cleanup
original_field = "Max Keyword Difficulty"
sink_node_id = "node-123"
sanitized = cleanup(original_field)
emit_key = f"tools_^_{sink_node_id}_~_{sanitized}"
assert emit_key == "tools_^_node-123_~_max_keyword_difficulty"
def test_emit_key_format_consistent(self):
"""Test that emit key format is consistent."""
test_cases = [
("field", "node", "tools_^_node_~_field"),
("Field Name", "node-123", "tools_^_node-123_~_field_name"),
("CPC ($)", "abc", "tools_^_abc_~_cpc____"),
]
cleanup = SmartDecisionMakerBlock.cleanup
for original_field, node_id, expected in test_cases:
sanitized = cleanup(original_field)
emit_key = f"tools_^_{node_id}_~_{sanitized}"
assert emit_key == expected, \
f"Expected {expected}, got {emit_key}"
def test_emit_key_sanitization_idempotent(self):
"""Test that sanitizing an already sanitized name gives same result."""
cleanup = SmartDecisionMakerBlock.cleanup
original = "Test Field Name"
first_clean = cleanup(original)
second_clean = cleanup(first_clean)
assert first_clean == second_clean
class TestToolFunctionMetadata:
"""Tests for tool function metadata handling."""
@pytest.mark.asyncio
async def test_sink_node_id_preserved(self):
"""Test that _sink_node_id is preserved in tool function."""
block = SmartDecisionMakerBlock()
mock_node = Mock()
mock_node.id = "specific-node-id"
mock_node.block = Mock()
mock_node.block.name = "TestBlock"
mock_node.block.description = "Test"
mock_node.block.input_schema = Mock()
mock_node.block.input_schema.jsonschema = Mock(
return_value={"properties": {}, "required": []}
)
mock_node.block.input_schema.get_field_schema = Mock(
return_value={"type": "string", "description": "test"}
)
mock_links = [
Mock(sink_name="field", sink_id="specific-node-id", source_id="source"),
]
signature = await block._create_block_function_signature(mock_node, mock_links)
assert signature["function"]["_sink_node_id"] == "specific-node-id"
@pytest.mark.asyncio
async def test_field_mapping_preserved(self):
"""Test that _field_mapping is preserved in tool function."""
block = SmartDecisionMakerBlock()
mock_node = Mock()
mock_node.id = "test-node"
mock_node.block = Mock()
mock_node.block.name = "TestBlock"
mock_node.block.description = "Test"
mock_node.block.input_schema = Mock()
mock_node.block.input_schema.jsonschema = Mock(
return_value={"properties": {}, "required": []}
)
mock_node.block.input_schema.get_field_schema = Mock(
return_value={"type": "string", "description": "test"}
)
mock_links = [
Mock(sink_name="Original Field Name", sink_id="test-node", source_id="source"),
]
signature = await block._create_block_function_signature(mock_node, mock_links)
field_mapping = signature["function"]["_field_mapping"]
assert "original_field_name" in field_mapping
assert field_mapping["original_field_name"] == "Original Field Name"
class TestRequiredFieldsHandling:
"""Tests for required fields handling."""
@pytest.mark.asyncio
async def test_required_fields_use_sanitized_names(self):
"""Test that required fields array uses sanitized names."""
block = SmartDecisionMakerBlock()
mock_node = Mock()
mock_node.id = "test-node"
mock_node.block = Mock()
mock_node.block.name = "TestBlock"
mock_node.block.description = "Test"
mock_node.block.input_schema = Mock()
mock_node.block.input_schema.jsonschema = Mock(
return_value={
"properties": {},
"required": ["Required Field", "Another Required"],
}
)
mock_node.block.input_schema.get_field_schema = Mock(
return_value={"type": "string", "description": "test"}
)
mock_links = [
Mock(sink_name="Required Field", sink_id="test-node", source_id="source"),
Mock(sink_name="Another Required", sink_id="test-node", source_id="source"),
Mock(sink_name="Optional Field", sink_id="test-node", source_id="source"),
]
signature = await block._create_block_function_signature(mock_node, mock_links)
required = signature["function"]["parameters"]["required"]
# Should use sanitized names
assert "required_field" in required
assert "another_required" in required
# Original names should NOT be in required
assert "Required Field" not in required
assert "Another Required" not in required
# Optional field should not be required
assert "optional_field" not in required
assert "Optional Field" not in required

View File

@@ -373,7 +373,7 @@ async def test_output_yielding_with_dynamic_fields():
input_data = block.input_schema(
prompt="Create a user dictionary",
credentials=llm.TEST_CREDENTIALS_INPUT,
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
agent_mode_max_iterations=0, # Use traditional mode to test output yielding
)
@@ -594,7 +594,7 @@ async def test_validation_errors_dont_pollute_conversation():
input_data = block.input_schema(
prompt="Test prompt",
credentials=llm.TEST_CREDENTIALS_INPUT,
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
retry=3, # Allow retries
agent_mode_max_iterations=1,
)

View File

@@ -1,871 +0,0 @@
"""
Tests for SmartDecisionMaker error handling failure modes.
Covers failure modes:
3. JSON Deserialization Without Exception Handling
4. Database Transaction Inconsistency
5. Missing Null Checks After Database Calls
15. Error Message Context Loss
17. No Validation of Dynamic Field Paths
"""
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from backend.blocks.smart_decision_maker import (
SmartDecisionMakerBlock,
_convert_raw_response_to_dict,
_create_tool_response,
)
class TestJSONDeserializationErrors:
"""
Tests for Failure Mode #3: JSON Deserialization Without Exception Handling
When LLM returns malformed JSON in tool call arguments, the json.loads()
call fails without proper error handling.
"""
def test_malformed_json_single_quotes(self):
"""
Test that single quotes in JSON cause parsing failure.
LLMs sometimes return {'key': 'value'} instead of {"key": "value"}
"""
malformed = "{'key': 'value'}"
with pytest.raises(json.JSONDecodeError):
json.loads(malformed)
def test_malformed_json_trailing_comma(self):
"""
Test that trailing commas cause parsing failure.
"""
malformed = '{"key": "value",}'
with pytest.raises(json.JSONDecodeError):
json.loads(malformed)
def test_malformed_json_unquoted_keys(self):
"""
Test that unquoted keys cause parsing failure.
"""
malformed = '{key: "value"}'
with pytest.raises(json.JSONDecodeError):
json.loads(malformed)
def test_malformed_json_python_none(self):
"""
Test that Python None instead of null causes failure.
"""
malformed = '{"key": None}'
with pytest.raises(json.JSONDecodeError):
json.loads(malformed)
def test_malformed_json_python_true_false(self):
"""
Test that Python True/False instead of true/false causes failure.
"""
malformed_true = '{"key": True}'
malformed_false = '{"key": False}'
with pytest.raises(json.JSONDecodeError):
json.loads(malformed_true)
with pytest.raises(json.JSONDecodeError):
json.loads(malformed_false)
@pytest.mark.asyncio
async def test_llm_returns_malformed_json_crashes_block(self):
"""
Test that malformed JSON from LLM causes block to crash.
BUG: The json.loads() at line 625, 706, 1124 can throw JSONDecodeError
which is not caught, causing the entire block to fail.
"""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
# Create response with malformed JSON
mock_tool_call = MagicMock()
mock_tool_call.function.name = "test_tool"
mock_tool_call.function.arguments = "{'malformed': 'json'}" # Single quotes!
mock_response = MagicMock()
mock_response.response = None
mock_response.tool_calls = [mock_tool_call]
mock_response.prompt_tokens = 50
mock_response.completion_tokens = 25
mock_response.reasoning = None
mock_response.raw_response = {"role": "assistant", "content": None}
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "test_tool",
"_sink_node_id": "sink",
"_field_mapping": {},
"parameters": {"properties": {"malformed": {"type": "string"}}, "required": []},
},
}
]
with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm:
mock_llm.return_value = mock_response
with patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures):
input_data = SmartDecisionMakerBlock.Input(
prompt="Test",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=0,
)
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = MagicMock()
# BUG: This should raise JSONDecodeError
with pytest.raises(json.JSONDecodeError):
async for _ in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
pass
class TestDatabaseTransactionInconsistency:
"""
Tests for Failure Mode #4: Database Transaction Inconsistency
When multiple database operations are performed in sequence,
a failure partway through leaves the database in an inconsistent state.
"""
@pytest.mark.asyncio
async def test_partial_input_insertion_on_failure(self):
"""
Test that partial failures during multi-input insertion
leave database in inconsistent state.
"""
import threading
from collections import defaultdict
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
# Track which inputs were inserted
inserted_inputs = []
call_count = 0
async def failing_upsert(node_id, graph_exec_id, input_name, input_data):
nonlocal call_count
call_count += 1
# Fail on the third input
if call_count == 3:
raise Exception("Database connection lost!")
inserted_inputs.append(input_name)
mock_result = MagicMock()
mock_result.node_exec_id = "exec-id"
return mock_result, {input_name: input_data}
mock_tool_call = MagicMock()
mock_tool_call.id = "call_1"
mock_tool_call.function.name = "multi_input_tool"
mock_tool_call.function.arguments = json.dumps({
"input1": "value1",
"input2": "value2",
"input3": "value3", # This one will fail
"input4": "value4",
"input5": "value5",
})
mock_response = MagicMock()
mock_response.response = None
mock_response.tool_calls = [mock_tool_call]
mock_response.prompt_tokens = 50
mock_response.completion_tokens = 25
mock_response.reasoning = None
mock_response.raw_response = {
"role": "assistant",
"content": [{"type": "tool_use", "id": "call_1"}]
}
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "multi_input_tool",
"_sink_node_id": "sink",
"_field_mapping": {
"input1": "input1",
"input2": "input2",
"input3": "input3",
"input4": "input4",
"input5": "input5",
},
"parameters": {
"properties": {
"input1": {"type": "string"},
"input2": {"type": "string"},
"input3": {"type": "string"},
"input4": {"type": "string"},
"input5": {"type": "string"},
},
"required": ["input1", "input2", "input3", "input4", "input5"],
},
},
}
]
mock_db_client = AsyncMock()
mock_node = MagicMock()
mock_node.block_id = "test-block"
mock_db_client.get_node.return_value = mock_node
mock_db_client.upsert_execution_input.side_effect = failing_upsert
with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm, \
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
mock_llm.return_value = mock_response
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = AsyncMock()
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = threading.Lock()
input_data = SmartDecisionMakerBlock.Input(
prompt="Test",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=1,
)
# The block should fail, but some inputs were already inserted
outputs = {}
try:
async for name, value in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[name] = value
except Exception:
pass # Expected
# BUG: Some inputs were inserted before failure
# Database is now in inconsistent state
assert len(inserted_inputs) == 2, \
f"Expected 2 inserted before failure, got {inserted_inputs}"
assert "input1" in inserted_inputs
assert "input2" in inserted_inputs
# input3, input4, input5 were never inserted
class TestMissingNullChecks:
"""
Tests for Failure Mode #5: Missing Null Checks After Database Calls
"""
@pytest.mark.asyncio
async def test_get_node_returns_none(self):
"""
Test handling when get_node returns None.
"""
import threading
from collections import defaultdict
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
mock_tool_call = MagicMock()
mock_tool_call.id = "call_1"
mock_tool_call.function.name = "test_tool"
mock_tool_call.function.arguments = json.dumps({"param": "value"})
mock_response = MagicMock()
mock_response.response = None
mock_response.tool_calls = [mock_tool_call]
mock_response.prompt_tokens = 50
mock_response.completion_tokens = 25
mock_response.reasoning = None
mock_response.raw_response = {
"role": "assistant",
"content": [{"type": "tool_use", "id": "call_1"}]
}
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "test_tool",
"_sink_node_id": "nonexistent-node",
"_field_mapping": {"param": "param"},
"parameters": {
"properties": {"param": {"type": "string"}},
"required": ["param"],
},
},
}
]
mock_db_client = AsyncMock()
mock_db_client.get_node.return_value = None # Node doesn't exist!
with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm, \
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
mock_llm.return_value = mock_response
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = AsyncMock()
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = threading.Lock()
input_data = SmartDecisionMakerBlock.Input(
prompt="Test",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=1,
)
# Should raise ValueError for missing node
with pytest.raises(ValueError, match="not found"):
async for _ in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
pass
@pytest.mark.asyncio
async def test_empty_execution_outputs(self):
"""
Test handling when get_execution_outputs_by_node_exec_id returns empty.
"""
import threading
from collections import defaultdict
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
call_count = 0
async def mock_llm_call(**kwargs):
nonlocal call_count
call_count += 1
if call_count > 1:
resp = MagicMock()
resp.response = "Done"
resp.tool_calls = []
resp.prompt_tokens = 10
resp.completion_tokens = 5
resp.reasoning = None
resp.raw_response = {"role": "assistant", "content": "Done"}
return resp
mock_tool_call = MagicMock()
mock_tool_call.id = "call_1"
mock_tool_call.function.name = "test_tool"
mock_tool_call.function.arguments = json.dumps({})
resp = MagicMock()
resp.response = None
resp.tool_calls = [mock_tool_call]
resp.prompt_tokens = 50
resp.completion_tokens = 25
resp.reasoning = None
resp.raw_response = {
"role": "assistant",
"content": [{"type": "tool_use", "id": "call_1"}]
}
return resp
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "test_tool",
"_sink_node_id": "sink",
"_field_mapping": {},
"parameters": {"properties": {}, "required": []},
},
}
]
mock_db_client = AsyncMock()
mock_node = MagicMock()
mock_node.block_id = "test-block"
mock_db_client.get_node.return_value = mock_node
mock_exec_result = MagicMock()
mock_exec_result.node_exec_id = "exec-id"
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {})
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {} # Empty!
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = AsyncMock()
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = threading.Lock()
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
input_data = SmartDecisionMakerBlock.Input(
prompt="Test",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=2,
)
outputs = {}
async for name, value in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[name] = value
# Empty outputs should be handled gracefully
# (uses "Tool executed successfully" as fallback)
assert "finished" in outputs or "conversations" in outputs
class TestErrorMessageContextLoss:
"""
Tests for Failure Mode #15: Error Message Context Loss
When exceptions are caught and converted to strings, important
debugging information is lost.
"""
def test_exception_to_string_loses_traceback(self):
"""
Test that converting exception to string loses traceback.
"""
try:
def inner():
raise ValueError("Inner error")
def outer():
inner()
outer()
except Exception as e:
error_string = str(e)
error_repr = repr(e)
# String representation loses call stack
assert "inner" not in error_string
assert "outer" not in error_string
# Even repr doesn't have full traceback
assert "Traceback" not in error_repr
def test_tool_response_loses_exception_type(self):
"""
Test that _create_tool_response loses exception type information.
"""
original_error = ConnectionError("Database unreachable")
tool_response = _create_tool_response(
"call_123",
f"Tool execution failed: {str(original_error)}"
)
content = tool_response.get("content", "")
# Original exception type is lost
assert "ConnectionError" not in content
# Only the message remains
assert "Database unreachable" in content
@pytest.mark.asyncio
async def test_agent_mode_error_response_lacks_context(self):
"""
Test that agent mode error responses lack debugging context.
"""
import threading
from collections import defaultdict
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
mock_tool_call = MagicMock()
mock_tool_call.id = "call_1"
mock_tool_call.function.name = "test_tool"
mock_tool_call.function.arguments = json.dumps({})
mock_response_1 = MagicMock()
mock_response_1.response = None
mock_response_1.tool_calls = [mock_tool_call]
mock_response_1.prompt_tokens = 50
mock_response_1.completion_tokens = 25
mock_response_1.reasoning = None
mock_response_1.raw_response = {
"role": "assistant",
"content": [{"type": "tool_use", "id": "call_1"}]
}
mock_response_2 = MagicMock()
mock_response_2.response = "Handled the error"
mock_response_2.tool_calls = []
mock_response_2.prompt_tokens = 30
mock_response_2.completion_tokens = 15
mock_response_2.reasoning = None
mock_response_2.raw_response = {"role": "assistant", "content": "Handled"}
call_count = 0
async def mock_llm_call(**kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return mock_response_1
return mock_response_2
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "test_tool",
"_sink_node_id": "sink",
"_field_mapping": {},
"parameters": {"properties": {}, "required": []},
},
}
]
# Create a complex error with nested cause
class CustomDatabaseError(Exception):
pass
def create_complex_error():
try:
raise ConnectionError("Network timeout after 30s")
except ConnectionError as e:
raise CustomDatabaseError("Failed to connect to database") from e
mock_db_client = AsyncMock()
mock_node = MagicMock()
mock_node.block_id = "test-block"
mock_db_client.get_node.return_value = mock_node
# Make upsert raise the complex error
try:
create_complex_error()
except CustomDatabaseError as e:
mock_db_client.upsert_execution_input.side_effect = e
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = AsyncMock()
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = threading.Lock()
input_data = SmartDecisionMakerBlock.Input(
prompt="Test",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=2,
)
outputs = {}
async for name, value in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[name] = value
# Check conversation for error details
conversations = outputs.get("conversations", [])
error_found = False
for msg in conversations:
content = msg.get("content", "")
if isinstance(content, list):
for item in content:
if item.get("type") == "tool_result":
result_content = item.get("content", "")
if "Error" in result_content or "failed" in result_content.lower():
error_found = True
# BUG: The error content lacks:
# - Exception type (CustomDatabaseError)
# - Chained cause (ConnectionError)
# - Stack trace
assert "CustomDatabaseError" not in result_content
assert "ConnectionError" not in result_content
# Note: error_found may be False if the error prevented tool response creation
class TestRawResponseConversion:
"""Tests for _convert_raw_response_to_dict edge cases."""
def test_string_response_converted(self):
"""Test that string responses are properly wrapped."""
result = _convert_raw_response_to_dict("Hello, world!")
assert result == {"role": "assistant", "content": "Hello, world!"}
def test_dict_response_unchanged(self):
"""Test that dict responses are passed through."""
original = {"role": "assistant", "content": "test", "extra": "field"}
result = _convert_raw_response_to_dict(original)
assert result == original
def test_object_response_converted(self):
"""Test that objects are converted using json.to_dict."""
mock_obj = MagicMock()
with patch("backend.blocks.smart_decision_maker.json.to_dict") as mock_to_dict:
mock_to_dict.return_value = {"converted": True}
result = _convert_raw_response_to_dict(mock_obj)
mock_to_dict.assert_called_once_with(mock_obj)
assert result == {"converted": True}
def test_none_response(self):
"""Test handling of None response."""
with patch("backend.blocks.smart_decision_maker.json.to_dict") as mock_to_dict:
mock_to_dict.return_value = None
result = _convert_raw_response_to_dict(None)
# None is not a string or dict, so it goes through to_dict
assert result is None
class TestValidationRetryMechanism:
"""Tests for the validation and retry mechanism."""
@pytest.mark.asyncio
async def test_validation_error_triggers_retry(self):
"""
Test that validation errors trigger retry with feedback.
"""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
call_count = 0
async def mock_llm_call(**kwargs):
nonlocal call_count
call_count += 1
prompt = kwargs.get("prompt", [])
if call_count == 1:
# First call: return tool call with wrong parameter
mock_tool_call = MagicMock()
mock_tool_call.function.name = "test_tool"
mock_tool_call.function.arguments = json.dumps({"wrong_param": "value"})
resp = MagicMock()
resp.response = None
resp.tool_calls = [mock_tool_call]
resp.prompt_tokens = 50
resp.completion_tokens = 25
resp.reasoning = None
resp.raw_response = {"role": "assistant", "content": None}
return resp
else:
# Second call: check that error feedback was added
has_error_feedback = any(
"parameter errors" in str(msg.get("content", "")).lower()
for msg in prompt
)
# Return correct tool call
mock_tool_call = MagicMock()
mock_tool_call.function.name = "test_tool"
mock_tool_call.function.arguments = json.dumps({"correct_param": "value"})
resp = MagicMock()
resp.response = None
resp.tool_calls = [mock_tool_call]
resp.prompt_tokens = 50
resp.completion_tokens = 25
resp.reasoning = None
resp.raw_response = {"role": "assistant", "content": None}
return resp
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "test_tool",
"_sink_node_id": "sink",
"_field_mapping": {"correct_param": "correct_param"},
"parameters": {
"properties": {"correct_param": {"type": "string"}},
"required": ["correct_param"],
},
},
}
]
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures):
input_data = SmartDecisionMakerBlock.Input(
prompt="Test",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=0, # Traditional mode
retry=3,
)
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = MagicMock()
outputs = {}
async for name, value in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[name] = value
# Should have made multiple calls due to retry
assert call_count >= 2
@pytest.mark.asyncio
async def test_max_retries_exceeded(self):
"""
Test behavior when max retries are exceeded.
"""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
async def mock_llm_call(**kwargs):
# Always return invalid tool call
mock_tool_call = MagicMock()
mock_tool_call.function.name = "test_tool"
mock_tool_call.function.arguments = json.dumps({"wrong": "param"})
resp = MagicMock()
resp.response = None
resp.tool_calls = [mock_tool_call]
resp.prompt_tokens = 50
resp.completion_tokens = 25
resp.reasoning = None
resp.raw_response = {"role": "assistant", "content": None}
return resp
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "test_tool",
"_sink_node_id": "sink",
"_field_mapping": {"correct": "correct"},
"parameters": {
"properties": {"correct": {"type": "string"}},
"required": ["correct"],
},
},
}
]
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures):
input_data = SmartDecisionMakerBlock.Input(
prompt="Test",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=0,
retry=2, # Only 2 retries
)
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = MagicMock()
# Should raise ValueError after max retries
with pytest.raises(ValueError, match="parameter errors"):
async for _ in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph",
node_id="test-node",
graph_exec_id="test-exec",
node_exec_id="test-node-exec",
user_id="test-user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
pass

View File

@@ -1,819 +0,0 @@
"""
Comprehensive tests for SmartDecisionMakerBlock pin name sanitization.
This test file addresses the critical bug where field names with spaces/special characters
(e.g., "Max Keyword Difficulty") are not consistently sanitized between frontend and backend,
causing tool calls to "go into the void".
The core issue:
- Frontend connects link with original name: tools_^_{node_id}_~_Max Keyword Difficulty
- Backend emits with sanitized name: tools_^_{node_id}_~_max_keyword_difficulty
- parse_execution_output compares sink_pin_name directly without sanitization
- Result: mismatch causes tool calls to fail silently
"""
import json
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.dynamic_fields import (
parse_execution_output,
sanitize_pin_name,
)
class TestCleanupFunction:
"""Tests for the SmartDecisionMakerBlock.cleanup() static method."""
def test_cleanup_spaces_to_underscores(self):
"""Spaces should be replaced with underscores."""
assert SmartDecisionMakerBlock.cleanup("Max Keyword Difficulty") == "max_keyword_difficulty"
def test_cleanup_mixed_case_to_lowercase(self):
"""Mixed case should be converted to lowercase."""
assert SmartDecisionMakerBlock.cleanup("MaxKeywordDifficulty") == "maxkeyworddifficulty"
assert SmartDecisionMakerBlock.cleanup("UPPER_CASE") == "upper_case"
def test_cleanup_special_characters(self):
"""Special characters should be replaced with underscores."""
assert SmartDecisionMakerBlock.cleanup("field@name!") == "field_name_"
assert SmartDecisionMakerBlock.cleanup("value#1") == "value_1"
assert SmartDecisionMakerBlock.cleanup("test$value") == "test_value"
assert SmartDecisionMakerBlock.cleanup("a%b^c") == "a_b_c"
def test_cleanup_preserves_valid_characters(self):
"""Valid characters (alphanumeric, underscore, hyphen) should be preserved."""
assert SmartDecisionMakerBlock.cleanup("valid_name-123") == "valid_name-123"
assert SmartDecisionMakerBlock.cleanup("abc123") == "abc123"
def test_cleanup_empty_string(self):
"""Empty string should return empty string."""
assert SmartDecisionMakerBlock.cleanup("") == ""
def test_cleanup_only_special_chars(self):
"""String of only special characters should return underscores."""
assert SmartDecisionMakerBlock.cleanup("@#$%") == "____"
def test_cleanup_unicode_characters(self):
"""Unicode characters should be replaced with underscores."""
assert SmartDecisionMakerBlock.cleanup("café") == "caf_"
assert SmartDecisionMakerBlock.cleanup("日本語") == "___"
def test_cleanup_multiple_consecutive_spaces(self):
"""Multiple consecutive spaces should become multiple underscores."""
assert SmartDecisionMakerBlock.cleanup("a b") == "a___b"
def test_cleanup_leading_trailing_spaces(self):
"""Leading/trailing spaces should become underscores."""
assert SmartDecisionMakerBlock.cleanup(" name ") == "_name_"
def test_cleanup_realistic_field_names(self):
"""Test realistic field names from actual use cases."""
# From the reported bug
assert SmartDecisionMakerBlock.cleanup("Max Keyword Difficulty") == "max_keyword_difficulty"
# Other realistic names
assert SmartDecisionMakerBlock.cleanup("Search Query") == "search_query"
assert SmartDecisionMakerBlock.cleanup("API Response (JSON)") == "api_response__json_"
assert SmartDecisionMakerBlock.cleanup("User's Input") == "user_s_input"
class TestFieldMappingCreation:
"""Tests for field mapping creation in function signatures."""
@pytest.mark.asyncio
async def test_field_mapping_with_spaces_in_names(self):
"""Test that field mapping correctly maps clean names back to original names with spaces."""
block = SmartDecisionMakerBlock()
mock_node = Mock()
mock_node.id = "test-node-id"
mock_node.block = Mock()
mock_node.block.name = "TestBlock"
mock_node.block.description = "Test description"
mock_node.block.input_schema = Mock()
mock_node.block.input_schema.jsonschema = Mock(
return_value={"properties": {}, "required": ["Max Keyword Difficulty"]}
)
def get_field_schema(field_name):
if field_name == "Max Keyword Difficulty":
return {"type": "integer", "description": "Maximum keyword difficulty (0-100)"}
raise KeyError(f"Field {field_name} not found")
mock_node.block.input_schema.get_field_schema = get_field_schema
mock_links = [
Mock(
source_name="tools_^_test_~_max_keyword_difficulty",
sink_name="Max Keyword Difficulty", # Original name with spaces
sink_id="test-node-id",
source_id="smart_node_id",
),
]
signature = await block._create_block_function_signature(mock_node, mock_links)
# Verify the cleaned name is used in properties
properties = signature["function"]["parameters"]["properties"]
assert "max_keyword_difficulty" in properties
# Verify the field mapping maps back to original
field_mapping = signature["function"]["_field_mapping"]
assert field_mapping["max_keyword_difficulty"] == "Max Keyword Difficulty"
@pytest.mark.asyncio
async def test_field_mapping_with_multiple_special_char_names(self):
"""Test field mapping with multiple fields containing special characters."""
block = SmartDecisionMakerBlock()
mock_node = Mock()
mock_node.id = "test-node-id"
mock_node.block = Mock()
mock_node.block.name = "SEO Tool"
mock_node.block.description = "SEO analysis tool"
mock_node.block.input_schema = Mock()
mock_node.block.input_schema.jsonschema = Mock(
return_value={"properties": {}, "required": []}
)
def get_field_schema(field_name):
schemas = {
"Max Keyword Difficulty": {"type": "integer", "description": "Max difficulty"},
"Search Volume (Monthly)": {"type": "integer", "description": "Monthly volume"},
"CPC ($)": {"type": "number", "description": "Cost per click"},
"Target URL": {"type": "string", "description": "URL to analyze"},
}
if field_name in schemas:
return schemas[field_name]
raise KeyError(f"Field {field_name} not found")
mock_node.block.input_schema.get_field_schema = get_field_schema
mock_links = [
Mock(sink_name="Max Keyword Difficulty", sink_id="test-node-id", source_id="smart_node_id"),
Mock(sink_name="Search Volume (Monthly)", sink_id="test-node-id", source_id="smart_node_id"),
Mock(sink_name="CPC ($)", sink_id="test-node-id", source_id="smart_node_id"),
Mock(sink_name="Target URL", sink_id="test-node-id", source_id="smart_node_id"),
]
signature = await block._create_block_function_signature(mock_node, mock_links)
properties = signature["function"]["parameters"]["properties"]
field_mapping = signature["function"]["_field_mapping"]
# Verify all cleaned names are in properties
assert "max_keyword_difficulty" in properties
assert "search_volume__monthly_" in properties
assert "cpc____" in properties
assert "target_url" in properties
# Verify field mappings
assert field_mapping["max_keyword_difficulty"] == "Max Keyword Difficulty"
assert field_mapping["search_volume__monthly_"] == "Search Volume (Monthly)"
assert field_mapping["cpc____"] == "CPC ($)"
assert field_mapping["target_url"] == "Target URL"
class TestFieldNameCollision:
"""Tests for detecting field name collisions after sanitization."""
@pytest.mark.asyncio
async def test_collision_detection_same_sanitized_name(self):
"""Test behavior when two different names sanitize to the same value."""
block = SmartDecisionMakerBlock()
# These two different names will sanitize to the same value
name1 = "max keyword difficulty" # -> max_keyword_difficulty
name2 = "Max Keyword Difficulty" # -> max_keyword_difficulty
name3 = "MAX_KEYWORD_DIFFICULTY" # -> max_keyword_difficulty
assert SmartDecisionMakerBlock.cleanup(name1) == SmartDecisionMakerBlock.cleanup(name2)
assert SmartDecisionMakerBlock.cleanup(name2) == SmartDecisionMakerBlock.cleanup(name3)
@pytest.mark.asyncio
async def test_collision_in_function_signature(self):
"""Test that collisions in sanitized names could cause issues."""
block = SmartDecisionMakerBlock()
mock_node = Mock()
mock_node.id = "test-node-id"
mock_node.block = Mock()
mock_node.block.name = "TestBlock"
mock_node.block.description = "Test description"
mock_node.block.input_schema = Mock()
mock_node.block.input_schema.jsonschema = Mock(
return_value={"properties": {}, "required": []}
)
def get_field_schema(field_name):
return {"type": "string", "description": f"Field: {field_name}"}
mock_node.block.input_schema.get_field_schema = get_field_schema
# Two different fields that sanitize to the same name
mock_links = [
Mock(sink_name="Test Field", sink_id="test-node-id", source_id="smart_node_id"),
Mock(sink_name="test field", sink_id="test-node-id", source_id="smart_node_id"),
]
signature = await block._create_block_function_signature(mock_node, mock_links)
properties = signature["function"]["parameters"]["properties"]
field_mapping = signature["function"]["_field_mapping"]
# Both sanitize to "test_field" - only one will be in properties
assert "test_field" in properties
# The field_mapping will have the last one written
assert field_mapping["test_field"] in ["Test Field", "test field"]
class TestOutputRouting:
"""Tests for output routing with sanitized names."""
def test_emit_key_format_with_spaces(self):
"""Test that emit keys use sanitized field names."""
block = SmartDecisionMakerBlock()
original_field_name = "Max Keyword Difficulty"
sink_node_id = "node-123"
sanitized_name = block.cleanup(original_field_name)
emit_key = f"tools_^_{sink_node_id}_~_{sanitized_name}"
assert emit_key == "tools_^_node-123_~_max_keyword_difficulty"
def test_parse_execution_output_exact_match(self):
"""Test parse_execution_output with exact matching names."""
output_item = ("tools_^_node-123_~_max_keyword_difficulty", 50)
# When sink_pin_name matches the sanitized name, it should work
result = parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id="node-123",
sink_pin_name="max_keyword_difficulty",
)
assert result == 50
def test_parse_execution_output_mismatch_original_vs_sanitized(self):
"""
CRITICAL TEST: This reproduces the exact bug reported.
When frontend creates a link with original name "Max Keyword Difficulty"
but backend emits with sanitized name "max_keyword_difficulty",
the tool call should still be routed correctly.
CURRENT BEHAVIOR (BUG): Returns None because names don't match
EXPECTED BEHAVIOR: Should return the value (50) after sanitizing both names
"""
output_item = ("tools_^_node-123_~_max_keyword_difficulty", 50)
# This is what happens: sink_pin_name comes from frontend link (unsanitized)
result = parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id="node-123",
sink_pin_name="Max Keyword Difficulty", # Original name with spaces
)
# BUG: This currently returns None because:
# - target_input_pin = "max_keyword_difficulty" (from emit key, sanitized)
# - sink_pin_name = "Max Keyword Difficulty" (from link, original)
# - They don't match, so routing fails
#
# TODO: When the bug is fixed, change this assertion to:
# assert result == 50
assert result is None # Current buggy behavior
def test_parse_execution_output_with_sanitized_sink_pin(self):
"""Test that if sink_pin_name is pre-sanitized, routing works."""
output_item = ("tools_^_node-123_~_max_keyword_difficulty", 50)
# If sink_pin_name is already sanitized, routing works
result = parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id="node-123",
sink_pin_name="max_keyword_difficulty", # Pre-sanitized
)
assert result == 50
class TestProcessToolCallsMapping:
"""Tests for _process_tool_calls method field mapping."""
@pytest.mark.asyncio
async def test_process_tool_calls_maps_clean_to_original(self):
"""Test that _process_tool_calls correctly maps clean names back to original."""
block = SmartDecisionMakerBlock()
mock_response = Mock()
mock_tool_call = Mock()
mock_tool_call.function.name = "seo_tool"
mock_tool_call.function.arguments = json.dumps({
"max_keyword_difficulty": 50, # LLM uses clean name
"search_query": "test query",
})
mock_response.tool_calls = [mock_tool_call]
tool_functions = [
{
"type": "function",
"function": {
"name": "seo_tool",
"parameters": {
"properties": {
"max_keyword_difficulty": {"type": "integer"},
"search_query": {"type": "string"},
},
"required": ["max_keyword_difficulty", "search_query"],
},
"_sink_node_id": "test-sink-node",
"_field_mapping": {
"max_keyword_difficulty": "Max Keyword Difficulty", # Original name
"search_query": "Search Query",
},
},
}
]
processed = block._process_tool_calls(mock_response, tool_functions)
assert len(processed) == 1
tool_info = processed[0]
# Verify input_data uses ORIGINAL field names
assert "Max Keyword Difficulty" in tool_info.input_data
assert "Search Query" in tool_info.input_data
assert tool_info.input_data["Max Keyword Difficulty"] == 50
assert tool_info.input_data["Search Query"] == "test query"
class TestToolOutputEmitting:
"""Tests for the tool output emitting in traditional mode."""
@pytest.mark.asyncio
async def test_emit_keys_use_sanitized_names(self):
"""Test that emit keys always use sanitized field names."""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
mock_tool_call = MagicMock()
mock_tool_call.function.name = "seo_tool"
mock_tool_call.function.arguments = json.dumps({
"max_keyword_difficulty": 50,
})
mock_response = MagicMock()
mock_response.response = None
mock_response.tool_calls = [mock_tool_call]
mock_response.prompt_tokens = 50
mock_response.completion_tokens = 25
mock_response.reasoning = None
mock_response.raw_response = {"role": "assistant", "content": None}
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "seo_tool",
"_sink_node_id": "test-sink-node-id",
"_field_mapping": {
"max_keyword_difficulty": "Max Keyword Difficulty",
},
"parameters": {
"properties": {
"max_keyword_difficulty": {"type": "integer"},
},
"required": ["max_keyword_difficulty"],
},
},
}
]
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response,
), patch.object(
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Test prompt",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=0,
)
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = MagicMock()
outputs = {}
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph-id",
node_id="test-node-id",
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
# The emit key should use the sanitized field name
# Even though the original was "Max Keyword Difficulty", emit uses sanitized
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
assert outputs["tools_^_test-sink-node-id_~_max_keyword_difficulty"] == 50
class TestSanitizationConsistency:
"""Tests for ensuring sanitization is consistent throughout the pipeline."""
@pytest.mark.asyncio
async def test_full_round_trip_with_spaces(self):
"""
Test the full round-trip of a field name with spaces through the system.
This simulates:
1. Frontend creates link with sink_name="Max Keyword Difficulty"
2. Backend creates function signature with cleaned property name
3. LLM responds with cleaned name
4. Backend processes response and maps back to original
5. Backend emits with sanitized name
6. Routing should match (currently broken)
"""
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
original_field_name = "Max Keyword Difficulty"
cleaned_field_name = SmartDecisionMakerBlock.cleanup(original_field_name)
# Step 1: Simulate frontend link creation
mock_link = Mock()
mock_link.sink_name = original_field_name # Frontend uses original
mock_link.sink_id = "test-sink-node-id"
mock_link.source_id = "smart-node-id"
# Step 2: Create function signature
mock_node = Mock()
mock_node.id = "test-sink-node-id"
mock_node.block = Mock()
mock_node.block.name = "SEO Tool"
mock_node.block.description = "SEO analysis"
mock_node.block.input_schema = Mock()
mock_node.block.input_schema.jsonschema = Mock(
return_value={"properties": {}, "required": [original_field_name]}
)
mock_node.block.input_schema.get_field_schema = Mock(
return_value={"type": "integer", "description": "Max difficulty"}
)
signature = await block._create_block_function_signature(mock_node, [mock_link])
# Verify cleaned name is in properties
assert cleaned_field_name in signature["function"]["parameters"]["properties"]
# Verify field mapping exists
assert signature["function"]["_field_mapping"][cleaned_field_name] == original_field_name
# Step 3: Simulate LLM response using cleaned name
mock_tool_call = MagicMock()
mock_tool_call.function.name = "seo_tool"
mock_tool_call.function.arguments = json.dumps({
cleaned_field_name: 50 # LLM uses cleaned name
})
mock_response = MagicMock()
mock_response.response = None
mock_response.tool_calls = [mock_tool_call]
mock_response.prompt_tokens = 50
mock_response.completion_tokens = 25
mock_response.reasoning = None
mock_response.raw_response = {"role": "assistant", "content": None}
# Prepare tool_functions as they would be in run()
tool_functions = [
{
"type": "function",
"function": {
"name": "seo_tool",
"_sink_node_id": "test-sink-node-id",
"_field_mapping": signature["function"]["_field_mapping"],
"parameters": signature["function"]["parameters"],
},
}
]
# Step 4: Process tool calls
processed = block._process_tool_calls(mock_response, tool_functions)
assert len(processed) == 1
# Input data should have ORIGINAL name
assert original_field_name in processed[0].input_data
assert processed[0].input_data[original_field_name] == 50
# Step 5: Emit key generation (from run method logic)
field_mapping = processed[0].field_mapping
for clean_arg_name in signature["function"]["parameters"]["properties"]:
original = field_mapping.get(clean_arg_name, clean_arg_name)
sanitized_arg_name = block.cleanup(original)
emit_key = f"tools_^_test-sink-node-id_~_{sanitized_arg_name}"
# Emit key uses sanitized name
assert emit_key == f"tools_^_test-sink-node-id_~_{cleaned_field_name}"
# Step 6: Routing check (this is where the bug manifests)
emit_key = f"tools_^_test-sink-node-id_~_{cleaned_field_name}"
output_item = (emit_key, 50)
# Current routing uses original sink_name from link
result = parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id="test-sink-node-id",
sink_pin_name=original_field_name, # Frontend's original name
)
# BUG: This returns None because sanitized != original
# When fixed, this should return 50
assert result is None # Current broken behavior
def test_sanitization_is_idempotent(self):
"""Test that sanitizing an already sanitized name gives the same result."""
original = "Max Keyword Difficulty"
first_clean = SmartDecisionMakerBlock.cleanup(original)
second_clean = SmartDecisionMakerBlock.cleanup(first_clean)
assert first_clean == second_clean
class TestEdgeCases:
"""Tests for edge cases in the sanitization pipeline."""
@pytest.mark.asyncio
async def test_empty_field_name(self):
"""Test handling of empty field name."""
assert SmartDecisionMakerBlock.cleanup("") == ""
@pytest.mark.asyncio
async def test_very_long_field_name(self):
"""Test handling of very long field names."""
long_name = "A" * 1000 + " " + "B" * 1000
cleaned = SmartDecisionMakerBlock.cleanup(long_name)
assert "_" in cleaned # Space was replaced
assert len(cleaned) == len(long_name)
@pytest.mark.asyncio
async def test_field_name_with_newlines(self):
"""Test handling of field names with newlines."""
name_with_newline = "First Line\nSecond Line"
cleaned = SmartDecisionMakerBlock.cleanup(name_with_newline)
assert "\n" not in cleaned
assert "_" in cleaned
@pytest.mark.asyncio
async def test_field_name_with_tabs(self):
"""Test handling of field names with tabs."""
name_with_tab = "First\tSecond"
cleaned = SmartDecisionMakerBlock.cleanup(name_with_tab)
assert "\t" not in cleaned
assert "_" in cleaned
@pytest.mark.asyncio
async def test_numeric_field_name(self):
"""Test handling of purely numeric field names."""
assert SmartDecisionMakerBlock.cleanup("123") == "123"
assert SmartDecisionMakerBlock.cleanup("123 456") == "123_456"
@pytest.mark.asyncio
async def test_hyphenated_field_names(self):
"""Test that hyphens are preserved (valid in function names)."""
assert SmartDecisionMakerBlock.cleanup("field-name") == "field-name"
assert SmartDecisionMakerBlock.cleanup("Field-Name") == "field-name"
class TestDynamicFieldsWithSpaces:
"""Tests for dynamic fields with spaces in their names."""
@pytest.mark.asyncio
async def test_dynamic_dict_field_with_spaces(self):
"""Test dynamic dictionary fields where the key contains spaces."""
block = SmartDecisionMakerBlock()
mock_node = Mock()
mock_node.id = "test-node-id"
mock_node.block = Mock()
mock_node.block.name = "CreateDictionary"
mock_node.block.description = "Creates a dictionary"
mock_node.block.input_schema = Mock()
mock_node.block.input_schema.jsonschema = Mock(
return_value={"properties": {}, "required": ["values"]}
)
mock_node.block.input_schema.get_field_schema = Mock(
side_effect=KeyError("not found")
)
# Dynamic field with a key containing spaces
mock_links = [
Mock(
sink_name="values_#_User Name", # Dict key with space
sink_id="test-node-id",
source_id="smart_node_id",
),
]
signature = await block._create_block_function_signature(mock_node, mock_links)
properties = signature["function"]["parameters"]["properties"]
field_mapping = signature["function"]["_field_mapping"]
# The cleaned name should be in properties
expected_clean = SmartDecisionMakerBlock.cleanup("values_#_User Name")
assert expected_clean in properties
# Field mapping should map back to original
assert field_mapping[expected_clean] == "values_#_User Name"
class TestAgentModeWithSpaces:
"""Tests for agent mode with field names containing spaces."""
@pytest.mark.asyncio
async def test_agent_mode_tool_execution_with_spaces(self):
"""Test that agent mode correctly handles field names with spaces."""
import threading
from collections import defaultdict
import backend.blocks.llm as llm_module
from backend.data.execution import ExecutionContext
block = SmartDecisionMakerBlock()
original_field = "Max Keyword Difficulty"
clean_field = SmartDecisionMakerBlock.cleanup(original_field)
mock_tool_call = MagicMock()
mock_tool_call.id = "call_1"
mock_tool_call.function.name = "seo_tool"
mock_tool_call.function.arguments = json.dumps({
clean_field: 50 # LLM uses clean name
})
mock_response_1 = MagicMock()
mock_response_1.response = None
mock_response_1.tool_calls = [mock_tool_call]
mock_response_1.prompt_tokens = 50
mock_response_1.completion_tokens = 25
mock_response_1.reasoning = None
mock_response_1.raw_response = {
"role": "assistant",
"content": None,
"tool_calls": [{"id": "call_1", "type": "function"}],
}
mock_response_2 = MagicMock()
mock_response_2.response = "Task completed"
mock_response_2.tool_calls = []
mock_response_2.prompt_tokens = 30
mock_response_2.completion_tokens = 15
mock_response_2.reasoning = None
mock_response_2.raw_response = {"role": "assistant", "content": "Task completed"}
llm_call_mock = AsyncMock()
llm_call_mock.side_effect = [mock_response_1, mock_response_2]
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "seo_tool",
"_sink_node_id": "test-sink-node-id",
"_field_mapping": {
clean_field: original_field,
},
"parameters": {
"properties": {
clean_field: {"type": "integer"},
},
"required": [clean_field],
},
},
}
]
mock_db_client = AsyncMock()
mock_node = MagicMock()
mock_node.block_id = "test-block-id"
mock_db_client.get_node.return_value = mock_node
mock_node_exec_result = MagicMock()
mock_node_exec_result.node_exec_id = "test-tool-exec-id"
# The input data should use ORIGINAL field name
mock_input_data = {original_field: 50}
mock_db_client.upsert_execution_input.return_value = (
mock_node_exec_result,
mock_input_data,
)
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
"result": {"status": "success"}
}
with patch("backend.blocks.llm.llm_call", llm_call_mock), patch.object(
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
), patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db_client,
):
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = AsyncMock()
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = threading.Lock()
mock_node_stats = MagicMock()
mock_node_stats.error = None
mock_execution_processor.on_node_execution = AsyncMock(
return_value=mock_node_stats
)
input_data = SmartDecisionMakerBlock.Input(
prompt="Analyze keywords",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT,
agent_mode_max_iterations=3,
)
outputs = {}
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph-id",
node_id="test-node-id",
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
# Verify upsert was called with original field name
upsert_calls = mock_db_client.upsert_execution_input.call_args_list
assert len(upsert_calls) > 0
# Check that the original field name was used
for call in upsert_calls:
input_name = call.kwargs.get("input_name") or call.args[2]
# The input name should be the original (mapped back)
assert input_name == original_field
class TestRequiredFieldsWithSpaces:
"""Tests for required field handling with spaces in names."""
@pytest.mark.asyncio
async def test_required_fields_use_clean_names(self):
"""Test that required fields array uses clean names for API compatibility."""
block = SmartDecisionMakerBlock()
mock_node = Mock()
mock_node.id = "test-node-id"
mock_node.block = Mock()
mock_node.block.name = "TestBlock"
mock_node.block.description = "Test"
mock_node.block.input_schema = Mock()
mock_node.block.input_schema.jsonschema = Mock(
return_value={
"properties": {},
"required": ["Max Keyword Difficulty", "Search Query"],
}
)
def get_field_schema(field_name):
return {"type": "string", "description": f"Field: {field_name}"}
mock_node.block.input_schema.get_field_schema = get_field_schema
mock_links = [
Mock(sink_name="Max Keyword Difficulty", sink_id="test-node-id", source_id="smart_node_id"),
Mock(sink_name="Search Query", sink_id="test-node-id", source_id="smart_node_id"),
]
signature = await block._create_block_function_signature(mock_node, mock_links)
required = signature["function"]["parameters"]["required"]
# Required array should use CLEAN names for API compatibility
assert "max_keyword_difficulty" in required
assert "search_query" in required
# Original names should NOT be in required
assert "Max Keyword Difficulty" not in required
assert "Search Query" not in required

View File

@@ -1,5 +1,5 @@
from gravitasml.parser import Parser
from gravitasml.token import Token, tokenize
from gravitasml.token import tokenize
from backend.data.block import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
from backend.data.model import SchemaField
@@ -25,38 +25,6 @@ class XMLParserBlock(Block):
],
)
@staticmethod
def _validate_tokens(tokens: list[Token]) -> None:
"""Ensure the XML has a single root element and no stray text."""
if not tokens:
raise ValueError("XML input is empty.")
depth = 0
root_seen = False
for token in tokens:
if token.type == "TAG_OPEN":
if depth == 0 and root_seen:
raise ValueError("XML must have a single root element.")
depth += 1
if depth == 1:
root_seen = True
elif token.type == "TAG_CLOSE":
depth -= 1
if depth < 0:
raise SyntaxError("Unexpected closing tag in XML input.")
elif token.type in {"TEXT", "ESCAPE"}:
if depth == 0 and token.value:
raise ValueError(
"XML contains text outside the root element; "
"wrap content in a single root tag."
)
if depth != 0:
raise SyntaxError("Unclosed tag detected in XML input.")
if not root_seen:
raise ValueError("XML must include a root element.")
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
# Security fix: Add size limits to prevent XML bomb attacks
MAX_XML_SIZE = 10 * 1024 * 1024 # 10MB limit for XML input
@@ -67,9 +35,7 @@ class XMLParserBlock(Block):
)
try:
tokens = list(tokenize(input_data.input_xml))
self._validate_tokens(tokens)
tokens = tokenize(input_data.input_xml)
parser = Parser(tokens)
parsed_result = parser.parse()
yield "parsed_xml", parsed_result

View File

@@ -111,8 +111,6 @@ class TranscribeYoutubeVideoBlock(Block):
return parsed_url.path.split("/")[2]
if parsed_url.path[:3] == "/v/":
return parsed_url.path.split("/")[2]
if parsed_url.path.startswith("/shorts/"):
return parsed_url.path.split("/")[2]
raise ValueError(f"Invalid YouTube URL: {url}")
def get_transcript(

View File

@@ -244,7 +244,11 @@ def websocket(server_address: str, graph_exec_id: str):
import websockets.asyncio.client
from backend.api.ws_api import WSMessage, WSMethod, WSSubscribeGraphExecutionRequest
from backend.server.ws_api import (
WSMessage,
WSMethod,
WSSubscribeGraphExecutionRequest,
)
async def send_message(server_address: str):
uri = f"ws://{server_address}"

View File

@@ -1 +0,0 @@
"""CLI utilities for backend development & administration"""

View File

@@ -1,57 +0,0 @@
#!/usr/bin/env python3
"""
Script to generate OpenAPI JSON specification for the FastAPI app.
This script imports the FastAPI app from backend.api.rest_api and outputs
the OpenAPI specification as JSON to stdout or a specified file.
Usage:
`poetry run python generate_openapi_json.py`
`poetry run python generate_openapi_json.py --output openapi.json`
`poetry run python generate_openapi_json.py --indent 4 --output openapi.json`
"""
import json
import os
from pathlib import Path
import click
@click.command()
@click.option(
"--output",
type=click.Path(dir_okay=False, path_type=Path),
help="Output file path (default: stdout)",
)
@click.option(
"--pretty",
type=click.BOOL,
default=False,
help="Pretty-print JSON output (indented 2 spaces)",
)
def main(output: Path, pretty: bool):
"""Generate and output the OpenAPI JSON specification."""
openapi_schema = get_openapi_schema()
json_output = json.dumps(openapi_schema, indent=2 if pretty else None)
if output:
output.write_text(json_output)
click.echo(f"✅ OpenAPI specification written to {output}\n\nPreview:")
click.echo(f"\n{json_output[:500]} ...")
else:
print(json_output)
def get_openapi_schema():
"""Get the OpenAPI schema from the FastAPI app"""
from backend.api.rest_api import app
return app.openapi()
if __name__ == "__main__":
os.environ["LOG_LEVEL"] = "ERROR" # disable stdout log output
main()

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
from backend.api.features.library.model import LibraryAgentPreset
from backend.server.v2.library.model import LibraryAgentPreset
from .graph import NodeModel
from .integrations import Webhook # noqa: F401

View File

@@ -1,24 +1,22 @@
import logging
import uuid
from datetime import datetime, timezone
from typing import Literal, Optional
from typing import Optional
from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission, APIKeyStatus
from prisma.models import APIKey as PrismaAPIKey
from prisma.types import APIKeyWhereUniqueInput
from pydantic import Field
from pydantic import BaseModel, Field
from backend.data.includes import MAX_USER_API_KEYS_FETCH
from backend.util.exceptions import NotAuthorizedError, NotFoundError
from .base import APIAuthorizationInfo
logger = logging.getLogger(__name__)
keysmith = APIKeySmith()
class APIKeyInfo(APIAuthorizationInfo):
class APIKeyInfo(BaseModel):
id: str
name: str
head: str = Field(
@@ -28,9 +26,12 @@ class APIKeyInfo(APIAuthorizationInfo):
description=f"The last {APIKeySmith.TAIL_LENGTH} characters of the key"
)
status: APIKeyStatus
permissions: list[APIKeyPermission]
created_at: datetime
last_used_at: Optional[datetime] = None
revoked_at: Optional[datetime] = None
description: Optional[str] = None
type: Literal["api_key"] = "api_key" # type: ignore
user_id: str
@staticmethod
def from_db(api_key: PrismaAPIKey):
@@ -40,7 +41,7 @@ class APIKeyInfo(APIAuthorizationInfo):
head=api_key.head,
tail=api_key.tail,
status=APIKeyStatus(api_key.status),
scopes=[APIKeyPermission(p) for p in api_key.permissions],
permissions=[APIKeyPermission(p) for p in api_key.permissions],
created_at=api_key.createdAt,
last_used_at=api_key.lastUsedAt,
revoked_at=api_key.revokedAt,
@@ -210,7 +211,7 @@ async def suspend_api_key(key_id: str, user_id: str) -> APIKeyInfo:
def has_permission(api_key: APIKeyInfo, required_permission: APIKeyPermission) -> bool:
return required_permission in api_key.scopes
return required_permission in api_key.permissions
async def get_api_key_by_id(key_id: str, user_id: str) -> Optional[APIKeyInfo]:

View File

@@ -1,15 +0,0 @@
from datetime import datetime
from typing import Literal, Optional
from prisma.enums import APIKeyPermission
from pydantic import BaseModel
class APIAuthorizationInfo(BaseModel):
user_id: str
scopes: list[APIKeyPermission]
type: Literal["oauth", "api_key"]
created_at: datetime
expires_at: Optional[datetime] = None
last_used_at: Optional[datetime] = None
revoked_at: Optional[datetime] = None

View File

@@ -1,872 +0,0 @@
"""
OAuth 2.0 Provider Data Layer
Handles management of OAuth applications, authorization codes,
access tokens, and refresh tokens.
Hashing strategy:
- Access tokens & Refresh tokens: SHA256 (deterministic, allows direct lookup by hash)
- Client secrets: Scrypt with salt (lookup by client_id, then verify with salt)
"""
import hashlib
import logging
import secrets
import uuid
from datetime import datetime, timedelta, timezone
from typing import Literal, Optional
from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission as APIPermission
from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
from prisma.models import OAuthApplication as PrismaOAuthApplication
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
from prisma.types import OAuthApplicationUpdateInput
from pydantic import BaseModel, Field, SecretStr
from .base import APIAuthorizationInfo
logger = logging.getLogger(__name__)
keysmith = APIKeySmith() # Only used for client secret hashing (Scrypt)
def _generate_token() -> str:
"""Generate a cryptographically secure random token."""
return secrets.token_urlsafe(32)
def _hash_token(token: str) -> str:
"""Hash a token using SHA256 (deterministic, for direct lookup)."""
return hashlib.sha256(token.encode()).hexdigest()
# Token TTLs
AUTHORIZATION_CODE_TTL = timedelta(minutes=10)
ACCESS_TOKEN_TTL = timedelta(hours=1)
REFRESH_TOKEN_TTL = timedelta(days=30)
ACCESS_TOKEN_PREFIX = "agpt_xt_"
REFRESH_TOKEN_PREFIX = "agpt_rt_"
# ============================================================================
# Exception Classes
# ============================================================================
class OAuthError(Exception):
"""Base OAuth error"""
pass
class InvalidClientError(OAuthError):
"""Invalid client_id or client_secret"""
pass
class InvalidGrantError(OAuthError):
"""Invalid or expired authorization code/refresh token"""
def __init__(self, reason: str):
self.reason = reason
super().__init__(f"Invalid grant: {reason}")
class InvalidTokenError(OAuthError):
"""Invalid, expired, or revoked token"""
def __init__(self, reason: str):
self.reason = reason
super().__init__(f"Invalid token: {reason}")
# ============================================================================
# Data Models
# ============================================================================
class OAuthApplicationInfo(BaseModel):
"""OAuth application information (without client secret hash)"""
id: str
name: str
description: Optional[str] = None
logo_url: Optional[str] = None
client_id: str
redirect_uris: list[str]
grant_types: list[str]
scopes: list[APIPermission]
owner_id: str
is_active: bool
created_at: datetime
updated_at: datetime
@staticmethod
def from_db(app: PrismaOAuthApplication):
return OAuthApplicationInfo(
id=app.id,
name=app.name,
description=app.description,
logo_url=app.logoUrl,
client_id=app.clientId,
redirect_uris=app.redirectUris,
grant_types=app.grantTypes,
scopes=[APIPermission(s) for s in app.scopes],
owner_id=app.ownerId,
is_active=app.isActive,
created_at=app.createdAt,
updated_at=app.updatedAt,
)
class OAuthApplicationInfoWithSecret(OAuthApplicationInfo):
"""OAuth application with client secret hash (for validation)"""
client_secret_hash: str
client_secret_salt: str
@staticmethod
def from_db(app: PrismaOAuthApplication):
return OAuthApplicationInfoWithSecret(
**OAuthApplicationInfo.from_db(app).model_dump(),
client_secret_hash=app.clientSecret,
client_secret_salt=app.clientSecretSalt,
)
def verify_secret(self, plaintext_secret: str) -> bool:
"""Verify a plaintext client secret against the stored hash"""
# Use keysmith.verify_key() with stored salt
return keysmith.verify_key(
plaintext_secret, self.client_secret_hash, self.client_secret_salt
)
class OAuthAuthorizationCodeInfo(BaseModel):
"""Authorization code information"""
id: str
code: str
created_at: datetime
expires_at: datetime
application_id: str
user_id: str
scopes: list[APIPermission]
redirect_uri: str
code_challenge: Optional[str] = None
code_challenge_method: Optional[str] = None
used_at: Optional[datetime] = None
@property
def is_used(self) -> bool:
return self.used_at is not None
@staticmethod
def from_db(code: PrismaOAuthAuthorizationCode):
return OAuthAuthorizationCodeInfo(
id=code.id,
code=code.code,
created_at=code.createdAt,
expires_at=code.expiresAt,
application_id=code.applicationId,
user_id=code.userId,
scopes=[APIPermission(s) for s in code.scopes],
redirect_uri=code.redirectUri,
code_challenge=code.codeChallenge,
code_challenge_method=code.codeChallengeMethod,
used_at=code.usedAt,
)
class OAuthAccessTokenInfo(APIAuthorizationInfo):
"""Access token information"""
id: str
expires_at: datetime # type: ignore
application_id: str
type: Literal["oauth"] = "oauth" # type: ignore
@staticmethod
def from_db(token: PrismaOAuthAccessToken):
return OAuthAccessTokenInfo(
id=token.id,
user_id=token.userId,
scopes=[APIPermission(s) for s in token.scopes],
created_at=token.createdAt,
expires_at=token.expiresAt,
last_used_at=None,
revoked_at=token.revokedAt,
application_id=token.applicationId,
)
class OAuthAccessToken(OAuthAccessTokenInfo):
"""Access token with plaintext token included (sensitive)"""
token: SecretStr = Field(description="Plaintext token (sensitive)")
@staticmethod
def from_db(token: PrismaOAuthAccessToken, plaintext_token: str): # type: ignore
return OAuthAccessToken(
**OAuthAccessTokenInfo.from_db(token).model_dump(),
token=SecretStr(plaintext_token),
)
class OAuthRefreshTokenInfo(BaseModel):
"""Refresh token information"""
id: str
user_id: str
scopes: list[APIPermission]
created_at: datetime
expires_at: datetime
application_id: str
revoked_at: Optional[datetime] = None
@property
def is_revoked(self) -> bool:
return self.revoked_at is not None
@staticmethod
def from_db(token: PrismaOAuthRefreshToken):
return OAuthRefreshTokenInfo(
id=token.id,
user_id=token.userId,
scopes=[APIPermission(s) for s in token.scopes],
created_at=token.createdAt,
expires_at=token.expiresAt,
application_id=token.applicationId,
revoked_at=token.revokedAt,
)
class OAuthRefreshToken(OAuthRefreshTokenInfo):
"""Refresh token with plaintext token included (sensitive)"""
token: SecretStr = Field(description="Plaintext token (sensitive)")
@staticmethod
def from_db(token: PrismaOAuthRefreshToken, plaintext_token: str): # type: ignore
return OAuthRefreshToken(
**OAuthRefreshTokenInfo.from_db(token).model_dump(),
token=SecretStr(plaintext_token),
)
class TokenIntrospectionResult(BaseModel):
"""Result of token introspection (RFC 7662)"""
active: bool
scopes: Optional[list[str]] = None
client_id: Optional[str] = None
user_id: Optional[str] = None
exp: Optional[int] = None # Unix timestamp
token_type: Optional[Literal["access_token", "refresh_token"]] = None
# ============================================================================
# OAuth Application Management
# ============================================================================
async def get_oauth_application(client_id: str) -> Optional[OAuthApplicationInfo]:
"""Get OAuth application by client ID (without secret)"""
app = await PrismaOAuthApplication.prisma().find_unique(
where={"clientId": client_id}
)
if not app:
return None
return OAuthApplicationInfo.from_db(app)
async def get_oauth_application_with_secret(
client_id: str,
) -> Optional[OAuthApplicationInfoWithSecret]:
"""Get OAuth application by client ID (with secret hash for validation)"""
app = await PrismaOAuthApplication.prisma().find_unique(
where={"clientId": client_id}
)
if not app:
return None
return OAuthApplicationInfoWithSecret.from_db(app)
async def validate_client_credentials(
client_id: str, client_secret: str
) -> OAuthApplicationInfo:
"""
Validate client credentials and return application info.
Raises:
InvalidClientError: If client_id or client_secret is invalid, or app is inactive
"""
app = await get_oauth_application_with_secret(client_id)
if not app:
raise InvalidClientError("Invalid client_id")
if not app.is_active:
raise InvalidClientError("Application is not active")
# Verify client secret
if not app.verify_secret(client_secret):
raise InvalidClientError("Invalid client_secret")
# Return without secret hash
return OAuthApplicationInfo(**app.model_dump(exclude={"client_secret_hash"}))
def validate_redirect_uri(app: OAuthApplicationInfo, redirect_uri: str) -> bool:
"""Validate that redirect URI is registered for the application"""
return redirect_uri in app.redirect_uris
def validate_scopes(
app: OAuthApplicationInfo, requested_scopes: list[APIPermission]
) -> bool:
"""Validate that all requested scopes are allowed for the application"""
return all(scope in app.scopes for scope in requested_scopes)
# ============================================================================
# Authorization Code Flow
# ============================================================================
def _generate_authorization_code() -> str:
"""Generate a cryptographically secure authorization code"""
# 32 bytes = 256 bits of entropy
return secrets.token_urlsafe(32)
async def create_authorization_code(
application_id: str,
user_id: str,
scopes: list[APIPermission],
redirect_uri: str,
code_challenge: Optional[str] = None,
code_challenge_method: Optional[Literal["S256", "plain"]] = None,
) -> OAuthAuthorizationCodeInfo:
"""
Create a new authorization code.
Expires in 10 minutes and can only be used once.
"""
code = _generate_authorization_code()
now = datetime.now(timezone.utc)
expires_at = now + AUTHORIZATION_CODE_TTL
saved_code = await PrismaOAuthAuthorizationCode.prisma().create(
data={
"id": str(uuid.uuid4()),
"code": code,
"expiresAt": expires_at,
"applicationId": application_id,
"userId": user_id,
"scopes": [s for s in scopes],
"redirectUri": redirect_uri,
"codeChallenge": code_challenge,
"codeChallengeMethod": code_challenge_method,
}
)
return OAuthAuthorizationCodeInfo.from_db(saved_code)
async def consume_authorization_code(
code: str,
application_id: str,
redirect_uri: str,
code_verifier: Optional[str] = None,
) -> tuple[str, list[APIPermission]]:
"""
Consume an authorization code and return (user_id, scopes).
This marks the code as used and validates:
- Code exists and matches application
- Code is not expired
- Code has not been used
- Redirect URI matches
- PKCE code verifier matches (if code challenge was provided)
Raises:
InvalidGrantError: If code is invalid, expired, used, or PKCE fails
"""
auth_code = await PrismaOAuthAuthorizationCode.prisma().find_unique(
where={"code": code}
)
if not auth_code:
raise InvalidGrantError("authorization code not found")
# Validate application
if auth_code.applicationId != application_id:
raise InvalidGrantError(
"authorization code does not belong to this application"
)
# Check if already used
if auth_code.usedAt is not None:
raise InvalidGrantError(
f"authorization code already used at {auth_code.usedAt}"
)
# Check expiration
now = datetime.now(timezone.utc)
if auth_code.expiresAt < now:
raise InvalidGrantError("authorization code expired")
# Validate redirect URI
if auth_code.redirectUri != redirect_uri:
raise InvalidGrantError("redirect_uri mismatch")
# Validate PKCE if code challenge was provided
if auth_code.codeChallenge:
if not code_verifier:
raise InvalidGrantError("code_verifier required but not provided")
if not _verify_pkce(
code_verifier, auth_code.codeChallenge, auth_code.codeChallengeMethod
):
raise InvalidGrantError("PKCE verification failed")
# Mark code as used
await PrismaOAuthAuthorizationCode.prisma().update(
where={"code": code},
data={"usedAt": now},
)
return auth_code.userId, [APIPermission(s) for s in auth_code.scopes]
def _verify_pkce(
code_verifier: str, code_challenge: str, code_challenge_method: Optional[str]
) -> bool:
"""
Verify PKCE code verifier against code challenge.
Supports:
- S256: SHA256(code_verifier) == code_challenge
- plain: code_verifier == code_challenge
"""
if code_challenge_method == "S256":
# Hash the verifier with SHA256 and base64url encode
hashed = hashlib.sha256(code_verifier.encode("ascii")).digest()
computed_challenge = (
secrets.token_urlsafe(len(hashed)).encode("ascii").decode("ascii")
)
# For proper base64url encoding
import base64
computed_challenge = (
base64.urlsafe_b64encode(hashed).decode("ascii").rstrip("=")
)
return secrets.compare_digest(computed_challenge, code_challenge)
elif code_challenge_method == "plain" or code_challenge_method is None:
# Plain comparison
return secrets.compare_digest(code_verifier, code_challenge)
else:
logger.warning(f"Unsupported code challenge method: {code_challenge_method}")
return False
# ============================================================================
# Access Token Management
# ============================================================================
async def create_access_token(
application_id: str, user_id: str, scopes: list[APIPermission]
) -> OAuthAccessToken:
"""
Create a new access token.
Returns OAuthAccessToken (with plaintext token).
"""
plaintext_token = ACCESS_TOKEN_PREFIX + _generate_token()
token_hash = _hash_token(plaintext_token)
now = datetime.now(timezone.utc)
expires_at = now + ACCESS_TOKEN_TTL
saved_token = await PrismaOAuthAccessToken.prisma().create(
data={
"id": str(uuid.uuid4()),
"token": token_hash, # SHA256 hash for direct lookup
"expiresAt": expires_at,
"applicationId": application_id,
"userId": user_id,
"scopes": [s for s in scopes],
}
)
return OAuthAccessToken.from_db(saved_token, plaintext_token=plaintext_token)
async def validate_access_token(
token: str,
) -> tuple[OAuthAccessTokenInfo, OAuthApplicationInfo]:
"""
Validate an access token and return token info.
Raises:
InvalidTokenError: If token is invalid, expired, or revoked
InvalidClientError: If the client application is not marked as active
"""
token_hash = _hash_token(token)
# Direct lookup by hash
access_token = await PrismaOAuthAccessToken.prisma().find_unique(
where={"token": token_hash}, include={"Application": True}
)
if not access_token:
raise InvalidTokenError("access token not found")
if not access_token.Application: # should be impossible
raise InvalidClientError("Client application not found")
if not access_token.Application.isActive:
raise InvalidClientError("Client application is disabled")
if access_token.revokedAt is not None:
raise InvalidTokenError("access token has been revoked")
# Check expiration
now = datetime.now(timezone.utc)
if access_token.expiresAt < now:
raise InvalidTokenError("access token expired")
return (
OAuthAccessTokenInfo.from_db(access_token),
OAuthApplicationInfo.from_db(access_token.Application),
)
async def revoke_access_token(
token: str, application_id: str
) -> OAuthAccessTokenInfo | None:
"""
Revoke an access token.
Args:
token: The plaintext access token to revoke
application_id: The application ID making the revocation request.
Only tokens belonging to this application will be revoked.
Returns:
OAuthAccessTokenInfo if token was found and revoked, None otherwise.
Note:
Always performs exactly 2 DB queries regardless of outcome to prevent
timing side-channel attacks that could reveal token existence.
"""
try:
token_hash = _hash_token(token)
# Use update_many to filter by both token and applicationId
updated_count = await PrismaOAuthAccessToken.prisma().update_many(
where={
"token": token_hash,
"applicationId": application_id,
"revokedAt": None,
},
data={"revokedAt": datetime.now(timezone.utc)},
)
# Always perform second query to ensure constant time
result = await PrismaOAuthAccessToken.prisma().find_unique(
where={"token": token_hash}
)
# Only return result if we actually revoked something
if updated_count == 0:
return None
return OAuthAccessTokenInfo.from_db(result) if result else None
except Exception as e:
logger.exception(f"Error revoking access token: {e}")
return None
# ============================================================================
# Refresh Token Management
# ============================================================================
async def create_refresh_token(
application_id: str, user_id: str, scopes: list[APIPermission]
) -> OAuthRefreshToken:
"""
Create a new refresh token.
Returns OAuthRefreshToken (with plaintext token).
"""
plaintext_token = REFRESH_TOKEN_PREFIX + _generate_token()
token_hash = _hash_token(plaintext_token)
now = datetime.now(timezone.utc)
expires_at = now + REFRESH_TOKEN_TTL
saved_token = await PrismaOAuthRefreshToken.prisma().create(
data={
"id": str(uuid.uuid4()),
"token": token_hash, # SHA256 hash for direct lookup
"expiresAt": expires_at,
"applicationId": application_id,
"userId": user_id,
"scopes": [s for s in scopes],
}
)
return OAuthRefreshToken.from_db(saved_token, plaintext_token=plaintext_token)
async def refresh_tokens(
refresh_token: str, application_id: str
) -> tuple[OAuthAccessToken, OAuthRefreshToken]:
"""
Use a refresh token to create new access and refresh tokens.
Returns (new_access_token, new_refresh_token) both with plaintext tokens included.
Raises:
InvalidGrantError: If refresh token is invalid, expired, or revoked
"""
token_hash = _hash_token(refresh_token)
# Direct lookup by hash
rt = await PrismaOAuthRefreshToken.prisma().find_unique(where={"token": token_hash})
if not rt:
raise InvalidGrantError("refresh token not found")
# NOTE: no need to check Application.isActive, this is checked by the token endpoint
if rt.revokedAt is not None:
raise InvalidGrantError("refresh token has been revoked")
# Validate application
if rt.applicationId != application_id:
raise InvalidGrantError("refresh token does not belong to this application")
# Check expiration
now = datetime.now(timezone.utc)
if rt.expiresAt < now:
raise InvalidGrantError("refresh token expired")
# Revoke old refresh token
await PrismaOAuthRefreshToken.prisma().update(
where={"token": token_hash},
data={"revokedAt": now},
)
# Create new access and refresh tokens with same scopes
scopes = [APIPermission(s) for s in rt.scopes]
new_access_token = await create_access_token(
rt.applicationId,
rt.userId,
scopes,
)
new_refresh_token = await create_refresh_token(
rt.applicationId,
rt.userId,
scopes,
)
return new_access_token, new_refresh_token
async def revoke_refresh_token(
token: str, application_id: str
) -> OAuthRefreshTokenInfo | None:
"""
Revoke a refresh token.
Args:
token: The plaintext refresh token to revoke
application_id: The application ID making the revocation request.
Only tokens belonging to this application will be revoked.
Returns:
OAuthRefreshTokenInfo if token was found and revoked, None otherwise.
Note:
Always performs exactly 2 DB queries regardless of outcome to prevent
timing side-channel attacks that could reveal token existence.
"""
try:
token_hash = _hash_token(token)
# Use update_many to filter by both token and applicationId
updated_count = await PrismaOAuthRefreshToken.prisma().update_many(
where={
"token": token_hash,
"applicationId": application_id,
"revokedAt": None,
},
data={"revokedAt": datetime.now(timezone.utc)},
)
# Always perform second query to ensure constant time
result = await PrismaOAuthRefreshToken.prisma().find_unique(
where={"token": token_hash}
)
# Only return result if we actually revoked something
if updated_count == 0:
return None
return OAuthRefreshTokenInfo.from_db(result) if result else None
except Exception as e:
logger.exception(f"Error revoking refresh token: {e}")
return None
# ============================================================================
# Token Introspection
# ============================================================================
async def introspect_token(
token: str,
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = None,
) -> TokenIntrospectionResult:
"""
Introspect a token and return its metadata (RFC 7662).
Returns TokenIntrospectionResult with active=True and metadata if valid,
or active=False if the token is invalid/expired/revoked.
"""
# Try as access token first (or if hint says "access_token")
if token_type_hint != "refresh_token":
try:
token_info, app = await validate_access_token(token)
return TokenIntrospectionResult(
active=True,
scopes=list(s.value for s in token_info.scopes),
client_id=app.client_id if app else None,
user_id=token_info.user_id,
exp=int(token_info.expires_at.timestamp()),
token_type="access_token",
)
except InvalidTokenError:
pass # Try as refresh token
# Try as refresh token
token_hash = _hash_token(token)
refresh_token = await PrismaOAuthRefreshToken.prisma().find_unique(
where={"token": token_hash}
)
if refresh_token and refresh_token.revokedAt is None:
# Check if valid (not expired)
now = datetime.now(timezone.utc)
if refresh_token.expiresAt > now:
app = await get_oauth_application_by_id(refresh_token.applicationId)
return TokenIntrospectionResult(
active=True,
scopes=list(s for s in refresh_token.scopes),
client_id=app.client_id if app else None,
user_id=refresh_token.userId,
exp=int(refresh_token.expiresAt.timestamp()),
token_type="refresh_token",
)
# Token not found or inactive
return TokenIntrospectionResult(active=False)
async def get_oauth_application_by_id(app_id: str) -> Optional[OAuthApplicationInfo]:
"""Get OAuth application by ID"""
app = await PrismaOAuthApplication.prisma().find_unique(where={"id": app_id})
if not app:
return None
return OAuthApplicationInfo.from_db(app)
async def list_user_oauth_applications(user_id: str) -> list[OAuthApplicationInfo]:
"""Get all OAuth applications owned by a user"""
apps = await PrismaOAuthApplication.prisma().find_many(
where={"ownerId": user_id},
order={"createdAt": "desc"},
)
return [OAuthApplicationInfo.from_db(app) for app in apps]
async def update_oauth_application(
app_id: str,
*,
owner_id: str,
is_active: Optional[bool] = None,
logo_url: Optional[str] = None,
) -> Optional[OAuthApplicationInfo]:
"""
Update OAuth application active status.
Only the owner can update their app's status.
Returns the updated app info, or None if app not found or not owned by user.
"""
# First verify ownership
app = await PrismaOAuthApplication.prisma().find_first(
where={"id": app_id, "ownerId": owner_id}
)
if not app:
return None
patch: OAuthApplicationUpdateInput = {}
if is_active is not None:
patch["isActive"] = is_active
if logo_url:
patch["logoUrl"] = logo_url
if not patch:
return OAuthApplicationInfo.from_db(app) # return unchanged
updated_app = await PrismaOAuthApplication.prisma().update(
where={"id": app_id},
data=patch,
)
return OAuthApplicationInfo.from_db(updated_app) if updated_app else None
# ============================================================================
# Token Cleanup
# ============================================================================
async def cleanup_expired_oauth_tokens() -> dict[str, int]:
"""
Delete expired OAuth tokens from the database.
This removes:
- Expired authorization codes (10 min TTL)
- Expired access tokens (1 hour TTL)
- Expired refresh tokens (30 day TTL)
Returns a dict with counts of deleted tokens by type.
"""
now = datetime.now(timezone.utc)
# Delete expired authorization codes
codes_result = await PrismaOAuthAuthorizationCode.prisma().delete_many(
where={"expiresAt": {"lt": now}}
)
# Delete expired access tokens
access_result = await PrismaOAuthAccessToken.prisma().delete_many(
where={"expiresAt": {"lt": now}}
)
# Delete expired refresh tokens
refresh_result = await PrismaOAuthRefreshToken.prisma().delete_many(
where={"expiresAt": {"lt": now}}
)
deleted = {
"authorization_codes": codes_result,
"access_tokens": access_result,
"refresh_tokens": refresh_result,
}
total = sum(deleted.values())
if total > 0:
logger.info(f"Cleaned up {total} expired OAuth tokens: {deleted}")
return deleted

View File

@@ -50,8 +50,6 @@ from .model import (
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
from .graph import Link
app_config = Config()
@@ -474,7 +472,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.block_type = block_type
self.webhook_config = webhook_config
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
self.requires_human_review: bool = False
if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig):
@@ -617,77 +614,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
block_id=self.id,
) from ex
async def is_block_exec_need_review(
self,
input_data: BlockInput,
*,
user_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: "ExecutionContext",
**kwargs,
) -> tuple[bool, BlockInput]:
"""
Check if this block execution needs human review and handle the review process.
Returns:
Tuple of (should_pause, input_data_to_use)
- should_pause: True if execution should be paused for review
- input_data_to_use: The input data to use (may be modified by reviewer)
"""
# Skip review if not required or safe mode is disabled
if not self.requires_human_review or not execution_context.safe_mode:
return False, input_data
from backend.blocks.helpers.review import HITLReviewHelper
# Handle the review request and get decision
decision = await HITLReviewHelper.handle_review_decision(
input_data=input_data,
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=self.name,
editable=True,
)
if decision is None:
# We're awaiting review - pause execution
return True, input_data
if not decision.should_proceed:
# Review was rejected, raise an error to stop execution
raise BlockExecutionError(
message=f"Block execution rejected by reviewer: {decision.message}",
block_name=self.name,
block_id=self.id,
)
# Review was approved - use the potentially modified data
# ReviewResult.data must be a dict for block inputs
reviewed_data = decision.review_result.data
if not isinstance(reviewed_data, dict):
raise BlockExecutionError(
message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
block_name=self.name,
block_id=self.id,
)
return False, reviewed_data
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
# Check for review requirement and get potentially modified input data
should_pause, input_data = await self.is_block_exec_need_review(
input_data, **kwargs
)
if should_pause:
return
# Validate the input data (original or reviewer-modified) once
if error := self.input_schema.validate_data(input_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
@@ -695,7 +622,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
block_id=self.id,
)
# Use the validated input data
async for output_name, output_data in self.run(
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
**kwargs,

View File

@@ -59,13 +59,12 @@ from backend.integrations.credentials_store import (
MODEL_COST: dict[LlmModel, int] = {
LlmModel.O3: 4,
LlmModel.O3_MINI: 2,
LlmModel.O1: 16,
LlmModel.O3_MINI: 2, # $1.10 / $4.40
LlmModel.O1: 16, # $15 / $60
LlmModel.O1_MINI: 4,
# GPT-5 models
LlmModel.GPT5_2: 6,
LlmModel.GPT5_1: 5,
LlmModel.GPT5: 2,
LlmModel.GPT5_1: 5,
LlmModel.GPT5_MINI: 1,
LlmModel.GPT5_NANO: 1,
LlmModel.GPT5_CHAT: 5,
@@ -88,7 +87,7 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.AIML_API_LLAMA3_3_70B: 1,
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
LlmModel.LLAMA3_3_70B: 1,
LlmModel.LLAMA3_3_70B: 1, # $0.59 / $0.79
LlmModel.LLAMA3_1_8B: 1,
LlmModel.OLLAMA_LLAMA3_3: 1,
LlmModel.OLLAMA_LLAMA3_2: 1,

View File

@@ -16,7 +16,6 @@ from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBala
from prisma.types import CreditRefundRequestCreateInput, CreditTransactionWhereInput
from pydantic import BaseModel
from backend.api.features.admin.model import UserHistoryResponse
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.db import query_raw_with_schema
from backend.data.includes import MAX_CREDIT_REFUND_REQUESTS_FETCH
@@ -30,6 +29,7 @@ from backend.data.model import (
from backend.data.notifications import NotificationEventModel, RefundRequestData
from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications.notifications import queue_notification_async
from backend.server.v2.admin.model import UserHistoryResponse
from backend.util.exceptions import InsufficientBalanceError
from backend.util.feature_flag import Flag, is_feature_enabled
from backend.util.json import SafeJson, dumps
@@ -341,19 +341,6 @@ class UserCreditBase(ABC):
if result:
# UserBalance is already updated by the CTE
# Clear insufficient funds notification flags when credits are added
# so user can receive alerts again if they run out in the future.
if transaction.amount > 0 and transaction.type in [
CreditTransactionType.GRANT,
CreditTransactionType.TOP_UP,
]:
from backend.executor.manager import (
clear_insufficient_funds_notifications,
)
await clear_insufficient_funds_notifications(user_id)
return result[0]["balance"]
async def _add_transaction(
@@ -543,22 +530,6 @@ class UserCreditBase(ABC):
if result:
new_balance, tx_key = result[0]["balance"], result[0]["transactionKey"]
# UserBalance is already updated by the CTE
# Clear insufficient funds notification flags when credits are added
# so user can receive alerts again if they run out in the future.
if (
amount > 0
and is_active
and transaction_type
in [CreditTransactionType.GRANT, CreditTransactionType.TOP_UP]
):
# Lazy import to avoid circular dependency with executor.manager
from backend.executor.manager import (
clear_insufficient_funds_notifications,
)
await clear_insufficient_funds_notifications(user_id)
return new_balance, tx_key
# If no result, either user doesn't exist or insufficient balance

View File

@@ -111,7 +111,7 @@ def get_database_schema() -> str:
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
"""Execute raw SQL query with proper schema handling."""
schema = get_database_schema()
schema_prefix = f'"{schema}".' if schema != "public" else ""
schema_prefix = f"{schema}." if schema != "public" else ""
formatted_query = query_template.format(schema_prefix=schema_prefix)
import prisma as prisma_module

View File

@@ -1,9 +1,8 @@
import logging
import queue
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from enum import Enum
from multiprocessing import Manager
from queue import Empty
from typing import (
TYPE_CHECKING,
Annotated,
@@ -383,7 +382,6 @@ class GraphExecutionWithNodes(GraphExecution):
self,
execution_context: ExecutionContext,
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_to_skip: Optional[set[str]] = None,
):
return GraphExecutionEntry(
user_id=self.user_id,
@@ -391,7 +389,6 @@ class GraphExecutionWithNodes(GraphExecution):
graph_version=self.graph_version or 0,
graph_exec_id=self.id,
nodes_input_masks=compiled_nodes_input_masks,
nodes_to_skip=nodes_to_skip or set(),
execution_context=execution_context,
)
@@ -1147,8 +1144,6 @@ class GraphExecutionEntry(BaseModel):
graph_id: str
graph_version: int
nodes_input_masks: Optional[NodesInputMasks] = None
nodes_to_skip: set[str] = Field(default_factory=set)
"""Node IDs that should be skipped due to optional credentials not being configured."""
execution_context: ExecutionContext = Field(default_factory=ExecutionContext)
@@ -1168,12 +1163,16 @@ class NodeExecutionEntry(BaseModel):
class ExecutionQueue(Generic[T]):
"""
Queue for managing the execution of agents.
This will be shared between different processes
Thread-safe queue for managing node execution within a single graph execution.
Note: Uses queue.Queue (not multiprocessing.Queue) since all access is from
threads within the same process. If migrating back to ProcessPoolExecutor,
replace with multiprocessing.Manager().Queue() for cross-process safety.
"""
def __init__(self):
self.queue = Manager().Queue()
# Thread-safe queue (not multiprocessing) — see class docstring
self.queue: queue.Queue[T] = queue.Queue()
def add(self, execution: T) -> T:
self.queue.put(execution)
@@ -1188,7 +1187,7 @@ class ExecutionQueue(Generic[T]):
def get_or_none(self) -> T | None:
try:
return self.queue.get_nowait()
except Empty:
except queue.Empty:
return None

View File

@@ -0,0 +1,60 @@
"""Tests for ExecutionQueue thread-safety."""
import queue
import threading
import pytest
from backend.data.execution import ExecutionQueue
def test_execution_queue_uses_stdlib_queue():
"""Verify ExecutionQueue uses queue.Queue (not multiprocessing)."""
q = ExecutionQueue()
assert isinstance(q.queue, queue.Queue)
def test_basic_operations():
"""Test add, get, empty, and get_or_none."""
q = ExecutionQueue()
assert q.empty() is True
assert q.get_or_none() is None
result = q.add("item1")
assert result == "item1"
assert q.empty() is False
item = q.get()
assert item == "item1"
assert q.empty() is True
def test_thread_safety():
"""Test concurrent access from multiple threads."""
q = ExecutionQueue()
results = []
num_items = 100
def producer():
for i in range(num_items):
q.add(f"item_{i}")
def consumer():
count = 0
while count < num_items:
item = q.get_or_none()
if item is not None:
results.append(item)
count += 1
producer_thread = threading.Thread(target=producer)
consumer_thread = threading.Thread(target=consumer)
producer_thread.start()
consumer_thread.start()
producer_thread.join(timeout=5)
consumer_thread.join(timeout=5)
assert len(results) == num_items

View File

@@ -94,15 +94,6 @@ class Node(BaseDbModel):
input_links: list[Link] = []
output_links: list[Link] = []
@property
def credentials_optional(self) -> bool:
"""
Whether credentials are optional for this node.
When True and credentials are not configured, the node will be skipped
during execution rather than causing a validation error.
"""
return self.metadata.get("credentials_optional", False)
@property
def block(self) -> AnyBlockSchema | "_UnknownBlockBase":
"""Get the block for this node. Returns UnknownBlock if block is deleted/missing."""
@@ -244,10 +235,7 @@ class BaseGraph(BaseDbModel):
return any(
node.block_id
for node in self.nodes
if (
node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
or node.block.requires_human_review
)
if node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
)
@property
@@ -338,35 +326,7 @@ class Graph(BaseGraph):
@computed_field
@property
def credentials_input_schema(self) -> dict[str, Any]:
schema = self._credentials_input_schema.jsonschema()
# Determine which credential fields are required based on credentials_optional metadata
graph_credentials_inputs = self.aggregate_credentials_inputs()
required_fields = []
# Build a map of node_id -> node for quick lookup
all_nodes = {node.id: node for node in self.nodes}
for sub_graph in self.sub_graphs:
for node in sub_graph.nodes:
all_nodes[node.id] = node
for field_key, (
_field_info,
node_field_pairs,
) in graph_credentials_inputs.items():
# A field is required if ANY node using it has credentials_optional=False
is_required = False
for node_id, _field_name in node_field_pairs:
node = all_nodes.get(node_id)
if node and not node.credentials_optional:
is_required = True
break
if is_required:
required_fields.append(field_key)
schema["required"] = required_fields
return schema
return self._credentials_input_schema.jsonschema()
@property
def _credentials_input_schema(self) -> type[BlockSchema]:

View File

@@ -6,14 +6,14 @@ import fastapi.exceptions
import pytest
from pytest_snapshot.plugin import Snapshot
import backend.api.features.store.model as store
from backend.api.model import CreateGraph
import backend.server.v2.store.model as store
from backend.blocks.basic import StoreValueBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.data.block import BlockSchema, BlockSchemaInput
from backend.data.graph import Graph, Link, Node
from backend.data.model import SchemaField
from backend.data.user import DEFAULT_USER_ID
from backend.server.model import CreateGraph
from backend.usecases.sample import create_test_user
from backend.util.test import SpinTestServer
@@ -396,58 +396,3 @@ async def test_access_store_listing_graph(server: SpinTestServer):
created_graph.id, created_graph.version, "3e53486c-cf57-477e-ba2a-cb02dc828e1b"
)
assert got_graph is not None
# ============================================================================
# Tests for Optional Credentials Feature
# ============================================================================
def test_node_credentials_optional_default():
"""Test that credentials_optional defaults to False when not set in metadata."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={},
)
assert node.credentials_optional is False
def test_node_credentials_optional_true():
"""Test that credentials_optional returns True when explicitly set."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={"credentials_optional": True},
)
assert node.credentials_optional is True
def test_node_credentials_optional_false():
"""Test that credentials_optional returns False when explicitly set to False."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={"credentials_optional": False},
)
assert node.credentials_optional is False
def test_node_credentials_optional_with_other_metadata():
"""Test that credentials_optional works correctly with other metadata present."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={
"position": {"x": 100, "y": 200},
"customized_name": "My Custom Node",
"credentials_optional": True,
},
)
assert node.credentials_optional is True
assert node.metadata["position"] == {"x": 100, "y": 200}
assert node.metadata["customized_name"] == "My Custom Node"

View File

@@ -13,7 +13,7 @@ from prisma.models import PendingHumanReview
from prisma.types import PendingHumanReviewUpdateInput
from pydantic import BaseModel
from backend.api.features.executions.review.model import (
from backend.server.v2.executions.review.model import (
PendingHumanReviewModel,
SafeJsonData,
)
@@ -100,7 +100,7 @@ async def get_or_create_human_review(
return None
else:
return ReviewResult(
data=review.payload,
data=review.payload if review.status == ReviewStatus.APPROVED else None,
status=review.status,
message=review.reviewMessage or "",
processed=review.processed,

View File

@@ -23,7 +23,7 @@ from backend.util.exceptions import NotFoundError
from backend.util.json import SafeJson
if TYPE_CHECKING:
from backend.api.features.library.model import LibraryAgentPreset
from backend.server.v2.library.model import LibraryAgentPreset
from .db import BaseDbModel
from .graph import NodeModel
@@ -79,7 +79,7 @@ class WebhookWithRelations(Webhook):
# integrations.py → library/model.py → integrations.py (for Webhook)
# Runtime import is used in WebhookWithRelations.from_db() method instead
# Import at runtime to avoid circular dependency
from backend.api.features.library.model import LibraryAgentPreset
from backend.server.v2.library.model import LibraryAgentPreset
return WebhookWithRelations(
**Webhook.from_db(webhook).model_dump(),
@@ -285,8 +285,8 @@ async def unlink_webhook_from_graph(
user_id: The ID of the user (for authorization)
"""
# Avoid circular imports
from backend.api.features.library.db import set_preset_webhook
from backend.data.graph import set_node_webhook
from backend.server.v2.library.db import set_preset_webhook
# Find all nodes in this graph that use this webhook
nodes = await AgentNode.prisma().find_many(

View File

@@ -4,8 +4,8 @@ from typing import AsyncGenerator
from pydantic import BaseModel, field_serializer
from backend.api.model import NotificationPayload
from backend.data.event_bus import AsyncRedisEventBus
from backend.server.model import NotificationPayload
from backend.util.settings import Settings

View File

@@ -9,8 +9,6 @@ from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
from backend.api.features.store.model import StoreAgentDetails
from backend.api.model import OnboardingNotificationPayload
from backend.data import execution as execution_db
from backend.data.credit import get_user_credit_model
from backend.data.notification_bus import (
@@ -18,6 +16,8 @@ from backend.data.notification_bus import (
NotificationEvent,
)
from backend.data.user import get_user_by_id
from backend.server.model import OnboardingNotificationPayload
from backend.server.v2.store.model import StoreAgentDetails
from backend.util.cache import cached
from backend.util.json import SafeJson
from backend.util.timezone_utils import get_user_timezone_or_utc
@@ -442,8 +442,6 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
runs=agent.runs,
rating=agent.rating,
versions=agent.versions,
agentGraphVersions=agent.agentGraphVersions,
agentGraphId=agent.agentGraphId,
last_updated=agent.updated_at,
)
for agent in recommended_agents

View File

@@ -1,513 +0,0 @@
"""
Tests for dynamic fields edge cases and failure modes.
Covers failure modes:
8. No Type Validation in Dynamic Field Merging
17. No Validation of Dynamic Field Paths
"""
from typing import Any
import pytest
from backend.data.dynamic_fields import (
DICT_SPLIT,
LIST_SPLIT,
OBJC_SPLIT,
extract_base_field_name,
get_dynamic_field_description,
is_dynamic_field,
is_tool_pin,
merge_execution_input,
parse_execution_output,
sanitize_pin_name,
)
class TestDynamicFieldMergingTypeValidation:
"""
Tests for Failure Mode #8: No Type Validation in Dynamic Field Merging
When merging dynamic fields, there's no validation that intermediate
structures have the correct type, leading to potential type coercion errors.
"""
def test_merge_dict_field_creates_dict(self):
"""Test that dictionary fields create dict structure."""
data = {
"values_#_name": "Alice",
"values_#_age": 30,
}
result = merge_execution_input(data)
assert "values" in result
assert isinstance(result["values"], dict)
assert result["values"]["name"] == "Alice"
assert result["values"]["age"] == 30
def test_merge_list_field_creates_list(self):
"""Test that list fields create list structure."""
data = {
"items_$_0": "first",
"items_$_1": "second",
"items_$_2": "third",
}
result = merge_execution_input(data)
assert "items" in result
assert isinstance(result["items"], list)
assert result["items"] == ["first", "second", "third"]
def test_merge_with_existing_primitive_type_conflict(self):
"""
Test behavior when merging into existing primitive value.
BUG: If the base field already exists as a primitive,
merging a dynamic field may fail or corrupt data.
"""
# Pre-existing primitive value
data = {
"value": "I am a string", # Primitive
"value_#_key": "dict value", # Dynamic dict field
}
# This may raise an error or produce unexpected results
# depending on merge order and implementation
try:
result = merge_execution_input(data)
# If it succeeds, check what happened
# The primitive may have been overwritten
if isinstance(result.get("value"), dict):
# Primitive was converted to dict - data loss!
assert "key" in result["value"]
else:
# Or the dynamic field was ignored
pass
except (TypeError, AttributeError):
# Expected error when trying to merge into primitive
pass
def test_merge_list_with_gaps(self):
"""Test merging list fields with non-contiguous indices."""
data = {
"items_$_0": "zero",
"items_$_2": "two", # Gap at index 1
"items_$_5": "five", # Larger gap
}
result = merge_execution_input(data)
assert "items" in result
# Check how gaps are handled
items = result["items"]
assert items[0] == "zero"
# Index 1 may be None or missing
assert items[2] == "two"
assert items[5] == "five"
def test_merge_nested_dynamic_fields(self):
"""Test merging deeply nested dynamic fields."""
data = {
"data_#_users_$_0": "user1",
"data_#_users_$_1": "user2",
"data_#_config_#_enabled": True,
}
result = merge_execution_input(data)
# Complex nested structures should be created
assert "data" in result
def test_merge_object_field(self):
"""Test merging object attribute fields."""
data = {
"user_@_name": "Alice",
"user_@_email": "alice@example.com",
}
result = merge_execution_input(data)
assert "user" in result
# Object fields create dict-like structure
assert result["user"]["name"] == "Alice"
assert result["user"]["email"] == "alice@example.com"
def test_merge_mixed_field_types(self):
"""Test merging mixed regular and dynamic fields."""
data = {
"regular": "value",
"dict_field_#_key": "dict_value",
"list_field_$_0": "list_item",
}
result = merge_execution_input(data)
assert result["regular"] == "value"
assert result["dict_field"]["key"] == "dict_value"
assert result["list_field"][0] == "list_item"
class TestDynamicFieldPathValidation:
"""
Tests for Failure Mode #17: No Validation of Dynamic Field Paths
When traversing dynamic field paths, intermediate None values
can cause TypeErrors instead of graceful failures.
"""
def test_parse_output_with_none_intermediate(self):
"""
Test parse_execution_output with None intermediate value.
If data contains {"items": None} and we try to access items[0],
it should return None gracefully, not raise TypeError.
"""
# Output with nested path
output_item = ("data_$_0", "value")
# When the base is None, should return None
# This tests the path traversal logic
result = parse_execution_output(
output_item,
link_output_selector="data",
sink_node_id=None,
sink_pin_name=None,
)
# Should handle gracefully (return the value or None)
# Not raise TypeError
def test_extract_base_field_name_with_multiple_delimiters(self):
"""Test extracting base name with multiple delimiters."""
# Multiple dict delimiters
assert extract_base_field_name("a_#_b_#_c") == "a"
# Multiple list delimiters
assert extract_base_field_name("a_$_0_$_1") == "a"
# Mixed delimiters
assert extract_base_field_name("a_#_b_$_0") == "a"
def test_is_dynamic_field_edge_cases(self):
"""Test is_dynamic_field with edge cases."""
# Standard dynamic fields
assert is_dynamic_field("values_#_key") is True
assert is_dynamic_field("items_$_0") is True
assert is_dynamic_field("obj_@_attr") is True
# Regular fields
assert is_dynamic_field("regular") is False
assert is_dynamic_field("with_underscore") is False
# Edge cases
assert is_dynamic_field("") is False
assert is_dynamic_field("_#_") is True # Just delimiter
assert is_dynamic_field("a_#_") is True # Trailing delimiter
def test_sanitize_pin_name_with_tool_pins(self):
"""Test sanitize_pin_name with various tool pin formats."""
# Tool pins should return "tools"
assert sanitize_pin_name("tools") == "tools"
assert sanitize_pin_name("tools_^_node_~_field") == "tools"
# Dynamic fields should return base name
assert sanitize_pin_name("values_#_key") == "values"
assert sanitize_pin_name("items_$_0") == "items"
# Regular fields unchanged
assert sanitize_pin_name("regular") == "regular"
class TestDynamicFieldDescriptions:
"""Tests for dynamic field description generation."""
def test_dict_field_description(self):
"""Test description for dictionary fields."""
desc = get_dynamic_field_description("values_#_user_name")
assert "Dictionary field" in desc
assert "values['user_name']" in desc
def test_list_field_description(self):
"""Test description for list fields."""
desc = get_dynamic_field_description("items_$_0")
assert "List item 0" in desc
assert "items[0]" in desc
def test_object_field_description(self):
"""Test description for object fields."""
desc = get_dynamic_field_description("user_@_email")
assert "Object attribute" in desc
assert "user.email" in desc
def test_regular_field_description(self):
"""Test description for regular (non-dynamic) fields."""
desc = get_dynamic_field_description("regular_field")
assert desc == "Value for regular_field"
def test_description_with_numeric_key(self):
"""Test description with numeric dictionary key."""
desc = get_dynamic_field_description("values_#_123")
assert "Dictionary field" in desc
assert "values['123']" in desc
class TestParseExecutionOutputToolRouting:
"""Tests for tool pin routing in parse_execution_output."""
def test_tool_pin_routing_exact_match(self):
"""Test tool pin routing with exact match."""
output_item = ("tools_^_node-123_~_field_name", "value")
result = parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id="node-123",
sink_pin_name="field_name",
)
assert result == "value"
def test_tool_pin_routing_node_mismatch(self):
"""Test tool pin routing with node ID mismatch."""
output_item = ("tools_^_node-123_~_field_name", "value")
result = parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id="different-node",
sink_pin_name="field_name",
)
assert result is None
def test_tool_pin_routing_field_mismatch(self):
"""Test tool pin routing with field name mismatch."""
output_item = ("tools_^_node-123_~_field_name", "value")
result = parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id="node-123",
sink_pin_name="different_field",
)
assert result is None
def test_tool_pin_missing_required_params(self):
"""Test that tool pins require node_id and pin_name."""
output_item = ("tools_^_node-123_~_field", "value")
with pytest.raises(ValueError, match="must be provided"):
parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id=None,
sink_pin_name="field",
)
with pytest.raises(ValueError, match="must be provided"):
parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id="node-123",
sink_pin_name=None,
)
class TestParseExecutionOutputDynamicFields:
"""Tests for dynamic field routing in parse_execution_output."""
def test_dict_field_extraction(self):
"""Test extraction of dictionary field value."""
# The output_item is (field_name, data_structure)
data = {"key1": "value1", "key2": "value2"}
output_item = ("values", data)
result = parse_execution_output(
output_item,
link_output_selector="values_#_key1",
sink_node_id=None,
sink_pin_name=None,
)
assert result == "value1"
def test_list_field_extraction(self):
"""Test extraction of list item value."""
data = ["zero", "one", "two"]
output_item = ("items", data)
result = parse_execution_output(
output_item,
link_output_selector="items_$_1",
sink_node_id=None,
sink_pin_name=None,
)
assert result == "one"
def test_nested_field_extraction(self):
"""Test extraction of nested field value."""
data = {
"users": [
{"name": "Alice", "email": "alice@example.com"},
{"name": "Bob", "email": "bob@example.com"},
]
}
output_item = ("data", data)
# Access nested path
result = parse_execution_output(
output_item,
link_output_selector="data_#_users",
sink_node_id=None,
sink_pin_name=None,
)
assert result == data["users"]
def test_missing_key_returns_none(self):
"""Test that missing keys return None."""
data = {"existing": "value"}
output_item = ("values", data)
result = parse_execution_output(
output_item,
link_output_selector="values_#_nonexistent",
sink_node_id=None,
sink_pin_name=None,
)
assert result is None
def test_index_out_of_bounds_returns_none(self):
"""Test that out-of-bounds indices return None."""
data = ["zero", "one"]
output_item = ("items", data)
result = parse_execution_output(
output_item,
link_output_selector="items_$_99",
sink_node_id=None,
sink_pin_name=None,
)
assert result is None
class TestIsToolPin:
"""Tests for is_tool_pin function."""
def test_tools_prefix(self):
"""Test that 'tools_^_' prefix is recognized."""
assert is_tool_pin("tools_^_node_~_field") is True
assert is_tool_pin("tools_^_anything") is True
def test_tools_exact(self):
"""Test that exact 'tools' is recognized."""
assert is_tool_pin("tools") is True
def test_non_tool_pins(self):
"""Test that non-tool pins are not recognized."""
assert is_tool_pin("input") is False
assert is_tool_pin("output") is False
assert is_tool_pin("toolsomething") is False
assert is_tool_pin("my_tools") is False
assert is_tool_pin("") is False
class TestMergeExecutionInputEdgeCases:
"""Edge case tests for merge_execution_input."""
def test_empty_input(self):
"""Test merging empty input."""
result = merge_execution_input({})
assert result == {}
def test_only_regular_fields(self):
"""Test merging only regular fields (no dynamic)."""
data = {"a": 1, "b": 2, "c": 3}
result = merge_execution_input(data)
assert result == data
def test_overwrite_behavior(self):
"""Test behavior when same key is set multiple times."""
# This shouldn't happen in practice, but test the behavior
data = {
"values_#_key": "first",
}
result = merge_execution_input(data)
assert result["values"]["key"] == "first"
def test_numeric_string_keys(self):
"""Test handling of numeric string keys in dict fields."""
data = {
"values_#_123": "numeric_key",
"values_#_456": "another_numeric",
}
result = merge_execution_input(data)
assert result["values"]["123"] == "numeric_key"
assert result["values"]["456"] == "another_numeric"
def test_special_characters_in_keys(self):
"""Test handling of special characters in keys."""
data = {
"values_#_key-with-dashes": "value1",
"values_#_key.with.dots": "value2",
}
result = merge_execution_input(data)
assert result["values"]["key-with-dashes"] == "value1"
assert result["values"]["key.with.dots"] == "value2"
def test_deeply_nested_list(self):
"""Test deeply nested list indices."""
data = {
"matrix_$_0_$_0": "0,0",
"matrix_$_0_$_1": "0,1",
"matrix_$_1_$_0": "1,0",
"matrix_$_1_$_1": "1,1",
}
# Note: Current implementation may not support this depth
# Test documents expected behavior
try:
result = merge_execution_input(data)
# If supported, verify structure
except (KeyError, TypeError, IndexError):
# Deep nesting may not be supported
pass
def test_none_values(self):
"""Test handling of None values in input."""
data = {
"regular": None,
"dict_#_key": None,
"list_$_0": None,
}
result = merge_execution_input(data)
assert result["regular"] is None
assert result["dict"]["key"] is None
assert result["list"][0] is None
def test_complex_values(self):
"""Test handling of complex values (dicts, lists)."""
data = {
"values_#_nested_dict": {"inner": "value"},
"values_#_nested_list": [1, 2, 3],
}
result = merge_execution_input(data)
assert result["values"]["nested_dict"] == {"inner": "value"}
assert result["values"]["nested_list"] == [1, 2, 3]

View File

@@ -1,463 +0,0 @@
"""
Tests for dynamic field routing with sanitized names.
This test file specifically tests the parse_execution_output function
which is responsible for routing tool outputs to the correct nodes.
The critical bug this addresses is the mismatch between:
- emit keys using sanitized names (e.g., "max_keyword_difficulty")
- sink_pin_name using original names (e.g., "Max Keyword Difficulty")
"""
import re
from typing import Any
import pytest
from backend.data.dynamic_fields import (
DICT_SPLIT,
LIST_SPLIT,
OBJC_SPLIT,
extract_base_field_name,
get_dynamic_field_description,
is_dynamic_field,
is_tool_pin,
merge_execution_input,
parse_execution_output,
sanitize_pin_name,
)
def cleanup(s: str) -> str:
"""
Simulate SmartDecisionMakerBlock.cleanup() for testing.
Clean up names for use as tool function names.
"""
return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()
class TestParseExecutionOutputToolRouting:
"""Tests for tool pin routing in parse_execution_output."""
def test_exact_match_routes_correctly(self):
"""When emit key field exactly matches sink_pin_name, routing works."""
output_item = ("tools_^_node-123_~_query", "test value")
result = parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id="node-123",
sink_pin_name="query",
)
assert result == "test value"
def test_sanitized_emit_vs_original_sink_fails(self):
"""
CRITICAL BUG TEST: When emit key uses sanitized name but sink uses original,
routing fails.
"""
# Backend emits with sanitized name
sanitized_field = cleanup("Max Keyword Difficulty")
output_item = (f"tools_^_node-123_~_{sanitized_field}", 50)
# Frontend link has original name
result = parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id="node-123",
sink_pin_name="Max Keyword Difficulty", # Original name
)
# BUG: This returns None because sanitized != original
# Once fixed, change this to: assert result == 50
assert result is None, "Expected None due to sanitization mismatch bug"
def test_node_id_mismatch_returns_none(self):
"""When node IDs don't match, routing should return None."""
output_item = ("tools_^_node-123_~_query", "test value")
result = parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id="different-node", # Different node
sink_pin_name="query",
)
assert result is None
def test_both_node_and_pin_must_match(self):
"""Both node_id and pin_name must match for routing to succeed."""
output_item = ("tools_^_node-123_~_query", "test value")
# Wrong node, right pin
result = parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id="wrong-node",
sink_pin_name="query",
)
assert result is None
# Right node, wrong pin
result = parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id="node-123",
sink_pin_name="wrong_pin",
)
assert result is None
# Right node, right pin
result = parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id="node-123",
sink_pin_name="query",
)
assert result == "test value"
class TestToolPinRoutingWithSpecialCharacters:
"""Tests for tool pin routing with various special characters in names."""
@pytest.mark.parametrize(
"original_name,sanitized_name",
[
("Max Keyword Difficulty", "max_keyword_difficulty"),
("Search Volume (Monthly)", "search_volume__monthly_"),
("CPC ($)", "cpc____"),
("User's Input", "user_s_input"),
("Query #1", "query__1"),
("API.Response", "api_response"),
("Field@Name", "field_name"),
("Test\tTab", "test_tab"),
("Test\nNewline", "test_newline"),
],
)
def test_routing_mismatch_with_special_chars(self, original_name, sanitized_name):
"""
Test that various special characters cause routing mismatches.
This test documents the current buggy behavior where sanitized emit keys
don't match original sink_pin_names.
"""
# Verify sanitization
assert cleanup(original_name) == sanitized_name
# Backend emits with sanitized name
output_item = (f"tools_^_node-123_~_{sanitized_name}", "value")
# Frontend link has original name
result = parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id="node-123",
sink_pin_name=original_name,
)
# BUG: Returns None due to mismatch
assert result is None, f"Routing should fail for '{original_name}' vs '{sanitized_name}'"
class TestToolPinMissingParameters:
"""Tests for missing required parameters in parse_execution_output."""
def test_missing_sink_node_id_raises_error(self):
"""Missing sink_node_id should raise ValueError for tool pins."""
output_item = ("tools_^_node-123_~_query", "test value")
with pytest.raises(ValueError, match="sink_node_id and sink_pin_name must be provided"):
parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id=None,
sink_pin_name="query",
)
def test_missing_sink_pin_name_raises_error(self):
"""Missing sink_pin_name should raise ValueError for tool pins."""
output_item = ("tools_^_node-123_~_query", "test value")
with pytest.raises(ValueError, match="sink_node_id and sink_pin_name must be provided"):
parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id="node-123",
sink_pin_name=None,
)
class TestIsToolPin:
"""Tests for is_tool_pin function."""
def test_tools_prefix_is_tool_pin(self):
"""Names starting with 'tools_^_' are tool pins."""
assert is_tool_pin("tools_^_node_~_field") is True
assert is_tool_pin("tools_^_anything") is True
def test_tools_exact_is_tool_pin(self):
"""Exact 'tools' is a tool pin."""
assert is_tool_pin("tools") is True
def test_non_tool_pins(self):
"""Non-tool pin names should return False."""
assert is_tool_pin("input") is False
assert is_tool_pin("output") is False
assert is_tool_pin("my_tools") is False
assert is_tool_pin("toolsomething") is False
class TestSanitizePinName:
"""Tests for sanitize_pin_name function."""
def test_extracts_base_from_dynamic_field(self):
"""Should extract base field name from dynamic fields."""
assert sanitize_pin_name("values_#_key") == "values"
assert sanitize_pin_name("items_$_0") == "items"
assert sanitize_pin_name("obj_@_attr") == "obj"
def test_returns_tools_for_tool_pins(self):
"""Tool pins should be sanitized to 'tools'."""
assert sanitize_pin_name("tools_^_node_~_field") == "tools"
assert sanitize_pin_name("tools") == "tools"
def test_regular_field_unchanged(self):
"""Regular field names should be unchanged."""
assert sanitize_pin_name("query") == "query"
assert sanitize_pin_name("max_difficulty") == "max_difficulty"
class TestDynamicFieldDescriptions:
"""Tests for dynamic field description generation."""
def test_dict_field_description_with_spaces_in_key(self):
"""Dictionary field keys with spaces should generate correct descriptions."""
# After cleanup, "User Name" becomes "user_name" in the field name
# But the original key might have had spaces
desc = get_dynamic_field_description("values_#_user_name")
assert "Dictionary field" in desc
assert "values['user_name']" in desc
def test_list_field_description(self):
"""List field descriptions should include index."""
desc = get_dynamic_field_description("items_$_0")
assert "List item 0" in desc
assert "items[0]" in desc
def test_object_field_description(self):
"""Object field descriptions should include attribute."""
desc = get_dynamic_field_description("user_@_email")
assert "Object attribute" in desc
assert "user.email" in desc
class TestMergeExecutionInput:
"""Tests for merge_execution_input function."""
def test_merges_dict_fields(self):
"""Dictionary fields should be merged into nested structure."""
data = {
"values_#_name": "Alice",
"values_#_age": 30,
"other_field": "unchanged",
}
result = merge_execution_input(data)
assert "values" in result
assert result["values"]["name"] == "Alice"
assert result["values"]["age"] == 30
assert result["other_field"] == "unchanged"
def test_merges_list_fields(self):
"""List fields should be merged into arrays."""
data = {
"items_$_0": "first",
"items_$_1": "second",
"items_$_2": "third",
}
result = merge_execution_input(data)
assert "items" in result
assert result["items"] == ["first", "second", "third"]
def test_merges_mixed_fields(self):
"""Mixed regular and dynamic fields should all be preserved."""
data = {
"regular": "value",
"dict_#_key": "dict_value",
"list_$_0": "list_item",
}
result = merge_execution_input(data)
assert result["regular"] == "value"
assert result["dict"]["key"] == "dict_value"
assert result["list"] == ["list_item"]
class TestExtractBaseFieldName:
"""Tests for extract_base_field_name function."""
def test_extracts_from_dict_delimiter(self):
"""Should extract base name before _#_ delimiter."""
assert extract_base_field_name("values_#_name") == "values"
assert extract_base_field_name("user_#_email_#_domain") == "user"
def test_extracts_from_list_delimiter(self):
"""Should extract base name before _$_ delimiter."""
assert extract_base_field_name("items_$_0") == "items"
assert extract_base_field_name("data_$_1_$_nested") == "data"
def test_extracts_from_object_delimiter(self):
"""Should extract base name before _@_ delimiter."""
assert extract_base_field_name("obj_@_attr") == "obj"
def test_no_delimiter_returns_original(self):
"""Names without delimiters should be returned unchanged."""
assert extract_base_field_name("regular_field") == "regular_field"
assert extract_base_field_name("query") == "query"
class TestIsDynamicField:
"""Tests for is_dynamic_field function."""
def test_dict_delimiter_is_dynamic(self):
"""Fields with _#_ are dynamic."""
assert is_dynamic_field("values_#_key") is True
def test_list_delimiter_is_dynamic(self):
"""Fields with _$_ are dynamic."""
assert is_dynamic_field("items_$_0") is True
def test_object_delimiter_is_dynamic(self):
"""Fields with _@_ are dynamic."""
assert is_dynamic_field("obj_@_attr") is True
def test_regular_fields_not_dynamic(self):
"""Regular field names without delimiters are not dynamic."""
assert is_dynamic_field("regular_field") is False
assert is_dynamic_field("query") is False
assert is_dynamic_field("Max Keyword Difficulty") is False
class TestRoutingEndToEnd:
"""End-to-end tests for the full routing flow."""
def test_successful_routing_without_spaces(self):
"""Full routing flow works when no spaces in names."""
field_name = "query"
node_id = "test-node-123"
# Emit key (as created by SmartDecisionMaker)
emit_key = f"tools_^_{node_id}_~_{cleanup(field_name)}"
output_item = (emit_key, "search term")
# Route (as called by executor)
result = parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id=node_id,
sink_pin_name=field_name,
)
assert result == "search term"
def test_failed_routing_with_spaces(self):
"""
Full routing flow FAILS when names have spaces.
This test documents the exact bug scenario:
1. Frontend creates link with sink_name="Max Keyword Difficulty"
2. SmartDecisionMaker emits with sanitized name in key
3. Executor calls parse_execution_output with original sink_pin_name
4. Routing fails because names don't match
"""
original_field_name = "Max Keyword Difficulty"
sanitized_field_name = cleanup(original_field_name)
node_id = "test-node-123"
# Step 1 & 2: SmartDecisionMaker emits with sanitized name
emit_key = f"tools_^_{node_id}_~_{sanitized_field_name}"
output_item = (emit_key, 50)
# Step 3: Executor routes with original name from link
result = parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id=node_id,
sink_pin_name=original_field_name, # Original from link!
)
# Step 4: BUG - Returns None instead of 50
assert result is None
# This is what should happen after fix:
# assert result == 50
def test_multiple_fields_with_spaces(self):
"""Test routing multiple fields where some have spaces."""
node_id = "test-node"
fields = {
"query": "test", # No spaces - should work
"Max Difficulty": 100, # Spaces - will fail
"min_volume": 1000, # No spaces - should work
}
results = {}
for original_name, value in fields.items():
sanitized = cleanup(original_name)
emit_key = f"tools_^_{node_id}_~_{sanitized}"
output_item = (emit_key, value)
result = parse_execution_output(
output_item,
link_output_selector="tools",
sink_node_id=node_id,
sink_pin_name=original_name,
)
results[original_name] = result
# Fields without spaces work
assert results["query"] == "test"
assert results["min_volume"] == 1000
# Fields with spaces fail
assert results["Max Difficulty"] is None # BUG!
class TestProposedFix:
"""
Tests for the proposed fix.
The fix should sanitize sink_pin_name before comparison in parse_execution_output.
This class contains tests that will pass once the fix is implemented.
"""
def test_routing_should_sanitize_both_sides(self):
"""
PROPOSED FIX: parse_execution_output should sanitize sink_pin_name
before comparing with the field from emit key.
Current behavior: Direct string comparison
Fixed behavior: Compare cleanup(target_input_pin) == cleanup(sink_pin_name)
"""
original_field = "Max Keyword Difficulty"
sanitized_field = cleanup(original_field)
node_id = "node-123"
emit_key = f"tools_^_{node_id}_~_{sanitized_field}"
output_item = (emit_key, 50)
# Extract the comparison being made
selector = emit_key[8:] # Remove "tools_^_"
target_node_id, target_input_pin = selector.split("_~_", 1)
# Current comparison (FAILS):
current_comparison = (target_input_pin == original_field)
assert current_comparison is False, "Current comparison fails"
# Proposed fixed comparison (PASSES):
# Either sanitize sink_pin_name, or sanitize both
fixed_comparison = (target_input_pin == cleanup(original_field))
assert fixed_comparison is True, "Fixed comparison should pass"

View File

@@ -2,11 +2,6 @@ import logging
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cast
from backend.api.features.library.db import (
add_store_agent_to_library,
list_library_agents,
)
from backend.api.features.store.db import get_store_agent_details, get_store_agents
from backend.data import db
from backend.data.analytics import (
get_accuracy_trends_and_alerts,
@@ -66,6 +61,8 @@ from backend.data.user import (
get_user_notification_preference,
update_user_integrations,
)
from backend.server.v2.library.db import add_store_agent_to_library, list_library_agents
from backend.server.v2.store.db import get_store_agent_details, get_store_agents
from backend.util.service import (
AppService,
AppServiceClient,

View File

@@ -48,8 +48,27 @@ from backend.data.notifications import (
ZeroBalanceData,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.activity_status_generator import (
generate_activity_status_for_execution,
)
from backend.executor.utils import (
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
GRAPH_EXECUTION_EXCHANGE,
GRAPH_EXECUTION_QUEUE_NAME,
GRAPH_EXECUTION_ROUTING_KEY,
CancelExecutionEvent,
ExecutionOutputEntry,
LogMetadata,
NodeExecutionProgress,
block_usage_cost,
create_execution_queue_config,
execution_usage_cost,
validate_exec,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.notifications.notifications import queue_notification
from backend.server.v2.AutoMod.manager import automod_manager
from backend.util import json
from backend.util.clients import (
get_async_execution_event_bus,
@@ -76,24 +95,7 @@ from backend.util.retry import (
)
from backend.util.settings import Settings
from .activity_status_generator import generate_activity_status_for_execution
from .automod.manager import automod_manager
from .cluster_lock import ClusterLock
from .utils import (
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
GRAPH_EXECUTION_EXCHANGE,
GRAPH_EXECUTION_QUEUE_NAME,
GRAPH_EXECUTION_ROUTING_KEY,
CancelExecutionEvent,
ExecutionOutputEntry,
LogMetadata,
NodeExecutionProgress,
block_usage_cost,
create_execution_queue_config,
execution_usage_cost,
validate_exec,
)
if TYPE_CHECKING:
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
@@ -114,40 +116,6 @@ utilization_gauge = Gauge(
"Ratio of active graph runs to max graph workers",
)
# Redis key prefix for tracking insufficient funds Discord notifications.
# We only send one notification per user per agent until they top up credits.
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
# TTL for the notification flag (30 days) - acts as a fallback cleanup
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60
async def clear_insufficient_funds_notifications(user_id: str) -> int:
"""
Clear all insufficient funds notification flags for a user.
This should be called when a user tops up their credits, allowing
Discord notifications to be sent again if they run out of funds.
Args:
user_id: The user ID to clear notifications for.
Returns:
The number of keys that were deleted.
"""
try:
redis_client = await redis.get_redis_async()
pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
keys = [key async for key in redis_client.scan_iter(match=pattern)]
if keys:
return await redis_client.delete(*keys)
return 0
except Exception as e:
logger.warning(
f"Failed to clear insufficient funds notification flags for user "
f"{user_id}: {e}"
)
return 0
# Thread-local storage for ExecutionProcessor instances
_tls = threading.local()
@@ -178,7 +146,6 @@ async def execute_node(
execution_processor: "ExecutionProcessor",
execution_stats: NodeExecutionStats | None = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_to_skip: Optional[set[str]] = None,
) -> BlockOutput:
"""
Execute a node in the graph. This will trigger a block execution on a node,
@@ -246,7 +213,6 @@ async def execute_node(
"user_id": user_id,
"execution_context": execution_context,
"execution_processor": execution_processor,
"nodes_to_skip": nodes_to_skip or set(),
}
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
@@ -544,7 +510,6 @@ class ExecutionProcessor:
node_exec_progress: NodeExecutionProgress,
nodes_input_masks: Optional[NodesInputMasks],
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
nodes_to_skip: Optional[set[str]] = None,
) -> NodeExecutionStats:
log_metadata = LogMetadata(
logger=_logger,
@@ -567,7 +532,6 @@ class ExecutionProcessor:
db_client=db_client,
log_metadata=log_metadata,
nodes_input_masks=nodes_input_masks,
nodes_to_skip=nodes_to_skip,
)
if isinstance(status, BaseException):
raise status
@@ -613,7 +577,6 @@ class ExecutionProcessor:
db_client: "DatabaseManagerAsyncClient",
log_metadata: LogMetadata,
nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_to_skip: Optional[set[str]] = None,
) -> ExecutionStatus:
status = ExecutionStatus.RUNNING
@@ -650,7 +613,6 @@ class ExecutionProcessor:
execution_processor=self,
execution_stats=stats,
nodes_input_masks=nodes_input_masks,
nodes_to_skip=nodes_to_skip,
):
await persist_output(output_name, output_data)
@@ -962,21 +924,6 @@ class ExecutionProcessor:
queued_node_exec = execution_queue.get()
# Check if this node should be skipped due to optional credentials
if queued_node_exec.node_id in graph_exec.nodes_to_skip:
log_metadata.info(
f"Skipping node execution {queued_node_exec.node_exec_id} "
f"for node {queued_node_exec.node_id} - optional credentials not configured"
)
# Mark the node as completed without executing
# No outputs will be produced, so downstream nodes won't trigger
update_node_execution_status(
db_client=db_client,
exec_id=queued_node_exec.node_exec_id,
status=ExecutionStatus.COMPLETED,
)
continue
log_metadata.debug(
f"Dispatching node execution {queued_node_exec.node_exec_id} "
f"for node {queued_node_exec.node_id}",
@@ -1037,7 +984,6 @@ class ExecutionProcessor:
execution_stats,
execution_stats_lock,
),
nodes_to_skip=graph_exec.nodes_to_skip,
),
self.node_execution_loop,
)
@@ -1317,40 +1263,12 @@ class ExecutionProcessor:
graph_id: str,
e: InsufficientBalanceError,
):
# Check if we've already sent a notification for this user+agent combo.
# We only send one notification per user per agent until they top up credits.
redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
try:
redis_client = redis.get_redis()
# SET NX returns True only if the key was newly set (didn't exist)
is_new_notification = redis_client.set(
redis_key,
"1",
nx=True,
ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
)
if not is_new_notification:
# Already notified for this user+agent, skip all notifications
logger.debug(
f"Skipping duplicate insufficient funds notification for "
f"user={user_id}, graph={graph_id}"
)
return
except Exception as redis_error:
# If Redis fails, log and continue to send the notification
# (better to occasionally duplicate than to never notify)
logger.warning(
f"Failed to check/set insufficient funds notification flag in Redis: "
f"{redis_error}"
)
shortfall = abs(e.amount) - e.balance
metadata = db_client.get_graph_metadata(graph_id)
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
# Queue user email notification
queue_notification(
NotificationEventModel(
user_id=user_id,
@@ -1364,7 +1282,6 @@ class ExecutionProcessor:
)
)
# Send Discord system alert
try:
user_email = db_client.get_user_email_by_id(user_id)

View File

@@ -1,560 +0,0 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from prisma.enums import NotificationType
from backend.data.notifications import ZeroBalanceData
from backend.executor.manager import (
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX,
ExecutionProcessor,
clear_insufficient_funds_notifications,
)
from backend.util.exceptions import InsufficientBalanceError
from backend.util.test import SpinTestServer
async def async_iter(items):
"""Helper to create an async iterator from a list."""
for item in items:
yield item
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_insufficient_funds_sends_discord_alert_first_time(
server: SpinTestServer,
):
"""Test that the first insufficient funds notification sends a Discord alert."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id = "test-graph-456"
error = InsufficientBalanceError(
message="Insufficient balance",
user_id=user_id,
balance=72, # $0.72
amount=-714, # Attempting to spend $7.14
)
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Setup mocks
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.frontend_base_url = "https://test.com"
# Mock Redis to simulate first-time notification (set returns True)
mock_redis_client = MagicMock()
mock_redis_module.get_redis.return_value = mock_redis_client
mock_redis_client.set.return_value = True # Key was newly set
# Create mock database client
mock_db_client = MagicMock()
mock_graph_metadata = MagicMock()
mock_graph_metadata.name = "Test Agent"
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# Test the insufficient funds handler
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id,
e=error,
)
# Verify notification was queued
mock_queue_notif.assert_called_once()
notification_call = mock_queue_notif.call_args[0][0]
assert notification_call.type == NotificationType.ZERO_BALANCE
assert notification_call.user_id == user_id
assert isinstance(notification_call.data, ZeroBalanceData)
assert notification_call.data.current_balance == 72
# Verify Redis was checked with correct key pattern
expected_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
mock_redis_client.set.assert_called_once()
call_args = mock_redis_client.set.call_args
assert call_args[0][0] == expected_key
assert call_args[1]["nx"] is True
# Verify Discord alert was sent
mock_client.discord_system_alert.assert_called_once()
discord_message = mock_client.discord_system_alert.call_args[0][0]
assert "Insufficient Funds Alert" in discord_message
assert "test@example.com" in discord_message
assert "Test Agent" in discord_message
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_insufficient_funds_skips_duplicate_notifications(
server: SpinTestServer,
):
"""Test that duplicate insufficient funds notifications skip both email and Discord."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id = "test-graph-456"
error = InsufficientBalanceError(
message="Insufficient balance",
user_id=user_id,
balance=72,
amount=-714,
)
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Setup mocks
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.frontend_base_url = "https://test.com"
# Mock Redis to simulate duplicate notification (set returns False/None)
mock_redis_client = MagicMock()
mock_redis_module.get_redis.return_value = mock_redis_client
mock_redis_client.set.return_value = None # Key already existed
# Create mock database client
mock_db_client = MagicMock()
mock_db_client.get_graph_metadata.return_value = MagicMock(name="Test Agent")
# Test the insufficient funds handler
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id,
e=error,
)
# Verify email notification was NOT queued (deduplication worked)
mock_queue_notif.assert_not_called()
# Verify Discord alert was NOT sent (deduplication worked)
mock_client.discord_system_alert.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
server: SpinTestServer,
):
"""Test that different agents for the same user get separate Discord alerts."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id_1 = "test-graph-111"
graph_id_2 = "test-graph-222"
error = InsufficientBalanceError(
message="Insufficient balance",
user_id=user_id,
balance=72,
amount=-714,
)
with patch("backend.executor.manager.queue_notification"), patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.frontend_base_url = "https://test.com"
mock_redis_client = MagicMock()
mock_redis_module.get_redis.return_value = mock_redis_client
# Both calls return True (first time for each agent)
mock_redis_client.set.return_value = True
mock_db_client = MagicMock()
mock_graph_metadata = MagicMock()
mock_graph_metadata.name = "Test Agent"
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# First agent notification
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id_1,
e=error,
)
# Second agent notification
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id_2,
e=error,
)
# Verify Discord alerts were sent for both agents
assert mock_client.discord_system_alert.call_count == 2
# Verify Redis was called with different keys
assert mock_redis_client.set.call_count == 2
calls = mock_redis_client.set.call_args_list
assert (
calls[0][0][0]
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_1}"
)
assert (
calls[1][0][0]
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_2}"
)
@pytest.mark.asyncio(loop_scope="session")
async def test_clear_insufficient_funds_notifications(server: SpinTestServer):
"""Test that clearing notifications removes all keys for a user."""
user_id = "test-user-123"
with patch("backend.executor.manager.redis") as mock_redis_module:
mock_redis_client = MagicMock()
# get_redis_async is an async function, so we need AsyncMock for it
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
# Mock scan_iter to return some keys as an async iterator
mock_keys = [
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1",
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-2",
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-3",
]
mock_redis_client.scan_iter.return_value = async_iter(mock_keys)
# delete is awaited, so use AsyncMock
mock_redis_client.delete = AsyncMock(return_value=3)
# Clear notifications
result = await clear_insufficient_funds_notifications(user_id)
# Verify correct pattern was used
expected_pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
mock_redis_client.scan_iter.assert_called_once_with(match=expected_pattern)
# Verify delete was called with all keys
mock_redis_client.delete.assert_called_once_with(*mock_keys)
# Verify return value
assert result == 3
@pytest.mark.asyncio(loop_scope="session")
async def test_clear_insufficient_funds_notifications_no_keys(server: SpinTestServer):
"""Test clearing notifications when there are no keys to clear."""
user_id = "test-user-no-notifications"
with patch("backend.executor.manager.redis") as mock_redis_module:
mock_redis_client = MagicMock()
# get_redis_async is an async function, so we need AsyncMock for it
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
# Mock scan_iter to return no keys as an async iterator
mock_redis_client.scan_iter.return_value = async_iter([])
# Clear notifications
result = await clear_insufficient_funds_notifications(user_id)
# Verify delete was not called
mock_redis_client.delete.assert_not_called()
# Verify return value
assert result == 0
@pytest.mark.asyncio(loop_scope="session")
async def test_clear_insufficient_funds_notifications_handles_redis_error(
server: SpinTestServer,
):
"""Test that clearing notifications handles Redis errors gracefully."""
user_id = "test-user-redis-error"
with patch("backend.executor.manager.redis") as mock_redis_module:
# Mock get_redis_async to raise an error
mock_redis_module.get_redis_async = AsyncMock(
side_effect=Exception("Redis connection failed")
)
# Clear notifications should not raise, just return 0
result = await clear_insufficient_funds_notifications(user_id)
# Verify it returned 0 (graceful failure)
assert result == 0
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_insufficient_funds_continues_on_redis_error(
server: SpinTestServer,
):
"""Test that both email and Discord notifications are still sent when Redis fails."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id = "test-graph-456"
error = InsufficientBalanceError(
message="Insufficient balance",
user_id=user_id,
balance=72,
amount=-714,
)
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.frontend_base_url = "https://test.com"
# Mock Redis to raise an error
mock_redis_client = MagicMock()
mock_redis_module.get_redis.return_value = mock_redis_client
mock_redis_client.set.side_effect = Exception("Redis connection error")
mock_db_client = MagicMock()
mock_graph_metadata = MagicMock()
mock_graph_metadata.name = "Test Agent"
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# Test the insufficient funds handler
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id,
e=error,
)
# Verify email notification was still queued despite Redis error
mock_queue_notif.assert_called_once()
# Verify Discord alert was still sent despite Redis error
mock_client.discord_system_alert.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_add_transaction_clears_notifications_on_grant(server: SpinTestServer):
"""Test that _add_transaction clears notification flags when adding GRANT credits."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-grant-clear"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 1000, "transactionKey": "test-tx-key"}]
# Mock async Redis for notification clearing
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
mock_redis_client.scan_iter.return_value = async_iter(
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
)
mock_redis_client.delete = AsyncMock(return_value=1)
# Create a concrete instance
credit_model = UserCredit()
# Call _add_transaction with GRANT type (should clear notifications)
await credit_model._add_transaction(
user_id=user_id,
amount=500, # Positive amount
transaction_type=CreditTransactionType.GRANT,
is_active=True, # Active transaction
)
# Verify notification clearing was called
mock_redis_module.get_redis_async.assert_called_once()
mock_redis_client.scan_iter.assert_called_once_with(
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
)
@pytest.mark.asyncio(loop_scope="session")
async def test_add_transaction_clears_notifications_on_top_up(server: SpinTestServer):
"""Test that _add_transaction clears notification flags when adding TOP_UP credits."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-topup-clear"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 2000, "transactionKey": "test-tx-key-2"}]
# Mock async Redis for notification clearing
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
mock_redis_client.scan_iter.return_value = async_iter([])
mock_redis_client.delete = AsyncMock(return_value=0)
credit_model = UserCredit()
# Call _add_transaction with TOP_UP type (should clear notifications)
await credit_model._add_transaction(
user_id=user_id,
amount=1000, # Positive amount
transaction_type=CreditTransactionType.TOP_UP,
is_active=True,
)
# Verify notification clearing was attempted
mock_redis_module.get_redis_async.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_add_transaction_skips_clearing_for_inactive_transaction(
server: SpinTestServer,
):
"""Test that _add_transaction does NOT clear notifications for inactive transactions."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-inactive"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 500, "transactionKey": "test-tx-key-3"}]
# Mock async Redis
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
credit_model = UserCredit()
# Call _add_transaction with is_active=False (should NOT clear notifications)
await credit_model._add_transaction(
user_id=user_id,
amount=500,
transaction_type=CreditTransactionType.TOP_UP,
is_active=False, # Inactive - pending Stripe payment
)
# Verify notification clearing was NOT called
mock_redis_module.get_redis_async.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_add_transaction_skips_clearing_for_usage_transaction(
server: SpinTestServer,
):
"""Test that _add_transaction does NOT clear notifications for USAGE transactions."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-usage"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 400, "transactionKey": "test-tx-key-4"}]
# Mock async Redis
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
credit_model = UserCredit()
# Call _add_transaction with USAGE type (spending, should NOT clear)
await credit_model._add_transaction(
user_id=user_id,
amount=-100, # Negative - spending credits
transaction_type=CreditTransactionType.USAGE,
is_active=True,
)
# Verify notification clearing was NOT called
mock_redis_module.get_redis_async.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_enable_transaction_clears_notifications(server: SpinTestServer):
"""Test that _enable_transaction clears notification flags when enabling a TOP_UP."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-enable"
with patch("backend.data.credit.CreditTransaction") as mock_credit_tx, patch(
"backend.data.credit.query_raw_with_schema"
) as mock_query, patch("backend.executor.manager.redis") as mock_redis_module:
# Mock finding the pending transaction
mock_transaction = MagicMock()
mock_transaction.amount = 1000
mock_transaction.type = CreditTransactionType.TOP_UP
mock_credit_tx.prisma.return_value.find_first = AsyncMock(
return_value=mock_transaction
)
# Mock the query to return updated balance
mock_query.return_value = [{"balance": 1500}]
# Mock async Redis for notification clearing
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
mock_redis_client.scan_iter.return_value = async_iter(
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
)
mock_redis_client.delete = AsyncMock(return_value=1)
credit_model = UserCredit()
# Call _enable_transaction (simulates Stripe checkout completion)
from backend.util.json import SafeJson
await credit_model._enable_transaction(
transaction_key="cs_test_123",
user_id=user_id,
metadata=SafeJson({"payment": "completed"}),
)
# Verify notification clearing was called
mock_redis_module.get_redis_async.assert_called_once()
mock_redis_client.scan_iter.assert_called_once_with(
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
)

View File

@@ -3,16 +3,16 @@ import logging
import fastapi.responses
import pytest
import backend.api.features.library.model
import backend.api.features.store.model
from backend.api.model import CreateGraph
from backend.api.rest_api import AgentServer
import backend.server.v2.library.model
import backend.server.v2.store.model
from backend.blocks.basic import StoreValueBlock
from backend.blocks.data_manipulation import FindInDictionaryBlock
from backend.blocks.io import AgentInputBlock
from backend.blocks.maths import CalculatorBlock, Operation
from backend.data import execution, graph
from backend.data.model import User
from backend.server.model import CreateGraph
from backend.server.rest_api import AgentServer
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.test import SpinTestServer, wait_execution
@@ -356,7 +356,7 @@ async def test_execute_preset(server: SpinTestServer):
test_graph = await create_graph(server, test_graph, test_user)
# Create preset with initial values
preset = backend.api.features.library.model.LibraryAgentPresetCreatable(
preset = backend.server.v2.library.model.LibraryAgentPresetCreatable(
name="Test Preset With Clash",
description="Test preset with clashing input values",
graph_id=test_graph.id,
@@ -444,7 +444,7 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
test_graph = await create_graph(server, test_graph, test_user)
# Create preset with initial values
preset = backend.api.features.library.model.LibraryAgentPresetCreatable(
preset = backend.server.v2.library.model.LibraryAgentPresetCreatable(
name="Test Preset With Clash",
description="Test preset with clashing input values",
graph_id=test_graph.id,
@@ -485,7 +485,7 @@ async def test_store_listing_graph(server: SpinTestServer):
test_user = await create_test_user()
test_graph = await create_graph(server, create_test_graph(), test_user)
store_submission_request = backend.api.features.store.model.StoreSubmissionRequest(
store_submission_request = backend.server.v2.store.model.StoreSubmissionRequest(
agent_id=test_graph.id,
agent_version=test_graph.version,
slug=test_graph.id,
@@ -514,7 +514,7 @@ async def test_store_listing_graph(server: SpinTestServer):
admin_user = await create_test_user(alt_user=True)
await server.agent_server.test_review_store_listing(
backend.api.features.store.model.ReviewSubmissionRequest(
backend.server.v2.store.model.ReviewSubmissionRequest(
store_listing_version_id=slv_id,
is_approved=True,
comments="Test comments",
@@ -523,7 +523,7 @@ async def test_store_listing_graph(server: SpinTestServer):
)
# Add the approved store listing to the admin user's library so they can execute it
from backend.api.features.library.db import add_store_agent_to_library
from backend.server.v2.library.db import add_store_agent_to_library
await add_store_agent_to_library(
store_listing_version_id=slv_id, user_id=admin_user.id

View File

@@ -23,7 +23,6 @@ from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import MetaData, create_engine
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
from backend.data.block import BlockInput
from backend.data.execution import GraphExecutionWithNodes
from backend.data.model import CredentialsMetaInput
@@ -243,12 +242,6 @@ def cleanup_expired_files():
run_async(cleanup_expired_files_async())
def cleanup_oauth_tokens():
"""Clean up expired OAuth tokens from the database."""
# Wait for completion
run_async(cleanup_expired_oauth_tokens())
def execution_accuracy_alerts():
"""Check execution accuracy and send alerts if drops are detected."""
return report_execution_accuracy_alerts()
@@ -453,17 +446,6 @@ class Scheduler(AppService):
jobstore=Jobstores.EXECUTION.value,
)
# OAuth Token Cleanup - configurable interval
self.scheduler.add_job(
cleanup_oauth_tokens,
id="cleanup_oauth_tokens",
trigger="interval",
replace_existing=True,
seconds=config.oauth_token_cleanup_interval_hours
* 3600, # Convert hours to seconds
jobstore=Jobstores.EXECUTION.value,
)
# Execution Accuracy Monitoring - configurable interval
self.scheduler.add_job(
execution_accuracy_alerts,
@@ -622,11 +604,6 @@ class Scheduler(AppService):
"""Manually trigger cleanup of expired cloud storage files."""
return cleanup_expired_files()
@expose
def execute_cleanup_oauth_tokens(self):
"""Manually trigger cleanup of expired OAuth tokens."""
return cleanup_oauth_tokens()
@expose
def execute_report_execution_accuracy_alerts(self):
"""Manually trigger execution accuracy alert checking."""

View File

@@ -1,7 +1,7 @@
import pytest
from backend.api.model import CreateGraph
from backend.data import db
from backend.server.model import CreateGraph
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.clients import get_scheduler_client
from backend.util.test import SpinTestServer

View File

@@ -239,19 +239,14 @@ async def _validate_node_input_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[dict[str, dict[str, str]], set[str]]:
) -> dict[str, dict[str, str]]:
"""
Checks all credentials for all nodes of the graph and returns structured errors
and a set of nodes that should be skipped due to optional missing credentials.
Checks all credentials for all nodes of the graph and returns structured errors.
Returns:
tuple[
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node,
set[node_id]: Nodes that should be skipped (optional credentials not configured)
]
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node
"""
credential_errors: dict[str, dict[str, str]] = defaultdict(dict)
nodes_to_skip: set[str] = set()
for node in graph.nodes:
block = node.block
@@ -261,46 +256,27 @@ async def _validate_node_input_credentials(
if not credentials_fields:
continue
# Track if any credential field is missing for this node
has_missing_credentials = False
for field_name, credentials_meta_type in credentials_fields.items():
try:
# Check nodes_input_masks first, then input_default
field_value = None
if (
nodes_input_masks
and (node_input_mask := nodes_input_masks.get(node.id))
and field_name in node_input_mask
):
field_value = node_input_mask[field_name]
credentials_meta = credentials_meta_type.model_validate(
node_input_mask[field_name]
)
elif field_name in node.input_default:
# For optional credentials, don't use input_default - treat as missing
# This prevents stale credential IDs from failing validation
if node.credentials_optional:
field_value = None
else:
field_value = node.input_default[field_name]
# Check if credentials are missing (None, empty, or not present)
if field_value is None or (
isinstance(field_value, dict) and not field_value.get("id")
):
has_missing_credentials = True
# If node has credentials_optional flag, mark for skipping instead of error
if node.credentials_optional:
continue # Don't add error, will be marked for skip after loop
else:
credential_errors[node.id][
field_name
] = "These credentials are required"
continue
credentials_meta = credentials_meta_type.model_validate(field_value)
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
else:
# Missing credentials
credential_errors[node.id][
field_name
] = "These credentials are required"
continue
except ValidationError as e:
# Validation error means credentials were provided but invalid
# This should always be an error, even if optional
credential_errors[node.id][field_name] = f"Invalid credentials: {e}"
continue
@@ -311,7 +287,6 @@ async def _validate_node_input_credentials(
)
except Exception as e:
# Handle any errors fetching credentials
# If credentials were explicitly configured but unavailable, it's an error
credential_errors[node.id][
field_name
] = f"Credentials not available: {e}"
@@ -338,19 +313,7 @@ async def _validate_node_input_credentials(
] = "Invalid credentials: type/provider mismatch"
continue
# If node has optional credentials and any are missing, mark for skipping
# But only if there are no other errors for this node
if (
has_missing_credentials
and node.credentials_optional
and node.id not in credential_errors
):
nodes_to_skip.add(node.id)
logger.info(
f"Node #{node.id} will be skipped: optional credentials not configured"
)
return credential_errors, nodes_to_skip
return credential_errors
def make_node_credentials_input_map(
@@ -392,25 +355,21 @@ async def validate_graph_with_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[Mapping[str, Mapping[str, str]], set[str]]:
) -> Mapping[str, Mapping[str, str]]:
"""
Validate graph including credentials and return structured errors per node,
along with a set of nodes that should be skipped due to optional missing credentials.
Validate graph including credentials and return structured errors per node.
Returns:
tuple[
dict[node_id, dict[field_name, error_message]]: Validation errors per node,
set[node_id]: Nodes that should be skipped (optional credentials not configured)
]
dict[node_id, dict[field_name, error_message]]: Validation errors per node
"""
# Get input validation errors
node_input_errors = GraphModel.validate_graph_get_errors(
graph, for_run=True, nodes_input_masks=nodes_input_masks
)
# Get credential input/availability/validation errors and nodes to skip
node_credential_input_errors, nodes_to_skip = (
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
# Get credential input/availability/validation errors
node_credential_input_errors = await _validate_node_input_credentials(
graph, user_id, nodes_input_masks
)
# Merge credential errors with structural errors
@@ -419,7 +378,7 @@ async def validate_graph_with_credentials(
node_input_errors[node_id] = {}
node_input_errors[node_id].update(field_errors)
return node_input_errors, nodes_to_skip
return node_input_errors
async def _construct_starting_node_execution_input(
@@ -427,7 +386,7 @@ async def _construct_starting_node_execution_input(
user_id: str,
graph_inputs: BlockInput,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[list[tuple[str, BlockInput]], set[str]]:
) -> list[tuple[str, BlockInput]]:
"""
Validates and prepares the input data for executing a graph.
This function checks the graph for starting nodes, validates the input data
@@ -441,14 +400,11 @@ async def _construct_starting_node_execution_input(
node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`
Returns:
tuple[
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID
and the corresponding input data for that node.
set[str]: Node IDs that should be skipped (optional credentials not configured)
]
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
the corresponding input data for that node.
"""
# Use new validation function that includes credentials
validation_errors, nodes_to_skip = await validate_graph_with_credentials(
validation_errors = await validate_graph_with_credentials(
graph, user_id, nodes_input_masks
)
n_error_nodes = len(validation_errors)
@@ -489,7 +445,7 @@ async def _construct_starting_node_execution_input(
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
)
return nodes_input, nodes_to_skip
return nodes_input
async def validate_and_construct_node_execution_input(
@@ -500,7 +456,7 @@ async def validate_and_construct_node_execution_input(
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
is_sub_graph: bool = False,
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks, set[str]]:
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]:
"""
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
This centralizes the logic used by both scheduler validation and actual execution.
@@ -517,7 +473,6 @@ async def validate_and_construct_node_execution_input(
GraphModel: Full graph object for the given `graph_id`.
list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
dict[str, BlockInput]: Node input masks including all passed-in credentials.
set[str]: Node IDs that should be skipped (optional credentials not configured).
Raises:
NotFoundError: If the graph is not found.
@@ -559,16 +514,14 @@ async def validate_and_construct_node_execution_input(
nodes_input_masks or {},
)
starting_nodes_input, nodes_to_skip = (
await _construct_starting_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=graph_inputs,
nodes_input_masks=nodes_input_masks,
)
starting_nodes_input = await _construct_starting_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=graph_inputs,
nodes_input_masks=nodes_input_masks,
)
return graph, starting_nodes_input, nodes_input_masks, nodes_to_skip
return graph, starting_nodes_input, nodes_input_masks
def _merge_nodes_input_masks(
@@ -826,9 +779,6 @@ async def add_graph_execution(
# Use existing execution's compiled input masks
compiled_nodes_input_masks = graph_exec.nodes_input_masks or {}
# For resumed executions, nodes_to_skip was already determined at creation time
# TODO: Consider storing nodes_to_skip in DB if we need to preserve it across resumes
nodes_to_skip: set[str] = set()
logger.info(f"Resuming graph execution #{graph_exec.id} for graph #{graph_id}")
else:
@@ -837,7 +787,7 @@ async def add_graph_execution(
)
# Create new execution
graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip = (
graph, starting_nodes_input, compiled_nodes_input_masks = (
await validate_and_construct_node_execution_input(
graph_id=graph_id,
user_id=user_id,
@@ -886,7 +836,6 @@ async def add_graph_execution(
try:
graph_exec_entry = graph_exec.to_graph_execution_entry(
compiled_nodes_input_masks=compiled_nodes_input_masks,
nodes_to_skip=nodes_to_skip,
execution_context=execution_context,
)
logger.info(f"Publishing execution {graph_exec.id} to execution queue")

View File

@@ -367,13 +367,10 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
)
# Setup mock returns
# The function returns (graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip)
nodes_to_skip: set[str] = set()
mock_validate.return_value = (
mock_graph,
starting_nodes_input,
compiled_nodes_input_masks,
nodes_to_skip,
)
mock_prisma.is_connected.return_value = True
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
@@ -459,212 +456,3 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
# Both executions should succeed (though they create different objects)
assert result1 == mock_graph_exec
assert result2 == mock_graph_exec_2
# ============================================================================
# Tests for Optional Credentials Feature
# ============================================================================
@pytest.mark.asyncio
async def test_validate_node_input_credentials_returns_nodes_to_skip(
mocker: MockerFixture,
):
"""
Test that _validate_node_input_credentials returns nodes_to_skip set
for nodes with credentials_optional=True and missing credentials.
"""
from backend.executor.utils import _validate_node_input_credentials
# Create a mock node with credentials_optional=True
mock_node = mocker.MagicMock()
mock_node.id = "node-with-optional-creds"
mock_node.credentials_optional = True
mock_node.input_default = {} # No credentials configured
# Create a mock block with credentials field
mock_block = mocker.MagicMock()
mock_credentials_field_type = mocker.MagicMock()
mock_block.input_schema.get_credentials_fields.return_value = {
"credentials": mock_credentials_field_type
}
mock_node.block = mock_block
# Create mock graph
mock_graph = mocker.MagicMock()
mock_graph.nodes = [mock_node]
# Call the function
errors, nodes_to_skip = await _validate_node_input_credentials(
graph=mock_graph,
user_id="test-user-id",
nodes_input_masks=None,
)
# Node should be in nodes_to_skip, not in errors
assert mock_node.id in nodes_to_skip
assert mock_node.id not in errors
@pytest.mark.asyncio
async def test_validate_node_input_credentials_required_missing_creds_error(
mocker: MockerFixture,
):
"""
Test that _validate_node_input_credentials returns errors
for nodes with credentials_optional=False and missing credentials.
"""
from backend.executor.utils import _validate_node_input_credentials
# Create a mock node with credentials_optional=False (required)
mock_node = mocker.MagicMock()
mock_node.id = "node-with-required-creds"
mock_node.credentials_optional = False
mock_node.input_default = {} # No credentials configured
# Create a mock block with credentials field
mock_block = mocker.MagicMock()
mock_credentials_field_type = mocker.MagicMock()
mock_block.input_schema.get_credentials_fields.return_value = {
"credentials": mock_credentials_field_type
}
mock_node.block = mock_block
# Create mock graph
mock_graph = mocker.MagicMock()
mock_graph.nodes = [mock_node]
# Call the function
errors, nodes_to_skip = await _validate_node_input_credentials(
graph=mock_graph,
user_id="test-user-id",
nodes_input_masks=None,
)
# Node should be in errors, not in nodes_to_skip
assert mock_node.id in errors
assert "credentials" in errors[mock_node.id]
assert "required" in errors[mock_node.id]["credentials"].lower()
assert mock_node.id not in nodes_to_skip
@pytest.mark.asyncio
async def test_validate_graph_with_credentials_returns_nodes_to_skip(
mocker: MockerFixture,
):
"""
Test that validate_graph_with_credentials returns nodes_to_skip set
from _validate_node_input_credentials.
"""
from backend.executor.utils import validate_graph_with_credentials
# Mock _validate_node_input_credentials to return specific values
mock_validate = mocker.patch(
"backend.executor.utils._validate_node_input_credentials"
)
expected_errors = {"node1": {"field": "error"}}
expected_nodes_to_skip = {"node2", "node3"}
mock_validate.return_value = (expected_errors, expected_nodes_to_skip)
# Mock GraphModel with validate_graph_get_errors method
mock_graph = mocker.MagicMock()
mock_graph.validate_graph_get_errors.return_value = {}
# Call the function
errors, nodes_to_skip = await validate_graph_with_credentials(
graph=mock_graph,
user_id="test-user-id",
nodes_input_masks=None,
)
# Verify nodes_to_skip is passed through
assert nodes_to_skip == expected_nodes_to_skip
assert "node1" in errors
@pytest.mark.asyncio
async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
"""
Test that add_graph_execution properly passes nodes_to_skip
to the graph execution entry.
"""
from backend.data.execution import GraphExecutionWithNodes
from backend.executor.utils import add_graph_execution
# Mock data
graph_id = "test-graph-id"
user_id = "test-user-id"
inputs = {"test_input": "test_value"}
graph_version = 1
# Mock the graph object
mock_graph = mocker.MagicMock()
mock_graph.version = graph_version
# Starting nodes and masks
starting_nodes_input = [("node1", {"input1": "value1"})]
compiled_nodes_input_masks = {}
nodes_to_skip = {"skipped-node-1", "skipped-node-2"}
# Mock the graph execution object
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
mock_graph_exec.id = "execution-id-123"
mock_graph_exec.node_executions = []
# Track what's passed to to_graph_execution_entry
captured_kwargs = {}
def capture_to_entry(**kwargs):
captured_kwargs.update(kwargs)
return mocker.MagicMock()
mock_graph_exec.to_graph_execution_entry.side_effect = capture_to_entry
# Setup mocks
mock_validate = mocker.patch(
"backend.executor.utils.validate_and_construct_node_execution_input"
)
mock_edb = mocker.patch("backend.executor.utils.execution_db")
mock_prisma = mocker.patch("backend.executor.utils.prisma")
mock_udb = mocker.patch("backend.executor.utils.user_db")
mock_gdb = mocker.patch("backend.executor.utils.graph_db")
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
mock_get_event_bus = mocker.patch(
"backend.executor.utils.get_async_execution_event_bus"
)
# Setup returns - include nodes_to_skip in the tuple
mock_validate.return_value = (
mock_graph,
starting_nodes_input,
compiled_nodes_input_masks,
nodes_to_skip, # This should be passed through
)
mock_prisma.is_connected.return_value = True
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
mock_edb.update_graph_execution_stats = mocker.AsyncMock(
return_value=mock_graph_exec
)
mock_edb.update_node_execution_status_batch = mocker.AsyncMock()
mock_user = mocker.MagicMock()
mock_user.timezone = "UTC"
mock_settings = mocker.MagicMock()
mock_settings.human_in_the_loop_safe_mode = True
mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user)
mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings)
mock_get_queue.return_value = mocker.AsyncMock()
mock_get_event_bus.return_value = mocker.MagicMock(publish=mocker.AsyncMock())
# Call the function
await add_graph_execution(
graph_id=graph_id,
user_id=user_id,
inputs=inputs,
graph_version=graph_version,
)
# Verify nodes_to_skip was passed to to_graph_execution_entry
assert "nodes_to_skip" in captured_kwargs
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip

View File

@@ -8,7 +8,6 @@ from .discord import DiscordOAuthHandler
from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .notion import NotionOAuthHandler
from .reddit import RedditOAuthHandler
from .twitter import TwitterOAuthHandler
if TYPE_CHECKING:
@@ -21,7 +20,6 @@ _ORIGINAL_HANDLERS = [
GitHubOAuthHandler,
GoogleOAuthHandler,
NotionOAuthHandler,
RedditOAuthHandler,
TwitterOAuthHandler,
TodoistOAuthHandler,
]

View File

@@ -1,208 +0,0 @@
import time
import urllib.parse
from typing import ClassVar, Optional
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.providers import ProviderName
from backend.util.request import Requests
from backend.util.settings import Settings
settings = Settings()
class RedditOAuthHandler(BaseOAuthHandler):
"""
Reddit OAuth 2.0 handler.
Based on the documentation at:
- https://github.com/reddit-archive/reddit/wiki/OAuth2
Notes:
- Reddit requires `duration=permanent` to get refresh tokens
- Access tokens expire after 1 hour (3600 seconds)
- Reddit requires HTTP Basic Auth for token requests
- Reddit requires a unique User-Agent header
"""
PROVIDER_NAME = ProviderName.REDDIT
DEFAULT_SCOPES: ClassVar[list[str]] = [
"identity", # Get username, verify auth
"read", # Access posts and comments
"submit", # Submit new posts and comments
"edit", # Edit own posts and comments
"history", # Access user's post history
"privatemessages", # Access inbox and send private messages
"flair", # Access and set flair on posts/subreddits
]
AUTHORIZE_URL = "https://www.reddit.com/api/v1/authorize"
TOKEN_URL = "https://www.reddit.com/api/v1/access_token"
USERNAME_URL = "https://oauth.reddit.com/api/v1/me"
REVOKE_URL = "https://www.reddit.com/api/v1/revoke_token"
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
) -> str:
"""Generate Reddit OAuth 2.0 authorization URL"""
scopes = self.handle_default_scopes(scopes)
params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": " ".join(scopes),
"state": state,
"duration": "permanent", # Required for refresh tokens
}
return f"{self.AUTHORIZE_URL}?{urllib.parse.urlencode(params)}"
async def exchange_code_for_tokens(
self, code: str, scopes: list[str], code_verifier: Optional[str]
) -> OAuth2Credentials:
"""Exchange authorization code for access tokens"""
scopes = self.handle_default_scopes(scopes)
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": settings.config.reddit_user_agent,
}
data = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": self.redirect_uri,
}
# Reddit requires HTTP Basic Auth for token requests
auth = (self.client_id, self.client_secret)
response = await Requests().post(
self.TOKEN_URL, headers=headers, data=data, auth=auth
)
if not response.ok:
error_text = response.text()
raise ValueError(
f"Reddit token exchange failed: {response.status} - {error_text}"
)
tokens = response.json()
if "error" in tokens:
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")
username = await self._get_username(tokens["access_token"])
return OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=None,
username=username,
access_token=tokens["access_token"],
refresh_token=tokens.get("refresh_token"),
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
refresh_token_expires_at=None, # Reddit refresh tokens don't expire
scopes=scopes,
)
async def _get_username(self, access_token: str) -> str:
"""Get the username from the access token"""
headers = {
"Authorization": f"Bearer {access_token}",
"User-Agent": settings.config.reddit_user_agent,
}
response = await Requests().get(self.USERNAME_URL, headers=headers)
if not response.ok:
raise ValueError(f"Failed to get Reddit username: {response.status}")
data = response.json()
return data.get("name", "unknown")
async def _refresh_tokens(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
"""Refresh access tokens using refresh token"""
if not credentials.refresh_token:
raise ValueError("No refresh token available")
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": settings.config.reddit_user_agent,
}
data = {
"grant_type": "refresh_token",
"refresh_token": credentials.refresh_token.get_secret_value(),
}
auth = (self.client_id, self.client_secret)
response = await Requests().post(
self.TOKEN_URL, headers=headers, data=data, auth=auth
)
if not response.ok:
error_text = response.text()
raise ValueError(
f"Reddit token refresh failed: {response.status} - {error_text}"
)
tokens = response.json()
if "error" in tokens:
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")
username = await self._get_username(tokens["access_token"])
# Reddit may or may not return a new refresh token
new_refresh_token = tokens.get("refresh_token")
if new_refresh_token:
refresh_token: SecretStr | None = SecretStr(new_refresh_token)
elif credentials.refresh_token:
# Keep the existing refresh token
refresh_token = credentials.refresh_token
else:
refresh_token = None
return OAuth2Credentials(
id=credentials.id,
provider=self.PROVIDER_NAME,
title=credentials.title,
username=username,
access_token=tokens["access_token"],
refresh_token=refresh_token,
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
refresh_token_expires_at=None,
scopes=credentials.scopes,
)
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
"""Revoke the access token"""
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": settings.config.reddit_user_agent,
}
data = {
"token": credentials.access_token.get_secret_value(),
"token_type_hint": "access_token",
}
auth = (self.client_id, self.client_secret)
response = await Requests().post(
self.REVOKE_URL, headers=headers, data=data, auth=auth
)
# Reddit returns 204 No Content on successful revocation
return response.ok

View File

@@ -149,10 +149,10 @@ async def setup_webhook_for_block(
async def migrate_legacy_triggered_graphs():
from prisma.models import AgentGraph
from backend.api.features.library.db import create_preset
from backend.api.features.library.model import LibraryAgentPresetCreatable
from backend.data.graph import AGENT_GRAPH_INCLUDE, GraphModel, set_node_webhook
from backend.data.model import is_credentials_field_name
from backend.server.v2.library.db import create_preset
from backend.server.v2.library.model import LibraryAgentPresetCreatable
triggered_graphs = [
GraphModel.from_db(_graph)

View File

@@ -1,5 +1,5 @@
from backend.api.rest_api import AgentServer
from backend.app import run_processes
from backend.server.rest_api import AgentServer
def main():

View File

@@ -3,12 +3,12 @@ from typing import Dict, Set
from fastapi import WebSocket
from backend.api.model import NotificationPayload, WSMessage, WSMethod
from backend.data.execution import (
ExecutionEventType,
GraphExecutionEvent,
NodeExecutionEvent,
)
from backend.server.model import NotificationPayload, WSMessage, WSMethod
_EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,

View File

@@ -4,13 +4,13 @@ from unittest.mock import AsyncMock
import pytest
from fastapi import WebSocket
from backend.api.conn_manager import ConnectionManager
from backend.api.model import NotificationPayload, WSMessage, WSMethod
from backend.data.execution import (
ExecutionStatus,
GraphExecutionEvent,
NodeExecutionEvent,
)
from backend.server.conn_manager import ConnectionManager
from backend.server.model import NotificationPayload, WSMessage, WSMethod
@pytest.fixture

View File

@@ -0,0 +1,29 @@
from fastapi import FastAPI
from backend.monitoring.instrumentation import instrument_fastapi
from backend.server.middleware.security import SecurityHeadersMiddleware
from .routes.integrations import integrations_router
from .routes.tools import tools_router
from .routes.v1 import v1_router
external_app = FastAPI(
title="AutoGPT External API",
description="External API for AutoGPT integrations",
docs_url="/docs",
version="1.0",
)
external_app.add_middleware(SecurityHeadersMiddleware)
external_app.include_router(v1_router, prefix="/v1")
external_app.include_router(tools_router, prefix="/v1")
external_app.include_router(integrations_router, prefix="/v1")
# Add Prometheus instrumentation
instrument_fastapi(
external_app,
service_name="external-api",
expose_endpoint=True,
endpoint="/metrics",
include_in_schema=True,
)

View File

@@ -0,0 +1,36 @@
from fastapi import HTTPException, Security
from fastapi.security import APIKeyHeader
from prisma.enums import APIKeyPermission
from backend.data.api_key import APIKeyInfo, has_permission, validate_api_key
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
"""Base middleware for API key authentication"""
if api_key is None:
raise HTTPException(status_code=401, detail="Missing API key")
api_key_obj = await validate_api_key(api_key)
if not api_key_obj:
raise HTTPException(status_code=401, detail="Invalid API key")
return api_key_obj
def require_permission(permission: APIKeyPermission):
"""Dependency function for checking specific permissions"""
async def check_permission(
api_key: APIKeyInfo = Security(require_api_key),
) -> APIKeyInfo:
if not has_permission(api_key, permission):
raise HTTPException(
status_code=403,
detail=f"API key lacks the required permission '{permission}'",
)
return api_key
return check_permission

Some files were not shown because too many files have changed in this diff Show More