Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-08 13:55:06 -05:00)

Compare commits: claude/tes… ... fix/execut… (4 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 9b20f4cd13 | |
| | a3d0f9cbd2 | |
| | 02ddb51446 | |
| | 750e096f15 | |
@@ -1,37 +0,0 @@
{
  "worktreeCopyPatterns": [
    ".env*",
    ".vscode/**",
    ".auth/**",
    ".claude/**",
    "autogpt_platform/.env*",
    "autogpt_platform/backend/.env*",
    "autogpt_platform/frontend/.env*",
    "autogpt_platform/frontend/.auth/**",
    "autogpt_platform/db/docker/.env*"
  ],
  "worktreeCopyIgnores": [
    "**/node_modules/**",
    "**/dist/**",
    "**/.git/**",
    "**/Thumbs.db",
    "**/.DS_Store",
    "**/.next/**",
    "**/__pycache__/**",
    "**/.ruff_cache/**",
    "**/.pytest_cache/**",
    "**/*.pyc",
    "**/playwright-report/**",
    "**/logs/**",
    "**/site/**"
  ],
  "worktreePathTemplate": "$BASE_PATH.worktree",
  "postCreateCmd": [
    "cd autogpt_platform/autogpt_libs && poetry install",
    "cd autogpt_platform/backend && poetry install && poetry run prisma generate",
    "cd autogpt_platform/frontend && pnpm install",
    "cd docs && pip install -r requirements.txt"
  ],
  "terminalCommand": "code .",
  "deleteBranchWithWorktree": false
}
@@ -16,7 +16,6 @@
 !autogpt_platform/backend/poetry.lock
 !autogpt_platform/backend/README.md
 !autogpt_platform/backend/.env
-!autogpt_platform/backend/gen_prisma_types_stub.py

 # Platform - Market
 !autogpt_platform/market/market/
.github/workflows/claude-dependabot.yml (2 changes, vendored)
@@ -74,7 +74,7 @@ jobs:

      - name: Generate Prisma Client
        working-directory: autogpt_platform/backend
-       run: poetry run prisma generate && poetry run gen-prisma-stub
+       run: poetry run prisma generate

      # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
      - name: Set up Node.js
.github/workflows/claude.yml (2 changes, vendored)
@@ -90,7 +90,7 @@ jobs:

      - name: Generate Prisma Client
        working-directory: autogpt_platform/backend
-       run: poetry run prisma generate && poetry run gen-prisma-stub
+       run: poetry run prisma generate

      # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
      - name: Set up Node.js
.github/workflows/copilot-setup-steps.yml (12 changes, vendored)
@@ -72,7 +72,7 @@ jobs:

      - name: Generate Prisma Client
        working-directory: autogpt_platform/backend
-       run: poetry run prisma generate && poetry run gen-prisma-stub
+       run: poetry run prisma generate

      # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
      - name: Set up Node.js

@@ -108,16 +108,6 @@ jobs:
      # run: pnpm playwright install --with-deps chromium

      # Docker setup for development environment
      - name: Free up disk space
        run: |
          # Remove large unused tools to free disk space for Docker builds
          sudo rm -rf /usr/share/dotnet
          sudo rm -rf /usr/local/lib/android
          sudo rm -rf /opt/ghc
          sudo rm -rf /opt/hostedtoolcache/CodeQL
          sudo docker system prune -af
          df -h

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
.github/workflows/platform-backend-ci.yml (2 changes, vendored)
@@ -134,7 +134,7 @@ jobs:
        run: poetry install

      - name: Generate Prisma Client
-       run: poetry run prisma generate && poetry run gen-prisma-stub
+       run: poetry run prisma generate

      - id: supabase
        name: Start Supabase
@@ -11,7 +11,7 @@ jobs:
  stale:
    runs-on: ubuntu-latest
    steps:
-     - uses: actions/stale@v10
+     - uses: actions/stale@v9
        with:
          # operations-per-run: 5000
          stale-issue-message: >
.github/workflows/repo-pr-label.yml (2 changes, vendored)
@@ -61,6 +61,6 @@ jobs:
      pull-requests: write
    runs-on: ubuntu-latest
    steps:
-     - uses: actions/labeler@v6
+     - uses: actions/labeler@v5
        with:
          sync-labels: true
@@ -12,7 +12,6 @@ reset-db:
    rm -rf db/docker/volumes/db/data
    cd backend && poetry run prisma migrate deploy
    cd backend && poetry run prisma generate
-   cd backend && poetry run gen-prisma-stub

# View logs for core services
logs-core:

@@ -34,7 +33,6 @@ init-env:
migrate:
    cd backend && poetry run prisma migrate deploy
    cd backend && poetry run prisma generate
-   cd backend && poetry run gen-prisma-stub

run-backend:
    cd backend && poetry run app
@@ -57,9 +57,6 @@ class APIKeySmith:

    def hash_key(self, raw_key: str) -> tuple[str, str]:
        """Migrate a legacy hash to secure hash format."""
        if not raw_key.startswith(self.PREFIX):
            raise ValueError("Key without 'agpt_' prefix would fail validation")

        salt = self._generate_salt()
        hash = self._hash_key_with_salt(raw_key, salt)
        return hash, salt.hex()
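The helpers `_generate_salt` and `_hash_key_with_salt` used above are not shown in this comparison. A minimal sketch of how such salted hashing is commonly implemented (an assumption for illustration, not the project's actual code):

```python
import hashlib
import os


def generate_salt(length: int = 16) -> bytes:
    # Cryptographically random salt; stored alongside the hash (hex-encoded).
    return os.urandom(length)


def hash_key_with_salt(raw_key: str, salt: bytes) -> str:
    # Bind the key to its salt so identical keys produce different digests.
    return hashlib.sha256(salt + raw_key.encode()).hexdigest()


salt = generate_salt()
digest = hash_key_with_salt("agpt_example_key", salt)
print(digest, salt.hex())
```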
@@ -1,25 +1,29 @@
 from fastapi import FastAPI
+from fastapi.openapi.utils import get_openapi

 from .jwt_utils import bearer_jwt_auth


 def add_auth_responses_to_openapi(app: FastAPI) -> None:
     """
-    Patch a FastAPI instance's `openapi()` method to add 401 responses
+    Set up custom OpenAPI schema generation that adds 401 responses
     to all authenticated endpoints.

     This is needed when using HTTPBearer with auto_error=False to get proper
     401 responses instead of 403, but FastAPI only automatically adds security
     responses when auto_error=True.
     """
-    # Wrap current method to allow stacking OpenAPI schema modifiers like this
-    wrapped_openapi = app.openapi

     def custom_openapi():
         if app.openapi_schema:
             return app.openapi_schema

-        openapi_schema = wrapped_openapi()
+        openapi_schema = get_openapi(
+            title=app.title,
+            version=app.version,
+            description=app.description,
+            routes=app.routes,
+        )

         # Add 401 response to all endpoints that have security requirements
         for path, methods in openapi_schema["paths"].items():
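For context, this helper is called once at application setup time, after all routers are included, so the documented 401 responses land on every authenticated route. A brief usage sketch (the import path is assumed for illustration):

```python
from fastapi import FastAPI

# Hypothetical import path; the real module location may differ.
from autogpt_libs.auth.openapi import add_auth_responses_to_openapi

app = FastAPI(title="Example API", version="0.1")
# ... include routers here ...
add_auth_responses_to_openapi(app)  # patches app.openapi() to document 401s
```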
@@ -48,8 +48,7 @@ RUN poetry install --no-ansi --no-root
 # Generate Prisma client
 COPY autogpt_platform/backend/schema.prisma ./
 COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
-COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
-RUN poetry run prisma generate && poetry run gen-prisma-stub
+RUN poetry run prisma generate

 FROM debian:13-slim AS server_dependencies
@@ -108,7 +108,7 @@ import fastapi.testclient
 import pytest
 from pytest_snapshot.plugin import Snapshot

-from backend.api.features.myroute import router
+from backend.server.v2.myroute import router

 app = fastapi.FastAPI()
 app.include_router(router)

@@ -149,7 +149,7 @@ These provide the easiest way to set up authentication mocking in test modules:
 import fastapi
 import fastapi.testclient
 import pytest
-from backend.api.features.myroute import router
+from backend.server.v2.myroute import router

 app = fastapi.FastAPI()
 app.include_router(router)
@@ -1,25 +0,0 @@
from fastapi import FastAPI

from backend.api.middleware.security import SecurityHeadersMiddleware
from backend.monitoring.instrumentation import instrument_fastapi

from .v1.routes import v1_router

external_api = FastAPI(
    title="AutoGPT External API",
    description="External API for AutoGPT integrations",
    docs_url="/docs",
    version="1.0",
)

external_api.add_middleware(SecurityHeadersMiddleware)
external_api.include_router(v1_router, prefix="/v1")

# Add Prometheus instrumentation
instrument_fastapi(
    external_api,
    service_name="external-api",
    expose_endpoint=True,
    endpoint="/metrics",
    include_in_schema=True,
)
@@ -1,107 +0,0 @@
from fastapi import HTTPException, Security, status
from fastapi.security import APIKeyHeader, HTTPAuthorizationCredentials, HTTPBearer
from prisma.enums import APIKeyPermission

from backend.data.auth.api_key import APIKeyInfo, validate_api_key
from backend.data.auth.base import APIAuthorizationInfo
from backend.data.auth.oauth import (
    InvalidClientError,
    InvalidTokenError,
    OAuthAccessTokenInfo,
    validate_access_token,
)

api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
bearer_auth = HTTPBearer(auto_error=False)


async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
    """Middleware for API key authentication only"""
    if api_key is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing API key"
        )

    api_key_obj = await validate_api_key(api_key)

    if not api_key_obj:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
        )

    return api_key_obj


async def require_access_token(
    bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
) -> OAuthAccessTokenInfo:
    """Middleware for OAuth access token authentication only"""
    if bearer is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing Authorization header",
        )

    try:
        token_info, _ = await validate_access_token(bearer.credentials)
    except (InvalidClientError, InvalidTokenError) as e:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))

    return token_info


async def require_auth(
    api_key: str | None = Security(api_key_header),
    bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
) -> APIAuthorizationInfo:
    """
    Unified authentication middleware supporting both API keys and OAuth tokens.

    Supports two authentication methods, which are checked in order:
    1. X-API-Key header (existing API key authentication)
    2. Authorization: Bearer <token> header (OAuth access token)

    Returns:
        APIAuthorizationInfo: base class of both APIKeyInfo and OAuthAccessTokenInfo.
    """
    # Try API key first
    if api_key is not None:
        api_key_info = await validate_api_key(api_key)
        if api_key_info:
            return api_key_info
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
        )

    # Try OAuth bearer token
    if bearer is not None:
        try:
            token_info, _ = await validate_access_token(bearer.credentials)
            return token_info
        except (InvalidClientError, InvalidTokenError) as e:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))

    # No credentials provided
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Missing authentication. Provide API key or access token.",
    )


def require_permission(permission: APIKeyPermission):
    """
    Dependency function for checking specific permissions
    (works with API keys and OAuth tokens)
    """

    async def check_permission(
        auth: APIAuthorizationInfo = Security(require_auth),
    ) -> APIAuthorizationInfo:
        if permission not in auth.scopes:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Missing required permission: {permission.value}",
            )
        return auth

    return check_permission
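For reference, a permission-scoped dependency like the one defined above is typically attached to a route as follows. This is an illustrative sketch, not code from the repository; the `EXECUTE_GRAPH` scope name comes from the OAuth docstrings later in this diff, and the `auth.user_id` attribute is assumed:

```python
from fastapi import APIRouter, Security
from prisma.enums import APIKeyPermission

# require_permission / APIAuthorizationInfo as defined in the deleted module above.
router = APIRouter()


@router.post("/graphs/{graph_id}/execute")
async def execute_graph(
    graph_id: str,
    auth: APIAuthorizationInfo = Security(
        require_permission(APIKeyPermission.EXECUTE_GRAPH)
    ),
):
    # Both API-key and OAuth callers reach this point with a verified scope.
    return {"graph_id": graph_id, "authorized_as": auth.user_id}
```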
@@ -1,340 +0,0 @@
"""Tests for analytics API endpoints."""

import json
from unittest.mock import AsyncMock, Mock

import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from pytest_snapshot.plugin import Snapshot

from .analytics import router as analytics_router

app = fastapi.FastAPI()
app.include_router(analytics_router)

client = fastapi.testclient.TestClient(app)


@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
    """Setup auth overrides for all tests in this module."""
    from autogpt_libs.auth.jwt_utils import get_jwt_payload

    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
    yield
    app.dependency_overrides.clear()


# =============================================================================
# /log_raw_metric endpoint tests
# =============================================================================


def test_log_raw_metric_success(
    mocker: pytest_mock.MockFixture,
    configured_snapshot: Snapshot,
    test_user_id: str,
) -> None:
    """Test successful raw metric logging."""
    mock_result = Mock(id="metric-123-uuid")
    mock_log_metric = mocker.patch(
        "backend.data.analytics.log_raw_metric",
        new_callable=AsyncMock,
        return_value=mock_result,
    )

    request_data = {
        "metric_name": "page_load_time",
        "metric_value": 2.5,
        "data_string": "/dashboard",
    }

    response = client.post("/log_raw_metric", json=request_data)

    assert response.status_code == 200, f"Unexpected response: {response.text}"
    assert response.json() == "metric-123-uuid"

    mock_log_metric.assert_called_once_with(
        user_id=test_user_id,
        metric_name="page_load_time",
        metric_value=2.5,
        data_string="/dashboard",
    )

    configured_snapshot.assert_match(
        json.dumps({"metric_id": response.json()}, indent=2, sort_keys=True),
        "analytics_log_metric_success",
    )


@pytest.mark.parametrize(
    "metric_value,metric_name,data_string,test_id",
    [
        (100, "api_calls_count", "external_api", "integer_value"),
        (0, "error_count", "no_errors", "zero_value"),
        (-5.2, "temperature_delta", "cooling", "negative_value"),
        (1.23456789, "precision_test", "float_precision", "float_precision"),
        (999999999, "large_number", "max_value", "large_number"),
        (0.0000001, "tiny_number", "min_value", "tiny_number"),
    ],
)
def test_log_raw_metric_various_values(
    mocker: pytest_mock.MockFixture,
    configured_snapshot: Snapshot,
    metric_value: float,
    metric_name: str,
    data_string: str,
    test_id: str,
) -> None:
    """Test raw metric logging with various metric values."""
    mock_result = Mock(id=f"metric-{test_id}-uuid")
    mocker.patch(
        "backend.data.analytics.log_raw_metric",
        new_callable=AsyncMock,
        return_value=mock_result,
    )

    request_data = {
        "metric_name": metric_name,
        "metric_value": metric_value,
        "data_string": data_string,
    }

    response = client.post("/log_raw_metric", json=request_data)

    assert response.status_code == 200, f"Failed for {test_id}: {response.text}"

    configured_snapshot.assert_match(
        json.dumps(
            {"metric_id": response.json(), "test_case": test_id},
            indent=2,
            sort_keys=True,
        ),
        f"analytics_metric_{test_id}",
    )


@pytest.mark.parametrize(
    "invalid_data,expected_error",
    [
        ({}, "Field required"),
        ({"metric_name": "test"}, "Field required"),
        (
            {"metric_name": "test", "metric_value": "not_a_number", "data_string": "x"},
            "Input should be a valid number",
        ),
        (
            {"metric_name": "", "metric_value": 1.0, "data_string": "test"},
            "String should have at least 1 character",
        ),
        (
            {"metric_name": "test", "metric_value": 1.0, "data_string": ""},
            "String should have at least 1 character",
        ),
    ],
    ids=[
        "empty_request",
        "missing_metric_value_and_data_string",
        "invalid_metric_value_type",
        "empty_metric_name",
        "empty_data_string",
    ],
)
def test_log_raw_metric_validation_errors(
    invalid_data: dict,
    expected_error: str,
) -> None:
    """Test validation errors for invalid metric requests."""
    response = client.post("/log_raw_metric", json=invalid_data)

    assert response.status_code == 422
    error_detail = response.json()
    assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"

    error_text = json.dumps(error_detail)
    assert (
        expected_error in error_text
    ), f"Expected '{expected_error}' in error response: {error_text}"


def test_log_raw_metric_service_error(
    mocker: pytest_mock.MockFixture,
    test_user_id: str,
) -> None:
    """Test error handling when analytics service fails."""
    mocker.patch(
        "backend.data.analytics.log_raw_metric",
        new_callable=AsyncMock,
        side_effect=Exception("Database connection failed"),
    )

    request_data = {
        "metric_name": "test_metric",
        "metric_value": 1.0,
        "data_string": "test",
    }

    response = client.post("/log_raw_metric", json=request_data)

    assert response.status_code == 500
    error_detail = response.json()["detail"]
    assert "Database connection failed" in error_detail["message"]
    assert "hint" in error_detail


# =============================================================================
# /log_raw_analytics endpoint tests
# =============================================================================


def test_log_raw_analytics_success(
    mocker: pytest_mock.MockFixture,
    configured_snapshot: Snapshot,
    test_user_id: str,
) -> None:
    """Test successful raw analytics logging."""
    mock_result = Mock(id="analytics-789-uuid")
    mock_log_analytics = mocker.patch(
        "backend.data.analytics.log_raw_analytics",
        new_callable=AsyncMock,
        return_value=mock_result,
    )

    request_data = {
        "type": "user_action",
        "data": {
            "action": "button_click",
            "button_id": "submit_form",
            "timestamp": "2023-01-01T00:00:00Z",
            "metadata": {"form_type": "registration", "fields_filled": 5},
        },
        "data_index": "button_click_submit_form",
    }

    response = client.post("/log_raw_analytics", json=request_data)

    assert response.status_code == 200, f"Unexpected response: {response.text}"
    assert response.json() == "analytics-789-uuid"

    mock_log_analytics.assert_called_once_with(
        test_user_id,
        "user_action",
        request_data["data"],
        "button_click_submit_form",
    )

    configured_snapshot.assert_match(
        json.dumps({"analytics_id": response.json()}, indent=2, sort_keys=True),
        "analytics_log_analytics_success",
    )


def test_log_raw_analytics_complex_data(
    mocker: pytest_mock.MockFixture,
    configured_snapshot: Snapshot,
) -> None:
    """Test raw analytics logging with complex nested data structures."""
    mock_result = Mock(id="analytics-complex-uuid")
    mocker.patch(
        "backend.data.analytics.log_raw_analytics",
        new_callable=AsyncMock,
        return_value=mock_result,
    )

    request_data = {
        "type": "agent_execution",
        "data": {
            "agent_id": "agent_123",
            "execution_id": "exec_456",
            "status": "completed",
            "duration_ms": 3500,
            "nodes_executed": 15,
            "blocks_used": [
                {"block_id": "llm_block", "count": 3},
                {"block_id": "http_block", "count": 5},
                {"block_id": "code_block", "count": 2},
            ],
            "errors": [],
            "metadata": {
                "trigger": "manual",
                "user_tier": "premium",
                "environment": "production",
            },
        },
        "data_index": "agent_123_exec_456",
    }

    response = client.post("/log_raw_analytics", json=request_data)

    assert response.status_code == 200

    configured_snapshot.assert_match(
        json.dumps(
            {"analytics_id": response.json(), "logged_data": request_data["data"]},
            indent=2,
            sort_keys=True,
        ),
        "analytics_log_analytics_complex_data",
    )


@pytest.mark.parametrize(
    "invalid_data,expected_error",
    [
        ({}, "Field required"),
        ({"type": "test"}, "Field required"),
        (
            {"type": "test", "data": "not_a_dict", "data_index": "test"},
            "Input should be a valid dictionary",
        ),
        ({"type": "test", "data": {"key": "value"}}, "Field required"),
    ],
    ids=[
        "empty_request",
        "missing_data_and_data_index",
        "invalid_data_type",
        "missing_data_index",
    ],
)
def test_log_raw_analytics_validation_errors(
    invalid_data: dict,
    expected_error: str,
) -> None:
    """Test validation errors for invalid analytics requests."""
    response = client.post("/log_raw_analytics", json=invalid_data)

    assert response.status_code == 422
    error_detail = response.json()
    assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"

    error_text = json.dumps(error_detail)
    assert (
        expected_error in error_text
    ), f"Expected '{expected_error}' in error response: {error_text}"


def test_log_raw_analytics_service_error(
    mocker: pytest_mock.MockFixture,
    test_user_id: str,
) -> None:
    """Test error handling when analytics service fails."""
    mocker.patch(
        "backend.data.analytics.log_raw_analytics",
        new_callable=AsyncMock,
        side_effect=Exception("Analytics DB unreachable"),
    )

    request_data = {
        "type": "test_event",
        "data": {"key": "value"},
        "data_index": "test_index",
    }

    response = client.post("/log_raw_analytics", json=request_data)

    assert response.status_code == 500
    error_detail = response.json()["detail"]
    assert "Analytics DB unreachable" in error_detail["message"]
    assert "hint" in error_detail
@@ -1,833 +0,0 @@
"""
OAuth 2.0 Provider Endpoints

Implements OAuth 2.0 Authorization Code flow with PKCE support.

Flow:
1. User clicks "Login with AutoGPT" in 3rd party app
2. App redirects user to /auth/authorize with client_id, redirect_uri, scope, state
3. User sees consent screen (if not already logged in, redirects to login first)
4. User approves → backend creates authorization code
5. User redirected back to app with code
6. App exchanges code for access/refresh tokens at /api/oauth/token
7. App uses access token to call external API endpoints
"""

import io
import logging
import os
import uuid
from datetime import datetime
from typing import Literal, Optional
from urllib.parse import urlencode

from autogpt_libs.auth import get_user_id
from fastapi import APIRouter, Body, HTTPException, Security, UploadFile, status
from gcloud.aio import storage as async_storage
from PIL import Image
from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field

from backend.data.auth.oauth import (
    InvalidClientError,
    InvalidGrantError,
    OAuthApplicationInfo,
    TokenIntrospectionResult,
    consume_authorization_code,
    create_access_token,
    create_authorization_code,
    create_refresh_token,
    get_oauth_application,
    get_oauth_application_by_id,
    introspect_token,
    list_user_oauth_applications,
    refresh_tokens,
    revoke_access_token,
    revoke_refresh_token,
    update_oauth_application,
    validate_client_credentials,
    validate_redirect_uri,
    validate_scopes,
)
from backend.util.settings import Settings
from backend.util.virus_scanner import scan_content_safe

settings = Settings()
logger = logging.getLogger(__name__)

router = APIRouter()


# ============================================================================
# Request/Response Models
# ============================================================================


class TokenResponse(BaseModel):
    """OAuth 2.0 token response"""

    token_type: Literal["Bearer"] = "Bearer"
    access_token: str
    access_token_expires_at: datetime
    refresh_token: str
    refresh_token_expires_at: datetime
    scopes: list[str]


class ErrorResponse(BaseModel):
    """OAuth 2.0 error response"""

    error: str
    error_description: Optional[str] = None


class OAuthApplicationPublicInfo(BaseModel):
    """Public information about an OAuth application (for consent screen)"""

    name: str
    description: Optional[str] = None
    logo_url: Optional[str] = None
    scopes: list[str]


# ============================================================================
# Application Info Endpoint
# ============================================================================


@router.get(
    "/app/{client_id}",
    responses={
        404: {"description": "Application not found or disabled"},
    },
)
async def get_oauth_app_info(
    client_id: str, user_id: str = Security(get_user_id)
) -> OAuthApplicationPublicInfo:
    """
    Get public information about an OAuth application.

    This endpoint is used by the consent screen to display application details
    to the user before they authorize access.

    Returns:
    - name: Application name
    - description: Application description (if provided)
    - scopes: List of scopes the application is allowed to request
    """
    app = await get_oauth_application(client_id)
    if not app or not app.is_active:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Application not found",
        )

    return OAuthApplicationPublicInfo(
        name=app.name,
        description=app.description,
        logo_url=app.logo_url,
        scopes=[s.value for s in app.scopes],
    )


# ============================================================================
# Authorization Endpoint
# ============================================================================


class AuthorizeRequest(BaseModel):
    """OAuth 2.0 authorization request"""

    client_id: str = Field(description="Client identifier")
    redirect_uri: str = Field(description="Redirect URI")
    scopes: list[str] = Field(description="List of scopes")
    state: str = Field(description="Anti-CSRF token from client")
    response_type: str = Field(
        default="code", description="Must be 'code' for authorization code flow"
    )
    code_challenge: str = Field(description="PKCE code challenge (required)")
    code_challenge_method: Literal["S256", "plain"] = Field(
        default="S256", description="PKCE code challenge method (S256 recommended)"
    )


class AuthorizeResponse(BaseModel):
    """OAuth 2.0 authorization response with redirect URL"""

    redirect_url: str = Field(description="URL to redirect the user to")


@router.post("/authorize")
async def authorize(
    request: AuthorizeRequest = Body(),
    user_id: str = Security(get_user_id),
) -> AuthorizeResponse:
    """
    OAuth 2.0 Authorization Endpoint

    User must be logged in (authenticated with Supabase JWT).
    This endpoint creates an authorization code and returns a redirect URL.

    PKCE (Proof Key for Code Exchange) is REQUIRED for all authorization requests.

    The frontend consent screen should call this endpoint after the user approves,
    then redirect the user to the returned `redirect_url`.

    Request Body:
    - client_id: The OAuth application's client ID
    - redirect_uri: Where to redirect after authorization (must match registered URI)
    - scopes: List of permissions (e.g., "EXECUTE_GRAPH READ_GRAPH")
    - state: Anti-CSRF token provided by client (will be returned in redirect)
    - response_type: Must be "code" (for authorization code flow)
    - code_challenge: PKCE code challenge (required)
    - code_challenge_method: "S256" (recommended) or "plain"

    Returns:
    - redirect_url: The URL to redirect the user to (includes authorization code)

    Error cases return a redirect_url with error parameters, or raise HTTPException
    for critical errors (like invalid redirect_uri).
    """
    try:
        # Validate response_type
        if request.response_type != "code":
            return _error_redirect_url(
                request.redirect_uri,
                request.state,
                "unsupported_response_type",
                "Only 'code' response type is supported",
            )

        # Get application
        app = await get_oauth_application(request.client_id)
        if not app:
            return _error_redirect_url(
                request.redirect_uri,
                request.state,
                "invalid_client",
                "Unknown client_id",
            )

        if not app.is_active:
            return _error_redirect_url(
                request.redirect_uri,
                request.state,
                "invalid_client",
                "Application is not active",
            )

        # Validate redirect URI
        if not validate_redirect_uri(app, request.redirect_uri):
            # For invalid redirect_uri, we can't redirect safely
            # Must return error instead
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=(
                    "Invalid redirect_uri. "
                    f"Must be one of: {', '.join(app.redirect_uris)}"
                ),
            )

        # Parse and validate scopes
        try:
            requested_scopes = [APIKeyPermission(s.strip()) for s in request.scopes]
        except ValueError as e:
            return _error_redirect_url(
                request.redirect_uri,
                request.state,
                "invalid_scope",
                f"Invalid scope: {e}",
            )

        if not requested_scopes:
            return _error_redirect_url(
                request.redirect_uri,
                request.state,
                "invalid_scope",
                "At least one scope is required",
            )

        if not validate_scopes(app, requested_scopes):
            return _error_redirect_url(
                request.redirect_uri,
                request.state,
                "invalid_scope",
                "Application is not authorized for all requested scopes. "
                f"Allowed: {', '.join(s.value for s in app.scopes)}",
            )

        # Create authorization code
        auth_code = await create_authorization_code(
            application_id=app.id,
            user_id=user_id,
            scopes=requested_scopes,
            redirect_uri=request.redirect_uri,
            code_challenge=request.code_challenge,
            code_challenge_method=request.code_challenge_method,
        )

        # Build redirect URL with authorization code
        params = {
            "code": auth_code.code,
            "state": request.state,
        }
        redirect_url = f"{request.redirect_uri}?{urlencode(params)}"

        logger.info(
            f"Authorization code issued for user #{user_id} "
            f"and app {app.name} (#{app.id})"
        )

        return AuthorizeResponse(redirect_url=redirect_url)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in authorization endpoint: {e}", exc_info=True)
        return _error_redirect_url(
            request.redirect_uri,
            request.state,
            "server_error",
            "An unexpected error occurred",
        )


def _error_redirect_url(
    redirect_uri: str,
    state: str,
    error: str,
    error_description: Optional[str] = None,
) -> AuthorizeResponse:
    """Helper to build redirect URL with OAuth error parameters"""
    params = {
        "error": error,
        "state": state,
    }
    if error_description:
        params["error_description"] = error_description

    redirect_url = f"{redirect_uri}?{urlencode(params)}"
    return AuthorizeResponse(redirect_url=redirect_url)


# ============================================================================
# Token Endpoint
# ============================================================================


class TokenRequestByCode(BaseModel):
    grant_type: Literal["authorization_code"]
    code: str = Field(description="Authorization code")
    redirect_uri: str = Field(
        description="Redirect URI (must match authorization request)"
    )
    client_id: str
    client_secret: str
    code_verifier: str = Field(description="PKCE code verifier")


class TokenRequestByRefreshToken(BaseModel):
    grant_type: Literal["refresh_token"]
    refresh_token: str
    client_id: str
    client_secret: str


@router.post("/token")
async def token(
    request: TokenRequestByCode | TokenRequestByRefreshToken = Body(),
) -> TokenResponse:
    """
    OAuth 2.0 Token Endpoint

    Exchanges authorization code or refresh token for access token.

    Grant Types:
    1. authorization_code: Exchange authorization code for tokens
       - Required: grant_type, code, redirect_uri, client_id, client_secret
       - Optional: code_verifier (required if PKCE was used)

    2. refresh_token: Exchange refresh token for new access token
       - Required: grant_type, refresh_token, client_id, client_secret

    Returns:
    - access_token: Bearer token for API access (1 hour TTL)
    - token_type: "Bearer"
    - expires_in: Seconds until access token expires
    - refresh_token: Token for refreshing access (30 days TTL)
    - scopes: List of scopes
    """
    # Validate client credentials
    try:
        app = await validate_client_credentials(
            request.client_id, request.client_secret
        )
    except InvalidClientError as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(e),
        )

    # Handle authorization_code grant
    if request.grant_type == "authorization_code":
        # Consume authorization code
        try:
            user_id, scopes = await consume_authorization_code(
                code=request.code,
                application_id=app.id,
                redirect_uri=request.redirect_uri,
                code_verifier=request.code_verifier,
            )
        except InvalidGrantError as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=str(e),
            )

        # Create access and refresh tokens
        access_token = await create_access_token(app.id, user_id, scopes)
        refresh_token = await create_refresh_token(app.id, user_id, scopes)

        logger.info(
            f"Access token issued for user #{user_id} and app {app.name} (#{app.id})"
            "via authorization code"
        )

        if not access_token.token or not refresh_token.token:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to generate tokens",
            )

        return TokenResponse(
            token_type="Bearer",
            access_token=access_token.token.get_secret_value(),
            access_token_expires_at=access_token.expires_at,
            refresh_token=refresh_token.token.get_secret_value(),
            refresh_token_expires_at=refresh_token.expires_at,
            scopes=list(s.value for s in scopes),
        )

    # Handle refresh_token grant
    elif request.grant_type == "refresh_token":
        # Refresh access token
        try:
            new_access_token, new_refresh_token = await refresh_tokens(
                request.refresh_token, app.id
            )
        except InvalidGrantError as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=str(e),
            )

        logger.info(
            f"Tokens refreshed for user #{new_access_token.user_id} "
            f"by app {app.name} (#{app.id})"
        )

        if not new_access_token.token or not new_refresh_token.token:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to generate tokens",
            )

        return TokenResponse(
            token_type="Bearer",
            access_token=new_access_token.token.get_secret_value(),
            access_token_expires_at=new_access_token.expires_at,
            refresh_token=new_refresh_token.token.get_secret_value(),
            refresh_token_expires_at=new_refresh_token.expires_at,
            scopes=list(s.value for s in new_access_token.scopes),
        )

    else:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unsupported grant_type: {request.grant_type}. "
            "Must be 'authorization_code' or 'refresh_token'",
        )


# ============================================================================
# Token Introspection Endpoint
# ============================================================================


@router.post("/introspect")
async def introspect(
    token: str = Body(description="Token to introspect"),
    token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
        None, description="Hint about token type ('access_token' or 'refresh_token')"
    ),
    client_id: str = Body(description="Client identifier"),
    client_secret: str = Body(description="Client secret"),
) -> TokenIntrospectionResult:
    """
    OAuth 2.0 Token Introspection Endpoint (RFC 7662)

    Allows clients to check if a token is valid and get its metadata.

    Returns:
    - active: Whether the token is currently active
    - scopes: List of authorized scopes (if active)
    - client_id: The client the token was issued to (if active)
    - user_id: The user the token represents (if active)
    - exp: Expiration timestamp (if active)
    - token_type: "access_token" or "refresh_token" (if active)
    """
    # Validate client credentials
    try:
        await validate_client_credentials(client_id, client_secret)
    except InvalidClientError as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(e),
        )

    # Introspect the token
    return await introspect_token(token, token_type_hint)


# ============================================================================
# Token Revocation Endpoint
# ============================================================================


@router.post("/revoke")
async def revoke(
    token: str = Body(description="Token to revoke"),
    token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
        None, description="Hint about token type ('access_token' or 'refresh_token')"
    ),
    client_id: str = Body(description="Client identifier"),
    client_secret: str = Body(description="Client secret"),
):
    """
    OAuth 2.0 Token Revocation Endpoint (RFC 7009)

    Allows clients to revoke an access or refresh token.

    Note: Revoking a refresh token does NOT revoke associated access tokens.
    Revoking an access token does NOT revoke the associated refresh token.
    """
    # Validate client credentials
    try:
        app = await validate_client_credentials(client_id, client_secret)
    except InvalidClientError as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(e),
        )

    # Try to revoke as access token first
    # Note: We pass app.id to ensure the token belongs to the authenticated app
    if token_type_hint != "refresh_token":
        revoked = await revoke_access_token(token, app.id)
        if revoked:
            logger.info(
                f"Access token revoked for app {app.name} (#{app.id}); "
                f"user #{revoked.user_id}"
            )
            return {"status": "ok"}

    # Try to revoke as refresh token
    revoked = await revoke_refresh_token(token, app.id)
    if revoked:
        logger.info(
            f"Refresh token revoked for app {app.name} (#{app.id}); "
            f"user #{revoked.user_id}"
        )
        return {"status": "ok"}

    # Per RFC 7009, revocation endpoint returns 200 even if token not found
    # or if token belongs to a different application.
    # This prevents token scanning attacks.
    logger.warning(f"Unsuccessful token revocation attempt by app {app.name} #{app.id}")
    return {"status": "ok"}


# ============================================================================
# Application Management Endpoints (for app owners)
# ============================================================================


@router.get("/apps/mine")
async def list_my_oauth_apps(
    user_id: str = Security(get_user_id),
) -> list[OAuthApplicationInfo]:
    """
    List all OAuth applications owned by the current user.

    Returns a list of OAuth applications with their details including:
    - id, name, description, logo_url
    - client_id (public identifier)
    - redirect_uris, grant_types, scopes
    - is_active status
    - created_at, updated_at timestamps

    Note: client_secret is never returned for security reasons.
    """
    return await list_user_oauth_applications(user_id)


@router.patch("/apps/{app_id}/status")
async def update_app_status(
    app_id: str,
    user_id: str = Security(get_user_id),
    is_active: bool = Body(description="Whether the app should be active", embed=True),
) -> OAuthApplicationInfo:
    """
    Enable or disable an OAuth application.

    Only the application owner can update the status.
    When disabled, the application cannot be used for new authorizations
    and existing access tokens will fail validation.

    Returns the updated application info.
    """
    updated_app = await update_oauth_application(
        app_id=app_id,
        owner_id=user_id,
        is_active=is_active,
    )

    if not updated_app:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Application not found or you don't have permission to update it",
        )

    action = "enabled" if is_active else "disabled"
    logger.info(f"OAuth app {updated_app.name} (#{app_id}) {action} by user #{user_id}")

    return updated_app


class UpdateAppLogoRequest(BaseModel):
    logo_url: str = Field(description="URL of the uploaded logo image")


@router.patch("/apps/{app_id}/logo")
async def update_app_logo(
    app_id: str,
    request: UpdateAppLogoRequest = Body(),
    user_id: str = Security(get_user_id),
) -> OAuthApplicationInfo:
    """
    Update the logo URL for an OAuth application.

    Only the application owner can update the logo.
    The logo should be uploaded first using the media upload endpoint,
    then this endpoint is called with the resulting URL.

    Logo requirements:
    - Must be square (1:1 aspect ratio)
    - Minimum 512x512 pixels
    - Maximum 2048x2048 pixels

    Returns the updated application info.
    """
    if (
        not (app := await get_oauth_application_by_id(app_id))
        or app.owner_id != user_id
    ):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="OAuth App not found",
        )

    # Delete the current app logo file (if any and it's in our cloud storage)
    await _delete_app_current_logo_file(app)

    updated_app = await update_oauth_application(
        app_id=app_id,
        owner_id=user_id,
        logo_url=request.logo_url,
    )

    if not updated_app:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Application not found or you don't have permission to update it",
        )

    logger.info(
        f"OAuth app {updated_app.name} (#{app_id}) logo updated by user #{user_id}"
    )

    return updated_app


# Logo upload constraints
LOGO_MIN_SIZE = 512
LOGO_MAX_SIZE = 2048
LOGO_ALLOWED_TYPES = {"image/jpeg", "image/png", "image/webp"}
LOGO_MAX_FILE_SIZE = 3 * 1024 * 1024  # 3MB


@router.post("/apps/{app_id}/logo/upload")
async def upload_app_logo(
    app_id: str,
    file: UploadFile,
    user_id: str = Security(get_user_id),
) -> OAuthApplicationInfo:
    """
    Upload a logo image for an OAuth application.

    Requirements:
    - Image must be square (1:1 aspect ratio)
    - Minimum 512x512 pixels
    - Maximum 2048x2048 pixels
    - Allowed formats: JPEG, PNG, WebP
    - Maximum file size: 3MB

    The image is uploaded to cloud storage and the app's logoUrl is updated.
    Returns the updated application info.
    """
    # Verify ownership to reduce vulnerability to DoS(torage) or DoM(oney) attacks
    if (
        not (app := await get_oauth_application_by_id(app_id))
        or app.owner_id != user_id
    ):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="OAuth App not found",
        )

    # Check GCS configuration
    if not settings.config.media_gcs_bucket_name:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Media storage is not configured",
        )

    # Validate content type
    content_type = file.content_type
    if content_type not in LOGO_ALLOWED_TYPES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid file type. Allowed: JPEG, PNG, WebP. Got: {content_type}",
        )

    # Read file content
    try:
        file_bytes = await file.read()
    except Exception as e:
        logger.error(f"Error reading logo file: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Failed to read uploaded file",
        )

    # Check file size
    if len(file_bytes) > LOGO_MAX_FILE_SIZE:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                "File too large. "
                f"Maximum size is {LOGO_MAX_FILE_SIZE // 1024 // 1024}MB"
            ),
        )

    # Validate image dimensions
    try:
        image = Image.open(io.BytesIO(file_bytes))
        width, height = image.size

        if width != height:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Logo must be square. Got {width}x{height}",
            )

        if width < LOGO_MIN_SIZE:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Logo too small. Minimum {LOGO_MIN_SIZE}x{LOGO_MIN_SIZE}. "
                f"Got {width}x{height}",
            )

        if width > LOGO_MAX_SIZE:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Logo too large. Maximum {LOGO_MAX_SIZE}x{LOGO_MAX_SIZE}. "
                f"Got {width}x{height}",
            )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error validating logo image: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid image file",
        )

    # Scan for viruses
    filename = file.filename or "logo"
    await scan_content_safe(file_bytes, filename=filename)

    # Generate unique filename
    file_ext = os.path.splitext(filename)[1].lower() or ".png"
    unique_filename = f"{uuid.uuid4()}{file_ext}"
    storage_path = f"oauth-apps/{app_id}/logo/{unique_filename}"

    # Upload to GCS
    try:
        async with async_storage.Storage() as async_client:
            bucket_name = settings.config.media_gcs_bucket_name

            await async_client.upload(
                bucket_name, storage_path, file_bytes, content_type=content_type
            )

            logo_url = f"https://storage.googleapis.com/{bucket_name}/{storage_path}"
    except Exception as e:
        logger.error(f"Error uploading logo to GCS: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to upload logo",
        )

    # Delete the current app logo file (if any and it's in our cloud storage)
    await _delete_app_current_logo_file(app)

    # Update the app with the new logo URL
    updated_app = await update_oauth_application(
        app_id=app_id,
        owner_id=user_id,
        logo_url=logo_url,
    )

    if not updated_app:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Application not found or you don't have permission to update it",
        )

    logger.info(
        f"OAuth app {updated_app.name} (#{app_id}) logo uploaded by user #{user_id}"
    )

    return updated_app


async def _delete_app_current_logo_file(app: OAuthApplicationInfo):
    """
    Delete the current logo file for the given app, if there is one in our cloud storage
    """
    bucket_name = settings.config.media_gcs_bucket_name
    storage_base_url = f"https://storage.googleapis.com/{bucket_name}/"

    if app.logo_url and app.logo_url.startswith(storage_base_url):
        # Parse blob path from URL: https://storage.googleapis.com/{bucket}/{path}
        old_path = app.logo_url.replace(storage_base_url, "")
        try:
            async with async_storage.Storage() as async_client:
                await async_client.delete(bucket_name, old_path)
            logger.info(f"Deleted old logo for OAuth app #{app.id}: {old_path}")
        except Exception as e:
            # Log but don't fail - the new logo was uploaded successfully
            logger.warning(
                f"Failed to delete old logo for OAuth app #{app.id}: {e}", exc_info=e
            )
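For reference, the authorization-code-with-PKCE flow described in the deleted module's docstring above could be exercised by a client roughly as follows. This is a hedged sketch: the host and the router's mount path are placeholders, while the token request field names follow `TokenRequestByCode` as shown above:

```python
import base64
import hashlib
import secrets

import requests  # illustrative HTTP client

# 1. Client generates a PKCE verifier and its S256 challenge (RFC 7636).
code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b"=").decode()
code_challenge = (
    base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
    .rstrip(b"=")
    .decode()
)

# 2. The user authorizes on the consent screen (per the docstring, /auth/authorize),
#    and the app receives an authorization `code` on its redirect URI.
# 3. The app exchanges the code for tokens at the token endpoint.
resp = requests.post(
    "https://example.com/api/oauth/token",  # placeholder host and path
    json={
        "grant_type": "authorization_code",
        "code": "<authorization code from redirect>",
        "redirect_uri": "https://example.com/callback",
        "client_id": "<client_id>",
        "client_secret": "<client_secret>",
        "code_verifier": code_verifier,
    },
)
tokens = resp.json()  # access_token, refresh_token, expiry timestamps, scopes
```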
File diff suppressed because it is too large.
@@ -1,41 +0,0 @@
from fastapi import FastAPI


def sort_openapi(app: FastAPI) -> None:
    """
    Patch a FastAPI instance's `openapi()` method to sort the endpoints,
    schemas, and responses.
    """
    wrapped_openapi = app.openapi

    def custom_openapi():
        if app.openapi_schema:
            return app.openapi_schema

        openapi_schema = wrapped_openapi()

        # Sort endpoints
        openapi_schema["paths"] = dict(sorted(openapi_schema["paths"].items()))

        # Sort endpoints -> methods
        for p in openapi_schema["paths"].keys():
            openapi_schema["paths"][p] = dict(
                sorted(openapi_schema["paths"][p].items())
            )

            # Sort endpoints -> methods -> responses
            for m in openapi_schema["paths"][p].keys():
                openapi_schema["paths"][p][m]["responses"] = dict(
                    sorted(openapi_schema["paths"][p][m]["responses"].items())
                )

        # Sort schemas and responses as well
        for k in openapi_schema["components"].keys():
            openapi_schema["components"][k] = dict(
                sorted(openapi_schema["components"][k].items())
            )

        app.openapi_schema = openapi_schema
        return openapi_schema

    app.openapi = custom_openapi
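Because this helper wraps the existing `app.openapi` rather than replacing it outright, it can be stacked with other schema modifiers such as `add_auth_responses_to_openapi` shown earlier in this diff. A small illustrative sketch, with both import paths assumed rather than taken from the repository:

```python
from fastapi import FastAPI

# Assumed import locations for the two helpers shown in this comparison.
from backend.api.openapi_sort import sort_openapi
from autogpt_libs.auth.openapi import add_auth_responses_to_openapi

app = FastAPI(title="AutoGPT API")
# ... include routers ...
add_auth_responses_to_openapi(app)  # add 401 responses first
sort_openapi(app)                   # then sort paths, methods, and schemas
```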
@@ -36,10 +36,10 @@ def main(**kwargs):
    Run all the processes required for the AutoGPT-server (REST and WebSocket APIs).
    """

-   from backend.api.rest_api import AgentServer
-   from backend.api.ws_api import WebsocketServer
    from backend.executor import DatabaseManager, ExecutionManager, Scheduler
    from backend.notifications import NotificationManager
+   from backend.server.rest_api import AgentServer
+   from backend.server.ws_api import WebsocketServer

    run_processes(
        DatabaseManager().set_log_level("warning"),
@@ -1,7 +1,6 @@
from typing import Any

from backend.blocks.llm import (
-   DEFAULT_LLM_MODEL,
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    AIBlockBase,

@@ -50,7 +49,7 @@ class AIConditionBlock(AIBlockBase):
    )
    model: LlmModel = SchemaField(
        title="LLM Model",
-       default=DEFAULT_LLM_MODEL,
+       default=LlmModel.GPT4O,
        description="The language model to use for evaluating the condition.",
        advanced=False,
    )

@@ -82,7 +81,7 @@ class AIConditionBlock(AIBlockBase):
                "condition": "the input is an email address",
                "yes_value": "Valid email",
                "no_value": "Not an email",
-               "model": DEFAULT_LLM_MODEL,
+               "model": LlmModel.GPT4O,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
@@ -20,7 +20,6 @@ from backend.data.model (
    SchemaField,
)
from backend.integrations.providers import ProviderName
-from backend.util.exceptions import BlockExecutionError
from backend.util.request import Requests

TEST_CREDENTIALS = APIKeyCredentials(

@@ -247,11 +246,7 @@ class AIShortformVideoCreatorBlock(Block):
            await asyncio.sleep(10)

        logger.error("Video creation timed out")
-       raise BlockExecutionError(
-           message="Video creation timed out",
-           block_name=self.name,
-           block_id=self.id,
-       )
+       raise TimeoutError("Video creation timed out")

    def __init__(self):
        super().__init__(

@@ -427,11 +422,7 @@ class AIAdMakerVideoCreatorBlock(Block):
            await asyncio.sleep(10)

        logger.error("Video creation timed out")
-       raise BlockExecutionError(
-           message="Video creation timed out",
-           block_name=self.name,
-           block_id=self.id,
-       )
+       raise TimeoutError("Video creation timed out")

    def __init__(self):
        super().__init__(

@@ -608,11 +599,7 @@ class AIScreenshotToVideoAdBlock(Block):
            await asyncio.sleep(10)

        logger.error("Video creation timed out")
-       raise BlockExecutionError(
-           message="Video creation timed out",
-           block_name=self.name,
-           block_id=self.id,
-       )
+       raise TimeoutError("Video creation timed out")

    def __init__(self):
        super().__init__(
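The pattern behind these three hunks is a bounded polling loop that gives up with an error once its attempts are exhausted. A generic sketch of that shape, with illustrative names rather than the blocks' actual code:

```python
import asyncio


async def wait_for_video(check_status, max_attempts: int = 30) -> str:
    # Poll an external render job until it finishes or we run out of attempts.
    for _ in range(max_attempts):
        result = await check_status()
        if result is not None:
            return result
        await asyncio.sleep(10)  # same 10-second back-off as the blocks above
    raise TimeoutError("Video creation timed out")
```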
@@ -6,9 +6,6 @@ import hashlib
import hmac
import logging
from enum import Enum
-from typing import cast
-
-from prisma.types import Serializable

from backend.sdk import (
    BaseWebhooksManager,

@@ -87,9 +84,7 @@ class AirtableWebhookManager(BaseWebhooksManager):
        # update webhook config
        await update_webhook(
            webhook.id,
-           config=cast(
-               dict[str, Serializable], {"base_id": base_id, "cursor": response.cursor}
-           ),
+           config={"base_id": base_id, "cursor": response.cursor},
        )

        event_type = "notification"
@@ -106,10 +106,7 @@ class ConditionBlock(Block):
ComparisonOperator.LESS_THAN_OR_EQUAL: lambda a, b: a <= b,
}

try:
result = comparison_funcs[operator](value1, value2)
except Exception as e:
raise ValueError(f"Comparison failed: {e}") from e
result = comparison_funcs[operator](value1, value2)

yield "result", result

@@ -182,10 +182,13 @@ class DataForSeoRelatedKeywordsBlock(Block):
if results and len(results) > 0:
# results is a list, get the first element
first_result = results[0] if isinstance(results, list) else results
# Handle missing key, null value, or valid list value
if isinstance(first_result, dict):
items = first_result.get("items") or []
else:
items = (
first_result.get("items", [])
if isinstance(first_result, dict)
else []
)
# Ensure items is never None
if items is None:
items = []
for item in items:
# Extract keyword_data from the item

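Note on the hunk above: it hinges on a small Python detail that is easy to miss. dict.get(key, default) only falls back to the default when the key is absent, not when the stored value is None. A minimal, self-contained illustration (the sample dict is made up, not taken from the API):

# "items" present but null, as the provider can return
first_result = {"items": None}

with_default = first_result.get("items", [])   # -> None; iterating would raise TypeError
with_or = first_result.get("items") or []      # -> []; safe to iterate

assert with_default is None
assert with_or == []
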
@@ -15,7 +15,6 @@ from backend.sdk import (
SchemaField,
cost,
)
from backend.util.exceptions import BlockExecutionError

from ._config import firecrawl

@@ -60,18 +59,11 @@ class FirecrawlExtractBlock(Block):
) -> BlockOutput:
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())

try:
extract_result = app.extract(
urls=input_data.urls,
prompt=input_data.prompt,
schema=input_data.output_schema,
enable_web_search=input_data.enable_web_search,
)
except Exception as e:
raise BlockExecutionError(
message=f"Extract failed: {e}",
block_name=self.name,
block_id=self.id,
) from e
extract_result = app.extract(
urls=input_data.urls,
prompt=input_data.prompt,
schema=input_data.output_schema,
enable_web_search=input_data.enable_web_search,
)

yield "data", extract_result.data

@@ -19,7 +19,6 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.exceptions import ModerationError
from backend.util.file import MediaFileType, store_media_file

TEST_CREDENTIALS = APIKeyCredentials(
@@ -154,8 +153,6 @@ class AIImageEditorBlock(Block):
),
aspect_ratio=input_data.aspect_ratio.value,
seed=input_data.seed,
user_id=user_id,
graph_exec_id=graph_exec_id,
)
yield "output_image", result

@@ -167,8 +164,6 @@ class AIImageEditorBlock(Block):
input_image_b64: Optional[str],
aspect_ratio: str,
seed: Optional[int],
user_id: str,
graph_exec_id: str,
) -> MediaFileType:
client = ReplicateClient(api_token=api_key.get_secret_value())
input_params = {
@@ -178,21 +173,11 @@ class AIImageEditorBlock(Block):
**({"seed": seed} if seed is not None else {}),
}

try:
output: FileOutput | list[FileOutput] = await client.async_run(  # type: ignore
model_name,
input=input_params,
wait=False,
)
except Exception as e:
if "flagged as sensitive" in str(e).lower():
raise ModerationError(
message="Content was flagged as sensitive by the model provider",
user_id=user_id,
graph_exec_id=graph_exec_id,
moderation_type="model_provider",
)
raise ValueError(f"Model execution failed: {e}") from e
output: FileOutput | list[FileOutput] = await client.async_run(  # type: ignore
model_name,
input=input_params,
wait=False,
)

if isinstance(output, list) and output:
output = output[0]

File diff suppressed because it is too large
@@ -1,184 +0,0 @@
"""
Shared helpers for Human-In-The-Loop (HITL) review functionality.
Used by both the dedicated HumanInTheLoopBlock and blocks that require human review.
"""

import logging
from typing import Any, Optional

from prisma.enums import ReviewStatus
from pydantic import BaseModel

from backend.data.execution import ExecutionContext, ExecutionStatus
from backend.data.human_review import ReviewResult
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client

logger = logging.getLogger(__name__)

class ReviewDecision(BaseModel):
"""Result of a review decision."""

should_proceed: bool
message: str
review_result: ReviewResult

class HITLReviewHelper:
"""Helper class for Human-In-The-Loop review operations."""

@staticmethod
async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]:
"""Create or retrieve a human review from the database."""
return await get_database_manager_async_client().get_or_create_human_review(
**kwargs
)

@staticmethod
async def update_node_execution_status(**kwargs) -> None:
"""Update the execution status of a node."""
await async_update_node_execution_status(
db_client=get_database_manager_async_client(), **kwargs
)

@staticmethod
async def update_review_processed_status(
node_exec_id: str, processed: bool
) -> None:
"""Update the processed status of a review."""
return await get_database_manager_async_client().update_review_processed_status(
node_exec_id, processed
)

@staticmethod
async def _handle_review_request(
input_data: Any,
user_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
block_name: str = "Block",
editable: bool = False,
) -> Optional[ReviewResult]:
"""
Handle a review request for a block that requires human review.

Args:
input_data: The input data to be reviewed
user_id: ID of the user requesting the review
node_exec_id: ID of the node execution
graph_exec_id: ID of the graph execution
graph_id: ID of the graph
graph_version: Version of the graph
execution_context: Current execution context
block_name: Name of the block requesting review
editable: Whether the reviewer can edit the data

Returns:
ReviewResult if review is complete, None if waiting for human input

Raises:
Exception: If review creation or status update fails
"""
# Skip review if safe mode is disabled - return auto-approved result
if not execution_context.safe_mode:
logger.info(
f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled"
)
return ReviewResult(
data=input_data,
status=ReviewStatus.APPROVED,
message="Auto-approved (safe mode disabled)",
processed=True,
node_exec_id=node_exec_id,
)

result = await HITLReviewHelper.get_or_create_human_review(
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
input_data=input_data,
message=f"Review required for {block_name} execution",
editable=editable,
)

if result is None:
logger.info(
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
)
await HITLReviewHelper.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
return None  # Signal that execution should pause

# Mark review as processed if not already done
if not result.processed:
await HITLReviewHelper.update_review_processed_status(
node_exec_id=node_exec_id, processed=True
)

return result

@staticmethod
async def handle_review_decision(
input_data: Any,
user_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
block_name: str = "Block",
editable: bool = False,
) -> Optional[ReviewDecision]:
"""
Handle a review request and return the decision in a single call.

Args:
input_data: The input data to be reviewed
user_id: ID of the user requesting the review
node_exec_id: ID of the node execution
graph_exec_id: ID of the graph execution
graph_id: ID of the graph
graph_version: Version of the graph
execution_context: Current execution context
block_name: Name of the block requesting review
editable: Whether the reviewer can edit the data

Returns:
ReviewDecision if review is complete (approved/rejected),
None if execution should pause (awaiting review)
"""
review_result = await HITLReviewHelper._handle_review_request(
input_data=input_data,
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=block_name,
editable=editable,
)

if review_result is None:
# Still awaiting review - return None to pause execution
return None

# Review is complete, determine outcome
should_proceed = review_result.status == ReviewStatus.APPROVED
message = review_result.message or (
"Execution approved by reviewer"
if should_proceed
else "Execution rejected by reviewer"
)

return ReviewDecision(
should_proceed=should_proceed, message=message, review_result=review_result
)

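For orientation, a hedged sketch of how a block's run() might drive the helper module above; the keyword arguments mirror handle_review_decision's signature in this file, while the surrounding block plumbing (output names, the editable flag) is illustrative only:

async def run(self, input_data, *, user_id, node_exec_id, graph_exec_id,
              graph_id, graph_version, execution_context, **kwargs):
    decision = await HITLReviewHelper.handle_review_decision(
        input_data=input_data.data,
        user_id=user_id,
        node_exec_id=node_exec_id,
        graph_exec_id=graph_exec_id,
        graph_id=graph_id,
        graph_version=graph_version,
        execution_context=execution_context,
        block_name=self.name,
        editable=False,  # illustrative default
    )
    if decision is None:
        return  # node is parked in REVIEW state; execution resumes after a human decides
    if decision.should_proceed:
        yield "approved_data", decision.review_result.data
    else:
        yield "rejected_data", decision.review_result.data
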
@@ -1,9 +1,8 @@
import logging
from typing import Any
from typing import Any, Literal

from prisma.enums import ReviewStatus

from backend.blocks.helpers.review import HITLReviewHelper
from backend.data.block import (
Block,
BlockCategory,
@@ -12,9 +11,11 @@ from backend.data.block import (
BlockSchemaOutput,
BlockType,
)
from backend.data.execution import ExecutionContext
from backend.data.execution import ExecutionContext, ExecutionStatus
from backend.data.human_review import ReviewResult
from backend.data.model import SchemaField
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client

logger = logging.getLogger(__name__)

@@ -44,11 +45,11 @@ class HumanInTheLoopBlock(Block):
)

class Output(BlockSchemaOutput):
approved_data: Any = SchemaField(
description="The data when approved (may be modified by reviewer)"
reviewed_data: Any = SchemaField(
description="The data after human review (may be modified)"
)
rejected_data: Any = SchemaField(
description="The data when rejected (may be modified by reviewer)"
status: Literal["approved", "rejected"] = SchemaField(
description="Status of the review: 'approved' or 'rejected'"
)
review_message: str = SchemaField(
description="Any message provided by the reviewer", default=""
@@ -68,29 +69,36 @@ class HumanInTheLoopBlock(Block):
"editable": True,
},
test_output=[
("approved_data", {"name": "John Doe", "age": 30}),
("status", "approved"),
("reviewed_data", {"name": "John Doe", "age": 30}),
],
test_mock={
"handle_review_decision": lambda **kwargs: type(
"ReviewDecision",
(),
{
"should_proceed": True,
"message": "Test approval message",
"review_result": ReviewResult(
data={"name": "John Doe", "age": 30},
status=ReviewStatus.APPROVED,
message="",
processed=False,
node_exec_id="test-node-exec-id",
),
},
)(),
"get_or_create_human_review": lambda *_args, **_kwargs: ReviewResult(
data={"name": "John Doe", "age": 30},
status=ReviewStatus.APPROVED,
message="",
processed=False,
node_exec_id="test-node-exec-id",
),
"update_node_execution_status": lambda *_args, **_kwargs: None,
"update_review_processed_status": lambda *_args, **_kwargs: None,
},
)

async def handle_review_decision(self, **kwargs):
return await HITLReviewHelper.handle_review_decision(**kwargs)
async def get_or_create_human_review(self, **kwargs):
return await get_database_manager_async_client().get_or_create_human_review(
**kwargs
)

async def update_node_execution_status(self, **kwargs):
return await async_update_node_execution_status(
db_client=get_database_manager_async_client(), **kwargs
)

async def update_review_processed_status(self, node_exec_id: str, processed: bool):
return await get_database_manager_async_client().update_review_processed_status(
node_exec_id, processed
)

async def run(
self,
@@ -102,38 +110,60 @@ class HumanInTheLoopBlock(Block):
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
**_kwargs,
**kwargs,
) -> BlockOutput:
if not execution_context.safe_mode:
logger.info(
f"HITL block skipping review for node {node_exec_id} - safe mode disabled"
)
yield "approved_data", input_data.data
yield "status", "approved"
yield "reviewed_data", input_data.data
yield "review_message", "Auto-approved (safe mode disabled)"
return

decision = await self.handle_review_decision(
input_data=input_data.data,
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=self.name,
editable=input_data.editable,
)
try:
result = await self.get_or_create_human_review(
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
input_data=input_data.data,
message=input_data.name,
editable=input_data.editable,
)
except Exception as e:
logger.error(f"Error in HITL block for node {node_exec_id}: {str(e)}")
raise

if decision is None:
return
if result is None:
logger.info(
f"HITL block pausing execution for node {node_exec_id} - awaiting human review"
)
try:
await self.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
return
except Exception as e:
logger.error(
f"Failed to update node status for HITL block {node_exec_id}: {str(e)}"
)
raise

status = decision.review_result.status
if status == ReviewStatus.APPROVED:
yield "approved_data", decision.review_result.data
elif status == ReviewStatus.REJECTED:
yield "rejected_data", decision.review_result.data
else:
raise RuntimeError(f"Unexpected review status: {status}")
if not result.processed:
await self.update_review_processed_status(
node_exec_id=node_exec_id, processed=True
)

if decision.message:
yield "review_message", decision.message
if result.status == ReviewStatus.APPROVED:
yield "status", "approved"
yield "reviewed_data", result.data
if result.message:
yield "review_message", result.message

elif result.status == ReviewStatus.REJECTED:
yield "status", "rejected"
if result.message:
yield "review_message", result.message

@@ -2,6 +2,7 @@ from enum import Enum
from typing import Any, Dict, Literal, Optional

from pydantic import SecretStr
from requests.exceptions import RequestException

from backend.data.block import (
Block,
@@ -331,8 +332,8 @@ class IdeogramModelBlock(Block):
try:
response = await Requests().post(url, headers=headers, json=data)
return response.json()["data"][0]["url"]
except Exception as e:
raise ValueError(f"Failed to fetch image with V3 endpoint: {e}") from e
except RequestException as e:
raise Exception(f"Failed to fetch image with V3 endpoint: {str(e)}")

async def _run_model_legacy(
self,
@@ -384,8 +385,8 @@ class IdeogramModelBlock(Block):
try:
response = await Requests().post(url, headers=headers, json=data)
return response.json()["data"][0]["url"]
except Exception as e:
raise ValueError(f"Failed to fetch image with legacy endpoint: {e}") from e
except RequestException as e:
raise Exception(f"Failed to fetch image with legacy endpoint: {str(e)}")

async def upscale_image(self, api_key: SecretStr, image_url: str):
url = "https://api.ideogram.ai/upscale"
@@ -412,5 +413,5 @@ class IdeogramModelBlock(Block):

return (response.json())["data"][0]["url"]

except Exception as e:
raise ValueError(f"Failed to upscale image: {e}") from e
except RequestException as e:
raise Exception(f"Failed to upscale image: {str(e)}")

@@ -16,7 +16,6 @@ from backend.data.block import (
BlockSchemaOutput,
)
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError

class SearchTheWebBlock(Block, GetRequest):
@@ -57,17 +56,7 @@ class SearchTheWebBlock(Block, GetRequest):

# Prepend the Jina Search URL to the encoded query
jina_search_url = f"https://s.jina.ai/{encoded_query}"

try:
results = await self.get_request(
jina_search_url, headers=headers, json=False
)
except Exception as e:
raise BlockExecutionError(
message=f"Search failed: {e}",
block_name=self.name,
block_id=self.id,
) from e
results = await self.get_request(jina_search_url, headers=headers, json=False)

# Output the search results
yield "results", results

@@ -92,9 +92,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
O1 = "o1"
O1_MINI = "o1-mini"
# GPT-5 models
GPT5_2 = "gpt-5.2-2025-12-11"
GPT5_1 = "gpt-5.1-2025-11-13"
GPT5 = "gpt-5-2025-08-07"
GPT5_1 = "gpt-5.1-2025-11-13"
GPT5_MINI = "gpt-5-mini-2025-08-07"
GPT5_NANO = "gpt-5-nano-2025-08-07"
GPT5_CHAT = "gpt-5-chat-latest"
@@ -195,9 +194,8 @@ MODEL_METADATA = {
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
# GPT-5 models
LlmModel.GPT5_2: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_MINI: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_NANO: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_CHAT: ModelMetadata("openai", 400000, 16384),
@@ -305,8 +303,6 @@ MODEL_METADATA = {
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000),
}

DEFAULT_LLM_MODEL = LlmModel.GPT5_2

for model in LlmModel:
if model not in MODEL_METADATA:
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
@@ -794,7 +790,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -859,7 +855,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
input_schema=AIStructuredResponseGeneratorBlock.Input,
output_schema=AIStructuredResponseGeneratorBlock.Output,
test_input={
"model": DEFAULT_LLM_MODEL,
"model": LlmModel.GPT4O,
"credentials": TEST_CREDENTIALS_INPUT,
"expected_format": {
"key1": "value1",
@@ -1225,7 +1221,7 @@ class AITextGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -1321,7 +1317,7 @@ class AITextSummarizerBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for summarizing the text.",
)
focus: str = SchemaField(
@@ -1538,7 +1534,7 @@ class AIConversationBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for the conversation.",
)
credentials: AICredentials = AICredentialsField()
@@ -1576,7 +1572,7 @@ class AIConversationBlock(AIBlockBase):
},
{"role": "user", "content": "Where was it played?"},
],
"model": DEFAULT_LLM_MODEL,
"model": LlmModel.GPT4O,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
@@ -1639,7 +1635,7 @@ class AIListGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for generating the list.",
advanced=True,
)
@@ -1696,7 +1692,7 @@ class AIListGeneratorBlock(AIBlockBase):
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
"fictional worlds."
),
"model": DEFAULT_LLM_MODEL,
"model": LlmModel.GPT4O,
"credentials": TEST_CREDENTIALS_INPUT,
"max_retries": 3,
"force_json_output": False,

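The hunks above toggle between a shared DEFAULT_LLM_MODEL constant and per-block LlmModel.GPT4O defaults. A self-contained sketch of the constant-based pattern; only the GPT5_2 value is taken from the diff, the rest (the helper, the GPT4O value) is illustrative:

from enum import Enum

class LlmModel(str, Enum):
    GPT4O = "gpt-4o"
    GPT5_2 = "gpt-5.2-2025-12-11"

# Single source of truth: changing this one line retargets every field that uses it.
DEFAULT_LLM_MODEL = LlmModel.GPT5_2

def model_field(description: str, default: LlmModel = DEFAULT_LLM_MODEL) -> dict:
    # Stand-in for SchemaField(...); blocks that hardcode LlmModel.GPT4O instead
    # must each be edited by hand when the preferred default changes.
    return {"title": "LLM Model", "default": default, "description": description}

assert model_field("The language model to use.")["default"] is LlmModel.GPT5_2
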
File diff suppressed because it is too large
@@ -18,7 +18,6 @@ from backend.data.block import (
BlockSchemaOutput,
)
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
from backend.util.exceptions import BlockExecutionError, BlockInputError

logger = logging.getLogger(__name__)

@@ -112,27 +111,9 @@ class ReplicateModelBlock(Block):
yield "status", "succeeded"
yield "model_name", input_data.model_name
except Exception as e:
error_msg = str(e)
logger.error(f"Error running Replicate model: {error_msg}")

# Input validation errors (422, 400) → BlockInputError
if (
"422" in error_msg
or "Input validation failed" in error_msg
or "400" in error_msg
):
raise BlockInputError(
message=f"Invalid model inputs: {error_msg}",
block_name=self.name,
block_id=self.id,
) from e
# Everything else → BlockExecutionError
else:
raise BlockExecutionError(
message=f"Replicate model error: {error_msg}",
block_name=self.name,
block_id=self.id,
) from e
error_msg = f"Unexpected error running Replicate model: {str(e)}"
logger.error(error_msg)
raise RuntimeError(error_msg)

async def run_model(self, model_ref: str, model_inputs: dict, api_key: SecretStr):
"""

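The branch removed above encodes a classification rule worth noting: validation-style provider failures map to an input error, everything else to an execution error. A hedged, self-contained sketch of that rule (the exception classes here are stand-ins, not the backend's real ones):

class BlockInputError(Exception):
    """Stand-in: the caller supplied inputs the model rejected."""

class BlockExecutionError(Exception):
    """Stand-in: the model ran but failed for another reason."""

def classify_replicate_error(exc: Exception) -> Exception:
    msg = str(exc)
    # 422/400 status codes and explicit validation messages point at bad inputs ...
    if "422" in msg or "400" in msg or "Input validation failed" in msg:
        return BlockInputError(f"Invalid model inputs: {msg}")
    # ... everything else is treated as a runtime failure of the block itself.
    return BlockExecutionError(f"Replicate model error: {msg}")

assert isinstance(classify_replicate_error(Exception("HTTP 422")), BlockInputError)
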
@@ -45,16 +45,10 @@ class GetWikipediaSummaryBlock(Block, GetRequest):
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
topic = input_data.topic
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"

# Note: User-Agent is now automatically set by the request library
# to comply with Wikimedia's robot policy (https://w.wiki/4wJS)
try:
response = await self.get_request(url, json=True)
if "extract" not in response:
raise ValueError(f"Unable to parse Wikipedia response: {response}")
yield "summary", response["extract"]
except Exception as e:
raise ValueError(f"Failed to fetch Wikipedia summary: {e}") from e
response = await self.get_request(url, json=True)
if "extract" not in response:
raise RuntimeError(f"Unable to parse Wikipedia response: {response}")
yield "summary", response["extract"]

TEST_CREDENTIALS = APIKeyCredentials(

@@ -226,7 +226,7 @@ class SmartDecisionMakerBlock(Block):
)
model: llm.LlmModel = SchemaField(
title="LLM Model",
default=llm.DEFAULT_LLM_MODEL,
default=llm.LlmModel.GPT4O,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -391,12 +391,8 @@ class SmartDecisionMakerBlock(Block):
"""
block = sink_node.block

# Use custom name from node metadata if set, otherwise fall back to block.name
custom_name = sink_node.metadata.get("customized_name")
tool_name = custom_name if custom_name else block.name

tool_function: dict[str, Any] = {
"name": SmartDecisionMakerBlock.cleanup(tool_name),
"name": SmartDecisionMakerBlock.cleanup(block.name),
"description": block.description,
}
sink_block_input_schema = block.input_schema
@@ -493,12 +489,8 @@ class SmartDecisionMakerBlock(Block):
f"Sink graph metadata not found: {graph_id} {graph_version}"
)

# Use custom name from node metadata if set, otherwise fall back to graph name
custom_name = sink_node.metadata.get("customized_name")
tool_name = custom_name if custom_name else sink_graph_meta.name

tool_function: dict[str, Any] = {
"name": SmartDecisionMakerBlock.cleanup(tool_name),
"name": SmartDecisionMakerBlock.cleanup(sink_graph_meta.name),
"description": sink_graph_meta.description,
}

@@ -983,28 +975,10 @@ class SmartDecisionMakerBlock(Block):
graph_version: int,
execution_context: ExecutionContext,
execution_processor: "ExecutionProcessor",
nodes_to_skip: set[str] | None = None,
**kwargs,
) -> BlockOutput:

tool_functions = await self._create_tool_node_signatures(node_id)
original_tool_count = len(tool_functions)

# Filter out tools for nodes that should be skipped (e.g., missing optional credentials)
if nodes_to_skip:
tool_functions = [
tf
for tf in tool_functions
if tf.get("function", {}).get("_sink_node_id") not in nodes_to_skip
]

# Only raise error if we had tools but they were all filtered out
if original_tool_count > 0 and not tool_functions:
raise ValueError(
"No available tools to execute - all downstream nodes are unavailable "
"(possibly due to missing optional credentials)"
)

yield "tool_functions", json.dumps(tool_functions)

conversation_history = input_data.conversation_history or []

@@ -196,15 +196,6 @@ class TestXMLParserBlockSecurity:
async for _ in block.run(XMLParserBlock.Input(input_xml=large_xml)):
pass

async def test_rejects_text_outside_root(self):
"""Ensure parser surfaces readable errors for invalid root text."""
block = XMLParserBlock()
invalid_xml = "<root><child>value</child></root> trailing"

with pytest.raises(ValueError, match="text outside the root element"):
async for _ in block.run(XMLParserBlock.Input(input_xml=invalid_xml)):
pass

class TestStoreMediaFileSecurity:
"""Test file storage security limits."""

@@ -28,7 +28,7 @@ class TestLLMStatsTracking:

response = await llm.llm_call(
credentials=llm.TEST_CREDENTIALS,
llm_model=llm.DEFAULT_LLM_MODEL,
llm_model=llm.LlmModel.GPT4O,
prompt=[{"role": "user", "content": "Hello"}],
max_tokens=100,
)
@@ -65,7 +65,7 @@ class TestLLMStatsTracking:
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test prompt",
expected_format={"key1": "desc1", "key2": "desc2"},
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore # type: ignore
)

@@ -109,7 +109,7 @@ class TestLLMStatsTracking:
# Run the block
input_data = llm.AITextGeneratorBlock.Input(
prompt="Generate text",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)

@@ -170,7 +170,7 @@ class TestLLMStatsTracking:
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test prompt",
expected_format={"key1": "desc1", "key2": "desc2"},
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2,
)
@@ -228,7 +228,7 @@ class TestLLMStatsTracking:

input_data = llm.AITextSummarizerBlock.Input(
text=long_text,
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_tokens=100, # Small chunks
chunk_overlap=10,
@@ -299,7 +299,7 @@ class TestLLMStatsTracking:
# Test with very short text (should only need 1 chunk + 1 final summary)
input_data = llm.AITextSummarizerBlock.Input(
text="This is a short text.",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_tokens=1000, # Large enough to avoid chunking
)
@@ -346,7 +346,7 @@ class TestLLMStatsTracking:
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"},
],
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)

@@ -387,7 +387,7 @@ class TestLLMStatsTracking:
# Run the block
input_data = llm.AIListGeneratorBlock.Input(
focus="test items",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_retries=3,
)
@@ -469,7 +469,7 @@ class TestLLMStatsTracking:
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test",
expected_format={"result": "desc"},
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)

@@ -513,7 +513,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
style=llm.SummaryStyle.BULLET_POINTS,
)
@@ -558,7 +558,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
style=llm.SummaryStyle.BULLET_POINTS,
max_tokens=1000,
@@ -593,7 +593,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)

@@ -623,7 +623,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_tokens=1000,
)
@@ -654,7 +654,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)

@@ -1,246 +0,0 @@
"""
Standalone tests for pin name sanitization that can run without full backend dependencies.

These tests verify the core sanitization logic independently of the full system.
Run with: python -m pytest test_pin_sanitization_standalone.py -v
Or simply: python test_pin_sanitization_standalone.py
"""

import re
from typing import Any

# Simulate the exact cleanup function from SmartDecisionMakerBlock
def cleanup(s: str) -> str:
"""Clean up names for use as tool function names."""
return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()

# Simulate the key parts of parse_execution_output
def simulate_tool_routing(
emit_key: str,
sink_node_id: str,
sink_pin_name: str,
) -> bool:
"""
Simulate the routing comparison from parse_execution_output.

Returns True if routing would succeed, False otherwise.
"""
if not emit_key.startswith("tools_^_") or "_~_" not in emit_key:
return False

# Extract routing info from emit key: tools_^_{node_id}_~_{field}
selector = emit_key[8:] # Remove "tools_^_"
target_node_id, target_input_pin = selector.split("_~_", 1)

# Current (buggy) comparison - direct string comparison
return target_node_id == sink_node_id and target_input_pin == sink_pin_name

def simulate_fixed_tool_routing(
emit_key: str,
sink_node_id: str,
sink_pin_name: str,
) -> bool:
"""
Simulate the FIXED routing comparison.

The fix: sanitize sink_pin_name before comparison.
"""
if not emit_key.startswith("tools_^_") or "_~_" not in emit_key:
return False

selector = emit_key[8:]
target_node_id, target_input_pin = selector.split("_~_", 1)

# Fixed comparison - sanitize sink_pin_name
return target_node_id == sink_node_id and target_input_pin == cleanup(sink_pin_name)

class TestCleanupFunction:
"""Tests for the cleanup function."""

def test_spaces_to_underscores(self):
assert cleanup("Max Keyword Difficulty") == "max_keyword_difficulty"

def test_mixed_case_to_lowercase(self):
assert cleanup("MaxKeywordDifficulty") == "maxkeyworddifficulty"

def test_special_chars_to_underscores(self):
assert cleanup("field@name!") == "field_name_"
assert cleanup("CPC ($)") == "cpc____"

def test_preserves_valid_chars(self):
assert cleanup("valid_name-123") == "valid_name-123"

def test_empty_string(self):
assert cleanup("") == ""

def test_consecutive_spaces(self):
assert cleanup("a b") == "a___b"

def test_unicode(self):
assert cleanup("café") == "caf_"

class TestCurrentRoutingBehavior:
"""Tests demonstrating the current (buggy) routing behavior."""

def test_exact_match_works(self):
"""When names match exactly, routing works."""
emit_key = "tools_^_node-123_~_query"
assert simulate_tool_routing(emit_key, "node-123", "query") is True

def test_spaces_cause_failure(self):
"""When sink_pin has spaces, routing fails."""
sanitized = cleanup("Max Keyword Difficulty")
emit_key = f"tools_^_node-123_~_{sanitized}"
assert simulate_tool_routing(emit_key, "node-123", "Max Keyword Difficulty") is False

def test_special_chars_cause_failure(self):
"""When sink_pin has special chars, routing fails."""
sanitized = cleanup("CPC ($)")
emit_key = f"tools_^_node-123_~_{sanitized}"
assert simulate_tool_routing(emit_key, "node-123", "CPC ($)") is False

class TestFixedRoutingBehavior:
"""Tests demonstrating the fixed routing behavior."""

def test_exact_match_still_works(self):
"""When names match exactly, routing still works."""
emit_key = "tools_^_node-123_~_query"
assert simulate_fixed_tool_routing(emit_key, "node-123", "query") is True

def test_spaces_work_with_fix(self):
"""With the fix, spaces in sink_pin work."""
sanitized = cleanup("Max Keyword Difficulty")
emit_key = f"tools_^_node-123_~_{sanitized}"
assert simulate_fixed_tool_routing(emit_key, "node-123", "Max Keyword Difficulty") is True

def test_special_chars_work_with_fix(self):
"""With the fix, special chars in sink_pin work."""
sanitized = cleanup("CPC ($)")
emit_key = f"tools_^_node-123_~_{sanitized}"
assert simulate_fixed_tool_routing(emit_key, "node-123", "CPC ($)") is True

class TestBugReproduction:
"""Exact reproduction of the reported bug."""

def test_max_keyword_difficulty_bug(self):
"""
Reproduce the exact bug from the issue:

"For this agent specifically the input pin has space and unsanitized,
the frontend somehow connect without sanitizing creating a link like:
tools_^_767682f5-..._~_Max Keyword Difficulty
but what's produced by backend is
tools_^_767682f5-..._~_max_keyword_difficulty
so the tool calls go into the void"
"""
node_id = "767682f5-fake-uuid"
original_field = "Max Keyword Difficulty"
sanitized_field = cleanup(original_field)

# What backend produces (emit key)
emit_key = f"tools_^_{node_id}_~_{sanitized_field}"
assert emit_key == f"tools_^_{node_id}_~_max_keyword_difficulty"

# What frontend link has (sink_pin_name)
frontend_sink = original_field

# Current behavior: FAILS
assert simulate_tool_routing(emit_key, node_id, frontend_sink) is False

# With fix: WORKS
assert simulate_fixed_tool_routing(emit_key, node_id, frontend_sink) is True

class TestCommonFieldNamePatterns:
"""Test common field name patterns that could cause issues."""

FIELD_NAMES = [
"Max Keyword Difficulty",
"Search Volume (Monthly)",
"CPC ($)",
"User's Input",
"Target URL",
"API Response",
"Query #1",
"First Name",
"Last Name",
"Email Address",
"Phone Number",
"Total Cost ($)",
"Discount (%)",
"Created At",
"Updated At",
"Is Active",
]

def test_current_behavior_fails_for_special_names(self):
"""Current behavior fails for names with spaces/special chars."""
failed = []
for name in self.FIELD_NAMES:
sanitized = cleanup(name)
emit_key = f"tools_^_node_~_{sanitized}"
if not simulate_tool_routing(emit_key, "node", name):
failed.append(name)

# All names with spaces should fail
names_with_spaces = [n for n in self.FIELD_NAMES if " " in n or any(c in n for c in "()$%#'")]
assert set(failed) == set(names_with_spaces)

def test_fixed_behavior_works_for_all_names(self):
"""Fixed behavior works for all names."""
for name in self.FIELD_NAMES:
sanitized = cleanup(name)
emit_key = f"tools_^_node_~_{sanitized}"
assert simulate_fixed_tool_routing(emit_key, "node", name) is True, f"Failed for: {name}"

def run_tests():
"""Run all tests manually without pytest."""
import traceback

test_classes = [
TestCleanupFunction,
TestCurrentRoutingBehavior,
TestFixedRoutingBehavior,
TestBugReproduction,
TestCommonFieldNamePatterns,
]

total = 0
passed = 0
failed = 0

for test_class in test_classes:
print(f"\n{test_class.__name__}:")
instance = test_class()
for name in dir(instance):
if name.startswith("test_"):
total += 1
try:
getattr(instance, name)()
print(f" ✓ {name}")
passed += 1
except AssertionError as e:
print(f" ✗ {name}: {e}")
failed += 1
except Exception as e:
print(f" ✗ {name}: {e}")
traceback.print_exc()
failed += 1

print(f"\n{'='*50}")
print(f"Total: {total}, Passed: {passed}, Failed: {failed}")
return failed == 0

if __name__ == "__main__":
import sys
success = run_tests()
sys.exit(0 if success else 1)

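The standalone module above both documents the bug and sketches the remedy; the essential fix it argues for is a sanitize-before-compare in the routing check. Condensed below, using the same cleanup() helper defined in that file (the function and variable names here are illustrative, not the backend's actual ones):

def route_matches(emit_key: str, sink_node_id: str, sink_pin_name: str) -> bool:
    # Emit keys look like: tools_^_{node_id}_~_{sanitized_field}
    if not emit_key.startswith("tools_^_") or "_~_" not in emit_key:
        return False
    target_node_id, target_input_pin = emit_key[8:].split("_~_", 1)
    # Sanitize the link's sink pin the same way the backend sanitized the emit key.
    return target_node_id == sink_node_id and target_input_pin == cleanup(sink_pin_name)

# e.g. route_matches("tools_^_node-123_~_max_keyword_difficulty",
#                    "node-123", "Max Keyword Difficulty") -> True
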
@@ -5,10 +5,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.model import CreateGraph
|
||||
from backend.api.rest_api import AgentServer
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import ProviderName, User
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
from backend.util.test import SpinTestServer, wait_execution
|
||||
|
||||
@@ -233,7 +233,7 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
# Create test input
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Should I continue with this task?",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -335,7 +335,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2, # Set retry to 2 for testing
|
||||
agent_mode_max_iterations=0,
|
||||
@@ -402,7 +402,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -462,7 +462,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -526,7 +526,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -648,7 +648,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2,
|
||||
agent_mode_max_iterations=0,
|
||||
@@ -722,7 +722,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Simple prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -778,7 +778,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Another test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -931,7 +931,7 @@ async def test_smart_decision_maker_agent_mode():
|
||||
# Test agent mode with max_iterations = 3
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Complete this task using tools",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=3, # Enable agent mode with 3 max iterations
|
||||
)
|
||||
@@ -1020,7 +1020,7 @@ async def test_smart_decision_maker_traditional_mode_default():
|
||||
# Test default behavior (traditional mode)
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0, # Traditional mode
|
||||
)
|
||||
@@ -1057,153 +1057,3 @@ async def test_smart_decision_maker_traditional_mode_default():
|
||||
) # Should yield individual tool parameters
|
||||
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
|
||||
assert "conversations" in outputs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_uses_customized_name_for_blocks():
|
||||
"""Test that SmartDecisionMakerBlock uses customized_name from node metadata for tool names."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node with customized_name in metadata
|
||||
mock_node = MagicMock(spec=Node)
|
||||
mock_node.id = "test-node-id"
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
|
||||
mock_node.block = StoreValueBlock()
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "input"
|
||||
|
||||
# Call the function directly
|
||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
# Verify the tool name uses the customized name (cleaned up)
|
||||
assert result["type"] == "function"
|
||||
assert result["function"]["name"] == "my_custom_tool_name" # Cleaned version
|
||||
assert result["function"]["_sink_node_id"] == "test-node-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_falls_back_to_block_name():
|
||||
"""Test that SmartDecisionMakerBlock falls back to block.name when no customized_name."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node without customized_name
|
||||
mock_node = MagicMock(spec=Node)
|
||||
mock_node.id = "test-node-id"
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {} # No customized_name
|
||||
mock_node.block = StoreValueBlock()
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "input"
|
||||
|
||||
# Call the function directly
|
||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
# Verify the tool name uses the block's default name
|
||||
assert result["type"] == "function"
|
||||
assert result["function"]["name"] == "storevalueblock" # Default block name cleaned
|
||||
assert result["function"]["_sink_node_id"] == "test-node-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_uses_customized_name_for_agents():
|
||||
"""Test that SmartDecisionMakerBlock uses customized_name from metadata for agent nodes."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node with customized_name in metadata
|
||||
mock_node = MagicMock(spec=Node)
|
||||
mock_node.id = "test-agent-node-id"
|
||||
mock_node.metadata = {"customized_name": "My Custom Agent"}
|
||||
mock_node.input_default = {
|
||||
"graph_id": "test-graph-id",
|
||||
"graph_version": 1,
|
||||
"input_schema": {"properties": {"test_input": {"description": "Test input"}}},
|
||||
}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "test_input"
|
||||
|
||||
# Mock the database client
|
||||
mock_graph_meta = MagicMock()
|
||||
mock_graph_meta.name = "Original Agent Name"
|
||||
mock_graph_meta.description = "Agent description"
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
||||
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
# Verify the tool name uses the customized name (cleaned up)
|
||||
assert result["type"] == "function"
|
||||
assert result["function"]["name"] == "my_custom_agent" # Cleaned version
|
||||
assert result["function"]["_sink_node_id"] == "test-agent-node-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_agent_falls_back_to_graph_name():
|
||||
"""Test that agent node falls back to graph name when no customized_name."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node without customized_name
|
||||
mock_node = MagicMock(spec=Node)
|
||||
mock_node.id = "test-agent-node-id"
|
||||
mock_node.metadata = {} # No customized_name
|
||||
mock_node.input_default = {
|
||||
"graph_id": "test-graph-id",
|
||||
"graph_version": 1,
|
||||
"input_schema": {"properties": {"test_input": {"description": "Test input"}}},
|
||||
}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "test_input"
|
||||
|
||||
# Mock the database client
|
||||
mock_graph_meta = MagicMock()
|
||||
mock_graph_meta.name = "Original Agent Name"
|
||||
mock_graph_meta.description = "Agent description"
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
||||
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
# Verify the tool name uses the graph's default name
|
||||
assert result["type"] == "function"
|
||||
assert result["function"]["name"] == "original_agent_name" # Graph name cleaned
|
||||
assert result["function"]["_sink_node_id"] == "test-agent-node-id"
|
||||
|
||||
@@ -1,916 +0,0 @@
|
||||
"""
|
||||
Tests for SmartDecisionMaker agent mode specific failure modes.
|
||||
|
||||
Covers failure modes:
|
||||
2. Silent Tool Failures in Agent Mode
|
||||
3. Unbounded Agent Mode Iterations
|
||||
10. Unbounded Agent Iterations
|
||||
12. Stale Credentials in Agent Mode
|
||||
13. Tool Signature Cache Invalidation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.smart_decision_maker import (
|
||||
SmartDecisionMakerBlock,
|
||||
ExecutionParams,
|
||||
ToolInfo,
|
||||
)
|
||||
|
||||
|
||||
class TestSilentToolFailuresInAgentMode:
|
||||
"""
|
||||
Tests for Failure Mode #2: Silent Tool Failures in Agent Mode
|
||||
|
||||
When tool execution fails in agent mode, the error is converted to a
|
||||
tool response and execution continues silently.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_execution_failure_converted_to_response(self):
|
||||
"""
|
||||
Test that tool execution failures are silently converted to responses.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# First response: tool call
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.id = "call_1"
|
||||
mock_tool_call.function.name = "failing_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({"param": "value"})
|
||||
|
||||
mock_response_1 = MagicMock()
|
||||
mock_response_1.response = None
|
||||
mock_response_1.tool_calls = [mock_tool_call]
|
||||
mock_response_1.prompt_tokens = 50
|
||||
mock_response_1.completion_tokens = 25
|
||||
mock_response_1.reasoning = None
|
||||
mock_response_1.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": "call_1"}]
|
||||
}
|
||||
|
||||
# Second response: finish after seeing error
|
||||
mock_response_2 = MagicMock()
|
||||
mock_response_2.response = "I encountered an error"
|
||||
mock_response_2.tool_calls = []
|
||||
mock_response_2.prompt_tokens = 30
|
||||
mock_response_2.completion_tokens = 15
|
||||
mock_response_2.reasoning = None
|
||||
mock_response_2.raw_response = {"role": "assistant", "content": "I encountered an error"}
|
||||
|
||||
llm_call_count = 0
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal llm_call_count
|
||||
llm_call_count += 1
|
||||
if llm_call_count == 1:
|
||||
return mock_response_1
|
||||
return mock_response_2
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "failing_tool",
|
||||
"_sink_node_id": "sink-node",
|
||||
"_field_mapping": {"param": "param"},
|
||||
"parameters": {
|
||||
"properties": {"param": {"type": "string"}},
|
||||
"required": ["param"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Mock database client that will fail
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_client.get_node.side_effect = Exception("Database connection failed!")
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Do something",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=5,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# The execution completed (didn't crash)
|
||||
assert "finished" in outputs or "conversations" in outputs
|
||||
|
||||
# BUG: The tool failure was silent - user doesn't know what happened
|
||||
# The error was just logged and converted to a tool response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_failure_causes_infinite_retry_loop(self):
|
||||
"""
|
||||
Test scenario where LLM keeps calling the same failing tool.
|
||||
|
||||
If tool fails but LLM doesn't realize it, it may keep trying.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
call_count = 0
|
||||
max_calls = 10 # Limit for test
|
||||
|
||||
def create_tool_call_response():
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.id = f"call_{call_count}"
|
||||
mock_tool_call.function.name = "persistent_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({"retry": call_count})
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.response = None
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
mock_response.prompt_tokens = 50
|
||||
mock_response.completion_tokens = 25
|
||||
mock_response.reasoning = None
|
||||
mock_response.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": f"call_{call_count}"}]
|
||||
}
|
||||
return mock_response
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
|
||||
if call_count >= max_calls:
|
||||
# Eventually finish to prevent actual infinite loop in test
|
||||
final = MagicMock()
|
||||
final.response = "Giving up"
|
||||
final.tool_calls = []
|
||||
final.prompt_tokens = 10
|
||||
final.completion_tokens = 5
|
||||
final.reasoning = None
|
||||
final.raw_response = {"role": "assistant", "content": "Giving up"}
|
||||
return final
|
||||
|
||||
return create_tool_call_response()
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "persistent_tool",
|
||||
"_sink_node_id": "sink-node",
|
||||
"_field_mapping": {"retry": "retry"},
|
||||
"parameters": {
|
||||
"properties": {"retry": {"type": "integer"}},
|
||||
"required": ["retry"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_client.get_node.side_effect = Exception("Always fails!")
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Keep trying",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=-1, # Infinite mode!
|
||||
)
|
||||
|
||||
# Use timeout to prevent actual infinite loop
|
||||
try:
|
||||
async with asyncio.timeout(5):
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
except asyncio.TimeoutError:
|
||||
pass # Expected if we hit infinite loop
|
||||
|
||||
# Document that many calls were made before we gave up
|
||||
assert call_count >= max_calls - 1, \
|
||||
f"Expected many retries, got {call_count}"
|
||||
|
||||
|
||||
class TestUnboundedAgentIterations:
|
||||
"""
|
||||
Tests for Failure Mode #3 and #10: Unbounded Agent Mode Iterations
|
||||
|
||||
With max_iterations = -1, the agent can run forever, consuming
|
||||
unlimited tokens and compute resources.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_infinite_mode_requires_llm_to_stop(self):
|
||||
"""
|
||||
Test that infinite mode (-1) only stops when LLM stops making tool calls.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
iterations = 0
|
||||
max_test_iterations = 20
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal iterations
|
||||
iterations += 1
|
||||
|
||||
if iterations >= max_test_iterations:
|
||||
# Stop to prevent actual infinite loop
|
||||
resp = MagicMock()
|
||||
resp.response = "Finally done"
|
||||
resp.tool_calls = []
|
||||
resp.prompt_tokens = 10
|
||||
resp.completion_tokens = 5
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {"role": "assistant", "content": "Done"}
|
||||
return resp
|
||||
|
||||
# Keep making tool calls
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = f"call_{iterations}"
|
||||
tool_call.function.name = "counter_tool"
|
||||
tool_call.function.arguments = json.dumps({"count": iterations})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": f"call_{iterations}"}]
|
||||
}
|
||||
return resp
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "counter_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {"count": "count"},
|
||||
"parameters": {
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
"required": ["count"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.node_exec_id = "exec-id"
|
||||
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {"count": 1})
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {"result": "ok"}
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Count forever",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=-1, # INFINITE MODE
|
||||
)
|
||||
|
||||
async with asyncio.timeout(10):
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# We ran many iterations before stopping
|
||||
assert iterations == max_test_iterations
|
||||
# BUG: No built-in safeguard against runaway iterations
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_iterations_limit_enforced(self):
|
||||
"""
|
||||
Test that max_iterations limit is properly enforced.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
iterations = 0
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal iterations
|
||||
iterations += 1
|
||||
|
||||
# Always make tool calls (never finish voluntarily)
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = f"call_{iterations}"
|
||||
tool_call.function.name = "endless_tool"
|
||||
tool_call.function.arguments = json.dumps({})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": f"call_{iterations}"}]
|
||||
}
|
||||
return resp
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "endless_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {},
|
||||
"parameters": {"properties": {}, "required": []},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.node_exec_id = "exec-id"
|
||||
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {})
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {}
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
|
||||
|
||||
MAX_ITERATIONS = 3
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Run forever",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=MAX_ITERATIONS,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# Should have stopped at max iterations
|
||||
assert iterations == MAX_ITERATIONS
|
||||
assert "finished" in outputs
|
||||
assert "limit reached" in outputs["finished"].lower()
|
||||
|
||||
|
||||
class TestStaleCredentialsInAgentMode:
|
||||
"""
|
||||
Tests for Failure Mode #12: Stale Credentials in Agent Mode
|
||||
|
||||
Credentials are validated once at start but can expire during
|
||||
long-running agent mode executions.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_credentials_not_revalidated_between_iterations(self):
|
||||
"""
|
||||
Test that credentials are used without revalidation in agent mode.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
credential_check_count = 0
|
||||
iteration = 0
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal credential_check_count, iteration
|
||||
iteration += 1
|
||||
|
||||
# Simulate credential check (in real code this happens in llm_call)
|
||||
credential_check_count += 1
|
||||
|
||||
if iteration >= 3:
|
||||
resp = MagicMock()
|
||||
resp.response = "Done"
|
||||
resp.tool_calls = []
|
||||
resp.prompt_tokens = 10
|
||||
resp.completion_tokens = 5
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {"role": "assistant", "content": "Done"}
|
||||
return resp
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = f"call_{iteration}"
|
||||
tool_call.function.name = "test_tool"
|
||||
tool_call.function.arguments = json.dumps({})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": f"call_{iteration}"}]
|
||||
}
|
||||
return resp
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {},
|
||||
"parameters": {"properties": {}, "required": []},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.node_exec_id = "exec-id"
|
||||
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {})
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {}
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test credentials",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=5,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# Credentials were checked on each LLM call but not refreshed
|
||||
# If they expired mid-execution, we'd get auth errors
|
||||
assert credential_check_count == iteration
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_credential_expiration_mid_execution(self):
|
||||
"""
|
||||
Test what happens when credentials expire during agent mode.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
iteration = 0
|
||||
|
||||
async def mock_llm_call_with_expiration(**kwargs):
|
||||
nonlocal iteration
|
||||
iteration += 1
|
||||
|
||||
if iteration >= 3:
|
||||
# Simulate credential expiration
|
||||
raise Exception("401 Unauthorized: API key expired")
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = f"call_{iteration}"
|
||||
tool_call.function.name = "test_tool"
|
||||
tool_call.function.arguments = json.dumps({})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": f"call_{iteration}"}]
|
||||
}
|
||||
return resp
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {},
|
||||
"parameters": {"properties": {}, "required": []},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.node_exec_id = "exec-id"
|
||||
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {})
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {}
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call_with_expiration), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test credentials",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=10,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# Should have an error output
|
||||
assert "error" in outputs
|
||||
assert "expired" in outputs["error"].lower() or "unauthorized" in outputs["error"].lower()
|
||||
|
||||
|
||||
class TestToolSignatureCacheInvalidation:
|
||||
"""
|
||||
Tests for Failure Mode #13: Tool Signature Cache Invalidation
|
||||
|
||||
Tool signatures are created once at the start of run() but the
|
||||
graph could change during agent mode execution.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signatures_created_once_at_start(self):
|
||||
"""
|
||||
Test that tool signatures are only created once, not refreshed.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
signature_creation_count = 0
|
||||
iteration = 0
|
||||
|
||||
original_create_signatures = block._create_tool_node_signatures
|
||||
|
||||
async def counting_create_signatures(node_id):
|
||||
nonlocal signature_creation_count
|
||||
signature_creation_count += 1
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "tool_v1",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {},
|
||||
"parameters": {"properties": {}, "required": []},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal iteration
|
||||
iteration += 1
|
||||
|
||||
if iteration >= 3:
|
||||
resp = MagicMock()
|
||||
resp.response = "Done"
|
||||
resp.tool_calls = []
|
||||
resp.prompt_tokens = 10
|
||||
resp.completion_tokens = 5
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {"role": "assistant", "content": "Done"}
|
||||
return resp
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = f"call_{iteration}"
|
||||
tool_call.function.name = "tool_v1"
|
||||
tool_call.function.arguments = json.dumps({})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": f"call_{iteration}"}]
|
||||
}
|
||||
return resp
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.node_exec_id = "exec-id"
|
||||
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {})
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {}
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", side_effect=counting_create_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test signatures",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=5,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# Signatures were only created once, even though we had multiple iterations
|
||||
assert signature_creation_count == 1
|
||||
assert iteration >= 3 # We had multiple iterations
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stale_signatures_cause_tool_mismatch(self):
"""
Test scenario where tool definitions change but agent uses stale signatures.
"""
# This documents the potential issue:
# 1. Agent starts with tool_v1
# 2. User modifies graph, tool becomes tool_v2
# 3. Agent still thinks tool_v1 exists
# 4. LLM calls tool_v1, but it no longer exists

# Since signatures are created once at start and never refreshed,
# any changes to the graph during execution won't be reflected.

# This is more of a documentation test - the actual fix would
# require either (a sketch of option (a) follows below):
# a) Refreshing signatures periodically
# b) Locking the graph during execution
# c) Checking tool existence before each call
pass
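# A minimal sketch of option (a) above, assuming a hypothetical async
# create_signatures callable standing in for
# SmartDecisionMakerBlock._create_tool_node_signatures; not the production fix.
async def signatures_with_refresh(create_signatures, cache: dict, iteration: int, every: int = 5):
    """Re-read tool signatures from the graph every `every` iterations."""
    if "signatures" not in cache or iteration % every == 0:
        cache["signatures"] = await create_signatures()
    return cache["signatures"]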
|
||||
|
||||
|
||||
class TestAgentModeConversationManagement:
|
||||
"""Tests for conversation management in agent mode."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_grows_with_iterations(self):
|
||||
"""
|
||||
Test that conversation history grows correctly with each iteration.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
iteration = 0
|
||||
conversation_lengths = []
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal iteration
|
||||
iteration += 1
|
||||
|
||||
# Record conversation length at each call
|
||||
prompt = kwargs.get("prompt", [])
|
||||
conversation_lengths.append(len(prompt))
|
||||
|
||||
if iteration >= 3:
|
||||
resp = MagicMock()
|
||||
resp.response = "Done"
|
||||
resp.tool_calls = []
|
||||
resp.prompt_tokens = 10
|
||||
resp.completion_tokens = 5
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {"role": "assistant", "content": "Done"}
|
||||
return resp
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = f"call_{iteration}"
|
||||
tool_call.function.name = "test_tool"
|
||||
tool_call.function.arguments = json.dumps({})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": f"call_{iteration}"}]
|
||||
}
|
||||
return resp
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {},
|
||||
"parameters": {"properties": {}, "required": []},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.node_exec_id = "exec-id"
|
||||
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {})
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {"result": "ok"}
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test conversation",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=5,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# Conversation should grow with each iteration
|
||||
# Each iteration adds: assistant message + tool response
|
||||
assert len(conversation_lengths) == 3
|
||||
for i in range(1, len(conversation_lengths)):
|
||||
assert conversation_lengths[i] > conversation_lengths[i-1], \
|
||||
f"Conversation should grow: {conversation_lengths}"
@@ -1,525 +0,0 @@
"""
Tests for SmartDecisionMaker concurrency issues and race conditions.

Covers failure modes:
1. Conversation History Race Condition
4. Concurrent Execution State Sharing
7. Race in Pending Tool Calls
11. Race in Pending Tool Call Retrieval
14. Concurrent State Sharing
"""

import asyncio
import json
import threading
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest

from backend.blocks.smart_decision_maker import (
SmartDecisionMakerBlock,
get_pending_tool_calls,
_create_tool_response,
_get_tool_requests,
_get_tool_responses,
)

class TestConversationHistoryRaceCondition:
|
||||
"""
|
||||
Tests for Failure Mode #1: Conversation History Race Condition
|
||||
|
||||
When multiple executions share conversation history, concurrent
|
||||
modifications can cause data loss or corruption.
|
||||
"""
|
||||
|
||||
def test_get_pending_tool_calls_with_concurrent_modification(self):
|
||||
"""
|
||||
Test that concurrent modifications to conversation history
|
||||
can cause inconsistent pending tool call counts.
|
||||
"""
|
||||
# Shared conversation history
|
||||
conversation_history = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "tool_use", "id": "toolu_1"},
|
||||
{"type": "tool_use", "id": "toolu_2"},
|
||||
{"type": "tool_use", "id": "toolu_3"},
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
def reader_thread():
|
||||
"""Repeatedly read pending calls."""
|
||||
for _ in range(100):
|
||||
try:
|
||||
pending = get_pending_tool_calls(conversation_history)
|
||||
results.append(len(pending))
|
||||
except Exception as e:
|
||||
errors.append(str(e))
|
||||
|
||||
def writer_thread():
|
||||
"""Modify conversation while readers are active."""
|
||||
for i in range(50):
|
||||
# Add a tool response
|
||||
conversation_history.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "tool_result", "tool_use_id": f"toolu_{(i % 3) + 1}"}]
|
||||
})
|
||||
# Remove it
|
||||
if len(conversation_history) > 1:
|
||||
conversation_history.pop()
|
||||
|
||||
# Run concurrent readers and writers
|
||||
threads = []
|
||||
for _ in range(3):
|
||||
threads.append(threading.Thread(target=reader_thread))
|
||||
threads.append(threading.Thread(target=writer_thread))
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# The issue: results may be inconsistent due to race conditions
|
||||
# In a correct implementation, we'd expect consistent results
|
||||
# Document that this CAN produce inconsistent results
|
||||
assert len(results) > 0, "Should have some results"
|
||||
# Note: This test documents the race condition exists
|
||||
# When fixed, all results should be consistent
|
||||
|
||||
def test_prompt_list_mutation_race(self):
|
||||
"""
|
||||
Test that mutating prompt list during iteration can cause issues.
|
||||
"""
|
||||
prompt = []
|
||||
errors = []
|
||||
|
||||
def appender():
|
||||
for i in range(100):
|
||||
prompt.append({"role": "user", "content": f"msg_{i}"})
|
||||
|
||||
def extender():
|
||||
for i in range(100):
|
||||
prompt.extend([{"role": "assistant", "content": f"resp_{i}"}])
|
||||
|
||||
def reader():
|
||||
for _ in range(100):
|
||||
try:
|
||||
# Iterate while others modify
|
||||
_ = [p for p in prompt if p.get("role") == "user"]
|
||||
except RuntimeError as e:
|
||||
# "dictionary changed size during iteration" or similar
|
||||
errors.append(str(e))
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=appender),
|
||||
threading.Thread(target=extender),
|
||||
threading.Thread(target=reader),
|
||||
]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Document that race conditions can occur
|
||||
# In production, this could cause silent data corruption
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_block_runs_share_state(self):
|
||||
"""
|
||||
Test that concurrent runs on same block instance can share state incorrectly.
|
||||
|
||||
This is Failure Mode #14: Concurrent State Sharing
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Track all outputs from all runs
|
||||
all_outputs = []
|
||||
lock = threading.Lock()
|
||||
|
||||
async def run_block(run_id: int):
|
||||
"""Run the block with a unique run_id."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.response = f"Response for run {run_id}"
|
||||
mock_response.tool_calls = [] # No tool calls, just finish
|
||||
mock_response.prompt_tokens = 50
|
||||
mock_response.completion_tokens = 25
|
||||
mock_response.reasoning = None
|
||||
mock_response.raw_response = {"role": "assistant", "content": f"Run {run_id}"}
|
||||
|
||||
mock_tool_signatures = []
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm:
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
with patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt=f"Prompt for run {run_id}",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id=f"graph-{run_id}",
|
||||
node_id=f"node-{run_id}",
|
||||
graph_exec_id=f"exec-{run_id}",
|
||||
node_exec_id=f"node-exec-{run_id}",
|
||||
user_id=f"user-{run_id}",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
with lock:
|
||||
all_outputs.append((run_id, outputs))
|
||||
|
||||
# Run multiple concurrent executions
|
||||
tasks = [run_block(i) for i in range(5)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# Verify each run got its own response (no cross-contamination)
|
||||
for run_id, outputs in all_outputs:
|
||||
if "finished" in outputs:
|
||||
assert f"run {run_id}" in outputs["finished"].lower() or outputs["finished"] == f"Response for run {run_id}", \
|
||||
f"Run {run_id} may have received contaminated response: {outputs}"
|
||||
|
||||
|
||||
class TestPendingToolCallRace:
|
||||
"""
|
||||
Tests for Failure Mode #7 and #11: Race in Pending Tool Calls
|
||||
|
||||
The get_pending_tool_calls function can race with modifications
|
||||
to the conversation history, causing StopIteration or incorrect counts.
|
||||
"""
|
||||
|
||||
def test_pending_tool_calls_counter_accuracy(self):
|
||||
"""Test that pending tool call counting is accurate."""
|
||||
conversation = [
|
||||
# Assistant makes 3 tool calls
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "tool_use", "id": "call_1"},
|
||||
{"type": "tool_use", "id": "call_2"},
|
||||
{"type": "tool_use", "id": "call_3"},
|
||||
]
|
||||
},
|
||||
# User provides 1 response
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "tool_result", "tool_use_id": "call_1"}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
pending = get_pending_tool_calls(conversation)
|
||||
|
||||
# Should have 2 pending (call_2, call_3)
|
||||
assert len(pending) == 2
|
||||
assert "call_2" in pending
|
||||
assert "call_3" in pending
|
||||
assert pending["call_2"] == 1
|
||||
assert pending["call_3"] == 1
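# The counting semantics pinned down above, spelled out as a simplified
# stand-in for get_pending_tool_calls (not its real body): pending ids are
# those whose request count exceeds their response count. The helpers
# _get_tool_requests/_get_tool_responses are imported at the top of this file.
from collections import Counter

def count_pending(conversation: list) -> Counter:
    counts = Counter()
    for entry in conversation or []:
        for call_id in _get_tool_requests(entry):
            counts[call_id] += 1
        for call_id in _get_tool_responses(entry):
            counts[call_id] -= 1
    return Counter({call_id: n for call_id, n in counts.items() if n > 0})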
|
||||
|
||||
def test_pending_tool_calls_duplicate_responses(self):
|
||||
"""Test handling of duplicate tool responses."""
|
||||
conversation = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": "call_1"}]
|
||||
},
|
||||
# Duplicate responses for same call
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "tool_result", "tool_use_id": "call_1"}]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "tool_result", "tool_use_id": "call_1"}]
|
||||
}
|
||||
]
|
||||
|
||||
pending = get_pending_tool_calls(conversation)
|
||||
|
||||
# call_1 has count -1 (1 request - 2 responses)
|
||||
# Should not be in pending (count <= 0)
|
||||
assert "call_1" not in pending or pending.get("call_1", 0) <= 0
|
||||
|
||||
def test_empty_conversation_no_pending(self):
|
||||
"""Test that empty conversation has no pending calls."""
|
||||
assert get_pending_tool_calls([]) == {}
|
||||
assert get_pending_tool_calls(None) == {}
|
||||
|
||||
def test_next_iter_on_empty_dict_raises_stop_iteration(self):
"""
Document the StopIteration vulnerability.

If pending_tool_calls becomes empty between the check and
next(iter(...)), StopIteration is raised.
"""
pending = {}

# This is the pattern used in smart_decision_maker.py:1019
# if pending_tool_calls and ...:
# first_call_id = next(iter(pending_tool_calls.keys()))

with pytest.raises(StopIteration):
next(iter(pending.keys()))

# Safe pattern should be:
# first_call_id = next(iter(pending_tool_calls.keys()), None)
safe_result = next(iter(pending.keys()), None)
assert safe_result is None
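# The defensive variant suggested above, wrapped in a small helper; a sketch
# only, the production code may structure the check differently.
def first_pending_call_id(pending_tool_calls: dict):
    """Return the first pending call id, or None when nothing is pending."""
    return next(iter(pending_tool_calls), None)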
|
||||
|
||||
|
||||
class TestToolRequestResponseParsing:
|
||||
"""Tests for tool request/response parsing edge cases."""
|
||||
|
||||
def test_get_tool_requests_openai_format(self):
|
||||
"""Test parsing OpenAI format tool requests."""
|
||||
entry = {
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{"id": "call_abc123"},
|
||||
{"id": "call_def456"},
|
||||
]
|
||||
}
|
||||
|
||||
requests = _get_tool_requests(entry)
|
||||
assert requests == ["call_abc123", "call_def456"]
|
||||
|
||||
def test_get_tool_requests_anthropic_format(self):
|
||||
"""Test parsing Anthropic format tool requests."""
|
||||
entry = {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "tool_use", "id": "toolu_abc123"},
|
||||
{"type": "text", "text": "Let me call this tool"},
|
||||
{"type": "tool_use", "id": "toolu_def456"},
|
||||
]
|
||||
}
|
||||
|
||||
requests = _get_tool_requests(entry)
|
||||
assert requests == ["toolu_abc123", "toolu_def456"]
|
||||
|
||||
def test_get_tool_requests_non_assistant_role(self):
|
||||
"""Non-assistant roles should return empty list."""
|
||||
entry = {"role": "user", "tool_calls": [{"id": "call_123"}]}
|
||||
assert _get_tool_requests(entry) == []
|
||||
|
||||
def test_get_tool_responses_openai_format(self):
|
||||
"""Test parsing OpenAI format tool responses."""
|
||||
entry = {
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_abc123",
|
||||
"content": "Result"
|
||||
}
|
||||
|
||||
responses = _get_tool_responses(entry)
|
||||
assert responses == ["call_abc123"]
|
||||
|
||||
def test_get_tool_responses_anthropic_format(self):
|
||||
"""Test parsing Anthropic format tool responses."""
|
||||
entry = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "tool_result", "tool_use_id": "toolu_abc123"},
|
||||
{"type": "tool_result", "tool_use_id": "toolu_def456"},
|
||||
]
|
||||
}
|
||||
|
||||
responses = _get_tool_responses(entry)
|
||||
assert responses == ["toolu_abc123", "toolu_def456"]
|
||||
|
||||
def test_get_tool_responses_mixed_content(self):
|
||||
"""Test parsing responses with mixed content types."""
|
||||
entry = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Here are the results"},
|
||||
{"type": "tool_result", "tool_use_id": "toolu_123"},
|
||||
{"type": "image", "url": "http://example.com/img.png"},
|
||||
]
|
||||
}
|
||||
|
||||
responses = _get_tool_responses(entry)
|
||||
assert responses == ["toolu_123"]
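# The two wire formats covered by the tests above, side by side. This is a
# simplified restatement of what _get_tool_responses is expected to accept,
# not its actual implementation.
def extract_tool_result_ids(entry: dict) -> list:
    if entry.get("role") == "tool" and entry.get("tool_call_id"):  # OpenAI format
        return [entry["tool_call_id"]]
    if entry.get("role") == "user" and isinstance(entry.get("content"), list):  # Anthropic format
        return [
            item["tool_use_id"]
            for item in entry["content"]
            if item.get("type") == "tool_result"
        ]
    return []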
|
||||
|
||||
|
||||
class TestConcurrentToolSignatureCreation:
|
||||
"""Tests for concurrent tool signature creation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_signature_creation_same_node(self):
|
||||
"""
|
||||
Test that concurrent signature creation for same node
|
||||
doesn't cause issues.
|
||||
"""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_node = Mock()
|
||||
mock_node.id = "test-node"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "Test"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={"properties": {}, "required": []}
|
||||
)
|
||||
mock_node.block.input_schema.get_field_schema = Mock(
|
||||
return_value={"type": "string", "description": "test"}
|
||||
)
|
||||
|
||||
mock_links = [
|
||||
Mock(sink_name="field1", sink_id="test-node", source_id="source"),
|
||||
Mock(sink_name="field2", sink_id="test-node", source_id="source"),
|
||||
]
|
||||
|
||||
# Run multiple concurrent signature creations
|
||||
tasks = [
|
||||
block._create_block_function_signature(mock_node, mock_links)
|
||||
for _ in range(10)
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# All results should be identical
|
||||
first = results[0]
|
||||
for i, result in enumerate(results[1:], 1):
|
||||
assert result["function"]["name"] == first["function"]["name"], \
|
||||
f"Result {i} has different name"
|
||||
assert set(result["function"]["parameters"]["properties"].keys()) == \
|
||||
set(first["function"]["parameters"]["properties"].keys()), \
|
||||
f"Result {i} has different properties"
|
||||
|
||||
|
||||
class TestThreadSafetyOfCleanup:
|
||||
"""Tests for thread safety of cleanup function."""
|
||||
|
||||
def test_cleanup_is_thread_safe(self):
|
||||
"""
|
||||
Test that cleanup function is thread-safe.
|
||||
|
||||
Since it's a pure function with no shared state, it should be safe.
|
||||
"""
|
||||
results = {}
|
||||
lock = threading.Lock()
|
||||
|
||||
test_inputs = [
|
||||
"Max Keyword Difficulty",
|
||||
"Search Volume (Monthly)",
|
||||
"CPC ($)",
|
||||
"Target URL",
|
||||
]
|
||||
|
||||
def worker(input_str: str, thread_id: int):
|
||||
for _ in range(100):
|
||||
result = SmartDecisionMakerBlock.cleanup(input_str)
|
||||
with lock:
|
||||
key = f"{thread_id}_{input_str}"
|
||||
if key not in results:
|
||||
results[key] = set()
|
||||
results[key].add(result)
|
||||
|
||||
threads = []
|
||||
for i, input_str in enumerate(test_inputs):
|
||||
for j in range(3):
|
||||
t = threading.Thread(target=worker, args=(input_str, i * 3 + j))
|
||||
threads.append(t)
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Each input should produce exactly one unique output
|
||||
for key, values in results.items():
|
||||
assert len(values) == 1, f"Non-deterministic cleanup for {key}: {values}"
|
||||
|
||||
|
||||
class TestAsyncConcurrencyPatterns:
|
||||
"""Tests for async concurrency patterns in the block."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_async_runs_isolation(self):
|
||||
"""
|
||||
Test that multiple async runs are properly isolated.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
run_count = 5
|
||||
results = []
|
||||
|
||||
async def single_run(run_id: int):
|
||||
mock_response = MagicMock()
|
||||
mock_response.response = f"Unique response {run_id}"
|
||||
mock_response.tool_calls = []
|
||||
mock_response.prompt_tokens = 10
|
||||
mock_response.completion_tokens = 5
|
||||
mock_response.reasoning = None
|
||||
mock_response.raw_response = {"role": "assistant", "content": f"Run {run_id}"}
|
||||
|
||||
# Add small random delay to increase chance of interleaving
|
||||
await asyncio.sleep(0.001 * (run_id % 3))
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm:
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
with patch.object(block, "_create_tool_node_signatures", return_value=[]):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt=f"Prompt {run_id}",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id=f"g{run_id}",
|
||||
node_id=f"n{run_id}",
|
||||
graph_exec_id=f"e{run_id}",
|
||||
node_exec_id=f"ne{run_id}",
|
||||
user_id=f"u{run_id}",
|
||||
graph_version=1,
|
||||
execution_context=ExecutionContext(safe_mode=False),
|
||||
execution_processor=MagicMock(),
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
return run_id, outputs
|
||||
|
||||
# Run all concurrently
|
||||
tasks = [single_run(i) for i in range(run_count)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Verify isolation
|
||||
for run_id, outputs in results:
|
||||
if "finished" in outputs:
|
||||
assert str(run_id) in outputs["finished"], \
|
||||
f"Run {run_id} got wrong response: {outputs['finished']}"
@@ -1,667 +0,0 @@
"""
Tests for SmartDecisionMaker conversation handling and corruption scenarios.

Covers failure modes:
6. Conversation Corruption in Error Paths
And related conversation management issues.
"""

import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest

from backend.blocks.smart_decision_maker import (
SmartDecisionMakerBlock,
get_pending_tool_calls,
_create_tool_response,
_combine_tool_responses,
_convert_raw_response_to_dict,
_get_tool_requests,
_get_tool_responses,
)

class TestConversationCorruptionInErrorPaths:
|
||||
"""
|
||||
Tests for Failure Mode #6: Conversation Corruption in Error Paths
|
||||
|
||||
When there's a logic error (orphaned tool output), the code appends
|
||||
it as a "user" message instead of proper tool response format,
|
||||
violating LLM conversation structure.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_orphaned_tool_output_creates_user_message(self):
|
||||
"""
|
||||
Test that orphaned tool output (no pending calls) creates wrong message type.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Response with no tool calls
|
||||
mock_response = MagicMock()
|
||||
mock_response.response = "No tools needed"
|
||||
mock_response.tool_calls = []
|
||||
mock_response.prompt_tokens = 50
|
||||
mock_response.completion_tokens = 25
|
||||
mock_response.reasoning = None
|
||||
mock_response.raw_response = {"role": "assistant", "content": "No tools needed"}
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm:
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
with patch.object(block, "_create_tool_node_signatures", return_value=[]):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=0,
|
||||
# Orphaned tool output - no pending calls but we have output
|
||||
last_tool_output={"result": "orphaned data"},
|
||||
conversation_history=[], # Empty - no pending calls
|
||||
)
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# Check the conversation for the orphaned output handling
|
||||
# The orphaned output is logged as error but may be added as user message
|
||||
# This is the BUG: should not add orphaned outputs to conversation
|
||||
|
||||
def test_create_tool_response_anthropic_format(self):
|
||||
"""Test that Anthropic format tool responses are created correctly."""
|
||||
response = _create_tool_response(
|
||||
"toolu_abc123",
|
||||
{"result": "success"}
|
||||
)
|
||||
|
||||
assert response["role"] == "user"
|
||||
assert response["type"] == "message"
|
||||
assert isinstance(response["content"], list)
|
||||
assert response["content"][0]["type"] == "tool_result"
|
||||
assert response["content"][0]["tool_use_id"] == "toolu_abc123"
|
||||
|
||||
def test_create_tool_response_openai_format(self):
|
||||
"""Test that OpenAI format tool responses are created correctly."""
|
||||
response = _create_tool_response(
|
||||
"call_abc123",
|
||||
{"result": "success"}
|
||||
)
|
||||
|
||||
assert response["role"] == "tool"
|
||||
assert response["tool_call_id"] == "call_abc123"
|
||||
assert "content" in response
|
||||
|
||||
def test_tool_response_with_string_content(self):
|
||||
"""Test tool response creation with string content."""
|
||||
response = _create_tool_response(
|
||||
"call_123",
|
||||
"Simple string result"
|
||||
)
|
||||
|
||||
assert response["content"] == "Simple string result"
|
||||
|
||||
def test_tool_response_with_complex_content(self):
|
||||
"""Test tool response creation with complex JSON content."""
|
||||
complex_data = {
|
||||
"nested": {"key": "value"},
|
||||
"list": [1, 2, 3],
|
||||
"null": None,
|
||||
}
|
||||
|
||||
response = _create_tool_response("call_123", complex_data)
|
||||
|
||||
# Content should be JSON string
|
||||
parsed = json.loads(response["content"])
|
||||
assert parsed == complex_data
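# The behaviour the four _create_tool_response tests above pin down, gathered
# into one sketch. The id prefix is assumed here to select the format
# (toolu_ -> Anthropic, otherwise OpenAI); the real _create_tool_response may
# dispatch differently.
import json

def sketch_tool_response(call_id: str, output) -> dict:
    content = output if isinstance(output, str) else json.dumps(output)
    if call_id.startswith("toolu_"):  # Anthropic: tool_result block in a user message
        return {
            "role": "user",
            "type": "message",
            "content": [{"type": "tool_result", "tool_use_id": call_id, "content": content}],
        }
    return {"role": "tool", "tool_call_id": call_id, "content": content}  # OpenAI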
|
||||
|
||||
|
||||
class TestCombineToolResponses:
|
||||
"""Tests for combining multiple tool responses."""
|
||||
|
||||
def test_combine_single_response_unchanged(self):
|
||||
"""Test that single response is returned unchanged."""
|
||||
responses = [
|
||||
{
|
||||
"role": "user",
|
||||
"type": "message",
|
||||
"content": [{"type": "tool_result", "tool_use_id": "123"}]
|
||||
}
|
||||
]
|
||||
|
||||
result = _combine_tool_responses(responses)
|
||||
assert result == responses
|
||||
|
||||
def test_combine_multiple_anthropic_responses(self):
|
||||
"""Test combining multiple Anthropic responses."""
|
||||
responses = [
|
||||
{
|
||||
"role": "user",
|
||||
"type": "message",
|
||||
"content": [{"type": "tool_result", "tool_use_id": "123", "content": "a"}]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"type": "message",
|
||||
"content": [{"type": "tool_result", "tool_use_id": "456", "content": "b"}]
|
||||
},
|
||||
]
|
||||
|
||||
result = _combine_tool_responses(responses)
|
||||
|
||||
# Should be combined into single message
|
||||
assert len(result) == 1
|
||||
assert result[0]["role"] == "user"
|
||||
assert len(result[0]["content"]) == 2
|
||||
|
||||
def test_combine_mixed_responses(self):
|
||||
"""Test combining mixed Anthropic and OpenAI responses."""
|
||||
responses = [
|
||||
{
|
||||
"role": "user",
|
||||
"type": "message",
|
||||
"content": [{"type": "tool_result", "tool_use_id": "123"}]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_456",
|
||||
"content": "openai result"
|
||||
},
|
||||
]
|
||||
|
||||
result = _combine_tool_responses(responses)
|
||||
|
||||
# Anthropic response combined, OpenAI kept separate
|
||||
assert len(result) == 2
|
||||
|
||||
def test_combine_empty_list(self):
|
||||
"""Test combining empty list."""
|
||||
result = _combine_tool_responses([])
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestConversationHistoryValidation:
|
||||
"""Tests for conversation history validation."""
|
||||
|
||||
def test_pending_tool_calls_basic(self):
|
||||
"""Test basic pending tool call counting."""
|
||||
history = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "tool_use", "id": "call_1"},
|
||||
{"type": "tool_use", "id": "call_2"},
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
pending = get_pending_tool_calls(history)
|
||||
|
||||
assert len(pending) == 2
|
||||
assert "call_1" in pending
|
||||
assert "call_2" in pending
|
||||
|
||||
def test_pending_tool_calls_with_responses(self):
|
||||
"""Test pending calls after some responses."""
|
||||
history = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "tool_use", "id": "call_1"},
|
||||
{"type": "tool_use", "id": "call_2"},
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "tool_result", "tool_use_id": "call_1"}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
pending = get_pending_tool_calls(history)
|
||||
|
||||
assert len(pending) == 1
|
||||
assert "call_2" in pending
|
||||
assert "call_1" not in pending
|
||||
|
||||
def test_pending_tool_calls_all_responded(self):
|
||||
"""Test when all tool calls have responses."""
|
||||
history = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": "call_1"}]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "tool_result", "tool_use_id": "call_1"}]
|
||||
}
|
||||
]
|
||||
|
||||
pending = get_pending_tool_calls(history)
|
||||
|
||||
assert len(pending) == 0
|
||||
|
||||
def test_pending_tool_calls_openai_format(self):
|
||||
"""Test pending calls with OpenAI format."""
|
||||
history = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{"id": "call_1"},
|
||||
{"id": "call_2"},
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": "result"
|
||||
}
|
||||
]
|
||||
|
||||
pending = get_pending_tool_calls(history)
|
||||
|
||||
assert len(pending) == 1
|
||||
assert "call_2" in pending
|
||||
|
||||
|
||||
class TestConversationUpdateBehavior:
|
||||
"""Tests for conversation update behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_includes_assistant_response(self):
|
||||
"""Test that assistant responses are added to conversation."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.response = "Final answer"
|
||||
mock_response.tool_calls = []
|
||||
mock_response.prompt_tokens = 50
|
||||
mock_response.completion_tokens = 25
|
||||
mock_response.reasoning = None
|
||||
mock_response.raw_response = {"role": "assistant", "content": "Final answer"}
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm:
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
with patch.object(block, "_create_tool_node_signatures", return_value=[]):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# No conversations output when no tool calls (just finished)
|
||||
assert "finished" in outputs
|
||||
assert outputs["finished"] == "Final answer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_with_tool_calls(self):
|
||||
"""Test that tool calls are properly added to conversation."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({"param": "value"})
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.response = None
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
mock_response.prompt_tokens = 50
|
||||
mock_response.completion_tokens = 25
|
||||
mock_response.reasoning = "I'll use the test tool"
|
||||
mock_response.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{"id": "call_1"}]
|
||||
}
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {"param": "param"},
|
||||
"parameters": {
|
||||
"properties": {"param": {"type": "string"}},
|
||||
"required": ["param"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm:
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
with patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# Should have conversations output
|
||||
assert "conversations" in outputs
|
||||
|
||||
# Conversation should include the assistant message
|
||||
conversations = outputs["conversations"]
|
||||
has_assistant = any(
|
||||
msg.get("role") == "assistant"
|
||||
for msg in conversations
|
||||
)
|
||||
assert has_assistant
|
||||
|
||||
|
||||
class TestConversationHistoryPreservation:
|
||||
"""Tests for conversation history preservation across calls."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_existing_history_preserved(self):
|
||||
"""Test that existing conversation history is preserved."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
existing_history = [
|
||||
{"role": "user", "content": "Previous message 1"},
|
||||
{"role": "assistant", "content": "Previous response 1"},
|
||||
{"role": "user", "content": "Previous message 2"},
|
||||
]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.response = "New response"
|
||||
mock_response.tool_calls = []
|
||||
mock_response.prompt_tokens = 50
|
||||
mock_response.completion_tokens = 25
|
||||
mock_response.reasoning = None
|
||||
mock_response.raw_response = {"role": "assistant", "content": "New response"}
|
||||
|
||||
captured_prompt = []
|
||||
|
||||
async def capture_llm_call(**kwargs):
|
||||
captured_prompt.extend(kwargs.get("prompt", []))
|
||||
return mock_response
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=capture_llm_call):
|
||||
with patch.object(block, "_create_tool_node_signatures", return_value=[]):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="New message",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=0,
|
||||
conversation_history=existing_history,
|
||||
)
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
pass
|
||||
|
||||
# Existing history should be in the prompt
|
||||
assert len(captured_prompt) >= len(existing_history)
|
||||
|
||||
|
||||
class TestRawResponseConversion:
|
||||
"""Tests for raw response to dict conversion."""
|
||||
|
||||
def test_string_response(self):
|
||||
"""Test conversion of string response."""
|
||||
result = _convert_raw_response_to_dict("Hello world")
|
||||
|
||||
assert result == {"role": "assistant", "content": "Hello world"}
|
||||
|
||||
def test_dict_response(self):
|
||||
"""Test that dict response is passed through."""
|
||||
original = {"role": "assistant", "content": "test", "extra": "data"}
|
||||
result = _convert_raw_response_to_dict(original)
|
||||
|
||||
assert result == original
|
||||
|
||||
def test_object_response(self):
|
||||
"""Test conversion of object response."""
|
||||
mock_obj = MagicMock()
|
||||
|
||||
with patch("backend.blocks.smart_decision_maker.json.to_dict") as mock_to_dict:
|
||||
mock_to_dict.return_value = {"role": "assistant", "content": "converted"}
|
||||
result = _convert_raw_response_to_dict(mock_obj)
|
||||
|
||||
mock_to_dict.assert_called_once_with(mock_obj)
|
||||
assert result["role"] == "assistant"
|
||||
|
||||
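# Illustrative sketch of the conversion behaviour exercised above: strings are
# wrapped as an assistant message, dicts pass through untouched, and anything
# else gets a best-effort object-to-dict conversion. Simplified local helper --
# the production _convert_raw_response_to_dict delegates to backend's json.to_dict.
def convert_raw_response_sketch(raw):
    if isinstance(raw, str):
        return {"role": "assistant", "content": raw}
    if isinstance(raw, dict):
        return raw
    # Fallback branch: approximate json.to_dict with the object's __dict__.
    return dict(getattr(raw, "__dict__", {})) or None
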
class TestConversationMessageStructure:
|
||||
"""Tests for correct conversation message structure."""
|
||||
|
||||
def test_system_message_not_duplicated(self):
"""Test that system messages are not duplicated."""
from backend.util.prompt import MAIN_OBJECTIVE_PREFIX

# Existing system message in history
existing_history = [
{"role": "system", "content": f"{MAIN_OBJECTIVE_PREFIX}Existing system prompt"},
]

# The block should not add another system message with the same prefix.
# Full verification requires inspecting the prompt passed to the LLM;
# here we at least confirm the existing entry already carries the prefix.
assert existing_history[0]["content"].startswith(MAIN_OBJECTIVE_PREFIX)

def test_user_message_not_duplicated(self):
"""Test that user messages are not duplicated."""
from backend.util.prompt import MAIN_OBJECTIVE_PREFIX

# Existing user message with MAIN_OBJECTIVE_PREFIX
existing_history = [
{"role": "user", "content": f"{MAIN_OBJECTIVE_PREFIX}Existing user prompt"},
]

# The block should not add another user message with the same prefix.
# Full verification requires inspecting the prompt passed to the LLM;
# here we at least confirm the existing entry already carries the prefix.
assert existing_history[0]["content"].startswith(MAIN_OBJECTIVE_PREFIX)

def test_tool_response_after_tool_call(self):
|
||||
"""Test that tool responses come after tool calls."""
|
||||
# Valid conversation structure
|
||||
valid_history = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": "call_1"}]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "tool_result", "tool_use_id": "call_1"}]
|
||||
}
|
||||
]
|
||||
|
||||
# This should be valid - tool result follows tool use
|
||||
pending = get_pending_tool_calls(valid_history)
|
||||
assert len(pending) == 0
|
||||
|
||||
def test_orphaned_tool_response_detected(self):
|
||||
"""Test detection of orphaned tool responses."""
|
||||
# Invalid: tool response without matching tool call
|
||||
invalid_history = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "tool_result", "tool_use_id": "orphan_call"}]
|
||||
}
|
||||
]
|
||||
|
||||
pending = get_pending_tool_calls(invalid_history)
|
||||
|
||||
# Orphan response creates negative count
|
||||
# Should have count -1 for orphan_call
|
||||
# But it's filtered out (count <= 0)
|
||||
assert "orphan_call" not in pending
|
||||
|
||||
|
||||
class TestValidationErrorInConversation:
|
||||
"""Tests for validation error handling in conversation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_error_feedback_not_in_final_conversation(self):
|
||||
"""
|
||||
Test that validation error feedback is not in final conversation output.
|
||||
|
||||
When retrying due to validation errors, the error feedback should
|
||||
only be used for the retry prompt, not persisted in final conversation.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
|
||||
if call_count == 1:
|
||||
# First call: invalid tool call
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({"wrong": "param"})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [mock_tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {"role": "assistant", "content": None}
|
||||
return resp
|
||||
else:
|
||||
# Second call: finish
|
||||
resp = MagicMock()
|
||||
resp.response = "Done"
|
||||
resp.tool_calls = []
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {"role": "assistant", "content": "Done"}
|
||||
return resp
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {"correct": "correct"},
|
||||
"parameters": {
|
||||
"properties": {"correct": {"type": "string"}},
|
||||
"required": ["correct"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call):
|
||||
with patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=0,
|
||||
retry=3,
|
||||
)
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# Should have finished successfully after retry
|
||||
assert "finished" in outputs
|
||||
|
||||
# Note: In traditional mode (agent_mode_max_iterations=0),
|
||||
# conversations are only output when there are tool calls
|
||||
# After the retry succeeds with no tool calls, we just get "finished"
|
||||
@@ -1,671 +0,0 @@
"""
Tests for SmartDecisionMaker data integrity failure modes.

Covers failure modes:
6. Conversation Corruption in Error Paths
7. Field Name Collision Not Detected
8. No Type Validation in Dynamic Field Merging
9. Unhandled Field Mapping Keys
16. Silent Value Loss in Output Routing
"""

import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest

from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock


class TestFieldNameCollisionDetection:
|
||||
"""
|
||||
Tests for Failure Mode #7: Field Name Collision Not Detected
|
||||
|
||||
When multiple field names sanitize to the same value,
|
||||
the last one silently overwrites previous mappings.
|
||||
"""
|
||||
|
||||
def test_different_names_same_sanitized_result(self):
|
||||
"""Test that different names can produce the same sanitized result."""
|
||||
cleanup = SmartDecisionMakerBlock.cleanup
|
||||
|
||||
# All these sanitize to "test_field"
|
||||
variants = [
|
||||
"test_field",
|
||||
"Test Field",
|
||||
"test field",
|
||||
"TEST_FIELD",
|
||||
"Test_Field",
|
||||
"test-field", # Note: hyphen is preserved, this is different
|
||||
]
|
||||
|
||||
sanitized = [cleanup(v) for v in variants]
|
||||
|
||||
# Count unique sanitized values
|
||||
unique = set(sanitized)
|
||||
# Most should collide (except hyphenated one)
|
||||
assert len(unique) < len(variants), \
|
||||
f"Expected collisions, got {unique}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collision_last_one_wins(self):
|
||||
"""Test that in case of collision, the last field mapping wins."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_node = Mock()
|
||||
mock_node.id = "test-node"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "Test"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={"properties": {}, "required": []}
|
||||
)
|
||||
mock_node.block.input_schema.get_field_schema = Mock(
|
||||
return_value={"type": "string", "description": "test"}
|
||||
)
|
||||
|
||||
# Two fields that sanitize to the same name
|
||||
mock_links = [
|
||||
Mock(sink_name="Test Field", sink_id="test-node", source_id="source"),
|
||||
Mock(sink_name="test field", sink_id="test-node", source_id="source"),
|
||||
]
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links)
|
||||
|
||||
field_mapping = signature["function"]["_field_mapping"]
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
|
||||
# Only one property (collision)
|
||||
assert len(properties) == 1
|
||||
assert "test_field" in properties
|
||||
|
||||
# The mapping has only the last one
|
||||
# This is the BUG: first field's mapping is lost
|
||||
assert field_mapping["test_field"] in ["Test Field", "test field"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collision_causes_data_loss(self):
|
||||
"""
|
||||
Test that field collision can cause actual data loss.
|
||||
|
||||
Scenario:
|
||||
1. Two fields "Field A" and "field a" both map to "field_a"
|
||||
2. LLM provides value for "field_a"
|
||||
3. Only one original field gets the value
|
||||
4. The other field's expected input is lost
|
||||
"""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Simulate processing tool calls with collision
|
||||
mock_response = Mock()
|
||||
mock_tool_call = Mock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
"field_a": "value_for_both" # LLM uses sanitized name
|
||||
})
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
|
||||
# Tool definition with collision in field mapping
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"field_a": {"type": "string"},
|
||||
},
|
||||
"required": ["field_a"],
|
||||
},
|
||||
"_sink_node_id": "sink",
|
||||
# BUG: Only one original name is stored
|
||||
# "Field A" was overwritten by "field a"
|
||||
"_field_mapping": {"field_a": "field a"},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
|
||||
assert len(processed) == 1
|
||||
input_data = processed[0].input_data
|
||||
|
||||
# Only "field a" gets the value
|
||||
assert "field a" in input_data
|
||||
assert input_data["field a"] == "value_for_both"
|
||||
|
||||
# "Field A" is completely lost!
|
||||
assert "Field A" not in input_data
|
||||
|
||||
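# Illustrative sketch of the collision discussed above: a lowercase/underscore
# sanitizer maps several distinct field names onto one key, and building the
# mapping with a plain dict means the last name silently wins. The sanitizer
# below only approximates SmartDecisionMakerBlock.cleanup for this example.
import re


def sanitize_sketch(name: str) -> str:
    return re.sub(r"[^a-z0-9_-]", "_", name.lower())


def build_field_mapping_sketch(original_names: list[str]) -> dict[str, str]:
    mapping: dict[str, str] = {}
    for original in original_names:
        # "Test Field" and "test field" both land on "test_field",
        # so the earlier entry is overwritten without any warning.
        mapping[sanitize_sketch(original)] = original
    return mapping


assert build_field_mapping_sketch(["Test Field", "test field"]) == {"test_field": "test field"}
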
class TestUnhandledFieldMappingKeys:
|
||||
"""
|
||||
Tests for Failure Mode #9: Unhandled Field Mapping Keys
|
||||
|
||||
When field_mapping is missing a key, the code falls back to
|
||||
the clean name, which may not be what the sink expects.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_field_mapping_falls_back_to_clean_name(self):
|
||||
"""Test that missing field mapping falls back to clean name."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_response = Mock()
|
||||
mock_tool_call = Mock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
"unmapped_field": "value"
|
||||
})
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
|
||||
# Tool definition with incomplete field mapping
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"unmapped_field": {"type": "string"},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {}, # Empty! No mapping for unmapped_field
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
|
||||
assert len(processed) == 1
|
||||
input_data = processed[0].input_data
|
||||
|
||||
# Falls back to clean name (which IS the key since it's already clean)
|
||||
assert "unmapped_field" in input_data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_field_mapping(self):
|
||||
"""Test behavior with partial field mapping."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_response = Mock()
|
||||
mock_tool_call = Mock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
"mapped_field": "value1",
|
||||
"unmapped_field": "value2",
|
||||
})
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"mapped_field": {"type": "string"},
|
||||
"unmapped_field": {"type": "string"},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
"_sink_node_id": "sink",
|
||||
# Only one field is mapped
|
||||
"_field_mapping": {
|
||||
"mapped_field": "Original Mapped Field",
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
|
||||
assert len(processed) == 1
|
||||
input_data = processed[0].input_data
|
||||
|
||||
# Mapped field uses original name
|
||||
assert "Original Mapped Field" in input_data
|
||||
# Unmapped field uses clean name (fallback)
|
||||
assert "unmapped_field" in input_data
|
||||
|
||||
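# Illustrative sketch of the fallback exercised above: when a sanitized argument
# name has no entry in _field_mapping, the sanitized name itself is used as the
# sink input key. Local helper only.
def resolve_sink_field_sketch(clean_name: str, field_mapping: dict[str, str]) -> str:
    return field_mapping.get(clean_name, clean_name)  # missing keys fall back to the clean name


assert resolve_sink_field_sketch("mapped_field", {"mapped_field": "Original Mapped Field"}) == "Original Mapped Field"
assert resolve_sink_field_sketch("unmapped_field", {}) == "unmapped_field"
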
class TestSilentValueLossInRouting:
|
||||
"""
|
||||
Tests for Failure Mode #16: Silent Value Loss in Output Routing
|
||||
|
||||
When routing fails in parse_execution_output, it returns None
|
||||
without any logging or indication of why it failed.
|
||||
"""
|
||||
|
||||
def test_routing_mismatch_returns_none_silently(self):
|
||||
"""Test that routing mismatch returns None without error."""
|
||||
from backend.data.dynamic_fields import parse_execution_output
|
||||
|
||||
output_item = ("tools_^_node-123_~_sanitized_name", "important_value")
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="Original Name", # Doesn't match sanitized_name
|
||||
)
|
||||
|
||||
# Silently returns None
|
||||
assert result is None
|
||||
# No way to distinguish "value is None" from "routing failed"
|
||||
|
||||
def test_wrong_node_id_returns_none(self):
|
||||
"""Test that wrong node ID returns None."""
|
||||
from backend.data.dynamic_fields import parse_execution_output
|
||||
|
||||
output_item = ("tools_^_node-123_~_field", "value")
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="different-node", # Wrong node
|
||||
sink_pin_name="field",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_wrong_selector_returns_none(self):
|
||||
"""Test that wrong selector returns None."""
|
||||
from backend.data.dynamic_fields import parse_execution_output
|
||||
|
||||
output_item = ("tools_^_node-123_~_field", "value")
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="different_selector", # Wrong selector
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="field",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_cannot_distinguish_none_value_from_routing_failure(self):
|
||||
"""
|
||||
Test that None as actual value is indistinguishable from routing failure.
|
||||
"""
|
||||
from backend.data.dynamic_fields import parse_execution_output
|
||||
|
||||
# Case 1: Actual None value
|
||||
output_with_none = ("field_name", None)
|
||||
result1 = parse_execution_output(
|
||||
output_with_none,
|
||||
link_output_selector="field_name",
|
||||
sink_node_id=None,
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
# Case 2: Routing failure
|
||||
output_mismatched = ("field_name", "value")
|
||||
result2 = parse_execution_output(
|
||||
output_mismatched,
|
||||
link_output_selector="different_field",
|
||||
sink_node_id=None,
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
# Both return None - cannot distinguish!
|
||||
assert result1 is None
|
||||
assert result2 is None
|
||||
|
||||
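# Illustrative sketch of one way to make the ambiguity above observable: return
# a dedicated sentinel for "routing failed" instead of None, so callers can tell
# a genuine None value apart from a selector/node/field mismatch. Design sketch
# only -- this is not how parse_execution_output currently behaves.
class _RoutingFailed:
    """Sentinel type returned when an output does not match the requested sink."""


ROUTING_FAILED = _RoutingFailed()


def route_output_sketch(output_item: tuple, selector: str):
    name, value = output_item
    if name != selector:
        return ROUTING_FAILED  # a mismatch is now distinguishable from None
    return value


assert route_output_sketch(("field_name", None), "field_name") is None
assert route_output_sketch(("field_name", "value"), "other_field") is ROUTING_FAILED
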
class TestProcessToolCallsInputData:
|
||||
"""Tests for _process_tool_calls input data generation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_expected_args_included(self):
|
||||
"""Test that all expected arguments are included in input_data."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_response = Mock()
|
||||
mock_tool_call = Mock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
"provided_field": "value",
|
||||
# optional_field not provided
|
||||
})
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"provided_field": {"type": "string"},
|
||||
"optional_field": {"type": "string"},
|
||||
},
|
||||
"required": ["provided_field"],
|
||||
},
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {
|
||||
"provided_field": "Provided Field",
|
||||
"optional_field": "Optional Field",
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
|
||||
assert len(processed) == 1
|
||||
input_data = processed[0].input_data
|
||||
|
||||
# Both fields should be in input_data
|
||||
assert "Provided Field" in input_data
|
||||
assert "Optional Field" in input_data
|
||||
|
||||
# Provided has value, optional is None
|
||||
assert input_data["Provided Field"] == "value"
|
||||
assert input_data["Optional Field"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extra_args_from_llm_ignored(self):
|
||||
"""Test that extra arguments from LLM not in schema are ignored."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_response = Mock()
|
||||
mock_tool_call = Mock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
"expected_field": "value",
|
||||
"unexpected_field": "should_be_ignored",
|
||||
})
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"expected_field": {"type": "string"},
|
||||
# unexpected_field not in schema
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {"expected_field": "Expected Field"},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
|
||||
assert len(processed) == 1
|
||||
input_data = processed[0].input_data
|
||||
|
||||
# Only expected field should be in input_data
|
||||
assert "Expected Field" in input_data
|
||||
assert "unexpected_field" not in input_data
|
||||
assert "Unexpected Field" not in input_data
|
||||
|
||||
|
||||
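# Illustrative sketch of how the assertions above expect tool-call arguments to
# become sink input data: every schema property gets an entry (None when the LLM
# omitted it), names are translated through _field_mapping with a fallback to the
# sanitized name, and arguments outside the schema are dropped. Simplified local
# helper, not the production _process_tool_calls.
def build_input_data_sketch(arguments: dict, tool_function: dict) -> dict:
    properties = tool_function["function"]["parameters"]["properties"]
    field_mapping = tool_function["function"].get("_field_mapping", {})
    input_data = {}
    for clean_name in properties:
        original_name = field_mapping.get(clean_name, clean_name)
        input_data[original_name] = arguments.get(clean_name)  # None if omitted
    return input_data


_example_tool = {
    "function": {
        "parameters": {"properties": {"provided_field": {}, "optional_field": {}}},
        "_field_mapping": {"provided_field": "Provided Field", "optional_field": "Optional Field"},
    }
}
assert build_input_data_sketch({"provided_field": "value", "extra": "ignored"}, _example_tool) == {
    "Provided Field": "value",
    "Optional Field": None,
}
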
class TestToolCallMatching:
|
||||
"""Tests for tool call matching logic."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_not_found_skipped(self):
|
||||
"""Test that tool calls for unknown tools are skipped."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_response = Mock()
|
||||
mock_tool_call = Mock()
|
||||
mock_tool_call.function.name = "unknown_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({})
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "known_tool", # Different name
|
||||
"parameters": {"properties": {}, "required": []},
|
||||
"_sink_node_id": "sink",
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
|
||||
# Unknown tool is skipped (not processed)
|
||||
assert len(processed) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_tool_fallback(self):
|
||||
"""Test fallback when only one tool exists but name doesn't match."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_response = Mock()
|
||||
mock_tool_call = Mock()
|
||||
mock_tool_call.function.name = "wrong_name"
|
||||
mock_tool_call.function.arguments = json.dumps({"field": "value"})
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
|
||||
# Only one tool defined
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "only_tool",
|
||||
"parameters": {
|
||||
"properties": {"field": {"type": "string"}},
|
||||
"required": [],
|
||||
},
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {"field": "Field"},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
|
||||
# Falls back to the only tool
|
||||
assert len(processed) == 1
|
||||
assert processed[0].input_data["Field"] == "value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_tool_calls_processed(self):
|
||||
"""Test that multiple tool calls are all processed."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_response = Mock()
|
||||
mock_tool_call_1 = Mock()
|
||||
mock_tool_call_1.function.name = "tool_a"
|
||||
mock_tool_call_1.function.arguments = json.dumps({"a": "1"})
|
||||
|
||||
mock_tool_call_2 = Mock()
|
||||
mock_tool_call_2.function.name = "tool_b"
|
||||
mock_tool_call_2.function.arguments = json.dumps({"b": "2"})
|
||||
|
||||
mock_response.tool_calls = [mock_tool_call_1, mock_tool_call_2]
|
||||
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "tool_a",
|
||||
"parameters": {
|
||||
"properties": {"a": {"type": "string"}},
|
||||
"required": [],
|
||||
},
|
||||
"_sink_node_id": "sink_a",
|
||||
"_field_mapping": {"a": "A"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "tool_b",
|
||||
"parameters": {
|
||||
"properties": {"b": {"type": "string"}},
|
||||
"required": [],
|
||||
},
|
||||
"_sink_node_id": "sink_b",
|
||||
"_field_mapping": {"b": "B"},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
|
||||
assert len(processed) == 2
|
||||
assert processed[0].input_data["A"] == "1"
|
||||
assert processed[1].input_data["B"] == "2"
|
||||
|
||||
class TestOutputEmitKeyGeneration:
|
||||
"""Tests for output emit key generation consistency."""
|
||||
|
||||
def test_emit_key_uses_sanitized_field_name(self):
|
||||
"""Test that emit keys use sanitized field names."""
|
||||
cleanup = SmartDecisionMakerBlock.cleanup
|
||||
|
||||
original_field = "Max Keyword Difficulty"
|
||||
sink_node_id = "node-123"
|
||||
|
||||
sanitized = cleanup(original_field)
|
||||
emit_key = f"tools_^_{sink_node_id}_~_{sanitized}"
|
||||
|
||||
assert emit_key == "tools_^_node-123_~_max_keyword_difficulty"
|
||||
|
||||
def test_emit_key_format_consistent(self):
|
||||
"""Test that emit key format is consistent."""
|
||||
test_cases = [
|
||||
("field", "node", "tools_^_node_~_field"),
|
||||
("Field Name", "node-123", "tools_^_node-123_~_field_name"),
|
||||
("CPC ($)", "abc", "tools_^_abc_~_cpc____"),
|
||||
]
|
||||
|
||||
cleanup = SmartDecisionMakerBlock.cleanup
|
||||
|
||||
for original_field, node_id, expected in test_cases:
|
||||
sanitized = cleanup(original_field)
|
||||
emit_key = f"tools_^_{node_id}_~_{sanitized}"
|
||||
assert emit_key == expected, \
|
||||
f"Expected {expected}, got {emit_key}"
|
||||
|
||||
def test_emit_key_sanitization_idempotent(self):
|
||||
"""Test that sanitizing an already sanitized name gives same result."""
|
||||
cleanup = SmartDecisionMakerBlock.cleanup
|
||||
|
||||
original = "Test Field Name"
|
||||
first_clean = cleanup(original)
|
||||
second_clean = cleanup(first_clean)
|
||||
|
||||
assert first_clean == second_clean
|
||||
|
||||
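# Illustrative sketch of the emit-key convention the assertions above rely on:
# "tools_^_<sink_node_id>_~_<sanitized_field_name>". The builder/parser pair is
# local to this example; only the key format itself comes from the tests.
def build_emit_key_sketch(sink_node_id: str, sanitized_field: str) -> str:
    return f"tools_^_{sink_node_id}_~_{sanitized_field}"


def parse_emit_key_sketch(key: str) -> tuple[str, str]:
    prefix, _, rest = key.partition("_^_")
    assert prefix == "tools", f"unexpected emit key prefix in {key!r}"
    node_id, _, field = rest.partition("_~_")
    return node_id, field


assert build_emit_key_sketch("node-123", "max_keyword_difficulty") == "tools_^_node-123_~_max_keyword_difficulty"
assert parse_emit_key_sketch("tools_^_abc_~_cpc____") == ("abc", "cpc____")
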
class TestToolFunctionMetadata:
|
||||
"""Tests for tool function metadata handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sink_node_id_preserved(self):
|
||||
"""Test that _sink_node_id is preserved in tool function."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_node = Mock()
|
||||
mock_node.id = "specific-node-id"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "Test"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={"properties": {}, "required": []}
|
||||
)
|
||||
mock_node.block.input_schema.get_field_schema = Mock(
|
||||
return_value={"type": "string", "description": "test"}
|
||||
)
|
||||
|
||||
mock_links = [
|
||||
Mock(sink_name="field", sink_id="specific-node-id", source_id="source"),
|
||||
]
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links)
|
||||
|
||||
assert signature["function"]["_sink_node_id"] == "specific-node-id"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_field_mapping_preserved(self):
|
||||
"""Test that _field_mapping is preserved in tool function."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_node = Mock()
|
||||
mock_node.id = "test-node"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "Test"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={"properties": {}, "required": []}
|
||||
)
|
||||
mock_node.block.input_schema.get_field_schema = Mock(
|
||||
return_value={"type": "string", "description": "test"}
|
||||
)
|
||||
|
||||
mock_links = [
|
||||
Mock(sink_name="Original Field Name", sink_id="test-node", source_id="source"),
|
||||
]
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links)
|
||||
|
||||
field_mapping = signature["function"]["_field_mapping"]
|
||||
assert "original_field_name" in field_mapping
|
||||
assert field_mapping["original_field_name"] == "Original Field Name"
|
||||
|
||||
|
||||
class TestRequiredFieldsHandling:
|
||||
"""Tests for required fields handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_required_fields_use_sanitized_names(self):
|
||||
"""Test that required fields array uses sanitized names."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_node = Mock()
|
||||
mock_node.id = "test-node"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "Test"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={
|
||||
"properties": {},
|
||||
"required": ["Required Field", "Another Required"],
|
||||
}
|
||||
)
|
||||
mock_node.block.input_schema.get_field_schema = Mock(
|
||||
return_value={"type": "string", "description": "test"}
|
||||
)
|
||||
|
||||
mock_links = [
|
||||
Mock(sink_name="Required Field", sink_id="test-node", source_id="source"),
|
||||
Mock(sink_name="Another Required", sink_id="test-node", source_id="source"),
|
||||
Mock(sink_name="Optional Field", sink_id="test-node", source_id="source"),
|
||||
]
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links)
|
||||
|
||||
required = signature["function"]["parameters"]["required"]
|
||||
|
||||
# Should use sanitized names
|
||||
assert "required_field" in required
|
||||
assert "another_required" in required
|
||||
|
||||
# Original names should NOT be in required
|
||||
assert "Required Field" not in required
|
||||
assert "Another Required" not in required
|
||||
|
||||
# Optional field should not be required
|
||||
assert "optional_field" not in required
|
||||
assert "Optional Field" not in required
|
||||
@@ -373,7 +373,7 @@ async def test_output_yielding_with_dynamic_fields():
input_data = block.input_schema(
prompt="Create a user dictionary",
credentials=llm.TEST_CREDENTIALS_INPUT,
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
agent_mode_max_iterations=0,  # Use traditional mode to test output yielding
)

@@ -594,7 +594,7 @@ async def test_validation_errors_dont_pollute_conversation():
input_data = block.input_schema(
prompt="Test prompt",
credentials=llm.TEST_CREDENTIALS_INPUT,
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
retry=3,  # Allow retries
agent_mode_max_iterations=1,
)

@@ -1,871 +0,0 @@
"""
Tests for SmartDecisionMaker error handling failure modes.

Covers failure modes:
3. JSON Deserialization Without Exception Handling
4. Database Transaction Inconsistency
5. Missing Null Checks After Database Calls
15. Error Message Context Loss
17. No Validation of Dynamic Field Paths
"""

import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest

from backend.blocks.smart_decision_maker import (
    SmartDecisionMakerBlock,
    _convert_raw_response_to_dict,
    _create_tool_response,
)


class TestJSONDeserializationErrors:
|
||||
"""
|
||||
Tests for Failure Mode #3: JSON Deserialization Without Exception Handling
|
||||
|
||||
When LLM returns malformed JSON in tool call arguments, the json.loads()
|
||||
call fails without proper error handling.
|
||||
"""
|
||||
|
||||
def test_malformed_json_single_quotes(self):
|
||||
"""
|
||||
Test that single quotes in JSON cause parsing failure.
|
||||
|
||||
LLMs sometimes return {'key': 'value'} instead of {"key": "value"}
|
||||
"""
|
||||
malformed = "{'key': 'value'}"
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
json.loads(malformed)
|
||||
|
||||
def test_malformed_json_trailing_comma(self):
|
||||
"""
|
||||
Test that trailing commas cause parsing failure.
|
||||
"""
|
||||
malformed = '{"key": "value",}'
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
json.loads(malformed)
|
||||
|
||||
def test_malformed_json_unquoted_keys(self):
|
||||
"""
|
||||
Test that unquoted keys cause parsing failure.
|
||||
"""
|
||||
malformed = '{key: "value"}'
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
json.loads(malformed)
|
||||
|
||||
def test_malformed_json_python_none(self):
|
||||
"""
|
||||
Test that Python None instead of null causes failure.
|
||||
"""
|
||||
malformed = '{"key": None}'
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
json.loads(malformed)
|
||||
|
||||
def test_malformed_json_python_true_false(self):
|
||||
"""
|
||||
Test that Python True/False instead of true/false causes failure.
|
||||
"""
|
||||
malformed_true = '{"key": True}'
|
||||
malformed_false = '{"key": False}'
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
json.loads(malformed_true)
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
json.loads(malformed_false)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_returns_malformed_json_crashes_block(self):
|
||||
"""
|
||||
Test that malformed JSON from LLM causes block to crash.
|
||||
|
||||
BUG: The json.loads() at line 625, 706, 1124 can throw JSONDecodeError
|
||||
which is not caught, causing the entire block to fail.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Create response with malformed JSON
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = "{'malformed': 'json'}" # Single quotes!
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.response = None
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
mock_response.prompt_tokens = 50
|
||||
mock_response.completion_tokens = 25
|
||||
mock_response.reasoning = None
|
||||
mock_response.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {},
|
||||
"parameters": {"properties": {"malformed": {"type": "string"}}, "required": []},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm:
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
with patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
# BUG: This should raise JSONDecodeError
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
pass
|
||||
|
||||
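# Illustrative sketch of defensive parsing for LLM-supplied tool-call arguments,
# the mitigation suggested by the failure mode above: wrap json.loads and return
# a readable error instead of letting JSONDecodeError escape. Possible mitigation
# only, not the block's current behaviour.
def parse_tool_arguments_sketch(raw_arguments: str):
    try:
        parsed = json.loads(raw_arguments)
    except json.JSONDecodeError as exc:
        return None, f"Invalid JSON in tool arguments: {exc}"
    if not isinstance(parsed, dict):
        return None, f"Tool arguments must be a JSON object, got {type(parsed).__name__}"
    return parsed, None


assert parse_tool_arguments_sketch('{"key": "value"}') == ({"key": "value"}, None)
assert parse_tool_arguments_sketch("{'key': 'value'}")[0] is None  # single quotes rejected
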
class TestDatabaseTransactionInconsistency:
|
||||
"""
|
||||
Tests for Failure Mode #4: Database Transaction Inconsistency
|
||||
|
||||
When multiple database operations are performed in sequence,
|
||||
a failure partway through leaves the database in an inconsistent state.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_input_insertion_on_failure(self):
|
||||
"""
|
||||
Test that partial failures during multi-input insertion
|
||||
leave database in inconsistent state.
|
||||
"""
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Track which inputs were inserted
|
||||
inserted_inputs = []
|
||||
call_count = 0
|
||||
|
||||
async def failing_upsert(node_id, graph_exec_id, input_name, input_data):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
|
||||
# Fail on the third input
|
||||
if call_count == 3:
|
||||
raise Exception("Database connection lost!")
|
||||
|
||||
inserted_inputs.append(input_name)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.node_exec_id = "exec-id"
|
||||
return mock_result, {input_name: input_data}
|
||||
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.id = "call_1"
|
||||
mock_tool_call.function.name = "multi_input_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
"input1": "value1",
|
||||
"input2": "value2",
|
||||
"input3": "value3", # This one will fail
|
||||
"input4": "value4",
|
||||
"input5": "value5",
|
||||
})
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.response = None
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
mock_response.prompt_tokens = 50
|
||||
mock_response.completion_tokens = 25
|
||||
mock_response.reasoning = None
|
||||
mock_response.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": "call_1"}]
|
||||
}
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "multi_input_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {
|
||||
"input1": "input1",
|
||||
"input2": "input2",
|
||||
"input3": "input3",
|
||||
"input4": "input4",
|
||||
"input5": "input5",
|
||||
},
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"input1": {"type": "string"},
|
||||
"input2": {"type": "string"},
|
||||
"input3": {"type": "string"},
|
||||
"input4": {"type": "string"},
|
||||
"input5": {"type": "string"},
|
||||
},
|
||||
"required": ["input1", "input2", "input3", "input4", "input5"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
mock_db_client.upsert_execution_input.side_effect = failing_upsert
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm, \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=1,
|
||||
)
|
||||
|
||||
# The block should fail, but some inputs were already inserted
|
||||
outputs = {}
|
||||
try:
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
except Exception:
|
||||
pass # Expected
|
||||
|
||||
# BUG: Some inputs were inserted before failure
|
||||
# Database is now in inconsistent state
|
||||
assert len(inserted_inputs) == 2, \
|
||||
f"Expected 2 inserted before failure, got {inserted_inputs}"
|
||||
assert "input1" in inserted_inputs
|
||||
assert "input2" in inserted_inputs
|
||||
# input3, input4, input5 were never inserted
|
||||
|
||||
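# Illustrative sketch of an all-or-nothing alternative to the sequential upserts
# exercised above: stage every input first, then write, and report what needs
# rolling back if a write fails partway. Design sketch with a hypothetical
# write_one callable -- not the production upsert_execution_input API.
async def stage_then_write_inputs_sketch(write_one, inputs: dict) -> dict:
    staged = list(inputs.items())  # stage everything before any write happens
    written: list[str] = []
    try:
        for name, value in staged:
            await write_one(name, value)
            written.append(name)
    except Exception:
        # Compensation step: the real storage layer would undo these writes
        # (or wrap them in a transaction); here we just report them.
        return {"ok": False, "needs_rollback": written}
    return {"ok": True, "needs_rollback": []}
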
class TestMissingNullChecks:
|
||||
"""
|
||||
Tests for Failure Mode #5: Missing Null Checks After Database Calls
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_node_returns_none(self):
|
||||
"""
|
||||
Test handling when get_node returns None.
|
||||
"""
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.id = "call_1"
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({"param": "value"})
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.response = None
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
mock_response.prompt_tokens = 50
|
||||
mock_response.completion_tokens = 25
|
||||
mock_response.reasoning = None
|
||||
mock_response.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": "call_1"}]
|
||||
}
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "nonexistent-node",
|
||||
"_field_mapping": {"param": "param"},
|
||||
"parameters": {
|
||||
"properties": {"param": {"type": "string"}},
|
||||
"required": ["param"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_client.get_node.return_value = None # Node doesn't exist!
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm, \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=1,
|
||||
)
|
||||
|
||||
# Should raise ValueError for missing node
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_execution_outputs(self):
|
||||
"""
|
||||
Test handling when get_execution_outputs_by_node_exec_id returns empty.
|
||||
"""
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
|
||||
if call_count > 1:
|
||||
resp = MagicMock()
|
||||
resp.response = "Done"
|
||||
resp.tool_calls = []
|
||||
resp.prompt_tokens = 10
|
||||
resp.completion_tokens = 5
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {"role": "assistant", "content": "Done"}
|
||||
return resp
|
||||
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.id = "call_1"
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [mock_tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": "call_1"}]
|
||||
}
|
||||
return resp
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {},
|
||||
"parameters": {"properties": {}, "required": []},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.node_exec_id = "exec-id"
|
||||
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {})
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {} # Empty!
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=2,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# Empty outputs should be handled gracefully
|
||||
# (uses "Tool executed successfully" as fallback)
|
||||
assert "finished" in outputs or "conversations" in outputs
|
||||
|
||||
|
||||
class TestErrorMessageContextLoss:
|
||||
"""
|
||||
Tests for Failure Mode #15: Error Message Context Loss
|
||||
|
||||
When exceptions are caught and converted to strings, important
|
||||
debugging information is lost.
|
||||
"""
|
||||
|
||||
def test_exception_to_string_loses_traceback(self):
|
||||
"""
|
||||
Test that converting exception to string loses traceback.
|
||||
"""
|
||||
try:
|
||||
def inner():
|
||||
raise ValueError("Inner error")
|
||||
|
||||
def outer():
|
||||
inner()
|
||||
|
||||
outer()
|
||||
except Exception as e:
|
||||
error_string = str(e)
|
||||
error_repr = repr(e)
|
||||
|
||||
# String representation loses call stack
|
||||
assert "inner" not in error_string
|
||||
assert "outer" not in error_string
|
||||
|
||||
# Even repr doesn't have full traceback
|
||||
assert "Traceback" not in error_repr
|
||||
|
||||
def test_tool_response_loses_exception_type(self):
|
||||
"""
|
||||
Test that _create_tool_response loses exception type information.
|
||||
"""
|
||||
original_error = ConnectionError("Database unreachable")
|
||||
tool_response = _create_tool_response(
|
||||
"call_123",
|
||||
f"Tool execution failed: {str(original_error)}"
|
||||
)
|
||||
|
||||
content = tool_response.get("content", "")
|
||||
|
||||
# Original exception type is lost
|
||||
assert "ConnectionError" not in content
|
||||
# Only the message remains
|
||||
assert "Database unreachable" in content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_mode_error_response_lacks_context(self):
|
||||
"""
|
||||
Test that agent mode error responses lack debugging context.
|
||||
"""
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.id = "call_1"
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({})
|
||||
|
||||
mock_response_1 = MagicMock()
|
||||
mock_response_1.response = None
|
||||
mock_response_1.tool_calls = [mock_tool_call]
|
||||
mock_response_1.prompt_tokens = 50
|
||||
mock_response_1.completion_tokens = 25
|
||||
mock_response_1.reasoning = None
|
||||
mock_response_1.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": "call_1"}]
|
||||
}
|
||||
|
||||
mock_response_2 = MagicMock()
|
||||
mock_response_2.response = "Handled the error"
|
||||
mock_response_2.tool_calls = []
|
||||
mock_response_2.prompt_tokens = 30
|
||||
mock_response_2.completion_tokens = 15
|
||||
mock_response_2.reasoning = None
|
||||
mock_response_2.raw_response = {"role": "assistant", "content": "Handled"}
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return mock_response_1
|
||||
return mock_response_2
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {},
|
||||
"parameters": {"properties": {}, "required": []},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Create a complex error with nested cause
|
||||
class CustomDatabaseError(Exception):
|
||||
pass
|
||||
|
||||
def create_complex_error():
|
||||
try:
|
||||
raise ConnectionError("Network timeout after 30s")
|
||||
except ConnectionError as e:
|
||||
raise CustomDatabaseError("Failed to connect to database") from e
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
|
||||
# Make upsert raise the complex error
|
||||
try:
|
||||
create_complex_error()
|
||||
except CustomDatabaseError as e:
|
||||
mock_db_client.upsert_execution_input.side_effect = e
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=2,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# Check conversation for error details
|
||||
conversations = outputs.get("conversations", [])
|
||||
error_found = False
|
||||
for msg in conversations:
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if item.get("type") == "tool_result":
|
||||
result_content = item.get("content", "")
|
||||
if "Error" in result_content or "failed" in result_content.lower():
|
||||
error_found = True
|
||||
# BUG: The error content lacks:
|
||||
# - Exception type (CustomDatabaseError)
|
||||
# - Chained cause (ConnectionError)
|
||||
# - Stack trace
|
||||
assert "CustomDatabaseError" not in result_content
|
||||
assert "ConnectionError" not in result_content
|
||||
|
||||
# Note: error_found may be False if the error prevented tool response creation
|
||||
|
||||
|
||||
class TestRawResponseConversion:
|
||||
"""Tests for _convert_raw_response_to_dict edge cases."""
|
||||
|
||||
def test_string_response_converted(self):
|
||||
"""Test that string responses are properly wrapped."""
|
||||
result = _convert_raw_response_to_dict("Hello, world!")
|
||||
assert result == {"role": "assistant", "content": "Hello, world!"}
|
||||
|
||||
def test_dict_response_unchanged(self):
|
||||
"""Test that dict responses are passed through."""
|
||||
original = {"role": "assistant", "content": "test", "extra": "field"}
|
||||
result = _convert_raw_response_to_dict(original)
|
||||
assert result == original
|
||||
|
||||
def test_object_response_converted(self):
|
||||
"""Test that objects are converted using json.to_dict."""
|
||||
mock_obj = MagicMock()
|
||||
|
||||
with patch("backend.blocks.smart_decision_maker.json.to_dict") as mock_to_dict:
|
||||
mock_to_dict.return_value = {"converted": True}
|
||||
result = _convert_raw_response_to_dict(mock_obj)
|
||||
mock_to_dict.assert_called_once_with(mock_obj)
|
||||
assert result == {"converted": True}
|
||||
|
||||
def test_none_response(self):
|
||||
"""Test handling of None response."""
|
||||
with patch("backend.blocks.smart_decision_maker.json.to_dict") as mock_to_dict:
|
||||
mock_to_dict.return_value = None
|
||||
result = _convert_raw_response_to_dict(None)
|
||||
# None is not a string or dict, so it goes through to_dict
|
||||
assert result is None
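Taken together, the four tests in this class pin down the conversion contract: strings are wrapped as assistant messages, dicts pass through unchanged, and everything else is delegated to the json.to_dict helper patched in these tests. A self-contained sketch that satisfies these tests (with a stand-in for that helper; the real implementation lives in backend.blocks.smart_decision_maker) is:

from typing import Any

def to_dict(obj: Any) -> Any:
    # Stand-in for the patched json.to_dict helper (assumed): best-effort conversion.
    return getattr(obj, "__dict__", obj)

def convert_raw_response_to_dict(raw_response: Any) -> Any:
    # Strings are wrapped as an assistant message.
    if isinstance(raw_response, str):
        return {"role": "assistant", "content": raw_response}
    # Dicts pass through unchanged.
    if isinstance(raw_response, dict):
        return raw_response
    # Everything else (provider objects, None) goes through the to_dict helper.
    return to_dict(raw_response)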
|
||||
|
||||
|
||||
class TestValidationRetryMechanism:
|
||||
"""Tests for the validation and retry mechanism."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_error_triggers_retry(self):
|
||||
"""
|
||||
Test that validation errors trigger retry with feedback.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
|
||||
prompt = kwargs.get("prompt", [])
|
||||
|
||||
if call_count == 1:
|
||||
# First call: return tool call with wrong parameter
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({"wrong_param": "value"})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [mock_tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {"role": "assistant", "content": None}
|
||||
return resp
|
||||
else:
|
||||
# Second call: check that error feedback was added
|
||||
has_error_feedback = any(
|
||||
"parameter errors" in str(msg.get("content", "")).lower()
|
||||
for msg in prompt
|
||||
)
|
||||
|
||||
# Return correct tool call
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({"correct_param": "value"})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [mock_tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {"role": "assistant", "content": None}
|
||||
return resp
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {"correct_param": "correct_param"},
|
||||
"parameters": {
|
||||
"properties": {"correct_param": {"type": "string"}},
|
||||
"required": ["correct_param"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=0, # Traditional mode
|
||||
retry=3,
|
||||
)
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# Should have made multiple calls due to retry
|
||||
assert call_count >= 2
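The retry behaviour exercised here and in the next test follows a familiar loop: validate the returned tool calls against the declared parameters, and on failure append the parameter errors to the conversation before asking the model again. A rough, hypothetical sketch of that loop (not the block's actual implementation; the validator is passed in):

from typing import Awaitable, Callable

async def call_with_validation_retry(
    llm_call: Callable[..., Awaitable],
    validate: Callable[[object], list[str]],  # returns a list of parameter error messages
    prompt: list[dict],
    tool_functions: list[dict],
    retry: int = 3,
):
    errors: list[str] = []
    for _ in range(max(retry, 1)):
        response = await llm_call(prompt=prompt, tools=tool_functions)
        errors = validate(response)
        if not errors:
            return response
        # Feed the parameter errors back to the model before the next attempt.
        prompt.append({
            "role": "user",
            "content": f"The previous tool call had parameter errors: {errors}. Please try again.",
        })
    raise ValueError(f"Tool call parameter errors after {retry} attempts: {errors}")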
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_retries_exceeded(self):
|
||||
"""
|
||||
Test behavior when max retries are exceeded.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
# Always return invalid tool call
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({"wrong": "param"})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [mock_tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {"role": "assistant", "content": None}
|
||||
return resp
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {"correct": "correct"},
|
||||
"parameters": {
|
||||
"properties": {"correct": {"type": "string"}},
|
||||
"required": ["correct"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=0,
|
||||
retry=2, # Only 2 retries
|
||||
)
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
# Should raise ValueError after max retries
|
||||
with pytest.raises(ValueError, match="parameter errors"):
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
pass
|
||||
@@ -1,819 +0,0 @@
"""
Comprehensive tests for SmartDecisionMakerBlock pin name sanitization.

This test file addresses the critical bug where field names with spaces/special characters
(e.g., "Max Keyword Difficulty") are not consistently sanitized between frontend and backend,
causing tool calls to "go into the void".

The core issue:
- Frontend connects link with original name: tools_^_{node_id}_~_Max Keyword Difficulty
- Backend emits with sanitized name: tools_^_{node_id}_~_max_keyword_difficulty
- parse_execution_output compares sink_pin_name directly without sanitization
- Result: mismatch causes tool calls to fail silently
"""

import json
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest

from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.dynamic_fields import (
    parse_execution_output,
    sanitize_pin_name,
)

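The mismatch described in the docstring reduces to comparing a sanitized emit key against an unsanitized link name. A minimal illustration of the comparison the fix implies, assuming sanitize_pin_name normalizes both sides the same way (this is not the actual parse_execution_output code):

from backend.data.dynamic_fields import sanitize_pin_name

def pin_names_match(emitted_pin_name: str, sink_pin_name: str) -> bool:
    # Treat "Max Keyword Difficulty" and "max_keyword_difficulty" as the same pin
    # by sanitizing both sides before comparing.
    return sanitize_pin_name(emitted_pin_name) == sanitize_pin_name(sink_pin_name)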
class TestCleanupFunction:
|
||||
"""Tests for the SmartDecisionMakerBlock.cleanup() static method."""
|
||||
|
||||
def test_cleanup_spaces_to_underscores(self):
|
||||
"""Spaces should be replaced with underscores."""
|
||||
assert SmartDecisionMakerBlock.cleanup("Max Keyword Difficulty") == "max_keyword_difficulty"
|
||||
|
||||
def test_cleanup_mixed_case_to_lowercase(self):
|
||||
"""Mixed case should be converted to lowercase."""
|
||||
assert SmartDecisionMakerBlock.cleanup("MaxKeywordDifficulty") == "maxkeyworddifficulty"
|
||||
assert SmartDecisionMakerBlock.cleanup("UPPER_CASE") == "upper_case"
|
||||
|
||||
def test_cleanup_special_characters(self):
|
||||
"""Special characters should be replaced with underscores."""
|
||||
assert SmartDecisionMakerBlock.cleanup("field@name!") == "field_name_"
|
||||
assert SmartDecisionMakerBlock.cleanup("value#1") == "value_1"
|
||||
assert SmartDecisionMakerBlock.cleanup("test$value") == "test_value"
|
||||
assert SmartDecisionMakerBlock.cleanup("a%b^c") == "a_b_c"
|
||||
|
||||
def test_cleanup_preserves_valid_characters(self):
|
||||
"""Valid characters (alphanumeric, underscore, hyphen) should be preserved."""
|
||||
assert SmartDecisionMakerBlock.cleanup("valid_name-123") == "valid_name-123"
|
||||
assert SmartDecisionMakerBlock.cleanup("abc123") == "abc123"
|
||||
|
||||
def test_cleanup_empty_string(self):
|
||||
"""Empty string should return empty string."""
|
||||
assert SmartDecisionMakerBlock.cleanup("") == ""
|
||||
|
||||
def test_cleanup_only_special_chars(self):
|
||||
"""String of only special characters should return underscores."""
|
||||
assert SmartDecisionMakerBlock.cleanup("@#$%") == "____"
|
||||
|
||||
def test_cleanup_unicode_characters(self):
|
||||
"""Unicode characters should be replaced with underscores."""
|
||||
assert SmartDecisionMakerBlock.cleanup("café") == "caf_"
|
||||
assert SmartDecisionMakerBlock.cleanup("日本語") == "___"
|
||||
|
||||
def test_cleanup_multiple_consecutive_spaces(self):
|
||||
"""Multiple consecutive spaces should become multiple underscores."""
|
||||
assert SmartDecisionMakerBlock.cleanup("a b") == "a___b"
|
||||
|
||||
def test_cleanup_leading_trailing_spaces(self):
|
||||
"""Leading/trailing spaces should become underscores."""
|
||||
assert SmartDecisionMakerBlock.cleanup(" name ") == "_name_"
|
||||
|
||||
def test_cleanup_realistic_field_names(self):
|
||||
"""Test realistic field names from actual use cases."""
|
||||
# From the reported bug
|
||||
assert SmartDecisionMakerBlock.cleanup("Max Keyword Difficulty") == "max_keyword_difficulty"
|
||||
# Other realistic names
|
||||
assert SmartDecisionMakerBlock.cleanup("Search Query") == "search_query"
|
||||
assert SmartDecisionMakerBlock.cleanup("API Response (JSON)") == "api_response__json_"
|
||||
assert SmartDecisionMakerBlock.cleanup("User's Input") == "user_s_input"
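The expected values in this class fully determine the sanitizer's behaviour: lowercase the name and replace every character outside [a-z0-9_-] with an underscore. A regex sketch consistent with these tests (not necessarily the exact implementation of cleanup):

import re

def cleanup_sketch(name: str) -> str:
    # Lowercase, then map every character that is not a-z, 0-9, "_" or "-" to "_".
    return re.sub(r"[^a-z0-9_-]", "_", name.lower())

assert cleanup_sketch("Max Keyword Difficulty") == "max_keyword_difficulty"
assert cleanup_sketch("API Response (JSON)") == "api_response__json_"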
|
||||
|
||||
|
||||
class TestFieldMappingCreation:
|
||||
"""Tests for field mapping creation in function signatures."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_field_mapping_with_spaces_in_names(self):
|
||||
"""Test that field mapping correctly maps clean names back to original names with spaces."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_node = Mock()
|
||||
mock_node.id = "test-node-id"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "Test description"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={"properties": {}, "required": ["Max Keyword Difficulty"]}
|
||||
)
|
||||
|
||||
def get_field_schema(field_name):
|
||||
if field_name == "Max Keyword Difficulty":
|
||||
return {"type": "integer", "description": "Maximum keyword difficulty (0-100)"}
|
||||
raise KeyError(f"Field {field_name} not found")
|
||||
|
||||
mock_node.block.input_schema.get_field_schema = get_field_schema
|
||||
|
||||
mock_links = [
|
||||
Mock(
|
||||
source_name="tools_^_test_~_max_keyword_difficulty",
|
||||
sink_name="Max Keyword Difficulty", # Original name with spaces
|
||||
sink_id="test-node-id",
|
||||
source_id="smart_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links)
|
||||
|
||||
# Verify the cleaned name is used in properties
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
assert "max_keyword_difficulty" in properties
|
||||
|
||||
# Verify the field mapping maps back to original
|
||||
field_mapping = signature["function"]["_field_mapping"]
|
||||
assert field_mapping["max_keyword_difficulty"] == "Max Keyword Difficulty"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_field_mapping_with_multiple_special_char_names(self):
|
||||
"""Test field mapping with multiple fields containing special characters."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_node = Mock()
|
||||
mock_node.id = "test-node-id"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "SEO Tool"
|
||||
mock_node.block.description = "SEO analysis tool"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={"properties": {}, "required": []}
|
||||
)
|
||||
|
||||
def get_field_schema(field_name):
|
||||
schemas = {
|
||||
"Max Keyword Difficulty": {"type": "integer", "description": "Max difficulty"},
|
||||
"Search Volume (Monthly)": {"type": "integer", "description": "Monthly volume"},
|
||||
"CPC ($)": {"type": "number", "description": "Cost per click"},
|
||||
"Target URL": {"type": "string", "description": "URL to analyze"},
|
||||
}
|
||||
if field_name in schemas:
|
||||
return schemas[field_name]
|
||||
raise KeyError(f"Field {field_name} not found")
|
||||
|
||||
mock_node.block.input_schema.get_field_schema = get_field_schema
|
||||
|
||||
mock_links = [
|
||||
Mock(sink_name="Max Keyword Difficulty", sink_id="test-node-id", source_id="smart_node_id"),
|
||||
Mock(sink_name="Search Volume (Monthly)", sink_id="test-node-id", source_id="smart_node_id"),
|
||||
Mock(sink_name="CPC ($)", sink_id="test-node-id", source_id="smart_node_id"),
|
||||
Mock(sink_name="Target URL", sink_id="test-node-id", source_id="smart_node_id"),
|
||||
]
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links)
|
||||
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
field_mapping = signature["function"]["_field_mapping"]
|
||||
|
||||
# Verify all cleaned names are in properties
|
||||
assert "max_keyword_difficulty" in properties
|
||||
assert "search_volume__monthly_" in properties
|
||||
assert "cpc____" in properties
|
||||
assert "target_url" in properties
|
||||
|
||||
# Verify field mappings
|
||||
assert field_mapping["max_keyword_difficulty"] == "Max Keyword Difficulty"
|
||||
assert field_mapping["search_volume__monthly_"] == "Search Volume (Monthly)"
|
||||
assert field_mapping["cpc____"] == "CPC ($)"
|
||||
assert field_mapping["target_url"] == "Target URL"
|
||||
|
||||
|
||||
class TestFieldNameCollision:
|
||||
"""Tests for detecting field name collisions after sanitization."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collision_detection_same_sanitized_name(self):
|
||||
"""Test behavior when two different names sanitize to the same value."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# These two different names will sanitize to the same value
|
||||
name1 = "max keyword difficulty" # -> max_keyword_difficulty
|
||||
name2 = "Max Keyword Difficulty" # -> max_keyword_difficulty
|
||||
name3 = "MAX_KEYWORD_DIFFICULTY" # -> max_keyword_difficulty
|
||||
|
||||
assert SmartDecisionMakerBlock.cleanup(name1) == SmartDecisionMakerBlock.cleanup(name2)
|
||||
assert SmartDecisionMakerBlock.cleanup(name2) == SmartDecisionMakerBlock.cleanup(name3)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collision_in_function_signature(self):
|
||||
"""Test that collisions in sanitized names could cause issues."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_node = Mock()
|
||||
mock_node.id = "test-node-id"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "Test description"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={"properties": {}, "required": []}
|
||||
)
|
||||
|
||||
def get_field_schema(field_name):
|
||||
return {"type": "string", "description": f"Field: {field_name}"}
|
||||
|
||||
mock_node.block.input_schema.get_field_schema = get_field_schema
|
||||
|
||||
# Two different fields that sanitize to the same name
|
||||
mock_links = [
|
||||
Mock(sink_name="Test Field", sink_id="test-node-id", source_id="smart_node_id"),
|
||||
Mock(sink_name="test field", sink_id="test-node-id", source_id="smart_node_id"),
|
||||
]
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links)
|
||||
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
field_mapping = signature["function"]["_field_mapping"]
|
||||
|
||||
# Both sanitize to "test_field" - only one will be in properties
|
||||
assert "test_field" in properties
|
||||
# The field_mapping will have the last one written
|
||||
assert field_mapping["test_field"] in ["Test Field", "test field"]
|
||||
|
||||
|
||||
class TestOutputRouting:
|
||||
"""Tests for output routing with sanitized names."""
|
||||
|
||||
def test_emit_key_format_with_spaces(self):
|
||||
"""Test that emit keys use sanitized field names."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
original_field_name = "Max Keyword Difficulty"
|
||||
sink_node_id = "node-123"
|
||||
|
||||
sanitized_name = block.cleanup(original_field_name)
|
||||
emit_key = f"tools_^_{sink_node_id}_~_{sanitized_name}"
|
||||
|
||||
assert emit_key == "tools_^_node-123_~_max_keyword_difficulty"
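The emit key is simply the "tools" selector, the sink node id, and the sanitized field name joined with the "_^_" and "_~_" separators. A small self-contained sketch of composing and splitting such keys (illustrative; the real helpers live in the backend's dynamic-fields module):

TOOLS_PREFIX = "tools"

def make_emit_key(sink_node_id: str, sanitized_field: str) -> str:
    return f"{TOOLS_PREFIX}_^_{sink_node_id}_~_{sanitized_field}"

def split_emit_key(emit_key: str) -> tuple[str, str, str]:
    selector, rest = emit_key.split("_^_", 1)
    sink_node_id, field_name = rest.split("_~_", 1)
    return selector, sink_node_id, field_name

assert split_emit_key(make_emit_key("node-123", "max_keyword_difficulty")) == (
    "tools",
    "node-123",
    "max_keyword_difficulty",
)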
|
||||
|
||||
def test_parse_execution_output_exact_match(self):
|
||||
"""Test parse_execution_output with exact matching names."""
|
||||
output_item = ("tools_^_node-123_~_max_keyword_difficulty", 50)
|
||||
|
||||
# When sink_pin_name matches the sanitized name, it should work
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="max_keyword_difficulty",
|
||||
)
|
||||
assert result == 50
|
||||
|
||||
def test_parse_execution_output_mismatch_original_vs_sanitized(self):
|
||||
"""
|
||||
CRITICAL TEST: This reproduces the exact bug reported.
|
||||
|
||||
When frontend creates a link with original name "Max Keyword Difficulty"
|
||||
but backend emits with sanitized name "max_keyword_difficulty",
|
||||
the tool call should still be routed correctly.
|
||||
|
||||
CURRENT BEHAVIOR (BUG): Returns None because names don't match
|
||||
EXPECTED BEHAVIOR: Should return the value (50) after sanitizing both names
|
||||
"""
|
||||
output_item = ("tools_^_node-123_~_max_keyword_difficulty", 50)
|
||||
|
||||
# This is what happens: sink_pin_name comes from frontend link (unsanitized)
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="Max Keyword Difficulty", # Original name with spaces
|
||||
)
|
||||
|
||||
# BUG: This currently returns None because:
|
||||
# - target_input_pin = "max_keyword_difficulty" (from emit key, sanitized)
|
||||
# - sink_pin_name = "Max Keyword Difficulty" (from link, original)
|
||||
# - They don't match, so routing fails
|
||||
#
|
||||
# TODO: When the bug is fixed, change this assertion to:
|
||||
# assert result == 50
|
||||
assert result is None # Current buggy behavior
|
||||
|
||||
def test_parse_execution_output_with_sanitized_sink_pin(self):
|
||||
"""Test that if sink_pin_name is pre-sanitized, routing works."""
|
||||
output_item = ("tools_^_node-123_~_max_keyword_difficulty", 50)
|
||||
|
||||
# If sink_pin_name is already sanitized, routing works
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="max_keyword_difficulty", # Pre-sanitized
|
||||
)
|
||||
assert result == 50
|
||||
|
||||
|
||||
class TestProcessToolCallsMapping:
|
||||
"""Tests for _process_tool_calls method field mapping."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_tool_calls_maps_clean_to_original(self):
|
||||
"""Test that _process_tool_calls correctly maps clean names back to original."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_response = Mock()
|
||||
mock_tool_call = Mock()
|
||||
mock_tool_call.function.name = "seo_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
"max_keyword_difficulty": 50, # LLM uses clean name
|
||||
"search_query": "test query",
|
||||
})
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "seo_tool",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"max_keyword_difficulty": {"type": "integer"},
|
||||
"search_query": {"type": "string"},
|
||||
},
|
||||
"required": ["max_keyword_difficulty", "search_query"],
|
||||
},
|
||||
"_sink_node_id": "test-sink-node",
|
||||
"_field_mapping": {
|
||||
"max_keyword_difficulty": "Max Keyword Difficulty", # Original name
|
||||
"search_query": "Search Query",
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
|
||||
assert len(processed) == 1
|
||||
tool_info = processed[0]
|
||||
|
||||
# Verify input_data uses ORIGINAL field names
|
||||
assert "Max Keyword Difficulty" in tool_info.input_data
|
||||
assert "Search Query" in tool_info.input_data
|
||||
assert tool_info.input_data["Max Keyword Difficulty"] == 50
|
||||
assert tool_info.input_data["Search Query"] == "test query"
|
||||
|
||||
|
||||
class TestToolOutputEmitting:
|
||||
"""Tests for the tool output emitting in traditional mode."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_keys_use_sanitized_names(self):
|
||||
"""Test that emit keys always use sanitized field names."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.function.name = "seo_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
"max_keyword_difficulty": 50,
|
||||
})
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.response = None
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
mock_response.prompt_tokens = 50
|
||||
mock_response.completion_tokens = 25
|
||||
mock_response.reasoning = None
|
||||
mock_response.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "seo_tool",
|
||||
"_sink_node_id": "test-sink-node-id",
|
||||
"_field_mapping": {
|
||||
"max_keyword_difficulty": "Max Keyword Difficulty",
|
||||
},
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"max_keyword_difficulty": {"type": "integer"},
|
||||
},
|
||||
"required": ["max_keyword_difficulty"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
), patch.object(
|
||||
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# The emit key should use the sanitized field name
|
||||
# Even though the original was "Max Keyword Difficulty", emit uses sanitized
|
||||
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
|
||||
assert outputs["tools_^_test-sink-node-id_~_max_keyword_difficulty"] == 50
|
||||
|
||||
|
||||
class TestSanitizationConsistency:
|
||||
"""Tests for ensuring sanitization is consistent throughout the pipeline."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_round_trip_with_spaces(self):
|
||||
"""
|
||||
Test the full round-trip of a field name with spaces through the system.
|
||||
|
||||
This simulates:
|
||||
1. Frontend creates link with sink_name="Max Keyword Difficulty"
|
||||
2. Backend creates function signature with cleaned property name
|
||||
3. LLM responds with cleaned name
|
||||
4. Backend processes response and maps back to original
|
||||
5. Backend emits with sanitized name
|
||||
6. Routing should match (currently broken)
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
original_field_name = "Max Keyword Difficulty"
|
||||
cleaned_field_name = SmartDecisionMakerBlock.cleanup(original_field_name)
|
||||
|
||||
# Step 1: Simulate frontend link creation
|
||||
mock_link = Mock()
|
||||
mock_link.sink_name = original_field_name # Frontend uses original
|
||||
mock_link.sink_id = "test-sink-node-id"
|
||||
mock_link.source_id = "smart-node-id"
|
||||
|
||||
# Step 2: Create function signature
|
||||
mock_node = Mock()
|
||||
mock_node.id = "test-sink-node-id"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "SEO Tool"
|
||||
mock_node.block.description = "SEO analysis"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={"properties": {}, "required": [original_field_name]}
|
||||
)
|
||||
mock_node.block.input_schema.get_field_schema = Mock(
|
||||
return_value={"type": "integer", "description": "Max difficulty"}
|
||||
)
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, [mock_link])
|
||||
|
||||
# Verify cleaned name is in properties
|
||||
assert cleaned_field_name in signature["function"]["parameters"]["properties"]
|
||||
# Verify field mapping exists
|
||||
assert signature["function"]["_field_mapping"][cleaned_field_name] == original_field_name
|
||||
|
||||
# Step 3: Simulate LLM response using cleaned name
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.function.name = "seo_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
cleaned_field_name: 50 # LLM uses cleaned name
|
||||
})
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.response = None
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
mock_response.prompt_tokens = 50
|
||||
mock_response.completion_tokens = 25
|
||||
mock_response.reasoning = None
|
||||
mock_response.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
# Prepare tool_functions as they would be in run()
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "seo_tool",
|
||||
"_sink_node_id": "test-sink-node-id",
|
||||
"_field_mapping": signature["function"]["_field_mapping"],
|
||||
"parameters": signature["function"]["parameters"],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Step 4: Process tool calls
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
assert len(processed) == 1
|
||||
# Input data should have ORIGINAL name
|
||||
assert original_field_name in processed[0].input_data
|
||||
assert processed[0].input_data[original_field_name] == 50
|
||||
|
||||
# Step 5: Emit key generation (from run method logic)
|
||||
field_mapping = processed[0].field_mapping
|
||||
for clean_arg_name in signature["function"]["parameters"]["properties"]:
|
||||
original = field_mapping.get(clean_arg_name, clean_arg_name)
|
||||
sanitized_arg_name = block.cleanup(original)
|
||||
emit_key = f"tools_^_test-sink-node-id_~_{sanitized_arg_name}"
|
||||
|
||||
# Emit key uses sanitized name
|
||||
assert emit_key == f"tools_^_test-sink-node-id_~_{cleaned_field_name}"
|
||||
|
||||
# Step 6: Routing check (this is where the bug manifests)
|
||||
emit_key = f"tools_^_test-sink-node-id_~_{cleaned_field_name}"
|
||||
output_item = (emit_key, 50)
|
||||
|
||||
# Current routing uses original sink_name from link
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="test-sink-node-id",
|
||||
sink_pin_name=original_field_name, # Frontend's original name
|
||||
)
|
||||
|
||||
# BUG: This returns None because sanitized != original
|
||||
# When fixed, this should return 50
|
||||
assert result is None # Current broken behavior
|
||||
|
||||
def test_sanitization_is_idempotent(self):
|
||||
"""Test that sanitizing an already sanitized name gives the same result."""
|
||||
original = "Max Keyword Difficulty"
|
||||
first_clean = SmartDecisionMakerBlock.cleanup(original)
|
||||
second_clean = SmartDecisionMakerBlock.cleanup(first_clean)
|
||||
|
||||
assert first_clean == second_clean
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Tests for edge cases in the sanitization pipeline."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_field_name(self):
|
||||
"""Test handling of empty field name."""
|
||||
assert SmartDecisionMakerBlock.cleanup("") == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_very_long_field_name(self):
|
||||
"""Test handling of very long field names."""
|
||||
long_name = "A" * 1000 + " " + "B" * 1000
|
||||
cleaned = SmartDecisionMakerBlock.cleanup(long_name)
|
||||
assert "_" in cleaned # Space was replaced
|
||||
assert len(cleaned) == len(long_name)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_field_name_with_newlines(self):
|
||||
"""Test handling of field names with newlines."""
|
||||
name_with_newline = "First Line\nSecond Line"
|
||||
cleaned = SmartDecisionMakerBlock.cleanup(name_with_newline)
|
||||
assert "\n" not in cleaned
|
||||
assert "_" in cleaned
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_field_name_with_tabs(self):
|
||||
"""Test handling of field names with tabs."""
|
||||
name_with_tab = "First\tSecond"
|
||||
cleaned = SmartDecisionMakerBlock.cleanup(name_with_tab)
|
||||
assert "\t" not in cleaned
|
||||
assert "_" in cleaned
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_numeric_field_name(self):
|
||||
"""Test handling of purely numeric field names."""
|
||||
assert SmartDecisionMakerBlock.cleanup("123") == "123"
|
||||
assert SmartDecisionMakerBlock.cleanup("123 456") == "123_456"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hyphenated_field_names(self):
|
||||
"""Test that hyphens are preserved (valid in function names)."""
|
||||
assert SmartDecisionMakerBlock.cleanup("field-name") == "field-name"
|
||||
assert SmartDecisionMakerBlock.cleanup("Field-Name") == "field-name"
|
||||
|
||||
|
||||
class TestDynamicFieldsWithSpaces:
|
||||
"""Tests for dynamic fields with spaces in their names."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_dict_field_with_spaces(self):
|
||||
"""Test dynamic dictionary fields where the key contains spaces."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_node = Mock()
|
||||
mock_node.id = "test-node-id"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "CreateDictionary"
|
||||
mock_node.block.description = "Creates a dictionary"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={"properties": {}, "required": ["values"]}
|
||||
)
|
||||
mock_node.block.input_schema.get_field_schema = Mock(
|
||||
side_effect=KeyError("not found")
|
||||
)
|
||||
|
||||
# Dynamic field with a key containing spaces
|
||||
mock_links = [
|
||||
Mock(
|
||||
sink_name="values_#_User Name", # Dict key with space
|
||||
sink_id="test-node-id",
|
||||
source_id="smart_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links)
|
||||
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
field_mapping = signature["function"]["_field_mapping"]
|
||||
|
||||
# The cleaned name should be in properties
|
||||
expected_clean = SmartDecisionMakerBlock.cleanup("values_#_User Name")
|
||||
assert expected_clean in properties
|
||||
|
||||
# Field mapping should map back to original
|
||||
assert field_mapping[expected_clean] == "values_#_User Name"
|
||||
|
||||
|
||||
class TestAgentModeWithSpaces:
|
||||
"""Tests for agent mode with field names containing spaces."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_mode_tool_execution_with_spaces(self):
|
||||
"""Test that agent mode correctly handles field names with spaces."""
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
original_field = "Max Keyword Difficulty"
|
||||
clean_field = SmartDecisionMakerBlock.cleanup(original_field)
|
||||
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.id = "call_1"
|
||||
mock_tool_call.function.name = "seo_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
clean_field: 50 # LLM uses clean name
|
||||
})
|
||||
|
||||
mock_response_1 = MagicMock()
|
||||
mock_response_1.response = None
|
||||
mock_response_1.tool_calls = [mock_tool_call]
|
||||
mock_response_1.prompt_tokens = 50
|
||||
mock_response_1.completion_tokens = 25
|
||||
mock_response_1.reasoning = None
|
||||
mock_response_1.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{"id": "call_1", "type": "function"}],
|
||||
}
|
||||
|
||||
mock_response_2 = MagicMock()
|
||||
mock_response_2.response = "Task completed"
|
||||
mock_response_2.tool_calls = []
|
||||
mock_response_2.prompt_tokens = 30
|
||||
mock_response_2.completion_tokens = 15
|
||||
mock_response_2.reasoning = None
|
||||
mock_response_2.raw_response = {"role": "assistant", "content": "Task completed"}
|
||||
|
||||
llm_call_mock = AsyncMock()
|
||||
llm_call_mock.side_effect = [mock_response_1, mock_response_2]
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "seo_tool",
|
||||
"_sink_node_id": "test-sink-node-id",
|
||||
"_field_mapping": {
|
||||
clean_field: original_field,
|
||||
},
|
||||
"parameters": {
|
||||
"properties": {
|
||||
clean_field: {"type": "integer"},
|
||||
},
|
||||
"required": [clean_field],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block-id"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
|
||||
mock_node_exec_result = MagicMock()
|
||||
mock_node_exec_result.node_exec_id = "test-tool-exec-id"
|
||||
|
||||
# The input data should use ORIGINAL field name
|
||||
mock_input_data = {original_field: 50}
|
||||
mock_db_client.upsert_execution_input.return_value = (
|
||||
mock_node_exec_result,
|
||||
mock_input_data,
|
||||
)
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
|
||||
"result": {"status": "success"}
|
||||
}
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", llm_call_mock), patch.object(
|
||||
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
|
||||
), patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
|
||||
mock_node_stats = MagicMock()
|
||||
mock_node_stats.error = None
|
||||
mock_execution_processor.on_node_execution = AsyncMock(
|
||||
return_value=mock_node_stats
|
||||
)
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Analyze keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=3,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Verify upsert was called with original field name
|
||||
upsert_calls = mock_db_client.upsert_execution_input.call_args_list
|
||||
assert len(upsert_calls) > 0
|
||||
# Check that the original field name was used
|
||||
for call in upsert_calls:
|
||||
input_name = call.kwargs.get("input_name") or call.args[2]
|
||||
# The input name should be the original (mapped back)
|
||||
assert input_name == original_field
|
||||
|
||||
|
||||
class TestRequiredFieldsWithSpaces:
|
||||
"""Tests for required field handling with spaces in names."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_required_fields_use_clean_names(self):
|
||||
"""Test that required fields array uses clean names for API compatibility."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_node = Mock()
|
||||
mock_node.id = "test-node-id"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "Test"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={
|
||||
"properties": {},
|
||||
"required": ["Max Keyword Difficulty", "Search Query"],
|
||||
}
|
||||
)
|
||||
|
||||
def get_field_schema(field_name):
|
||||
return {"type": "string", "description": f"Field: {field_name}"}
|
||||
|
||||
mock_node.block.input_schema.get_field_schema = get_field_schema
|
||||
|
||||
mock_links = [
|
||||
Mock(sink_name="Max Keyword Difficulty", sink_id="test-node-id", source_id="smart_node_id"),
|
||||
Mock(sink_name="Search Query", sink_id="test-node-id", source_id="smart_node_id"),
|
||||
]
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links)
|
||||
|
||||
required = signature["function"]["parameters"]["required"]
|
||||
|
||||
# Required array should use CLEAN names for API compatibility
|
||||
assert "max_keyword_difficulty" in required
|
||||
assert "search_query" in required
|
||||
# Original names should NOT be in required
|
||||
assert "Max Keyword Difficulty" not in required
|
||||
assert "Search Query" not in required
|
||||
@@ -1,5 +1,5 @@
from gravitasml.parser import Parser
from gravitasml.token import Token, tokenize
from gravitasml.token import tokenize

from backend.data.block import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
from backend.data.model import SchemaField
@@ -25,38 +25,6 @@ class XMLParserBlock(Block):
            ],
        )

    @staticmethod
    def _validate_tokens(tokens: list[Token]) -> None:
        """Ensure the XML has a single root element and no stray text."""
        if not tokens:
            raise ValueError("XML input is empty.")

        depth = 0
        root_seen = False

        for token in tokens:
            if token.type == "TAG_OPEN":
                if depth == 0 and root_seen:
                    raise ValueError("XML must have a single root element.")
                depth += 1
                if depth == 1:
                    root_seen = True
            elif token.type == "TAG_CLOSE":
                depth -= 1
                if depth < 0:
                    raise SyntaxError("Unexpected closing tag in XML input.")
            elif token.type in {"TEXT", "ESCAPE"}:
                if depth == 0 and token.value:
                    raise ValueError(
                        "XML contains text outside the root element; "
                        "wrap content in a single root tag."
                    )

        if depth != 0:
            raise SyntaxError("Unclosed tag detected in XML input.")
        if not root_seen:
            raise ValueError("XML must include a root element.")

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Security fix: Add size limits to prevent XML bomb attacks
        MAX_XML_SIZE = 10 * 1024 * 1024  # 10MB limit for XML input
@@ -67,9 +35,7 @@
            )

        try:
            tokens = list(tokenize(input_data.input_xml))
            self._validate_tokens(tokens)

            tokens = tokenize(input_data.input_xml)
            parser = Parser(tokens)
            parsed_result = parser.parse()
            yield "parsed_xml", parsed_result

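For reference, the parsing path that remains after this change is tokenize-then-parse. A small usage sketch, assuming gravitasml's tokenize and Parser behave as used above (the exact shape returned by parse() is not shown in this diff):

from gravitasml.parser import Parser
from gravitasml.token import tokenize

xml = "<root><name>AutoGPT</name></root>"
tokens = tokenize(xml)
parsed = Parser(tokens).parse()  # assumed to yield a dict-like structure for the document
print(parsed)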
@@ -111,8 +111,6 @@ class TranscribeYoutubeVideoBlock(Block):
            return parsed_url.path.split("/")[2]
        if parsed_url.path[:3] == "/v/":
            return parsed_url.path.split("/")[2]
        if parsed_url.path.startswith("/shorts/"):
            return parsed_url.path.split("/")[2]
        raise ValueError(f"Invalid YouTube URL: {url}")

    def get_transcript(

@@ -244,7 +244,11 @@ def websocket(server_address: str, graph_exec_id: str):

    import websockets.asyncio.client

    from backend.api.ws_api import WSMessage, WSMethod, WSSubscribeGraphExecutionRequest
    from backend.server.ws_api import (
        WSMessage,
        WSMethod,
        WSSubscribeGraphExecutionRequest,
    )

    async def send_message(server_address: str):
        uri = f"ws://{server_address}"

@@ -1 +0,0 @@
"""CLI utilities for backend development & administration"""
@@ -1,57 +0,0 @@
#!/usr/bin/env python3
"""
Script to generate OpenAPI JSON specification for the FastAPI app.

This script imports the FastAPI app from backend.api.rest_api and outputs
the OpenAPI specification as JSON to stdout or a specified file.

Usage:
    `poetry run python generate_openapi_json.py`
    `poetry run python generate_openapi_json.py --output openapi.json`
    `poetry run python generate_openapi_json.py --indent 4 --output openapi.json`
"""

import json
import os
from pathlib import Path

import click


@click.command()
@click.option(
    "--output",
    type=click.Path(dir_okay=False, path_type=Path),
    help="Output file path (default: stdout)",
)
@click.option(
    "--pretty",
    type=click.BOOL,
    default=False,
    help="Pretty-print JSON output (indented 2 spaces)",
)
def main(output: Path, pretty: bool):
    """Generate and output the OpenAPI JSON specification."""
    openapi_schema = get_openapi_schema()

    json_output = json.dumps(openapi_schema, indent=2 if pretty else None)

    if output:
        output.write_text(json_output)
        click.echo(f"✅ OpenAPI specification written to {output}\n\nPreview:")
        click.echo(f"\n{json_output[:500]} ...")
    else:
        print(json_output)


def get_openapi_schema():
    """Get the OpenAPI schema from the FastAPI app"""
    from backend.api.rest_api import app

    return app.openapi()


if __name__ == "__main__":
    os.environ["LOG_LEVEL"] = "ERROR"  # disable stdout log output

    main()
File diff suppressed because it is too large
@@ -1,4 +1,4 @@
from backend.api.features.library.model import LibraryAgentPreset
from backend.server.v2.library.model import LibraryAgentPreset

from .graph import NodeModel
from .integrations import Webhook  # noqa: F401

@@ -1,24 +1,22 @@
import logging
import uuid
from datetime import datetime, timezone
from typing import Literal, Optional
from typing import Optional

from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission, APIKeyStatus
from prisma.models import APIKey as PrismaAPIKey
from prisma.types import APIKeyWhereUniqueInput
from pydantic import Field
from pydantic import BaseModel, Field

from backend.data.includes import MAX_USER_API_KEYS_FETCH
from backend.util.exceptions import NotAuthorizedError, NotFoundError

from .base import APIAuthorizationInfo

logger = logging.getLogger(__name__)
keysmith = APIKeySmith()


class APIKeyInfo(APIAuthorizationInfo):
class APIKeyInfo(BaseModel):
    id: str
    name: str
    head: str = Field(
@@ -28,9 +26,12 @@ class APIKeyInfo(APIAuthorizationInfo):
        description=f"The last {APIKeySmith.TAIL_LENGTH} characters of the key"
    )
    status: APIKeyStatus
    permissions: list[APIKeyPermission]
    created_at: datetime
    last_used_at: Optional[datetime] = None
    revoked_at: Optional[datetime] = None
    description: Optional[str] = None

    type: Literal["api_key"] = "api_key"  # type: ignore
    user_id: str

    @staticmethod
    def from_db(api_key: PrismaAPIKey):
@@ -40,7 +41,7 @@ class APIKeyInfo(APIAuthorizationInfo):
            head=api_key.head,
            tail=api_key.tail,
            status=APIKeyStatus(api_key.status),
            scopes=[APIKeyPermission(p) for p in api_key.permissions],
            permissions=[APIKeyPermission(p) for p in api_key.permissions],
            created_at=api_key.createdAt,
            last_used_at=api_key.lastUsedAt,
            revoked_at=api_key.revokedAt,
@@ -210,7 +211,7 @@ async def suspend_api_key(key_id: str, user_id: str) -> APIKeyInfo:


def has_permission(api_key: APIKeyInfo, required_permission: APIKeyPermission) -> bool:
    return required_permission in api_key.scopes
    return required_permission in api_key.permissions


async def get_api_key_by_id(key_id: str, user_id: str) -> Optional[APIKeyInfo]:
@@ -1,15 +0,0 @@
from datetime import datetime
from typing import Literal, Optional

from prisma.enums import APIKeyPermission
from pydantic import BaseModel


class APIAuthorizationInfo(BaseModel):
    user_id: str
    scopes: list[APIKeyPermission]
    type: Literal["oauth", "api_key"]
    created_at: datetime
    expires_at: Optional[datetime] = None
    last_used_at: Optional[datetime] = None
    revoked_at: Optional[datetime] = None
@@ -1,872 +0,0 @@
"""
OAuth 2.0 Provider Data Layer

Handles management of OAuth applications, authorization codes,
access tokens, and refresh tokens.

Hashing strategy:
- Access tokens & Refresh tokens: SHA256 (deterministic, allows direct lookup by hash)
- Client secrets: Scrypt with salt (lookup by client_id, then verify with salt)
"""

import hashlib
import logging
import secrets
import uuid
from datetime import datetime, timedelta, timezone
from typing import Literal, Optional

from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission as APIPermission
from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
from prisma.models import OAuthApplication as PrismaOAuthApplication
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
from prisma.types import OAuthApplicationUpdateInput
from pydantic import BaseModel, Field, SecretStr

from .base import APIAuthorizationInfo

logger = logging.getLogger(__name__)
keysmith = APIKeySmith()  # Only used for client secret hashing (Scrypt)


def _generate_token() -> str:
    """Generate a cryptographically secure random token."""
    return secrets.token_urlsafe(32)


def _hash_token(token: str) -> str:
    """Hash a token using SHA256 (deterministic, for direct lookup)."""
    return hashlib.sha256(token.encode()).hexdigest()


# Token TTLs
AUTHORIZATION_CODE_TTL = timedelta(minutes=10)
ACCESS_TOKEN_TTL = timedelta(hours=1)
REFRESH_TOKEN_TTL = timedelta(days=30)

ACCESS_TOKEN_PREFIX = "agpt_xt_"
REFRESH_TOKEN_PREFIX = "agpt_rt_"

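Putting the helpers above together, issuing a token and later matching a presented token both rely on the deterministic SHA256 hash described in the module docstring. A short sketch of that flow (persistence through the Prisma models further down is omitted):

def issue_access_token() -> tuple[str, str]:
    # The plaintext token goes to the client once; only the SHA256 hash is stored.
    plaintext = ACCESS_TOKEN_PREFIX + _generate_token()
    stored_hash = _hash_token(plaintext)
    return plaintext, stored_hash

# Later, a presented token is located by hashing it again and comparing hashes,
# which is what "deterministic, allows direct lookup by hash" refers to.
plaintext, stored_hash = issue_access_token()
assert _hash_token(plaintext) == stored_hash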
# ============================================================================
|
||||
# Exception Classes
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthError(Exception):
|
||||
"""Base OAuth error"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidClientError(OAuthError):
|
||||
"""Invalid client_id or client_secret"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidGrantError(OAuthError):
|
||||
"""Invalid or expired authorization code/refresh token"""
|
||||
|
||||
def __init__(self, reason: str):
|
||||
self.reason = reason
|
||||
super().__init__(f"Invalid grant: {reason}")
|
||||
|
||||
|
||||
class InvalidTokenError(OAuthError):
|
||||
"""Invalid, expired, or revoked token"""
|
||||
|
||||
def __init__(self, reason: str):
|
||||
self.reason = reason
|
||||
super().__init__(f"Invalid token: {reason}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Data Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthApplicationInfo(BaseModel):
|
||||
"""OAuth application information (without client secret hash)"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
client_id: str
|
||||
redirect_uris: list[str]
|
||||
grant_types: list[str]
|
||||
scopes: list[APIPermission]
|
||||
owner_id: str
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@staticmethod
|
||||
def from_db(app: PrismaOAuthApplication):
|
||||
return OAuthApplicationInfo(
|
||||
id=app.id,
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
logo_url=app.logoUrl,
|
||||
client_id=app.clientId,
|
||||
redirect_uris=app.redirectUris,
|
||||
grant_types=app.grantTypes,
|
||||
scopes=[APIPermission(s) for s in app.scopes],
|
||||
owner_id=app.ownerId,
|
||||
is_active=app.isActive,
|
||||
created_at=app.createdAt,
|
||||
updated_at=app.updatedAt,
|
||||
)
|
||||
|
||||
|
||||
class OAuthApplicationInfoWithSecret(OAuthApplicationInfo):
|
||||
"""OAuth application with client secret hash (for validation)"""
|
||||
|
||||
client_secret_hash: str
|
||||
client_secret_salt: str
|
||||
|
||||
@staticmethod
|
||||
def from_db(app: PrismaOAuthApplication):
|
||||
return OAuthApplicationInfoWithSecret(
|
||||
**OAuthApplicationInfo.from_db(app).model_dump(),
|
||||
client_secret_hash=app.clientSecret,
|
||||
client_secret_salt=app.clientSecretSalt,
|
||||
)
|
||||
|
||||
def verify_secret(self, plaintext_secret: str) -> bool:
|
||||
"""Verify a plaintext client secret against the stored hash"""
|
||||
# Use keysmith.verify_key() with stored salt
|
||||
return keysmith.verify_key(
|
||||
plaintext_secret, self.client_secret_hash, self.client_secret_salt
|
||||
)
|
||||
|
||||
|
||||
class OAuthAuthorizationCodeInfo(BaseModel):
|
||||
"""Authorization code information"""
|
||||
|
||||
id: str
|
||||
code: str
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
application_id: str
|
||||
user_id: str
|
||||
scopes: list[APIPermission]
|
||||
redirect_uri: str
|
||||
code_challenge: Optional[str] = None
|
||||
code_challenge_method: Optional[str] = None
|
||||
used_at: Optional[datetime] = None
|
||||
|
||||
@property
|
||||
def is_used(self) -> bool:
|
||||
return self.used_at is not None
|
||||
|
||||
@staticmethod
|
||||
def from_db(code: PrismaOAuthAuthorizationCode):
|
||||
return OAuthAuthorizationCodeInfo(
|
||||
id=code.id,
|
||||
code=code.code,
|
||||
created_at=code.createdAt,
|
||||
expires_at=code.expiresAt,
|
||||
application_id=code.applicationId,
|
||||
user_id=code.userId,
|
||||
scopes=[APIPermission(s) for s in code.scopes],
|
||||
redirect_uri=code.redirectUri,
|
||||
code_challenge=code.codeChallenge,
|
||||
code_challenge_method=code.codeChallengeMethod,
|
||||
used_at=code.usedAt,
|
||||
)
|
||||
|
||||
|
||||
class OAuthAccessTokenInfo(APIAuthorizationInfo):
    """Access token information"""

    id: str
    expires_at: datetime  # type: ignore
    application_id: str

    type: Literal["oauth"] = "oauth"  # type: ignore

    @staticmethod
    def from_db(token: PrismaOAuthAccessToken):
        return OAuthAccessTokenInfo(
            id=token.id,
            user_id=token.userId,
            scopes=[APIPermission(s) for s in token.scopes],
            created_at=token.createdAt,
            expires_at=token.expiresAt,
            last_used_at=None,
            revoked_at=token.revokedAt,
            application_id=token.applicationId,
        )


class OAuthAccessToken(OAuthAccessTokenInfo):
    """Access token with plaintext token included (sensitive)"""

    token: SecretStr = Field(description="Plaintext token (sensitive)")

    @staticmethod
    def from_db(token: PrismaOAuthAccessToken, plaintext_token: str):  # type: ignore
        return OAuthAccessToken(
            **OAuthAccessTokenInfo.from_db(token).model_dump(),
            token=SecretStr(plaintext_token),
        )


class OAuthRefreshTokenInfo(BaseModel):
    """Refresh token information"""

    id: str
    user_id: str
    scopes: list[APIPermission]
    created_at: datetime
    expires_at: datetime
    application_id: str
    revoked_at: Optional[datetime] = None

    @property
    def is_revoked(self) -> bool:
        return self.revoked_at is not None

    @staticmethod
    def from_db(token: PrismaOAuthRefreshToken):
        return OAuthRefreshTokenInfo(
            id=token.id,
            user_id=token.userId,
            scopes=[APIPermission(s) for s in token.scopes],
            created_at=token.createdAt,
            expires_at=token.expiresAt,
            application_id=token.applicationId,
            revoked_at=token.revokedAt,
        )


class OAuthRefreshToken(OAuthRefreshTokenInfo):
    """Refresh token with plaintext token included (sensitive)"""

    token: SecretStr = Field(description="Plaintext token (sensitive)")

    @staticmethod
    def from_db(token: PrismaOAuthRefreshToken, plaintext_token: str):  # type: ignore
        return OAuthRefreshToken(
            **OAuthRefreshTokenInfo.from_db(token).model_dump(),
            token=SecretStr(plaintext_token),
        )


class TokenIntrospectionResult(BaseModel):
    """Result of token introspection (RFC 7662)"""

    active: bool
    scopes: Optional[list[str]] = None
    client_id: Optional[str] = None
    user_id: Optional[str] = None
    exp: Optional[int] = None  # Unix timestamp
    token_type: Optional[Literal["access_token", "refresh_token"]] = None


# ============================================================================
# OAuth Application Management
# ============================================================================


async def get_oauth_application(client_id: str) -> Optional[OAuthApplicationInfo]:
    """Get OAuth application by client ID (without secret)"""
    app = await PrismaOAuthApplication.prisma().find_unique(
        where={"clientId": client_id}
    )
    if not app:
        return None
    return OAuthApplicationInfo.from_db(app)


async def get_oauth_application_with_secret(
    client_id: str,
) -> Optional[OAuthApplicationInfoWithSecret]:
    """Get OAuth application by client ID (with secret hash for validation)"""
    app = await PrismaOAuthApplication.prisma().find_unique(
        where={"clientId": client_id}
    )
    if not app:
        return None
    return OAuthApplicationInfoWithSecret.from_db(app)


async def validate_client_credentials(
    client_id: str, client_secret: str
) -> OAuthApplicationInfo:
    """
    Validate client credentials and return application info.

    Raises:
        InvalidClientError: If client_id or client_secret is invalid, or app is inactive
    """
    app = await get_oauth_application_with_secret(client_id)
    if not app:
        raise InvalidClientError("Invalid client_id")

    if not app.is_active:
        raise InvalidClientError("Application is not active")

    # Verify client secret
    if not app.verify_secret(client_secret):
        raise InvalidClientError("Invalid client_secret")

    # Return without secret hash
    return OAuthApplicationInfo(**app.model_dump(exclude={"client_secret_hash"}))


def validate_redirect_uri(app: OAuthApplicationInfo, redirect_uri: str) -> bool:
    """Validate that redirect URI is registered for the application"""
    return redirect_uri in app.redirect_uris


def validate_scopes(
    app: OAuthApplicationInfo, requested_scopes: list[APIPermission]
) -> bool:
    """Validate that all requested scopes are allowed for the application"""
    return all(scope in app.scopes for scope in requested_scopes)


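# --- Illustrative usage (not part of this module) ----------------------------
# A minimal sketch of how an /authorize handler might combine the helpers
# above before issuing a code. The handler name is hypothetical, and the
# choice of InvalidGrantError for redirect/scope failures is an assumption;
# a real endpoint may use more specific OAuth error codes.
async def _example_check_authorize_request(
    client_id: str, redirect_uri: str, requested_scopes: list[APIPermission]
) -> OAuthApplicationInfo:
    app = await get_oauth_application(client_id)
    if not app or not app.is_active:
        raise InvalidClientError("Unknown or inactive client")
    if not validate_redirect_uri(app, redirect_uri):
        raise InvalidGrantError("redirect_uri is not registered for this client")
    if not validate_scopes(app, requested_scopes):
        raise InvalidGrantError("requested scopes exceed the application's scopes")
    return app
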
# ============================================================================
# Authorization Code Flow
# ============================================================================


def _generate_authorization_code() -> str:
    """Generate a cryptographically secure authorization code"""
    # 32 bytes = 256 bits of entropy
    return secrets.token_urlsafe(32)


async def create_authorization_code(
    application_id: str,
    user_id: str,
    scopes: list[APIPermission],
    redirect_uri: str,
    code_challenge: Optional[str] = None,
    code_challenge_method: Optional[Literal["S256", "plain"]] = None,
) -> OAuthAuthorizationCodeInfo:
    """
    Create a new authorization code.
    Expires in 10 minutes and can only be used once.
    """
    code = _generate_authorization_code()
    now = datetime.now(timezone.utc)
    expires_at = now + AUTHORIZATION_CODE_TTL

    saved_code = await PrismaOAuthAuthorizationCode.prisma().create(
        data={
            "id": str(uuid.uuid4()),
            "code": code,
            "expiresAt": expires_at,
            "applicationId": application_id,
            "userId": user_id,
            "scopes": [s for s in scopes],
            "redirectUri": redirect_uri,
            "codeChallenge": code_challenge,
            "codeChallengeMethod": code_challenge_method,
        }
    )

    return OAuthAuthorizationCodeInfo.from_db(saved_code)


async def consume_authorization_code(
    code: str,
    application_id: str,
    redirect_uri: str,
    code_verifier: Optional[str] = None,
) -> tuple[str, list[APIPermission]]:
    """
    Consume an authorization code and return (user_id, scopes).

    This marks the code as used and validates:
    - Code exists and matches application
    - Code is not expired
    - Code has not been used
    - Redirect URI matches
    - PKCE code verifier matches (if code challenge was provided)

    Raises:
        InvalidGrantError: If code is invalid, expired, used, or PKCE fails
    """
    auth_code = await PrismaOAuthAuthorizationCode.prisma().find_unique(
        where={"code": code}
    )

    if not auth_code:
        raise InvalidGrantError("authorization code not found")

    # Validate application
    if auth_code.applicationId != application_id:
        raise InvalidGrantError(
            "authorization code does not belong to this application"
        )

    # Check if already used
    if auth_code.usedAt is not None:
        raise InvalidGrantError(
            f"authorization code already used at {auth_code.usedAt}"
        )

    # Check expiration
    now = datetime.now(timezone.utc)
    if auth_code.expiresAt < now:
        raise InvalidGrantError("authorization code expired")

    # Validate redirect URI
    if auth_code.redirectUri != redirect_uri:
        raise InvalidGrantError("redirect_uri mismatch")

    # Validate PKCE if code challenge was provided
    if auth_code.codeChallenge:
        if not code_verifier:
            raise InvalidGrantError("code_verifier required but not provided")

        if not _verify_pkce(
            code_verifier, auth_code.codeChallenge, auth_code.codeChallengeMethod
        ):
            raise InvalidGrantError("PKCE verification failed")

    # Mark code as used
    await PrismaOAuthAuthorizationCode.prisma().update(
        where={"code": code},
        data={"usedAt": now},
    )

    return auth_code.userId, [APIPermission(s) for s in auth_code.scopes]


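# --- Illustrative usage (not part of this module) ----------------------------
# A minimal sketch of how a token endpoint might exchange an authorization
# code for tokens using consume_authorization_code() together with
# create_access_token() / create_refresh_token() defined further below.
# The handler name and its calling convention are assumptions for illustration.
async def _example_exchange_authorization_code(
    application_id: str,
    code: str,
    redirect_uri: str,
    code_verifier: Optional[str] = None,
) -> tuple[OAuthAccessToken, OAuthRefreshToken]:
    user_id, scopes = await consume_authorization_code(
        code, application_id, redirect_uri, code_verifier
    )
    access = await create_access_token(application_id, user_id, scopes)
    refresh = await create_refresh_token(application_id, user_id, scopes)
    return access, refresh
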
def _verify_pkce(
    code_verifier: str, code_challenge: str, code_challenge_method: Optional[str]
) -> bool:
    """
    Verify PKCE code verifier against code challenge.

    Supports:
    - S256: base64url(SHA256(code_verifier)) == code_challenge
    - plain: code_verifier == code_challenge
    """
    if code_challenge_method == "S256":
        # For proper base64url encoding
        import base64

        # Hash the verifier with SHA256, then base64url-encode without padding
        hashed = hashlib.sha256(code_verifier.encode("ascii")).digest()
        computed_challenge = (
            base64.urlsafe_b64encode(hashed).decode("ascii").rstrip("=")
        )
        return secrets.compare_digest(computed_challenge, code_challenge)
    elif code_challenge_method == "plain" or code_challenge_method is None:
        # Plain comparison
        return secrets.compare_digest(code_verifier, code_challenge)
    else:
        logger.warning(f"Unsupported code challenge method: {code_challenge_method}")
        return False


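# --- Illustrative client-side counterpart (not part of this module) ----------
# A minimal sketch of how a client would generate the PKCE pair that
# _verify_pkce() checks: a random code_verifier and its S256 challenge.
# The function name is hypothetical; the encoding follows RFC 7636.
def _example_generate_pkce_pair() -> tuple[str, str]:
    import base64

    code_verifier = secrets.token_urlsafe(64)  # 43-128 chars per RFC 7636
    code_challenge = (
        base64.urlsafe_b64encode(
            hashlib.sha256(code_verifier.encode("ascii")).digest()
        )
        .decode("ascii")
        .rstrip("=")
    )
    return code_verifier, code_challenge
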
# ============================================================================
# Access Token Management
# ============================================================================


async def create_access_token(
    application_id: str, user_id: str, scopes: list[APIPermission]
) -> OAuthAccessToken:
    """
    Create a new access token.
    Returns OAuthAccessToken (with plaintext token).
    """
    plaintext_token = ACCESS_TOKEN_PREFIX + _generate_token()
    token_hash = _hash_token(plaintext_token)
    now = datetime.now(timezone.utc)
    expires_at = now + ACCESS_TOKEN_TTL

    saved_token = await PrismaOAuthAccessToken.prisma().create(
        data={
            "id": str(uuid.uuid4()),
            "token": token_hash,  # SHA256 hash for direct lookup
            "expiresAt": expires_at,
            "applicationId": application_id,
            "userId": user_id,
            "scopes": [s for s in scopes],
        }
    )

    return OAuthAccessToken.from_db(saved_token, plaintext_token=plaintext_token)


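# --- Assumed helpers (not shown in this hunk) ---------------------------------
# _generate_token() and _hash_token() are referenced above but defined
# elsewhere in this module. A plausible sketch, assuming a urlsafe random
# token and a SHA-256 hex digest used as the lookup key; the real
# implementations may differ.
#
# def _generate_token() -> str:
#     return secrets.token_urlsafe(32)
#
# def _hash_token(token: str) -> str:
#     return hashlib.sha256(token.encode("ascii")).hexdigest()
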
async def validate_access_token(
    token: str,
) -> tuple[OAuthAccessTokenInfo, OAuthApplicationInfo]:
    """
    Validate an access token and return token info.

    Raises:
        InvalidTokenError: If token is invalid, expired, or revoked
        InvalidClientError: If the client application is not marked as active
    """
    token_hash = _hash_token(token)

    # Direct lookup by hash
    access_token = await PrismaOAuthAccessToken.prisma().find_unique(
        where={"token": token_hash}, include={"Application": True}
    )

    if not access_token:
        raise InvalidTokenError("access token not found")

    if not access_token.Application:  # should be impossible
        raise InvalidClientError("Client application not found")

    if not access_token.Application.isActive:
        raise InvalidClientError("Client application is disabled")

    if access_token.revokedAt is not None:
        raise InvalidTokenError("access token has been revoked")

    # Check expiration
    now = datetime.now(timezone.utc)
    if access_token.expiresAt < now:
        raise InvalidTokenError("access token expired")

    return (
        OAuthAccessTokenInfo.from_db(access_token),
        OAuthApplicationInfo.from_db(access_token.Application),
    )


async def revoke_access_token(
    token: str, application_id: str
) -> OAuthAccessTokenInfo | None:
    """
    Revoke an access token.

    Args:
        token: The plaintext access token to revoke
        application_id: The application ID making the revocation request.
            Only tokens belonging to this application will be revoked.

    Returns:
        OAuthAccessTokenInfo if token was found and revoked, None otherwise.

    Note:
        Always performs exactly 2 DB queries regardless of outcome to prevent
        timing side-channel attacks that could reveal token existence.
    """
    try:
        token_hash = _hash_token(token)

        # Use update_many to filter by both token and applicationId
        updated_count = await PrismaOAuthAccessToken.prisma().update_many(
            where={
                "token": token_hash,
                "applicationId": application_id,
                "revokedAt": None,
            },
            data={"revokedAt": datetime.now(timezone.utc)},
        )

        # Always perform second query to ensure constant time
        result = await PrismaOAuthAccessToken.prisma().find_unique(
            where={"token": token_hash}
        )

        # Only return result if we actually revoked something
        if updated_count == 0:
            return None

        return OAuthAccessTokenInfo.from_db(result) if result else None
    except Exception as e:
        logger.exception(f"Error revoking access token: {e}")
        return None


# ============================================================================
# Refresh Token Management
# ============================================================================


async def create_refresh_token(
    application_id: str, user_id: str, scopes: list[APIPermission]
) -> OAuthRefreshToken:
    """
    Create a new refresh token.
    Returns OAuthRefreshToken (with plaintext token).
    """
    plaintext_token = REFRESH_TOKEN_PREFIX + _generate_token()
    token_hash = _hash_token(plaintext_token)
    now = datetime.now(timezone.utc)
    expires_at = now + REFRESH_TOKEN_TTL

    saved_token = await PrismaOAuthRefreshToken.prisma().create(
        data={
            "id": str(uuid.uuid4()),
            "token": token_hash,  # SHA256 hash for direct lookup
            "expiresAt": expires_at,
            "applicationId": application_id,
            "userId": user_id,
            "scopes": [s for s in scopes],
        }
    )

    return OAuthRefreshToken.from_db(saved_token, plaintext_token=plaintext_token)


async def refresh_tokens(
    refresh_token: str, application_id: str
) -> tuple[OAuthAccessToken, OAuthRefreshToken]:
    """
    Use a refresh token to create new access and refresh tokens.
    Returns (new_access_token, new_refresh_token) both with plaintext tokens included.

    Raises:
        InvalidGrantError: If refresh token is invalid, expired, or revoked
    """
    token_hash = _hash_token(refresh_token)

    # Direct lookup by hash
    rt = await PrismaOAuthRefreshToken.prisma().find_unique(where={"token": token_hash})

    if not rt:
        raise InvalidGrantError("refresh token not found")

    # NOTE: no need to check Application.isActive, this is checked by the token endpoint

    if rt.revokedAt is not None:
        raise InvalidGrantError("refresh token has been revoked")

    # Validate application
    if rt.applicationId != application_id:
        raise InvalidGrantError("refresh token does not belong to this application")

    # Check expiration
    now = datetime.now(timezone.utc)
    if rt.expiresAt < now:
        raise InvalidGrantError("refresh token expired")

    # Revoke old refresh token
    await PrismaOAuthRefreshToken.prisma().update(
        where={"token": token_hash},
        data={"revokedAt": now},
    )

    # Create new access and refresh tokens with same scopes
    scopes = [APIPermission(s) for s in rt.scopes]
    new_access_token = await create_access_token(
        rt.applicationId,
        rt.userId,
        scopes,
    )
    new_refresh_token = await create_refresh_token(
        rt.applicationId,
        rt.userId,
        scopes,
    )

    return new_access_token, new_refresh_token


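# --- Illustrative usage (not part of this module) ----------------------------
# A minimal sketch of a refresh_token grant handler built on refresh_tokens(),
# assuming the client has been authenticated via validate_client_credentials()
# and that OAuthApplicationInfo exposes the application's internal `id`.
# The handler name and response shape are assumptions for illustration.
async def _example_handle_refresh_grant(
    client_id: str, client_secret: str, refresh_token: str
) -> dict:
    app = await validate_client_credentials(client_id, client_secret)
    access, refresh = await refresh_tokens(refresh_token, app.id)
    return {
        "access_token": access.token.get_secret_value(),
        "refresh_token": refresh.token.get_secret_value(),
        "token_type": "Bearer",
        "expires_in": int(
            (access.expires_at - datetime.now(timezone.utc)).total_seconds()
        ),
        "scope": " ".join(s.value for s in access.scopes),
    }
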
async def revoke_refresh_token(
    token: str, application_id: str
) -> OAuthRefreshTokenInfo | None:
    """
    Revoke a refresh token.

    Args:
        token: The plaintext refresh token to revoke
        application_id: The application ID making the revocation request.
            Only tokens belonging to this application will be revoked.

    Returns:
        OAuthRefreshTokenInfo if token was found and revoked, None otherwise.

    Note:
        Always performs exactly 2 DB queries regardless of outcome to prevent
        timing side-channel attacks that could reveal token existence.
    """
    try:
        token_hash = _hash_token(token)

        # Use update_many to filter by both token and applicationId
        updated_count = await PrismaOAuthRefreshToken.prisma().update_many(
            where={
                "token": token_hash,
                "applicationId": application_id,
                "revokedAt": None,
            },
            data={"revokedAt": datetime.now(timezone.utc)},
        )

        # Always perform second query to ensure constant time
        result = await PrismaOAuthRefreshToken.prisma().find_unique(
            where={"token": token_hash}
        )

        # Only return result if we actually revoked something
        if updated_count == 0:
            return None

        return OAuthRefreshTokenInfo.from_db(result) if result else None
    except Exception as e:
        logger.exception(f"Error revoking refresh token: {e}")
        return None


# ============================================================================
# Token Introspection
# ============================================================================


async def introspect_token(
    token: str,
    token_type_hint: Optional[Literal["access_token", "refresh_token"]] = None,
) -> TokenIntrospectionResult:
    """
    Introspect a token and return its metadata (RFC 7662).

    Returns TokenIntrospectionResult with active=True and metadata if valid,
    or active=False if the token is invalid/expired/revoked.
    """
    # Try as access token first (or if hint says "access_token")
    if token_type_hint != "refresh_token":
        try:
            token_info, app = await validate_access_token(token)
            return TokenIntrospectionResult(
                active=True,
                scopes=list(s.value for s in token_info.scopes),
                client_id=app.client_id if app else None,
                user_id=token_info.user_id,
                exp=int(token_info.expires_at.timestamp()),
                token_type="access_token",
            )
        except InvalidTokenError:
            pass  # Try as refresh token

    # Try as refresh token
    token_hash = _hash_token(token)
    refresh_token = await PrismaOAuthRefreshToken.prisma().find_unique(
        where={"token": token_hash}
    )

    if refresh_token and refresh_token.revokedAt is None:
        # Check if valid (not expired)
        now = datetime.now(timezone.utc)
        if refresh_token.expiresAt > now:
            app = await get_oauth_application_by_id(refresh_token.applicationId)
            return TokenIntrospectionResult(
                active=True,
                scopes=list(s for s in refresh_token.scopes),
                client_id=app.client_id if app else None,
                user_id=refresh_token.userId,
                exp=int(refresh_token.expiresAt.timestamp()),
                token_type="refresh_token",
            )

    # Token not found or inactive
    return TokenIntrospectionResult(active=False)


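# --- Illustrative usage (not part of this module) ----------------------------
# A minimal sketch of how an /introspect endpoint might map
# TokenIntrospectionResult onto the RFC 7662 wire format. The function name
# and the exact response keys chosen here are assumptions for illustration.
async def _example_introspection_response(token: str) -> dict:
    result = await introspect_token(token)
    if not result.active:
        # RFC 7662 only requires "active": false for unknown/invalid tokens
        return {"active": False}
    return {
        "active": True,
        "scope": " ".join(result.scopes or []),
        "client_id": result.client_id,
        "sub": result.user_id,
        "exp": result.exp,
        "token_type": result.token_type,
    }
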
async def get_oauth_application_by_id(app_id: str) -> Optional[OAuthApplicationInfo]:
    """Get OAuth application by ID"""
    app = await PrismaOAuthApplication.prisma().find_unique(where={"id": app_id})
    if not app:
        return None
    return OAuthApplicationInfo.from_db(app)


async def list_user_oauth_applications(user_id: str) -> list[OAuthApplicationInfo]:
    """Get all OAuth applications owned by a user"""
    apps = await PrismaOAuthApplication.prisma().find_many(
        where={"ownerId": user_id},
        order={"createdAt": "desc"},
    )
    return [OAuthApplicationInfo.from_db(app) for app in apps]


async def update_oauth_application(
    app_id: str,
    *,
    owner_id: str,
    is_active: Optional[bool] = None,
    logo_url: Optional[str] = None,
) -> Optional[OAuthApplicationInfo]:
    """
    Update an OAuth application's active status and/or logo.
    Only the owner can update their app.

    Returns the updated app info, or None if app not found or not owned by user.
    """
    # First verify ownership
    app = await PrismaOAuthApplication.prisma().find_first(
        where={"id": app_id, "ownerId": owner_id}
    )
    if not app:
        return None

    patch: OAuthApplicationUpdateInput = {}
    if is_active is not None:
        patch["isActive"] = is_active
    if logo_url:
        patch["logoUrl"] = logo_url
    if not patch:
        return OAuthApplicationInfo.from_db(app)  # return unchanged

    updated_app = await PrismaOAuthApplication.prisma().update(
        where={"id": app_id},
        data=patch,
    )
    return OAuthApplicationInfo.from_db(updated_app) if updated_app else None


# ============================================================================
# Token Cleanup
# ============================================================================


async def cleanup_expired_oauth_tokens() -> dict[str, int]:
    """
    Delete expired OAuth tokens from the database.

    This removes:
    - Expired authorization codes (10 min TTL)
    - Expired access tokens (1 hour TTL)
    - Expired refresh tokens (30 day TTL)

    Returns a dict with counts of deleted tokens by type.
    """
    now = datetime.now(timezone.utc)

    # Delete expired authorization codes
    codes_result = await PrismaOAuthAuthorizationCode.prisma().delete_many(
        where={"expiresAt": {"lt": now}}
    )

    # Delete expired access tokens
    access_result = await PrismaOAuthAccessToken.prisma().delete_many(
        where={"expiresAt": {"lt": now}}
    )

    # Delete expired refresh tokens
    refresh_result = await PrismaOAuthRefreshToken.prisma().delete_many(
        where={"expiresAt": {"lt": now}}
    )

    deleted = {
        "authorization_codes": codes_result,
        "access_tokens": access_result,
        "refresh_tokens": refresh_result,
    }

    total = sum(deleted.values())
    if total > 0:
        logger.info(f"Cleaned up {total} expired OAuth tokens: {deleted}")

    return deleted
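

# --- Illustrative scheduling (not part of this module) ------------------------
# A minimal sketch of running the cleanup periodically with asyncio; the
# interval and task wiring are assumptions, and the platform may use its own
# scheduler instead.
async def _example_run_token_cleanup_loop(interval_seconds: int = 3600) -> None:
    import asyncio

    while True:
        deleted = await cleanup_expired_oauth_tokens()
        logger.debug(f"OAuth token cleanup pass finished: {deleted}")
        await asyncio.sleep(interval_seconds)
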
@@ -50,8 +50,6 @@ from .model import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
from .graph import Link
|
||||
|
||||
app_config = Config()
|
||||
@@ -474,7 +472,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.block_type = block_type
|
||||
self.webhook_config = webhook_config
|
||||
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
|
||||
self.requires_human_review: bool = False
|
||||
|
||||
if self.webhook_config:
|
||||
if isinstance(self.webhook_config, BlockWebhookConfig):
|
||||
@@ -617,77 +614,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
block_id=self.id,
|
||||
) from ex
|
||||
|
||||
async def is_block_exec_need_review(
|
||||
self,
|
||||
input_data: BlockInput,
|
||||
*,
|
||||
user_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: "ExecutionContext",
|
||||
**kwargs,
|
||||
) -> tuple[bool, BlockInput]:
|
||||
"""
|
||||
Check if this block execution needs human review and handle the review process.
|
||||
|
||||
Returns:
|
||||
Tuple of (should_pause, input_data_to_use)
|
||||
- should_pause: True if execution should be paused for review
|
||||
- input_data_to_use: The input data to use (may be modified by reviewer)
|
||||
"""
|
||||
# Skip review if not required or safe mode is disabled
|
||||
if not self.requires_human_review or not execution_context.safe_mode:
|
||||
return False, input_data
|
||||
|
||||
from backend.blocks.helpers.review import HITLReviewHelper
|
||||
|
||||
# Handle the review request and get decision
|
||||
decision = await HITLReviewHelper.handle_review_decision(
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_context=execution_context,
|
||||
block_name=self.name,
|
||||
editable=True,
|
||||
)
|
||||
|
||||
if decision is None:
|
||||
# We're awaiting review - pause execution
|
||||
return True, input_data
|
||||
|
||||
if not decision.should_proceed:
|
||||
# Review was rejected, raise an error to stop execution
|
||||
raise BlockExecutionError(
|
||||
message=f"Block execution rejected by reviewer: {decision.message}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
|
||||
# Review was approved - use the potentially modified data
|
||||
# ReviewResult.data must be a dict for block inputs
|
||||
reviewed_data = decision.review_result.data
|
||||
if not isinstance(reviewed_data, dict):
|
||||
raise BlockExecutionError(
|
||||
message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
return False, reviewed_data
|
||||
|
||||
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
||||
# Check for review requirement and get potentially modified input data
|
||||
should_pause, input_data = await self.is_block_exec_need_review(
|
||||
input_data, **kwargs
|
||||
)
|
||||
if should_pause:
|
||||
return
|
||||
|
||||
# Validate the input data (original or reviewer-modified) once
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise BlockInputError(
|
||||
message=f"Unable to execute block with invalid input data: {error}",
|
||||
@@ -695,7 +622,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
block_id=self.id,
|
||||
)
|
||||
|
||||
# Use the validated input data
|
||||
async for output_name, output_data in self.run(
|
||||
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
|
||||
**kwargs,
|
||||
|
||||
@@ -59,13 +59,12 @@ from backend.integrations.credentials_store import (
|
||||
|
||||
MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.O3: 4,
|
||||
LlmModel.O3_MINI: 2,
|
||||
LlmModel.O1: 16,
|
||||
LlmModel.O3_MINI: 2, # $1.10 / $4.40
|
||||
LlmModel.O1: 16, # $15 / $60
|
||||
LlmModel.O1_MINI: 4,
|
||||
# GPT-5 models
|
||||
LlmModel.GPT5_2: 6,
|
||||
LlmModel.GPT5_1: 5,
|
||||
LlmModel.GPT5: 2,
|
||||
LlmModel.GPT5_1: 5,
|
||||
LlmModel.GPT5_MINI: 1,
|
||||
LlmModel.GPT5_NANO: 1,
|
||||
LlmModel.GPT5_CHAT: 5,
|
||||
@@ -88,7 +87,7 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.AIML_API_LLAMA3_3_70B: 1,
|
||||
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
|
||||
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
|
||||
LlmModel.LLAMA3_3_70B: 1,
|
||||
LlmModel.LLAMA3_3_70B: 1, # $0.59 / $0.79
|
||||
LlmModel.LLAMA3_1_8B: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_3: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_2: 1,
|
||||
|
||||
@@ -16,7 +16,6 @@ from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBala
|
||||
from prisma.types import CreditRefundRequestCreateInput, CreditTransactionWhereInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.admin.model import UserHistoryResponse
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.data.includes import MAX_CREDIT_REFUND_REQUESTS_FETCH
|
||||
@@ -30,6 +29,7 @@ from backend.data.model import (
|
||||
from backend.data.notifications import NotificationEventModel, RefundRequestData
|
||||
from backend.data.user import get_user_by_id, get_user_email_by_id
|
||||
from backend.notifications.notifications import queue_notification_async
|
||||
from backend.server.v2.admin.model import UserHistoryResponse
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
from backend.util.json import SafeJson, dumps
|
||||
@@ -341,19 +341,6 @@ class UserCreditBase(ABC):
|
||||
|
||||
if result:
|
||||
# UserBalance is already updated by the CTE
|
||||
|
||||
# Clear insufficient funds notification flags when credits are added
|
||||
# so user can receive alerts again if they run out in the future.
|
||||
if transaction.amount > 0 and transaction.type in [
|
||||
CreditTransactionType.GRANT,
|
||||
CreditTransactionType.TOP_UP,
|
||||
]:
|
||||
from backend.executor.manager import (
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
|
||||
await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
return result[0]["balance"]
|
||||
|
||||
async def _add_transaction(
|
||||
@@ -543,22 +530,6 @@ class UserCreditBase(ABC):
|
||||
if result:
|
||||
new_balance, tx_key = result[0]["balance"], result[0]["transactionKey"]
|
||||
# UserBalance is already updated by the CTE
|
||||
|
||||
# Clear insufficient funds notification flags when credits are added
|
||||
# so user can receive alerts again if they run out in the future.
|
||||
if (
|
||||
amount > 0
|
||||
and is_active
|
||||
and transaction_type
|
||||
in [CreditTransactionType.GRANT, CreditTransactionType.TOP_UP]
|
||||
):
|
||||
# Lazy import to avoid circular dependency with executor.manager
|
||||
from backend.executor.manager import (
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
|
||||
await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
return new_balance, tx_key
|
||||
|
||||
# If no result, either user doesn't exist or insufficient balance
|
||||
|
||||
@@ -111,7 +111,7 @@ def get_database_schema() -> str:
|
||||
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
|
||||
"""Execute raw SQL query with proper schema handling."""
|
||||
schema = get_database_schema()
|
||||
schema_prefix = f'"{schema}".' if schema != "public" else ""
|
||||
schema_prefix = f"{schema}." if schema != "public" else ""
|
||||
formatted_query = query_template.format(schema_prefix=schema_prefix)
|
||||
|
||||
import prisma as prisma_module
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import logging
|
||||
import queue
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from multiprocessing import Manager
|
||||
from queue import Empty
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
@@ -383,7 +382,6 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
self,
|
||||
execution_context: ExecutionContext,
|
||||
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
):
|
||||
return GraphExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
@@ -391,7 +389,6 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
graph_version=self.graph_version or 0,
|
||||
graph_exec_id=self.id,
|
||||
nodes_input_masks=compiled_nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip or set(),
|
||||
execution_context=execution_context,
|
||||
)
|
||||
|
||||
@@ -1147,8 +1144,6 @@ class GraphExecutionEntry(BaseModel):
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None
|
||||
nodes_to_skip: set[str] = Field(default_factory=set)
|
||||
"""Node IDs that should be skipped due to optional credentials not being configured."""
|
||||
execution_context: ExecutionContext = Field(default_factory=ExecutionContext)
|
||||
|
||||
|
||||
@@ -1168,12 +1163,16 @@ class NodeExecutionEntry(BaseModel):
|
||||
|
||||
class ExecutionQueue(Generic[T]):
|
||||
"""
|
||||
Queue for managing the execution of agents.
|
||||
This will be shared between different processes
|
||||
Thread-safe queue for managing node execution within a single graph execution.
|
||||
|
||||
Note: Uses queue.Queue (not multiprocessing.Queue) since all access is from
|
||||
threads within the same process. If migrating back to ProcessPoolExecutor,
|
||||
replace with multiprocessing.Manager().Queue() for cross-process safety.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.queue = Manager().Queue()
|
||||
# Thread-safe queue (not multiprocessing) — see class docstring
|
||||
self.queue: queue.Queue[T] = queue.Queue()
|
||||
|
||||
def add(self, execution: T) -> T:
|
||||
self.queue.put(execution)
|
||||
@@ -1188,7 +1187,7 @@ class ExecutionQueue(Generic[T]):
|
||||
def get_or_none(self) -> T | None:
|
||||
try:
|
||||
return self.queue.get_nowait()
|
||||
except Empty:
|
||||
except queue.Empty:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
"""Tests for ExecutionQueue thread-safety."""
|
||||
|
||||
import queue
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionQueue
|
||||
|
||||
|
||||
def test_execution_queue_uses_stdlib_queue():
|
||||
"""Verify ExecutionQueue uses queue.Queue (not multiprocessing)."""
|
||||
q = ExecutionQueue()
|
||||
assert isinstance(q.queue, queue.Queue)
|
||||
|
||||
|
||||
def test_basic_operations():
|
||||
"""Test add, get, empty, and get_or_none."""
|
||||
q = ExecutionQueue()
|
||||
|
||||
assert q.empty() is True
|
||||
assert q.get_or_none() is None
|
||||
|
||||
result = q.add("item1")
|
||||
assert result == "item1"
|
||||
assert q.empty() is False
|
||||
|
||||
item = q.get()
|
||||
assert item == "item1"
|
||||
assert q.empty() is True
|
||||
|
||||
|
||||
def test_thread_safety():
|
||||
"""Test concurrent access from multiple threads."""
|
||||
q = ExecutionQueue()
|
||||
results = []
|
||||
num_items = 100
|
||||
|
||||
def producer():
|
||||
for i in range(num_items):
|
||||
q.add(f"item_{i}")
|
||||
|
||||
def consumer():
|
||||
count = 0
|
||||
while count < num_items:
|
||||
item = q.get_or_none()
|
||||
if item is not None:
|
||||
results.append(item)
|
||||
count += 1
|
||||
|
||||
producer_thread = threading.Thread(target=producer)
|
||||
consumer_thread = threading.Thread(target=consumer)
|
||||
|
||||
producer_thread.start()
|
||||
consumer_thread.start()
|
||||
|
||||
producer_thread.join(timeout=5)
|
||||
consumer_thread.join(timeout=5)
|
||||
|
||||
assert len(results) == num_items
|
||||
@@ -94,15 +94,6 @@ class Node(BaseDbModel):
|
||||
input_links: list[Link] = []
|
||||
output_links: list[Link] = []
|
||||
|
||||
@property
|
||||
def credentials_optional(self) -> bool:
|
||||
"""
|
||||
Whether credentials are optional for this node.
|
||||
When True and credentials are not configured, the node will be skipped
|
||||
during execution rather than causing a validation error.
|
||||
"""
|
||||
return self.metadata.get("credentials_optional", False)
|
||||
|
||||
@property
|
||||
def block(self) -> AnyBlockSchema | "_UnknownBlockBase":
|
||||
"""Get the block for this node. Returns UnknownBlock if block is deleted/missing."""
|
||||
@@ -244,10 +235,7 @@ class BaseGraph(BaseDbModel):
|
||||
return any(
|
||||
node.block_id
|
||||
for node in self.nodes
|
||||
if (
|
||||
node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
|
||||
or node.block.requires_human_review
|
||||
)
|
||||
if node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -338,35 +326,7 @@ class Graph(BaseGraph):
|
||||
@computed_field
|
||||
@property
|
||||
def credentials_input_schema(self) -> dict[str, Any]:
|
||||
schema = self._credentials_input_schema.jsonschema()
|
||||
|
||||
# Determine which credential fields are required based on credentials_optional metadata
|
||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||
required_fields = []
|
||||
|
||||
# Build a map of node_id -> node for quick lookup
|
||||
all_nodes = {node.id: node for node in self.nodes}
|
||||
for sub_graph in self.sub_graphs:
|
||||
for node in sub_graph.nodes:
|
||||
all_nodes[node.id] = node
|
||||
|
||||
for field_key, (
|
||||
_field_info,
|
||||
node_field_pairs,
|
||||
) in graph_credentials_inputs.items():
|
||||
# A field is required if ANY node using it has credentials_optional=False
|
||||
is_required = False
|
||||
for node_id, _field_name in node_field_pairs:
|
||||
node = all_nodes.get(node_id)
|
||||
if node and not node.credentials_optional:
|
||||
is_required = True
|
||||
break
|
||||
|
||||
if is_required:
|
||||
required_fields.append(field_key)
|
||||
|
||||
schema["required"] = required_fields
|
||||
return schema
|
||||
return self._credentials_input_schema.jsonschema()
|
||||
|
||||
@property
|
||||
def _credentials_input_schema(self) -> type[BlockSchema]:
|
||||
|
||||
@@ -6,14 +6,14 @@ import fastapi.exceptions
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.api.features.store.model as store
|
||||
from backend.api.model import CreateGraph
|
||||
import backend.server.v2.store.model as store
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.data.block import BlockSchema, BlockSchemaInput
|
||||
from backend.data.graph import Graph, Link, Node
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.usecases.sample import create_test_user
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
@@ -396,58 +396,3 @@ async def test_access_store_listing_graph(server: SpinTestServer):
|
||||
created_graph.id, created_graph.version, "3e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
||||
)
|
||||
assert got_graph is not None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for Optional Credentials Feature
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_node_credentials_optional_default():
|
||||
"""Test that credentials_optional defaults to False when not set in metadata."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={},
|
||||
)
|
||||
assert node.credentials_optional is False
|
||||
|
||||
|
||||
def test_node_credentials_optional_true():
|
||||
"""Test that credentials_optional returns True when explicitly set."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={"credentials_optional": True},
|
||||
)
|
||||
assert node.credentials_optional is True
|
||||
|
||||
|
||||
def test_node_credentials_optional_false():
|
||||
"""Test that credentials_optional returns False when explicitly set to False."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={"credentials_optional": False},
|
||||
)
|
||||
assert node.credentials_optional is False
|
||||
|
||||
|
||||
def test_node_credentials_optional_with_other_metadata():
|
||||
"""Test that credentials_optional works correctly with other metadata present."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={
|
||||
"position": {"x": 100, "y": 200},
|
||||
"customized_name": "My Custom Node",
|
||||
"credentials_optional": True,
|
||||
},
|
||||
)
|
||||
assert node.credentials_optional is True
|
||||
assert node.metadata["position"] == {"x": 100, "y": 200}
|
||||
assert node.metadata["customized_name"] == "My Custom Node"
|
||||
|
||||
@@ -13,7 +13,7 @@ from prisma.models import PendingHumanReview
|
||||
from prisma.types import PendingHumanReviewUpdateInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.executions.review.model import (
|
||||
from backend.server.v2.executions.review.model import (
|
||||
PendingHumanReviewModel,
|
||||
SafeJsonData,
|
||||
)
|
||||
@@ -100,7 +100,7 @@ async def get_or_create_human_review(
|
||||
return None
|
||||
else:
|
||||
return ReviewResult(
|
||||
data=review.payload,
|
||||
data=review.payload if review.status == ReviewStatus.APPROVED else None,
|
||||
status=review.status,
|
||||
message=review.reviewMessage or "",
|
||||
processed=review.processed,
|
||||
|
||||
@@ -23,7 +23,7 @@ from backend.util.exceptions import NotFoundError
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.api.features.library.model import LibraryAgentPreset
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
|
||||
from .db import BaseDbModel
|
||||
from .graph import NodeModel
|
||||
@@ -79,7 +79,7 @@ class WebhookWithRelations(Webhook):
|
||||
# integrations.py → library/model.py → integrations.py (for Webhook)
|
||||
# Runtime import is used in WebhookWithRelations.from_db() method instead
|
||||
# Import at runtime to avoid circular dependency
|
||||
from backend.api.features.library.model import LibraryAgentPreset
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
|
||||
return WebhookWithRelations(
|
||||
**Webhook.from_db(webhook).model_dump(),
|
||||
@@ -285,8 +285,8 @@ async def unlink_webhook_from_graph(
|
||||
user_id: The ID of the user (for authorization)
|
||||
"""
|
||||
# Avoid circular imports
|
||||
from backend.api.features.library.db import set_preset_webhook
|
||||
from backend.data.graph import set_node_webhook
|
||||
from backend.server.v2.library.db import set_preset_webhook
|
||||
|
||||
# Find all nodes in this graph that use this webhook
|
||||
nodes = await AgentNode.prisma().find_many(
|
||||
|
||||
@@ -4,8 +4,8 @@ from typing import AsyncGenerator
|
||||
|
||||
from pydantic import BaseModel, field_serializer
|
||||
|
||||
from backend.api.model import NotificationPayload
|
||||
from backend.data.event_bus import AsyncRedisEventBus
|
||||
from backend.server.model import NotificationPayload
|
||||
from backend.util.settings import Settings
|
||||
|
||||
|
||||
|
||||
@@ -9,8 +9,6 @@ from prisma.enums import OnboardingStep
|
||||
from prisma.models import UserOnboarding
|
||||
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
|
||||
|
||||
from backend.api.features.store.model import StoreAgentDetails
|
||||
from backend.api.model import OnboardingNotificationPayload
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.data.notification_bus import (
|
||||
@@ -18,6 +16,8 @@ from backend.data.notification_bus import (
|
||||
NotificationEvent,
|
||||
)
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.server.model import OnboardingNotificationPayload
|
||||
from backend.server.v2.store.model import StoreAgentDetails
|
||||
from backend.util.cache import cached
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.timezone_utils import get_user_timezone_or_utc
|
||||
@@ -442,8 +442,6 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
versions=agent.versions,
|
||||
agentGraphVersions=agent.agentGraphVersions,
|
||||
agentGraphId=agent.agentGraphId,
|
||||
last_updated=agent.updated_at,
|
||||
)
|
||||
for agent in recommended_agents
|
||||
|
||||
@@ -1,513 +0,0 @@
|
||||
"""
|
||||
Tests for dynamic fields edge cases and failure modes.
|
||||
|
||||
Covers failure modes:
|
||||
8. No Type Validation in Dynamic Field Merging
|
||||
17. No Validation of Dynamic Field Paths
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.dynamic_fields import (
|
||||
DICT_SPLIT,
|
||||
LIST_SPLIT,
|
||||
OBJC_SPLIT,
|
||||
extract_base_field_name,
|
||||
get_dynamic_field_description,
|
||||
is_dynamic_field,
|
||||
is_tool_pin,
|
||||
merge_execution_input,
|
||||
parse_execution_output,
|
||||
sanitize_pin_name,
|
||||
)
|
||||
|
||||
|
||||
class TestDynamicFieldMergingTypeValidation:
|
||||
"""
|
||||
Tests for Failure Mode #8: No Type Validation in Dynamic Field Merging
|
||||
|
||||
When merging dynamic fields, there's no validation that intermediate
|
||||
structures have the correct type, leading to potential type coercion errors.
|
||||
"""
|
||||
|
||||
def test_merge_dict_field_creates_dict(self):
|
||||
"""Test that dictionary fields create dict structure."""
|
||||
data = {
|
||||
"values_#_name": "Alice",
|
||||
"values_#_age": 30,
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert "values" in result
|
||||
assert isinstance(result["values"], dict)
|
||||
assert result["values"]["name"] == "Alice"
|
||||
assert result["values"]["age"] == 30
|
||||
|
||||
def test_merge_list_field_creates_list(self):
|
||||
"""Test that list fields create list structure."""
|
||||
data = {
|
||||
"items_$_0": "first",
|
||||
"items_$_1": "second",
|
||||
"items_$_2": "third",
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert "items" in result
|
||||
assert isinstance(result["items"], list)
|
||||
assert result["items"] == ["first", "second", "third"]
|
||||
|
||||
def test_merge_with_existing_primitive_type_conflict(self):
|
||||
"""
|
||||
Test behavior when merging into existing primitive value.
|
||||
|
||||
BUG: If the base field already exists as a primitive,
|
||||
merging a dynamic field may fail or corrupt data.
|
||||
"""
|
||||
# Pre-existing primitive value
|
||||
data = {
|
||||
"value": "I am a string", # Primitive
|
||||
"value_#_key": "dict value", # Dynamic dict field
|
||||
}
|
||||
|
||||
# This may raise an error or produce unexpected results
|
||||
# depending on merge order and implementation
|
||||
try:
|
||||
result = merge_execution_input(data)
|
||||
# If it succeeds, check what happened
|
||||
# The primitive may have been overwritten
|
||||
if isinstance(result.get("value"), dict):
|
||||
# Primitive was converted to dict - data loss!
|
||||
assert "key" in result["value"]
|
||||
else:
|
||||
# Or the dynamic field was ignored
|
||||
pass
|
||||
except (TypeError, AttributeError):
|
||||
# Expected error when trying to merge into primitive
|
||||
pass
|
||||
|
||||
def test_merge_list_with_gaps(self):
|
||||
"""Test merging list fields with non-contiguous indices."""
|
||||
data = {
|
||||
"items_$_0": "zero",
|
||||
"items_$_2": "two", # Gap at index 1
|
||||
"items_$_5": "five", # Larger gap
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert "items" in result
|
||||
# Check how gaps are handled
|
||||
items = result["items"]
|
||||
assert items[0] == "zero"
|
||||
# Index 1 may be None or missing
|
||||
assert items[2] == "two"
|
||||
assert items[5] == "five"
|
||||
|
||||
def test_merge_nested_dynamic_fields(self):
|
||||
"""Test merging deeply nested dynamic fields."""
|
||||
data = {
|
||||
"data_#_users_$_0": "user1",
|
||||
"data_#_users_$_1": "user2",
|
||||
"data_#_config_#_enabled": True,
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
# Complex nested structures should be created
|
||||
assert "data" in result
|
||||
|
||||
def test_merge_object_field(self):
|
||||
"""Test merging object attribute fields."""
|
||||
data = {
|
||||
"user_@_name": "Alice",
|
||||
"user_@_email": "alice@example.com",
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert "user" in result
|
||||
# Object fields create dict-like structure
|
||||
assert result["user"]["name"] == "Alice"
|
||||
assert result["user"]["email"] == "alice@example.com"
|
||||
|
||||
def test_merge_mixed_field_types(self):
|
||||
"""Test merging mixed regular and dynamic fields."""
|
||||
data = {
|
||||
"regular": "value",
|
||||
"dict_field_#_key": "dict_value",
|
||||
"list_field_$_0": "list_item",
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert result["regular"] == "value"
|
||||
assert result["dict_field"]["key"] == "dict_value"
|
||||
assert result["list_field"][0] == "list_item"
|
||||
|
||||
|
||||
class TestDynamicFieldPathValidation:
|
||||
"""
|
||||
Tests for Failure Mode #17: No Validation of Dynamic Field Paths
|
||||
|
||||
When traversing dynamic field paths, intermediate None values
|
||||
can cause TypeErrors instead of graceful failures.
|
||||
"""
|
||||
|
||||
def test_parse_output_with_none_intermediate(self):
|
||||
"""
|
||||
Test parse_execution_output with None intermediate value.
|
||||
|
||||
If data contains {"items": None} and we try to access items[0],
|
||||
it should return None gracefully, not raise TypeError.
|
||||
"""
|
||||
# Output with nested path
|
||||
output_item = ("data_$_0", "value")
|
||||
|
||||
# When the base is None, should return None
|
||||
# This tests the path traversal logic
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="data",
|
||||
sink_node_id=None,
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
# Should handle gracefully (return the value or None)
|
||||
# Not raise TypeError
|
||||
|
||||
def test_extract_base_field_name_with_multiple_delimiters(self):
|
||||
"""Test extracting base name with multiple delimiters."""
|
||||
# Multiple dict delimiters
|
||||
assert extract_base_field_name("a_#_b_#_c") == "a"
|
||||
|
||||
# Multiple list delimiters
|
||||
assert extract_base_field_name("a_$_0_$_1") == "a"
|
||||
|
||||
# Mixed delimiters
|
||||
assert extract_base_field_name("a_#_b_$_0") == "a"
|
||||
|
||||
def test_is_dynamic_field_edge_cases(self):
|
||||
"""Test is_dynamic_field with edge cases."""
|
||||
# Standard dynamic fields
|
||||
assert is_dynamic_field("values_#_key") is True
|
||||
assert is_dynamic_field("items_$_0") is True
|
||||
assert is_dynamic_field("obj_@_attr") is True
|
||||
|
||||
# Regular fields
|
||||
assert is_dynamic_field("regular") is False
|
||||
assert is_dynamic_field("with_underscore") is False
|
||||
|
||||
# Edge cases
|
||||
assert is_dynamic_field("") is False
|
||||
assert is_dynamic_field("_#_") is True # Just delimiter
|
||||
assert is_dynamic_field("a_#_") is True # Trailing delimiter
|
||||
|
||||
def test_sanitize_pin_name_with_tool_pins(self):
|
||||
"""Test sanitize_pin_name with various tool pin formats."""
|
||||
# Tool pins should return "tools"
|
||||
assert sanitize_pin_name("tools") == "tools"
|
||||
assert sanitize_pin_name("tools_^_node_~_field") == "tools"
|
||||
|
||||
# Dynamic fields should return base name
|
||||
assert sanitize_pin_name("values_#_key") == "values"
|
||||
assert sanitize_pin_name("items_$_0") == "items"
|
||||
|
||||
# Regular fields unchanged
|
||||
assert sanitize_pin_name("regular") == "regular"
|
||||
|
||||
|
||||
class TestDynamicFieldDescriptions:
|
||||
"""Tests for dynamic field description generation."""
|
||||
|
||||
def test_dict_field_description(self):
|
||||
"""Test description for dictionary fields."""
|
||||
desc = get_dynamic_field_description("values_#_user_name")
|
||||
|
||||
assert "Dictionary field" in desc
|
||||
assert "values['user_name']" in desc
|
||||
|
||||
def test_list_field_description(self):
|
||||
"""Test description for list fields."""
|
||||
desc = get_dynamic_field_description("items_$_0")
|
||||
|
||||
assert "List item 0" in desc
|
||||
assert "items[0]" in desc
|
||||
|
||||
def test_object_field_description(self):
|
||||
"""Test description for object fields."""
|
||||
desc = get_dynamic_field_description("user_@_email")
|
||||
|
||||
assert "Object attribute" in desc
|
||||
assert "user.email" in desc
|
||||
|
||||
def test_regular_field_description(self):
|
||||
"""Test description for regular (non-dynamic) fields."""
|
||||
desc = get_dynamic_field_description("regular_field")
|
||||
|
||||
assert desc == "Value for regular_field"
|
||||
|
||||
def test_description_with_numeric_key(self):
|
||||
"""Test description with numeric dictionary key."""
|
||||
desc = get_dynamic_field_description("values_#_123")
|
||||
|
||||
assert "Dictionary field" in desc
|
||||
assert "values['123']" in desc
|
||||
|
||||
|
||||
class TestParseExecutionOutputToolRouting:
|
||||
"""Tests for tool pin routing in parse_execution_output."""
|
||||
|
||||
def test_tool_pin_routing_exact_match(self):
|
||||
"""Test tool pin routing with exact match."""
|
||||
output_item = ("tools_^_node-123_~_field_name", "value")
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="field_name",
|
||||
)
|
||||
|
||||
assert result == "value"
|
||||
|
||||
def test_tool_pin_routing_node_mismatch(self):
|
||||
"""Test tool pin routing with node ID mismatch."""
|
||||
output_item = ("tools_^_node-123_~_field_name", "value")
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="different-node",
|
||||
sink_pin_name="field_name",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_tool_pin_routing_field_mismatch(self):
|
||||
"""Test tool pin routing with field name mismatch."""
|
||||
output_item = ("tools_^_node-123_~_field_name", "value")
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="different_field",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_tool_pin_missing_required_params(self):
|
||||
"""Test that tool pins require node_id and pin_name."""
|
||||
output_item = ("tools_^_node-123_~_field", "value")
|
||||
|
||||
with pytest.raises(ValueError, match="must be provided"):
|
||||
parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id=None,
|
||||
sink_pin_name="field",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="must be provided"):
|
||||
parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
|
||||
class TestParseExecutionOutputDynamicFields:
|
||||
"""Tests for dynamic field routing in parse_execution_output."""
|
||||
|
||||
def test_dict_field_extraction(self):
|
||||
"""Test extraction of dictionary field value."""
|
||||
# The output_item is (field_name, data_structure)
|
||||
data = {"key1": "value1", "key2": "value2"}
|
||||
output_item = ("values", data)
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="values_#_key1",
|
||||
sink_node_id=None,
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
assert result == "value1"
|
||||
|
||||
def test_list_field_extraction(self):
|
||||
"""Test extraction of list item value."""
|
||||
data = ["zero", "one", "two"]
|
||||
output_item = ("items", data)
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="items_$_1",
|
||||
sink_node_id=None,
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
assert result == "one"
|
||||
|
||||
def test_nested_field_extraction(self):
|
||||
"""Test extraction of nested field value."""
|
||||
data = {
|
||||
"users": [
|
||||
{"name": "Alice", "email": "alice@example.com"},
|
||||
{"name": "Bob", "email": "bob@example.com"},
|
||||
]
|
||||
}
|
||||
output_item = ("data", data)
|
||||
|
||||
# Access nested path
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="data_#_users",
|
||||
sink_node_id=None,
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
assert result == data["users"]
|
||||
|
||||
def test_missing_key_returns_none(self):
|
||||
"""Test that missing keys return None."""
|
||||
data = {"existing": "value"}
|
||||
output_item = ("values", data)
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="values_#_nonexistent",
|
||||
sink_node_id=None,
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_index_out_of_bounds_returns_none(self):
|
||||
"""Test that out-of-bounds indices return None."""
|
||||
data = ["zero", "one"]
|
||||
output_item = ("items", data)
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="items_$_99",
|
||||
sink_node_id=None,
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestIsToolPin:
|
||||
"""Tests for is_tool_pin function."""
|
||||
|
||||
def test_tools_prefix(self):
|
||||
"""Test that 'tools_^_' prefix is recognized."""
|
||||
assert is_tool_pin("tools_^_node_~_field") is True
|
||||
assert is_tool_pin("tools_^_anything") is True
|
||||
|
||||
def test_tools_exact(self):
|
||||
"""Test that exact 'tools' is recognized."""
|
||||
assert is_tool_pin("tools") is True
|
||||
|
||||
def test_non_tool_pins(self):
|
||||
"""Test that non-tool pins are not recognized."""
|
||||
assert is_tool_pin("input") is False
|
||||
assert is_tool_pin("output") is False
|
||||
assert is_tool_pin("toolsomething") is False
|
||||
assert is_tool_pin("my_tools") is False
|
||||
assert is_tool_pin("") is False
|
||||
|
||||
|
||||
class TestMergeExecutionInputEdgeCases:
|
||||
"""Edge case tests for merge_execution_input."""
|
||||
|
||||
def test_empty_input(self):
|
||||
"""Test merging empty input."""
|
||||
result = merge_execution_input({})
|
||||
assert result == {}
|
||||
|
||||
def test_only_regular_fields(self):
|
||||
"""Test merging only regular fields (no dynamic)."""
|
||||
data = {"a": 1, "b": 2, "c": 3}
|
||||
result = merge_execution_input(data)
|
||||
assert result == data
|
||||
|
||||
def test_overwrite_behavior(self):
|
||||
"""Test behavior when same key is set multiple times."""
|
||||
# This shouldn't happen in practice, but test the behavior
|
||||
data = {
|
||||
"values_#_key": "first",
|
||||
}
|
||||
result = merge_execution_input(data)
|
||||
assert result["values"]["key"] == "first"
|
||||
|
||||
def test_numeric_string_keys(self):
|
||||
"""Test handling of numeric string keys in dict fields."""
|
||||
data = {
|
||||
"values_#_123": "numeric_key",
|
||||
"values_#_456": "another_numeric",
|
||||
}
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert result["values"]["123"] == "numeric_key"
|
||||
assert result["values"]["456"] == "another_numeric"
|
||||
|
||||
def test_special_characters_in_keys(self):
|
||||
"""Test handling of special characters in keys."""
|
||||
data = {
|
||||
"values_#_key-with-dashes": "value1",
|
||||
"values_#_key.with.dots": "value2",
|
||||
}
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert result["values"]["key-with-dashes"] == "value1"
|
||||
assert result["values"]["key.with.dots"] == "value2"
|
||||
|
||||
def test_deeply_nested_list(self):
|
||||
"""Test deeply nested list indices."""
|
||||
data = {
|
||||
"matrix_$_0_$_0": "0,0",
|
||||
"matrix_$_0_$_1": "0,1",
|
||||
"matrix_$_1_$_0": "1,0",
|
||||
"matrix_$_1_$_1": "1,1",
|
||||
}
|
||||
|
||||
# Note: Current implementation may not support this depth
|
||||
# Test documents expected behavior
|
||||
try:
|
||||
result = merge_execution_input(data)
|
||||
# If supported, verify structure
|
||||
except (KeyError, TypeError, IndexError):
|
||||
# Deep nesting may not be supported
|
||||
pass
|
||||
|
||||
def test_none_values(self):
|
||||
"""Test handling of None values in input."""
|
||||
data = {
|
||||
"regular": None,
|
||||
"dict_#_key": None,
|
||||
"list_$_0": None,
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert result["regular"] is None
|
||||
assert result["dict"]["key"] is None
|
||||
assert result["list"][0] is None
|
||||
|
||||
def test_complex_values(self):
|
||||
"""Test handling of complex values (dicts, lists)."""
|
||||
data = {
|
||||
"values_#_nested_dict": {"inner": "value"},
|
||||
"values_#_nested_list": [1, 2, 3],
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert result["values"]["nested_dict"] == {"inner": "value"}
|
||||
assert result["values"]["nested_list"] == [1, 2, 3]
|
||||
@@ -1,463 +0,0 @@
|
||||
"""
|
||||
Tests for dynamic field routing with sanitized names.
|
||||
|
||||
This test file specifically tests the parse_execution_output function
|
||||
which is responsible for routing tool outputs to the correct nodes.
|
||||
The critical bug this addresses is the mismatch between:
|
||||
- emit keys using sanitized names (e.g., "max_keyword_difficulty")
|
||||
- sink_pin_name using original names (e.g., "Max Keyword Difficulty")
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.dynamic_fields import (
|
||||
DICT_SPLIT,
|
||||
LIST_SPLIT,
|
||||
OBJC_SPLIT,
|
||||
extract_base_field_name,
|
||||
get_dynamic_field_description,
|
||||
is_dynamic_field,
|
||||
is_tool_pin,
|
||||
merge_execution_input,
|
||||
parse_execution_output,
|
||||
sanitize_pin_name,
|
||||
)
|
||||
|
||||
|
||||
def cleanup(s: str) -> str:
|
||||
"""
|
||||
Simulate SmartDecisionMakerBlock.cleanup() for testing.
|
||||
Clean up names for use as tool function names.
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()
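
# For reference, cleanup() maps the display names used in these tests as follows
# (these pairs mirror the parametrized cases further down in this file):
#     cleanup("Max Keyword Difficulty")  -> "max_keyword_difficulty"
#     cleanup("Search Volume (Monthly)") -> "search_volume__monthly_"
#     cleanup("CPC ($)")                 -> "cpc____"
#     cleanup("query")                   -> "query"  (already-sanitized names pass through)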
|
||||
|
||||
|
||||
class TestParseExecutionOutputToolRouting:
|
||||
"""Tests for tool pin routing in parse_execution_output."""
|
||||
|
||||
def test_exact_match_routes_correctly(self):
|
||||
"""When emit key field exactly matches sink_pin_name, routing works."""
|
||||
output_item = ("tools_^_node-123_~_query", "test value")
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="query",
|
||||
)
|
||||
assert result == "test value"
|
||||
|
||||
def test_sanitized_emit_vs_original_sink_fails(self):
|
||||
"""
|
||||
CRITICAL BUG TEST: When emit key uses sanitized name but sink uses original,
|
||||
routing fails.
|
||||
"""
|
||||
# Backend emits with sanitized name
|
||||
sanitized_field = cleanup("Max Keyword Difficulty")
|
||||
output_item = (f"tools_^_node-123_~_{sanitized_field}", 50)
|
||||
|
||||
# Frontend link has original name
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="Max Keyword Difficulty", # Original name
|
||||
)
|
||||
|
||||
# BUG: This returns None because sanitized != original
|
||||
# Once fixed, change this to: assert result == 50
|
||||
assert result is None, "Expected None due to sanitization mismatch bug"
|
||||
|
||||
def test_node_id_mismatch_returns_none(self):
|
||||
"""When node IDs don't match, routing should return None."""
|
||||
output_item = ("tools_^_node-123_~_query", "test value")
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="different-node", # Different node
|
||||
sink_pin_name="query",
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_both_node_and_pin_must_match(self):
|
||||
"""Both node_id and pin_name must match for routing to succeed."""
|
||||
output_item = ("tools_^_node-123_~_query", "test value")
|
||||
|
||||
# Wrong node, right pin
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="wrong-node",
|
||||
sink_pin_name="query",
|
||||
)
|
||||
assert result is None
|
||||
|
||||
# Right node, wrong pin
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="wrong_pin",
|
||||
)
|
||||
assert result is None
|
||||
|
||||
# Right node, right pin
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="query",
|
||||
)
|
||||
assert result == "test value"
|
||||
|
||||
|
||||
class TestToolPinRoutingWithSpecialCharacters:
|
||||
"""Tests for tool pin routing with various special characters in names."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"original_name,sanitized_name",
|
||||
[
|
||||
("Max Keyword Difficulty", "max_keyword_difficulty"),
|
||||
("Search Volume (Monthly)", "search_volume__monthly_"),
|
||||
("CPC ($)", "cpc____"),
|
||||
("User's Input", "user_s_input"),
|
||||
("Query #1", "query__1"),
|
||||
("API.Response", "api_response"),
|
||||
("Field@Name", "field_name"),
|
||||
("Test\tTab", "test_tab"),
|
||||
("Test\nNewline", "test_newline"),
|
||||
],
|
||||
)
|
||||
def test_routing_mismatch_with_special_chars(self, original_name, sanitized_name):
|
||||
"""
|
||||
Test that various special characters cause routing mismatches.
|
||||
|
||||
This test documents the current buggy behavior where sanitized emit keys
|
||||
don't match original sink_pin_names.
|
||||
"""
|
||||
# Verify sanitization
|
||||
assert cleanup(original_name) == sanitized_name
|
||||
|
||||
# Backend emits with sanitized name
|
||||
output_item = (f"tools_^_node-123_~_{sanitized_name}", "value")
|
||||
|
||||
# Frontend link has original name
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name=original_name,
|
||||
)
|
||||
|
||||
# BUG: Returns None due to mismatch
|
||||
assert result is None, f"Routing should fail for '{original_name}' vs '{sanitized_name}'"
|
||||
|
||||
|
||||
class TestToolPinMissingParameters:
|
||||
"""Tests for missing required parameters in parse_execution_output."""
|
||||
|
||||
def test_missing_sink_node_id_raises_error(self):
|
||||
"""Missing sink_node_id should raise ValueError for tool pins."""
|
||||
output_item = ("tools_^_node-123_~_query", "test value")
|
||||
|
||||
with pytest.raises(ValueError, match="sink_node_id and sink_pin_name must be provided"):
|
||||
parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id=None,
|
||||
sink_pin_name="query",
|
||||
)
|
||||
|
||||
def test_missing_sink_pin_name_raises_error(self):
|
||||
"""Missing sink_pin_name should raise ValueError for tool pins."""
|
||||
output_item = ("tools_^_node-123_~_query", "test value")
|
||||
|
||||
with pytest.raises(ValueError, match="sink_node_id and sink_pin_name must be provided"):
|
||||
parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
|
||||
class TestIsToolPin:
|
||||
"""Tests for is_tool_pin function."""
|
||||
|
||||
def test_tools_prefix_is_tool_pin(self):
|
||||
"""Names starting with 'tools_^_' are tool pins."""
|
||||
assert is_tool_pin("tools_^_node_~_field") is True
|
||||
assert is_tool_pin("tools_^_anything") is True
|
||||
|
||||
def test_tools_exact_is_tool_pin(self):
|
||||
"""Exact 'tools' is a tool pin."""
|
||||
assert is_tool_pin("tools") is True
|
||||
|
||||
def test_non_tool_pins(self):
|
||||
"""Non-tool pin names should return False."""
|
||||
assert is_tool_pin("input") is False
|
||||
assert is_tool_pin("output") is False
|
||||
assert is_tool_pin("my_tools") is False
|
||||
assert is_tool_pin("toolsomething") is False
|
||||
|
||||
|
||||
class TestSanitizePinName:
|
||||
"""Tests for sanitize_pin_name function."""
|
||||
|
||||
def test_extracts_base_from_dynamic_field(self):
|
||||
"""Should extract base field name from dynamic fields."""
|
||||
assert sanitize_pin_name("values_#_key") == "values"
|
||||
assert sanitize_pin_name("items_$_0") == "items"
|
||||
assert sanitize_pin_name("obj_@_attr") == "obj"
|
||||
|
||||
def test_returns_tools_for_tool_pins(self):
|
||||
"""Tool pins should be sanitized to 'tools'."""
|
||||
assert sanitize_pin_name("tools_^_node_~_field") == "tools"
|
||||
assert sanitize_pin_name("tools") == "tools"
|
||||
|
||||
def test_regular_field_unchanged(self):
|
||||
"""Regular field names should be unchanged."""
|
||||
assert sanitize_pin_name("query") == "query"
|
||||
assert sanitize_pin_name("max_difficulty") == "max_difficulty"
|
||||
|
||||
|
||||
class TestDynamicFieldDescriptions:
|
||||
"""Tests for dynamic field description generation."""
|
||||
|
||||
def test_dict_field_description_with_spaces_in_key(self):
|
||||
"""Dictionary field keys with spaces should generate correct descriptions."""
|
||||
# After cleanup, "User Name" becomes "user_name" in the field name
|
||||
# But the original key might have had spaces
|
||||
desc = get_dynamic_field_description("values_#_user_name")
|
||||
assert "Dictionary field" in desc
|
||||
assert "values['user_name']" in desc
|
||||
|
||||
def test_list_field_description(self):
|
||||
"""List field descriptions should include index."""
|
||||
desc = get_dynamic_field_description("items_$_0")
|
||||
assert "List item 0" in desc
|
||||
assert "items[0]" in desc
|
||||
|
||||
def test_object_field_description(self):
|
||||
"""Object field descriptions should include attribute."""
|
||||
desc = get_dynamic_field_description("user_@_email")
|
||||
assert "Object attribute" in desc
|
||||
assert "user.email" in desc
|
||||
|
||||
|
||||
class TestMergeExecutionInput:
|
||||
"""Tests for merge_execution_input function."""
|
||||
|
||||
def test_merges_dict_fields(self):
|
||||
"""Dictionary fields should be merged into nested structure."""
|
||||
data = {
|
||||
"values_#_name": "Alice",
|
||||
"values_#_age": 30,
|
||||
"other_field": "unchanged",
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert "values" in result
|
||||
assert result["values"]["name"] == "Alice"
|
||||
assert result["values"]["age"] == 30
|
||||
assert result["other_field"] == "unchanged"
|
||||
|
||||
def test_merges_list_fields(self):
|
||||
"""List fields should be merged into arrays."""
|
||||
data = {
|
||||
"items_$_0": "first",
|
||||
"items_$_1": "second",
|
||||
"items_$_2": "third",
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert "items" in result
|
||||
assert result["items"] == ["first", "second", "third"]
|
||||
|
||||
def test_merges_mixed_fields(self):
|
||||
"""Mixed regular and dynamic fields should all be preserved."""
|
||||
data = {
|
||||
"regular": "value",
|
||||
"dict_#_key": "dict_value",
|
||||
"list_$_0": "list_item",
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert result["regular"] == "value"
|
||||
assert result["dict"]["key"] == "dict_value"
|
||||
assert result["list"] == ["list_item"]
|
||||
|
||||
|
||||
class TestExtractBaseFieldName:
|
||||
"""Tests for extract_base_field_name function."""
|
||||
|
||||
def test_extracts_from_dict_delimiter(self):
|
||||
"""Should extract base name before _#_ delimiter."""
|
||||
assert extract_base_field_name("values_#_name") == "values"
|
||||
assert extract_base_field_name("user_#_email_#_domain") == "user"
|
||||
|
||||
def test_extracts_from_list_delimiter(self):
|
||||
"""Should extract base name before _$_ delimiter."""
|
||||
assert extract_base_field_name("items_$_0") == "items"
|
||||
assert extract_base_field_name("data_$_1_$_nested") == "data"
|
||||
|
||||
def test_extracts_from_object_delimiter(self):
|
||||
"""Should extract base name before _@_ delimiter."""
|
||||
assert extract_base_field_name("obj_@_attr") == "obj"
|
||||
|
||||
def test_no_delimiter_returns_original(self):
|
||||
"""Names without delimiters should be returned unchanged."""
|
||||
assert extract_base_field_name("regular_field") == "regular_field"
|
||||
assert extract_base_field_name("query") == "query"
|
||||
|
||||
|
||||
class TestIsDynamicField:
|
||||
"""Tests for is_dynamic_field function."""
|
||||
|
||||
def test_dict_delimiter_is_dynamic(self):
|
||||
"""Fields with _#_ are dynamic."""
|
||||
assert is_dynamic_field("values_#_key") is True
|
||||
|
||||
def test_list_delimiter_is_dynamic(self):
|
||||
"""Fields with _$_ are dynamic."""
|
||||
assert is_dynamic_field("items_$_0") is True
|
||||
|
||||
def test_object_delimiter_is_dynamic(self):
|
||||
"""Fields with _@_ are dynamic."""
|
||||
assert is_dynamic_field("obj_@_attr") is True
|
||||
|
||||
def test_regular_fields_not_dynamic(self):
|
||||
"""Regular field names without delimiters are not dynamic."""
|
||||
assert is_dynamic_field("regular_field") is False
|
||||
assert is_dynamic_field("query") is False
|
||||
assert is_dynamic_field("Max Keyword Difficulty") is False
|
||||
|
||||
|
||||
class TestRoutingEndToEnd:
|
||||
"""End-to-end tests for the full routing flow."""
|
||||
|
||||
def test_successful_routing_without_spaces(self):
|
||||
"""Full routing flow works when no spaces in names."""
|
||||
field_name = "query"
|
||||
node_id = "test-node-123"
|
||||
|
||||
# Emit key (as created by SmartDecisionMaker)
|
||||
emit_key = f"tools_^_{node_id}_~_{cleanup(field_name)}"
|
||||
output_item = (emit_key, "search term")
|
||||
|
||||
# Route (as called by executor)
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id=node_id,
|
||||
sink_pin_name=field_name,
|
||||
)
|
||||
|
||||
assert result == "search term"
|
||||
|
||||
def test_failed_routing_with_spaces(self):
|
||||
"""
|
||||
Full routing flow FAILS when names have spaces.
|
||||
|
||||
This test documents the exact bug scenario:
|
||||
1. Frontend creates link with sink_name="Max Keyword Difficulty"
|
||||
2. SmartDecisionMaker emits with sanitized name in key
|
||||
3. Executor calls parse_execution_output with original sink_pin_name
|
||||
4. Routing fails because names don't match
|
||||
"""
|
||||
original_field_name = "Max Keyword Difficulty"
|
||||
sanitized_field_name = cleanup(original_field_name)
|
||||
node_id = "test-node-123"
|
||||
|
||||
# Step 1 & 2: SmartDecisionMaker emits with sanitized name
|
||||
emit_key = f"tools_^_{node_id}_~_{sanitized_field_name}"
|
||||
output_item = (emit_key, 50)
|
||||
|
||||
# Step 3: Executor routes with original name from link
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id=node_id,
|
||||
sink_pin_name=original_field_name, # Original from link!
|
||||
)
|
||||
|
||||
# Step 4: BUG - Returns None instead of 50
|
||||
assert result is None
|
||||
|
||||
# This is what should happen after fix:
|
||||
# assert result == 50
|
||||
|
||||
def test_multiple_fields_with_spaces(self):
|
||||
"""Test routing multiple fields where some have spaces."""
|
||||
node_id = "test-node"
|
||||
|
||||
fields = {
|
||||
"query": "test", # No spaces - should work
|
||||
"Max Difficulty": 100, # Spaces - will fail
|
||||
"min_volume": 1000, # No spaces - should work
|
||||
}
|
||||
|
||||
results = {}
|
||||
for original_name, value in fields.items():
|
||||
sanitized = cleanup(original_name)
|
||||
emit_key = f"tools_^_{node_id}_~_{sanitized}"
|
||||
output_item = (emit_key, value)
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id=node_id,
|
||||
sink_pin_name=original_name,
|
||||
)
|
||||
results[original_name] = result
|
||||
|
||||
# Fields without spaces work
|
||||
assert results["query"] == "test"
|
||||
assert results["min_volume"] == 1000
|
||||
|
||||
# Fields with spaces fail
|
||||
assert results["Max Difficulty"] is None # BUG!
|
||||
|
||||
|
||||
class TestProposedFix:
|
||||
"""
|
||||
Tests for the proposed fix.
|
||||
|
||||
The fix should sanitize sink_pin_name before comparison in parse_execution_output.
|
||||
    The test below already passes; it spells out both the comparison made today
    (which fails to match) and the sanitized comparison the fix should perform.
|
||||
"""
|
||||
|
||||
def test_routing_should_sanitize_both_sides(self):
|
||||
"""
|
||||
PROPOSED FIX: parse_execution_output should sanitize sink_pin_name
|
||||
before comparing with the field from emit key.
|
||||
|
||||
Current behavior: Direct string comparison
|
||||
Fixed behavior: Compare cleanup(target_input_pin) == cleanup(sink_pin_name)
|
||||
"""
|
||||
original_field = "Max Keyword Difficulty"
|
||||
sanitized_field = cleanup(original_field)
|
||||
node_id = "node-123"
|
||||
|
||||
emit_key = f"tools_^_{node_id}_~_{sanitized_field}"
|
||||
output_item = (emit_key, 50)
|
||||
|
||||
# Extract the comparison being made
|
||||
selector = emit_key[8:] # Remove "tools_^_"
|
||||
target_node_id, target_input_pin = selector.split("_~_", 1)
|
||||
|
||||
# Current comparison (FAILS):
|
||||
current_comparison = (target_input_pin == original_field)
|
||||
assert current_comparison is False, "Current comparison fails"
|
||||
|
||||
# Proposed fixed comparison (PASSES):
|
||||
# Either sanitize sink_pin_name, or sanitize both
|
||||
fixed_comparison = (target_input_pin == cleanup(original_field))
|
||||
assert fixed_comparison is True, "Fixed comparison should pass"
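
# A minimal sketch of the comparison the fix needs. The helper name below is
# hypothetical and this is not the real parse_execution_output signature; it only
# assumes the "tools_^_<node_id>_~_<sanitized_field>" emit-key layout used
# throughout this file.
def _tool_pin_matches(emit_key: str, sink_node_id: str, sink_pin_name: str) -> bool:
    selector = emit_key[len("tools_^_"):]
    target_node_id, target_input_pin = selector.split("_~_", 1)
    # Sanitize the link-side name so "Max Keyword Difficulty" can match the
    # already-sanitized "max_keyword_difficulty" carried in the emit key.
    return target_node_id == sink_node_id and target_input_pin == cleanup(sink_pin_name)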
|
||||
@@ -2,11 +2,6 @@ import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cast
|
||||
|
||||
from backend.api.features.library.db import (
|
||||
add_store_agent_to_library,
|
||||
list_library_agents,
|
||||
)
|
||||
from backend.api.features.store.db import get_store_agent_details, get_store_agents
|
||||
from backend.data import db
|
||||
from backend.data.analytics import (
|
||||
get_accuracy_trends_and_alerts,
|
||||
@@ -66,6 +61,8 @@ from backend.data.user import (
|
||||
get_user_notification_preference,
|
||||
update_user_integrations,
|
||||
)
|
||||
from backend.server.v2.library.db import add_store_agent_to_library, list_library_agents
|
||||
from backend.server.v2.store.db import get_store_agent_details, get_store_agents
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
|
||||
@@ -48,8 +48,27 @@ from backend.data.notifications import (
|
||||
ZeroBalanceData,
|
||||
)
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.activity_status_generator import (
|
||||
generate_activity_status_for_execution,
|
||||
)
|
||||
from backend.executor.utils import (
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_EXCHANGE,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_ROUTING_KEY,
|
||||
CancelExecutionEvent,
|
||||
ExecutionOutputEntry,
|
||||
LogMetadata,
|
||||
NodeExecutionProgress,
|
||||
block_usage_cost,
|
||||
create_execution_queue_config,
|
||||
execution_usage_cost,
|
||||
validate_exec,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.server.v2.AutoMod.manager import automod_manager
|
||||
from backend.util import json
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
@@ -76,24 +95,7 @@ from backend.util.retry import (
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .activity_status_generator import generate_activity_status_for_execution
|
||||
from .automod.manager import automod_manager
|
||||
from .cluster_lock import ClusterLock
|
||||
from .utils import (
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_EXCHANGE,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_ROUTING_KEY,
|
||||
CancelExecutionEvent,
|
||||
ExecutionOutputEntry,
|
||||
LogMetadata,
|
||||
NodeExecutionProgress,
|
||||
block_usage_cost,
|
||||
create_execution_queue_config,
|
||||
execution_usage_cost,
|
||||
validate_exec,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
|
||||
@@ -114,40 +116,6 @@ utilization_gauge = Gauge(
|
||||
"Ratio of active graph runs to max graph workers",
|
||||
)
|
||||
|
||||
# Redis key prefix for tracking insufficient funds Discord notifications.
# We only send one notification per user per agent until they top up credits.
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
# TTL for the notification flag (30 days) - acts as a fallback cleanup
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60
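# Dedup keys take the form "<prefix>:<user_id>:<graph_id>", e.g.
# "insufficient_funds_discord_notified:user-123:graph-456". SET with nx=True only
# succeeds for the first writer, so later failures for the same user+agent are
# skipped until the key is deleted on top-up or the 30-day TTL expires.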
|
||||
|
||||
|
||||
async def clear_insufficient_funds_notifications(user_id: str) -> int:
|
||||
"""
|
||||
Clear all insufficient funds notification flags for a user.
|
||||
|
||||
This should be called when a user tops up their credits, allowing
|
||||
Discord notifications to be sent again if they run out of funds.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to clear notifications for.
|
||||
|
||||
Returns:
|
||||
The number of keys that were deleted.
|
||||
"""
|
||||
try:
|
||||
redis_client = await redis.get_redis_async()
|
||||
pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
keys = [key async for key in redis_client.scan_iter(match=pattern)]
|
||||
if keys:
|
||||
return await redis_client.delete(*keys)
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to clear insufficient funds notification flags for user "
|
||||
f"{user_id}: {e}"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
# Thread-local storage for ExecutionProcessor instances
|
||||
_tls = threading.local()
|
||||
@@ -178,7 +146,6 @@ async def execute_node(
|
||||
execution_processor: "ExecutionProcessor",
|
||||
execution_stats: NodeExecutionStats | None = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Execute a node in the graph. This will trigger a block execution on a node,
|
||||
@@ -246,7 +213,6 @@ async def execute_node(
|
||||
"user_id": user_id,
|
||||
"execution_context": execution_context,
|
||||
"execution_processor": execution_processor,
|
||||
"nodes_to_skip": nodes_to_skip or set(),
|
||||
}
|
||||
|
||||
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
|
||||
@@ -544,7 +510,6 @@ class ExecutionProcessor:
|
||||
node_exec_progress: NodeExecutionProgress,
|
||||
nodes_input_masks: Optional[NodesInputMasks],
|
||||
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
) -> NodeExecutionStats:
|
||||
log_metadata = LogMetadata(
|
||||
logger=_logger,
|
||||
@@ -567,7 +532,6 @@ class ExecutionProcessor:
|
||||
db_client=db_client,
|
||||
log_metadata=log_metadata,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip,
|
||||
)
|
||||
if isinstance(status, BaseException):
|
||||
raise status
|
||||
@@ -613,7 +577,6 @@ class ExecutionProcessor:
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
) -> ExecutionStatus:
|
||||
status = ExecutionStatus.RUNNING
|
||||
|
||||
@@ -650,7 +613,6 @@ class ExecutionProcessor:
|
||||
execution_processor=self,
|
||||
execution_stats=stats,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip,
|
||||
):
|
||||
await persist_output(output_name, output_data)
|
||||
|
||||
@@ -962,21 +924,6 @@ class ExecutionProcessor:
|
||||
|
||||
queued_node_exec = execution_queue.get()
|
||||
|
||||
# Check if this node should be skipped due to optional credentials
|
||||
if queued_node_exec.node_id in graph_exec.nodes_to_skip:
|
||||
log_metadata.info(
|
||||
f"Skipping node execution {queued_node_exec.node_exec_id} "
|
||||
f"for node {queued_node_exec.node_id} - optional credentials not configured"
|
||||
)
|
||||
# Mark the node as completed without executing
|
||||
# No outputs will be produced, so downstream nodes won't trigger
|
||||
update_node_execution_status(
|
||||
db_client=db_client,
|
||||
exec_id=queued_node_exec.node_exec_id,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
)
|
||||
continue
|
||||
|
||||
log_metadata.debug(
|
||||
f"Dispatching node execution {queued_node_exec.node_exec_id} "
|
||||
f"for node {queued_node_exec.node_id}",
|
||||
@@ -1037,7 +984,6 @@ class ExecutionProcessor:
|
||||
execution_stats,
|
||||
execution_stats_lock,
|
||||
),
|
||||
nodes_to_skip=graph_exec.nodes_to_skip,
|
||||
),
|
||||
self.node_execution_loop,
|
||||
)
|
||||
@@ -1317,40 +1263,12 @@ class ExecutionProcessor:
|
||||
graph_id: str,
|
||||
e: InsufficientBalanceError,
|
||||
):
|
||||
# Check if we've already sent a notification for this user+agent combo.
|
||||
# We only send one notification per user per agent until they top up credits.
|
||||
redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
|
||||
try:
|
||||
redis_client = redis.get_redis()
|
||||
# SET NX returns True only if the key was newly set (didn't exist)
|
||||
is_new_notification = redis_client.set(
|
||||
redis_key,
|
||||
"1",
|
||||
nx=True,
|
||||
ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
|
||||
)
|
||||
if not is_new_notification:
|
||||
# Already notified for this user+agent, skip all notifications
|
||||
logger.debug(
|
||||
f"Skipping duplicate insufficient funds notification for "
|
||||
f"user={user_id}, graph={graph_id}"
|
||||
)
|
||||
return
|
||||
except Exception as redis_error:
|
||||
# If Redis fails, log and continue to send the notification
|
||||
# (better to occasionally duplicate than to never notify)
|
||||
logger.warning(
|
||||
f"Failed to check/set insufficient funds notification flag in Redis: "
|
||||
f"{redis_error}"
|
||||
)
|
||||
|
||||
shortfall = abs(e.amount) - e.balance
|
||||
metadata = db_client.get_graph_metadata(graph_id)
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
|
||||
# Queue user email notification
|
||||
queue_notification(
|
||||
NotificationEventModel(
|
||||
user_id=user_id,
|
||||
@@ -1364,7 +1282,6 @@ class ExecutionProcessor:
|
||||
)
|
||||
)
|
||||
|
||||
# Send Discord system alert
|
||||
try:
|
||||
user_email = db_client.get_user_email_by_id(user_id)
|
||||
|
||||
|
||||
@@ -1,560 +0,0 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import NotificationType
|
||||
|
||||
from backend.data.notifications import ZeroBalanceData
|
||||
from backend.executor.manager import (
|
||||
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX,
|
||||
ExecutionProcessor,
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
async def async_iter(items):
|
||||
"""Helper to create an async iterator from a list."""
|
||||
for item in items:
|
||||
yield item
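
# Illustrative usage (as in the Redis-mocking tests below): wiring async_iter into a
# MagicMock so the code under test can run `async for key in client.scan_iter(...)`:
#
#     mock_redis_client = MagicMock()
#     mock_redis_client.scan_iter.return_value = async_iter(["key-1", "key-2"])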
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_sends_discord_alert_first_time(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that the first insufficient funds notification sends a Discord alert."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72, # $0.72
|
||||
amount=-714, # Attempting to spend $7.14
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Mock Redis to simulate first-time notification (set returns True)
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
mock_redis_client.set.return_value = True # Key was newly set
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify notification was queued
|
||||
mock_queue_notif.assert_called_once()
|
||||
notification_call = mock_queue_notif.call_args[0][0]
|
||||
assert notification_call.type == NotificationType.ZERO_BALANCE
|
||||
assert notification_call.user_id == user_id
|
||||
assert isinstance(notification_call.data, ZeroBalanceData)
|
||||
assert notification_call.data.current_balance == 72
|
||||
|
||||
# Verify Redis was checked with correct key pattern
|
||||
expected_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
|
||||
mock_redis_client.set.assert_called_once()
|
||||
call_args = mock_redis_client.set.call_args
|
||||
assert call_args[0][0] == expected_key
|
||||
assert call_args[1]["nx"] is True
|
||||
|
||||
# Verify Discord alert was sent
|
||||
mock_client.discord_system_alert.assert_called_once()
|
||||
discord_message = mock_client.discord_system_alert.call_args[0][0]
|
||||
assert "Insufficient Funds Alert" in discord_message
|
||||
assert "test@example.com" in discord_message
|
||||
assert "Test Agent" in discord_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_skips_duplicate_notifications(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that duplicate insufficient funds notifications skip both email and Discord."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72,
|
||||
amount=-714,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Mock Redis to simulate duplicate notification (set returns False/None)
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
mock_redis_client.set.return_value = None # Key already existed
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
mock_db_client.get_graph_metadata.return_value = MagicMock(name="Test Agent")
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify email notification was NOT queued (deduplication worked)
|
||||
mock_queue_notif.assert_not_called()
|
||||
|
||||
# Verify Discord alert was NOT sent (deduplication worked)
|
||||
mock_client.discord_system_alert.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that different agents for the same user get separate Discord alerts."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id_1 = "test-graph-111"
|
||||
graph_id_2 = "test-graph-222"
|
||||
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72,
|
||||
amount=-714,
|
||||
)
|
||||
|
||||
with patch("backend.executor.manager.queue_notification"), patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
# Both calls return True (first time for each agent)
|
||||
mock_redis_client.set.return_value = True
|
||||
|
||||
mock_db_client = MagicMock()
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# First agent notification
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id_1,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Second agent notification
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id_2,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify Discord alerts were sent for both agents
|
||||
assert mock_client.discord_system_alert.call_count == 2
|
||||
|
||||
# Verify Redis was called with different keys
|
||||
assert mock_redis_client.set.call_count == 2
|
||||
calls = mock_redis_client.set.call_args_list
|
||||
assert (
|
||||
calls[0][0][0]
|
||||
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_1}"
|
||||
)
|
||||
assert (
|
||||
calls[1][0][0]
|
||||
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_2}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_clear_insufficient_funds_notifications(server: SpinTestServer):
|
||||
"""Test that clearing notifications removes all keys for a user."""
|
||||
|
||||
user_id = "test-user-123"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
# get_redis_async is an async function, so we need AsyncMock for it
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
# Mock scan_iter to return some keys as an async iterator
|
||||
mock_keys = [
|
||||
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1",
|
||||
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-2",
|
||||
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-3",
|
||||
]
|
||||
mock_redis_client.scan_iter.return_value = async_iter(mock_keys)
|
||||
# delete is awaited, so use AsyncMock
|
||||
mock_redis_client.delete = AsyncMock(return_value=3)
|
||||
|
||||
# Clear notifications
|
||||
result = await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
# Verify correct pattern was used
|
||||
expected_pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
mock_redis_client.scan_iter.assert_called_once_with(match=expected_pattern)
|
||||
|
||||
# Verify delete was called with all keys
|
||||
mock_redis_client.delete.assert_called_once_with(*mock_keys)
|
||||
|
||||
# Verify return value
|
||||
assert result == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_clear_insufficient_funds_notifications_no_keys(server: SpinTestServer):
|
||||
"""Test clearing notifications when there are no keys to clear."""
|
||||
|
||||
user_id = "test-user-no-notifications"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
# get_redis_async is an async function, so we need AsyncMock for it
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
# Mock scan_iter to return no keys as an async iterator
|
||||
mock_redis_client.scan_iter.return_value = async_iter([])
|
||||
|
||||
# Clear notifications
|
||||
result = await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
# Verify delete was not called
|
||||
mock_redis_client.delete.assert_not_called()
|
||||
|
||||
# Verify return value
|
||||
assert result == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_clear_insufficient_funds_notifications_handles_redis_error(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that clearing notifications handles Redis errors gracefully."""
|
||||
|
||||
user_id = "test-user-redis-error"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
# Mock get_redis_async to raise an error
|
||||
mock_redis_module.get_redis_async = AsyncMock(
|
||||
side_effect=Exception("Redis connection failed")
|
||||
)
|
||||
|
||||
# Clear notifications should not raise, just return 0
|
||||
result = await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
# Verify it returned 0 (graceful failure)
|
||||
assert result == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_continues_on_redis_error(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that both email and Discord notifications are still sent when Redis fails."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72,
|
||||
amount=-714,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Mock Redis to raise an error
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
mock_redis_client.set.side_effect = Exception("Redis connection error")
|
||||
|
||||
mock_db_client = MagicMock()
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify email notification was still queued despite Redis error
|
||||
mock_queue_notif.assert_called_once()
|
||||
|
||||
# Verify Discord alert was still sent despite Redis error
|
||||
mock_client.discord_system_alert.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_clears_notifications_on_grant(server: SpinTestServer):
|
||||
"""Test that _add_transaction clears notification flags when adding GRANT credits."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-grant-clear"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 1000, "transactionKey": "test-tx-key"}]
|
||||
|
||||
# Mock async Redis for notification clearing
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
mock_redis_client.scan_iter.return_value = async_iter(
|
||||
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
|
||||
)
|
||||
mock_redis_client.delete = AsyncMock(return_value=1)
|
||||
|
||||
# Create a concrete instance
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with GRANT type (should clear notifications)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=500, # Positive amount
|
||||
transaction_type=CreditTransactionType.GRANT,
|
||||
is_active=True, # Active transaction
|
||||
)
|
||||
|
||||
# Verify notification clearing was called
|
||||
mock_redis_module.get_redis_async.assert_called_once()
|
||||
mock_redis_client.scan_iter.assert_called_once_with(
|
||||
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_clears_notifications_on_top_up(server: SpinTestServer):
|
||||
"""Test that _add_transaction clears notification flags when adding TOP_UP credits."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-topup-clear"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 2000, "transactionKey": "test-tx-key-2"}]
|
||||
|
||||
# Mock async Redis for notification clearing
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
mock_redis_client.scan_iter.return_value = async_iter([])
|
||||
mock_redis_client.delete = AsyncMock(return_value=0)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with TOP_UP type (should clear notifications)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=1000, # Positive amount
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
# Verify notification clearing was attempted
|
||||
mock_redis_module.get_redis_async.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_skips_clearing_for_inactive_transaction(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that _add_transaction does NOT clear notifications for inactive transactions."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-inactive"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 500, "transactionKey": "test-tx-key-3"}]
|
||||
|
||||
# Mock async Redis
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with is_active=False (should NOT clear notifications)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=500,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
is_active=False, # Inactive - pending Stripe payment
|
||||
)
|
||||
|
||||
# Verify notification clearing was NOT called
|
||||
mock_redis_module.get_redis_async.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_skips_clearing_for_usage_transaction(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that _add_transaction does NOT clear notifications for USAGE transactions."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-usage"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 400, "transactionKey": "test-tx-key-4"}]
|
||||
|
||||
# Mock async Redis
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with USAGE type (spending, should NOT clear)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=-100, # Negative - spending credits
|
||||
transaction_type=CreditTransactionType.USAGE,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
# Verify notification clearing was NOT called
|
||||
mock_redis_module.get_redis_async.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_enable_transaction_clears_notifications(server: SpinTestServer):
|
||||
"""Test that _enable_transaction clears notification flags when enabling a TOP_UP."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-enable"
|
||||
|
||||
with patch("backend.data.credit.CreditTransaction") as mock_credit_tx, patch(
|
||||
"backend.data.credit.query_raw_with_schema"
|
||||
) as mock_query, patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
# Mock finding the pending transaction
|
||||
mock_transaction = MagicMock()
|
||||
mock_transaction.amount = 1000
|
||||
mock_transaction.type = CreditTransactionType.TOP_UP
|
||||
mock_credit_tx.prisma.return_value.find_first = AsyncMock(
|
||||
return_value=mock_transaction
|
||||
)
|
||||
|
||||
# Mock the query to return updated balance
|
||||
mock_query.return_value = [{"balance": 1500}]
|
||||
|
||||
# Mock async Redis for notification clearing
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
mock_redis_client.scan_iter.return_value = async_iter(
|
||||
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
|
||||
)
|
||||
mock_redis_client.delete = AsyncMock(return_value=1)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _enable_transaction (simulates Stripe checkout completion)
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
await credit_model._enable_transaction(
|
||||
transaction_key="cs_test_123",
|
||||
user_id=user_id,
|
||||
metadata=SafeJson({"payment": "completed"}),
|
||||
)
|
||||
|
||||
# Verify notification clearing was called
|
||||
mock_redis_module.get_redis_async.assert_called_once()
|
||||
mock_redis_client.scan_iter.assert_called_once_with(
|
||||
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
)
|
||||
@@ -3,16 +3,16 @@ import logging
|
||||
import fastapi.responses
|
||||
import pytest
|
||||
|
||||
import backend.api.features.library.model
|
||||
import backend.api.features.store.model
|
||||
from backend.api.model import CreateGraph
|
||||
from backend.api.rest_api import AgentServer
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.store.model
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.data_manipulation import FindInDictionaryBlock
|
||||
from backend.blocks.io import AgentInputBlock
|
||||
from backend.blocks.maths import CalculatorBlock, Operation
|
||||
from backend.data import execution, graph
|
||||
from backend.data.model import User
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
from backend.util.test import SpinTestServer, wait_execution
|
||||
|
||||
@@ -356,7 +356,7 @@ async def test_execute_preset(server: SpinTestServer):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
# Create preset with initial values
|
||||
preset = backend.api.features.library.model.LibraryAgentPresetCreatable(
|
||||
preset = backend.server.v2.library.model.LibraryAgentPresetCreatable(
|
||||
name="Test Preset With Clash",
|
||||
description="Test preset with clashing input values",
|
||||
graph_id=test_graph.id,
|
||||
@@ -444,7 +444,7 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
# Create preset with initial values
|
||||
preset = backend.api.features.library.model.LibraryAgentPresetCreatable(
|
||||
preset = backend.server.v2.library.model.LibraryAgentPresetCreatable(
|
||||
name="Test Preset With Clash",
|
||||
description="Test preset with clashing input values",
|
||||
graph_id=test_graph.id,
|
||||
@@ -485,7 +485,7 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
|
||||
store_submission_request = backend.api.features.store.model.StoreSubmissionRequest(
|
||||
store_submission_request = backend.server.v2.store.model.StoreSubmissionRequest(
|
||||
agent_id=test_graph.id,
|
||||
agent_version=test_graph.version,
|
||||
slug=test_graph.id,
|
||||
@@ -514,7 +514,7 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
|
||||
admin_user = await create_test_user(alt_user=True)
|
||||
await server.agent_server.test_review_store_listing(
|
||||
backend.api.features.store.model.ReviewSubmissionRequest(
|
||||
backend.server.v2.store.model.ReviewSubmissionRequest(
|
||||
store_listing_version_id=slv_id,
|
||||
is_approved=True,
|
||||
comments="Test comments",
|
||||
@@ -523,7 +523,7 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
)
|
||||
|
||||
# Add the approved store listing to the admin user's library so they can execute it
|
||||
from backend.api.features.library.db import add_store_agent_to_library
|
||||
from backend.server.v2.library.db import add_store_agent_to_library
|
||||
|
||||
await add_store_agent_to_library(
|
||||
store_listing_version_id=slv_id, user_id=admin_user.id
|
||||
|
||||
@@ -23,7 +23,6 @@ from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import MetaData, create_engine
|
||||
|
||||
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
@@ -243,12 +242,6 @@ def cleanup_expired_files():
|
||||
run_async(cleanup_expired_files_async())
|
||||
|
||||
|
||||
def cleanup_oauth_tokens():
|
||||
"""Clean up expired OAuth tokens from the database."""
|
||||
# Wait for completion
|
||||
run_async(cleanup_expired_oauth_tokens())
|
||||
|
||||
|
||||
def execution_accuracy_alerts():
|
||||
"""Check execution accuracy and send alerts if drops are detected."""
|
||||
return report_execution_accuracy_alerts()
|
||||
@@ -453,17 +446,6 @@ class Scheduler(AppService):
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
# OAuth Token Cleanup - configurable interval
|
||||
self.scheduler.add_job(
|
||||
cleanup_oauth_tokens,
|
||||
id="cleanup_oauth_tokens",
|
||||
trigger="interval",
|
||||
replace_existing=True,
|
||||
seconds=config.oauth_token_cleanup_interval_hours
|
||||
* 3600, # Convert hours to seconds
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
# Execution Accuracy Monitoring - configurable interval
|
||||
self.scheduler.add_job(
|
||||
execution_accuracy_alerts,
|
||||
@@ -622,11 +604,6 @@ class Scheduler(AppService):
|
||||
"""Manually trigger cleanup of expired cloud storage files."""
|
||||
return cleanup_expired_files()
|
||||
|
||||
@expose
|
||||
def execute_cleanup_oauth_tokens(self):
|
||||
"""Manually trigger cleanup of expired OAuth tokens."""
|
||||
return cleanup_oauth_tokens()
|
||||
|
||||
@expose
|
||||
def execute_report_execution_accuracy_alerts(self):
|
||||
"""Manually trigger execution accuracy alert checking."""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from backend.api.model import CreateGraph
|
||||
from backend.data import db
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
@@ -239,19 +239,14 @@ async def _validate_node_input_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[dict[str, dict[str, str]], set[str]]:
) -> dict[str, dict[str, str]]:
"""
Checks all credentials for all nodes of the graph and returns structured errors
and a set of nodes that should be skipped due to optional missing credentials.
Checks all credentials for all nodes of the graph and returns structured errors.

Returns:
tuple[
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node,
set[node_id]: Nodes that should be skipped (optional credentials not configured)
]
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node
"""
credential_errors: dict[str, dict[str, str]] = defaultdict(dict)
nodes_to_skip: set[str] = set()

for node in graph.nodes:
block = node.block
@@ -261,46 +256,27 @@ async def _validate_node_input_credentials(
if not credentials_fields:
continue

# Track if any credential field is missing for this node
has_missing_credentials = False

for field_name, credentials_meta_type in credentials_fields.items():
try:
# Check nodes_input_masks first, then input_default
field_value = None
if (
nodes_input_masks
and (node_input_mask := nodes_input_masks.get(node.id))
and field_name in node_input_mask
):
field_value = node_input_mask[field_name]
credentials_meta = credentials_meta_type.model_validate(
node_input_mask[field_name]
)
elif field_name in node.input_default:
# For optional credentials, don't use input_default - treat as missing
# This prevents stale credential IDs from failing validation
if node.credentials_optional:
field_value = None
else:
field_value = node.input_default[field_name]

# Check if credentials are missing (None, empty, or not present)
if field_value is None or (
isinstance(field_value, dict) and not field_value.get("id")
):
has_missing_credentials = True
# If node has credentials_optional flag, mark for skipping instead of error
if node.credentials_optional:
continue  # Don't add error, will be marked for skip after loop
else:
credential_errors[node.id][
field_name
] = "These credentials are required"
continue

credentials_meta = credentials_meta_type.model_validate(field_value)

credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
else:
# Missing credentials
credential_errors[node.id][
field_name
] = "These credentials are required"
continue
except ValidationError as e:
# Validation error means credentials were provided but invalid
# This should always be an error, even if optional
credential_errors[node.id][field_name] = f"Invalid credentials: {e}"
continue

@@ -311,7 +287,6 @@ async def _validate_node_input_credentials(
)
except Exception as e:
# Handle any errors fetching credentials
# If credentials were explicitly configured but unavailable, it's an error
credential_errors[node.id][
field_name
] = f"Credentials not available: {e}"
@@ -338,19 +313,7 @@ async def _validate_node_input_credentials(
] = "Invalid credentials: type/provider mismatch"
continue

# If node has optional credentials and any are missing, mark for skipping
# But only if there are no other errors for this node
if (
has_missing_credentials
and node.credentials_optional
and node.id not in credential_errors
):
nodes_to_skip.add(node.id)
logger.info(
f"Node #{node.id} will be skipped: optional credentials not configured"
)

return credential_errors, nodes_to_skip
return credential_errors


def make_node_credentials_input_map(
@@ -392,25 +355,21 @@ async def validate_graph_with_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[Mapping[str, Mapping[str, str]], set[str]]:
) -> Mapping[str, Mapping[str, str]]:
"""
Validate graph including credentials and return structured errors per node,
along with a set of nodes that should be skipped due to optional missing credentials.
Validate graph including credentials and return structured errors per node.

Returns:
tuple[
dict[node_id, dict[field_name, error_message]]: Validation errors per node,
set[node_id]: Nodes that should be skipped (optional credentials not configured)
]
dict[node_id, dict[field_name, error_message]]: Validation errors per node
"""
# Get input validation errors
node_input_errors = GraphModel.validate_graph_get_errors(
graph, for_run=True, nodes_input_masks=nodes_input_masks
)

# Get credential input/availability/validation errors and nodes to skip
node_credential_input_errors, nodes_to_skip = (
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
# Get credential input/availability/validation errors
node_credential_input_errors = await _validate_node_input_credentials(
graph, user_id, nodes_input_masks
)

# Merge credential errors with structural errors
@@ -419,7 +378,7 @@ async def validate_graph_with_credentials(
node_input_errors[node_id] = {}
node_input_errors[node_id].update(field_errors)

return node_input_errors, nodes_to_skip
return node_input_errors


async def _construct_starting_node_execution_input(
@@ -427,7 +386,7 @@ async def _construct_starting_node_execution_input(
user_id: str,
graph_inputs: BlockInput,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[list[tuple[str, BlockInput]], set[str]]:
) -> list[tuple[str, BlockInput]]:
"""
Validates and prepares the input data for executing a graph.
This function checks the graph for starting nodes, validates the input data
@@ -441,14 +400,11 @@ async def _construct_starting_node_execution_input(
node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`

Returns:
tuple[
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID
and the corresponding input data for that node.
set[str]: Node IDs that should be skipped (optional credentials not configured)
]
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
the corresponding input data for that node.
"""
# Use new validation function that includes credentials
validation_errors, nodes_to_skip = await validate_graph_with_credentials(
validation_errors = await validate_graph_with_credentials(
graph, user_id, nodes_input_masks
)
n_error_nodes = len(validation_errors)
@@ -489,7 +445,7 @@ async def _construct_starting_node_execution_input(
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
)

return nodes_input, nodes_to_skip
return nodes_input


async def validate_and_construct_node_execution_input(
@@ -500,7 +456,7 @@ async def validate_and_construct_node_execution_input(
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
is_sub_graph: bool = False,
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks, set[str]]:
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]:
"""
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
This centralizes the logic used by both scheduler validation and actual execution.
@@ -517,7 +473,6 @@ async def validate_and_construct_node_execution_input(
GraphModel: Full graph object for the given `graph_id`.
list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
dict[str, BlockInput]: Node input masks including all passed-in credentials.
set[str]: Node IDs that should be skipped (optional credentials not configured).

Raises:
NotFoundError: If the graph is not found.
@@ -559,16 +514,14 @@ async def validate_and_construct_node_execution_input(
nodes_input_masks or {},
)

starting_nodes_input, nodes_to_skip = (
await _construct_starting_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=graph_inputs,
nodes_input_masks=nodes_input_masks,
)
starting_nodes_input = await _construct_starting_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=graph_inputs,
nodes_input_masks=nodes_input_masks,
)

return graph, starting_nodes_input, nodes_input_masks, nodes_to_skip
return graph, starting_nodes_input, nodes_input_masks


def _merge_nodes_input_masks(
@@ -826,9 +779,6 @@ async def add_graph_execution(

# Use existing execution's compiled input masks
compiled_nodes_input_masks = graph_exec.nodes_input_masks or {}
# For resumed executions, nodes_to_skip was already determined at creation time
# TODO: Consider storing nodes_to_skip in DB if we need to preserve it across resumes
nodes_to_skip: set[str] = set()

logger.info(f"Resuming graph execution #{graph_exec.id} for graph #{graph_id}")
else:
@@ -837,7 +787,7 @@ async def add_graph_execution(
)

# Create new execution
graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip = (
graph, starting_nodes_input, compiled_nodes_input_masks = (
await validate_and_construct_node_execution_input(
graph_id=graph_id,
user_id=user_id,
@@ -886,7 +836,6 @@ async def add_graph_execution(
try:
graph_exec_entry = graph_exec.to_graph_execution_entry(
compiled_nodes_input_masks=compiled_nodes_input_masks,
nodes_to_skip=nodes_to_skip,
execution_context=execution_context,
)
logger.info(f"Publishing execution {graph_exec.id} to execution queue")
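The executor hunks above collapse the (errors, nodes_to_skip) tuple returned by validate_graph_with_credentials into a single error mapping. A minimal sketch of how a call site adapts, assuming the simplified signature from this diff (the wrapper function itself is hypothetical):

from backend.executor.utils import validate_graph_with_credentials

async def check_graph(graph, user_id, nodes_input_masks=None):
    # Old shape: errors, nodes_to_skip = await validate_graph_with_credentials(...)
    # New shape: a single mapping of node_id -> {field_name: error_message}
    errors = await validate_graph_with_credentials(graph, user_id, nodes_input_masks)
    if errors:
        raise ValueError(f"{len(errors)} node(s) failed credential validation")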
@@ -367,13 +367,10 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
)

# Setup mock returns
# The function returns (graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip)
nodes_to_skip: set[str] = set()
mock_validate.return_value = (
mock_graph,
starting_nodes_input,
compiled_nodes_input_masks,
nodes_to_skip,
)
mock_prisma.is_connected.return_value = True
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
@@ -459,212 +456,3 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
# Both executions should succeed (though they create different objects)
assert result1 == mock_graph_exec
assert result2 == mock_graph_exec_2


# ============================================================================
# Tests for Optional Credentials Feature
# ============================================================================


@pytest.mark.asyncio
async def test_validate_node_input_credentials_returns_nodes_to_skip(
mocker: MockerFixture,
):
"""
Test that _validate_node_input_credentials returns nodes_to_skip set
for nodes with credentials_optional=True and missing credentials.
"""
from backend.executor.utils import _validate_node_input_credentials

# Create a mock node with credentials_optional=True
mock_node = mocker.MagicMock()
mock_node.id = "node-with-optional-creds"
mock_node.credentials_optional = True
mock_node.input_default = {} # No credentials configured

# Create a mock block with credentials field
mock_block = mocker.MagicMock()
mock_credentials_field_type = mocker.MagicMock()
mock_block.input_schema.get_credentials_fields.return_value = {
"credentials": mock_credentials_field_type
}
mock_node.block = mock_block

# Create mock graph
mock_graph = mocker.MagicMock()
mock_graph.nodes = [mock_node]

# Call the function
errors, nodes_to_skip = await _validate_node_input_credentials(
graph=mock_graph,
user_id="test-user-id",
nodes_input_masks=None,
)

# Node should be in nodes_to_skip, not in errors
assert mock_node.id in nodes_to_skip
assert mock_node.id not in errors


@pytest.mark.asyncio
async def test_validate_node_input_credentials_required_missing_creds_error(
mocker: MockerFixture,
):
"""
Test that _validate_node_input_credentials returns errors
for nodes with credentials_optional=False and missing credentials.
"""
from backend.executor.utils import _validate_node_input_credentials

# Create a mock node with credentials_optional=False (required)
mock_node = mocker.MagicMock()
mock_node.id = "node-with-required-creds"
mock_node.credentials_optional = False
mock_node.input_default = {} # No credentials configured

# Create a mock block with credentials field
mock_block = mocker.MagicMock()
mock_credentials_field_type = mocker.MagicMock()
mock_block.input_schema.get_credentials_fields.return_value = {
"credentials": mock_credentials_field_type
}
mock_node.block = mock_block

# Create mock graph
mock_graph = mocker.MagicMock()
mock_graph.nodes = [mock_node]

# Call the function
errors, nodes_to_skip = await _validate_node_input_credentials(
graph=mock_graph,
user_id="test-user-id",
nodes_input_masks=None,
)

# Node should be in errors, not in nodes_to_skip
assert mock_node.id in errors
assert "credentials" in errors[mock_node.id]
assert "required" in errors[mock_node.id]["credentials"].lower()
assert mock_node.id not in nodes_to_skip


@pytest.mark.asyncio
async def test_validate_graph_with_credentials_returns_nodes_to_skip(
mocker: MockerFixture,
):
"""
Test that validate_graph_with_credentials returns nodes_to_skip set
from _validate_node_input_credentials.
"""
from backend.executor.utils import validate_graph_with_credentials

# Mock _validate_node_input_credentials to return specific values
mock_validate = mocker.patch(
"backend.executor.utils._validate_node_input_credentials"
)
expected_errors = {"node1": {"field": "error"}}
expected_nodes_to_skip = {"node2", "node3"}
mock_validate.return_value = (expected_errors, expected_nodes_to_skip)

# Mock GraphModel with validate_graph_get_errors method
mock_graph = mocker.MagicMock()
mock_graph.validate_graph_get_errors.return_value = {}

# Call the function
errors, nodes_to_skip = await validate_graph_with_credentials(
graph=mock_graph,
user_id="test-user-id",
nodes_input_masks=None,
)

# Verify nodes_to_skip is passed through
assert nodes_to_skip == expected_nodes_to_skip
assert "node1" in errors


@pytest.mark.asyncio
async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
"""
Test that add_graph_execution properly passes nodes_to_skip
to the graph execution entry.
"""
from backend.data.execution import GraphExecutionWithNodes
from backend.executor.utils import add_graph_execution

# Mock data
graph_id = "test-graph-id"
user_id = "test-user-id"
inputs = {"test_input": "test_value"}
graph_version = 1

# Mock the graph object
mock_graph = mocker.MagicMock()
mock_graph.version = graph_version

# Starting nodes and masks
starting_nodes_input = [("node1", {"input1": "value1"})]
compiled_nodes_input_masks = {}
nodes_to_skip = {"skipped-node-1", "skipped-node-2"}

# Mock the graph execution object
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
mock_graph_exec.id = "execution-id-123"
mock_graph_exec.node_executions = []

# Track what's passed to to_graph_execution_entry
captured_kwargs = {}

def capture_to_entry(**kwargs):
captured_kwargs.update(kwargs)
return mocker.MagicMock()

mock_graph_exec.to_graph_execution_entry.side_effect = capture_to_entry

# Setup mocks
mock_validate = mocker.patch(
"backend.executor.utils.validate_and_construct_node_execution_input"
)
mock_edb = mocker.patch("backend.executor.utils.execution_db")
mock_prisma = mocker.patch("backend.executor.utils.prisma")
mock_udb = mocker.patch("backend.executor.utils.user_db")
mock_gdb = mocker.patch("backend.executor.utils.graph_db")
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
mock_get_event_bus = mocker.patch(
"backend.executor.utils.get_async_execution_event_bus"
)

# Setup returns - include nodes_to_skip in the tuple
mock_validate.return_value = (
mock_graph,
starting_nodes_input,
compiled_nodes_input_masks,
nodes_to_skip, # This should be passed through
)
mock_prisma.is_connected.return_value = True
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
mock_edb.update_graph_execution_stats = mocker.AsyncMock(
return_value=mock_graph_exec
)
mock_edb.update_node_execution_status_batch = mocker.AsyncMock()

mock_user = mocker.MagicMock()
mock_user.timezone = "UTC"
mock_settings = mocker.MagicMock()
mock_settings.human_in_the_loop_safe_mode = True

mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user)
mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings)
mock_get_queue.return_value = mocker.AsyncMock()
mock_get_event_bus.return_value = mocker.MagicMock(publish=mocker.AsyncMock())

# Call the function
await add_graph_execution(
graph_id=graph_id,
user_id=user_id,
inputs=inputs,
graph_version=graph_version,
)

# Verify nodes_to_skip was passed to to_graph_execution_entry
assert "nodes_to_skip" in captured_kwargs
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip
@@ -8,7 +8,6 @@ from .discord import DiscordOAuthHandler
from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .notion import NotionOAuthHandler
from .reddit import RedditOAuthHandler
from .twitter import TwitterOAuthHandler

if TYPE_CHECKING:
@@ -21,7 +20,6 @@ _ORIGINAL_HANDLERS = [
GitHubOAuthHandler,
GoogleOAuthHandler,
NotionOAuthHandler,
RedditOAuthHandler,
TwitterOAuthHandler,
TodoistOAuthHandler,
]
@@ -1,208 +0,0 @@
import time
import urllib.parse
from typing import ClassVar, Optional

from pydantic import SecretStr

from backend.data.model import OAuth2Credentials
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.providers import ProviderName
from backend.util.request import Requests
from backend.util.settings import Settings

settings = Settings()


class RedditOAuthHandler(BaseOAuthHandler):
"""
Reddit OAuth 2.0 handler.

Based on the documentation at:
- https://github.com/reddit-archive/reddit/wiki/OAuth2

Notes:
- Reddit requires `duration=permanent` to get refresh tokens
- Access tokens expire after 1 hour (3600 seconds)
- Reddit requires HTTP Basic Auth for token requests
- Reddit requires a unique User-Agent header
"""

PROVIDER_NAME = ProviderName.REDDIT
DEFAULT_SCOPES: ClassVar[list[str]] = [
"identity", # Get username, verify auth
"read", # Access posts and comments
"submit", # Submit new posts and comments
"edit", # Edit own posts and comments
"history", # Access user's post history
"privatemessages", # Access inbox and send private messages
"flair", # Access and set flair on posts/subreddits
]

AUTHORIZE_URL = "https://www.reddit.com/api/v1/authorize"
TOKEN_URL = "https://www.reddit.com/api/v1/access_token"
USERNAME_URL = "https://oauth.reddit.com/api/v1/me"
REVOKE_URL = "https://www.reddit.com/api/v1/revoke_token"

def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri

def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
) -> str:
"""Generate Reddit OAuth 2.0 authorization URL"""
scopes = self.handle_default_scopes(scopes)

params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": " ".join(scopes),
"state": state,
"duration": "permanent", # Required for refresh tokens
}

return f"{self.AUTHORIZE_URL}?{urllib.parse.urlencode(params)}"

async def exchange_code_for_tokens(
self, code: str, scopes: list[str], code_verifier: Optional[str]
) -> OAuth2Credentials:
"""Exchange authorization code for access tokens"""
scopes = self.handle_default_scopes(scopes)

headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": settings.config.reddit_user_agent,
}

data = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": self.redirect_uri,
}

# Reddit requires HTTP Basic Auth for token requests
auth = (self.client_id, self.client_secret)

response = await Requests().post(
self.TOKEN_URL, headers=headers, data=data, auth=auth
)

if not response.ok:
error_text = response.text()
raise ValueError(
f"Reddit token exchange failed: {response.status} - {error_text}"
)

tokens = response.json()

if "error" in tokens:
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")

username = await self._get_username(tokens["access_token"])

return OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=None,
username=username,
access_token=tokens["access_token"],
refresh_token=tokens.get("refresh_token"),
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
refresh_token_expires_at=None, # Reddit refresh tokens don't expire
scopes=scopes,
)

async def _get_username(self, access_token: str) -> str:
"""Get the username from the access token"""
headers = {
"Authorization": f"Bearer {access_token}",
"User-Agent": settings.config.reddit_user_agent,
}

response = await Requests().get(self.USERNAME_URL, headers=headers)

if not response.ok:
raise ValueError(f"Failed to get Reddit username: {response.status}")

data = response.json()
return data.get("name", "unknown")

async def _refresh_tokens(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
"""Refresh access tokens using refresh token"""
if not credentials.refresh_token:
raise ValueError("No refresh token available")

headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": settings.config.reddit_user_agent,
}

data = {
"grant_type": "refresh_token",
"refresh_token": credentials.refresh_token.get_secret_value(),
}

auth = (self.client_id, self.client_secret)

response = await Requests().post(
self.TOKEN_URL, headers=headers, data=data, auth=auth
)

if not response.ok:
error_text = response.text()
raise ValueError(
f"Reddit token refresh failed: {response.status} - {error_text}"
)

tokens = response.json()

if "error" in tokens:
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")

username = await self._get_username(tokens["access_token"])

# Reddit may or may not return a new refresh token
new_refresh_token = tokens.get("refresh_token")
if new_refresh_token:
refresh_token: SecretStr | None = SecretStr(new_refresh_token)
elif credentials.refresh_token:
# Keep the existing refresh token
refresh_token = credentials.refresh_token
else:
refresh_token = None

return OAuth2Credentials(
id=credentials.id,
provider=self.PROVIDER_NAME,
title=credentials.title,
username=username,
access_token=tokens["access_token"],
refresh_token=refresh_token,
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
refresh_token_expires_at=None,
scopes=credentials.scopes,
)

async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
"""Revoke the access token"""
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": settings.config.reddit_user_agent,
}

data = {
"token": credentials.access_token.get_secret_value(),
"token_type_hint": "access_token",
}

auth = (self.client_id, self.client_secret)

response = await Requests().post(
self.REVOKE_URL, headers=headers, data=data, auth=auth
)

# Reddit returns 204 No Content on successful revocation
return response.ok
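For reference, the token exchange described in the deleted handler's docstring (HTTP Basic Auth, form-encoded body, mandatory User-Agent) looks roughly like this as a standalone sketch using the requests library; all argument values are placeholders:

import requests

def exchange_reddit_code(client_id, client_secret, redirect_uri, code, user_agent):
    response = requests.post(
        "https://www.reddit.com/api/v1/access_token",
        headers={"User-Agent": user_agent},  # Reddit rejects requests without a descriptive User-Agent
        data={
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": redirect_uri,
        },
        auth=(client_id, client_secret),  # HTTP Basic Auth, as Reddit requires
        timeout=30,
    )
    response.raise_for_status()
    return response.json()  # access_token, expires_in, and refresh_token when duration=permanent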
@@ -149,10 +149,10 @@ async def setup_webhook_for_block(
async def migrate_legacy_triggered_graphs():
from prisma.models import AgentGraph

from backend.api.features.library.db import create_preset
from backend.api.features.library.model import LibraryAgentPresetCreatable
from backend.data.graph import AGENT_GRAPH_INCLUDE, GraphModel, set_node_webhook
from backend.data.model import is_credentials_field_name
from backend.server.v2.library.db import create_preset
from backend.server.v2.library.model import LibraryAgentPresetCreatable

triggered_graphs = [
GraphModel.from_db(_graph)

@@ -1,5 +1,5 @@
from backend.api.rest_api import AgentServer
from backend.app import run_processes
from backend.server.rest_api import AgentServer


def main():

@@ -3,12 +3,12 @@ from typing import Dict, Set

from fastapi import WebSocket

from backend.api.model import NotificationPayload, WSMessage, WSMethod
from backend.data.execution import (
ExecutionEventType,
GraphExecutionEvent,
NodeExecutionEvent,
)
from backend.server.model import NotificationPayload, WSMessage, WSMethod

_EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,
@@ -4,13 +4,13 @@ from unittest.mock import AsyncMock
import pytest
from fastapi import WebSocket

from backend.api.conn_manager import ConnectionManager
from backend.api.model import NotificationPayload, WSMessage, WSMethod
from backend.data.execution import (
ExecutionStatus,
GraphExecutionEvent,
NodeExecutionEvent,
)
from backend.server.conn_manager import ConnectionManager
from backend.server.model import NotificationPayload, WSMessage, WSMethod


@pytest.fixture
29
autogpt_platform/backend/backend/server/external/api.py
vendored
Normal file
@@ -0,0 +1,29 @@
from fastapi import FastAPI

from backend.monitoring.instrumentation import instrument_fastapi
from backend.server.middleware.security import SecurityHeadersMiddleware

from .routes.integrations import integrations_router
from .routes.tools import tools_router
from .routes.v1 import v1_router

external_app = FastAPI(
title="AutoGPT External API",
description="External API for AutoGPT integrations",
docs_url="/docs",
version="1.0",
)

external_app.add_middleware(SecurityHeadersMiddleware)
external_app.include_router(v1_router, prefix="/v1")
external_app.include_router(tools_router, prefix="/v1")
external_app.include_router(integrations_router, prefix="/v1")

# Add Prometheus instrumentation
instrument_fastapi(
external_app,
service_name="external-api",
expose_endpoint=True,
endpoint="/metrics",
include_in_schema=True,
)
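The new external/api.py above only builds the FastAPI app. One way to serve it locally for a quick check is with uvicorn; this is a sketch, not how the platform actually runs it, and the host and port are arbitrary choices:

import uvicorn

from backend.server.external.api import external_app

if __name__ == "__main__":
    # Serves the /v1 routes plus /docs and /metrics as configured above
    uvicorn.run(external_app, host="0.0.0.0", port=8006)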
36
autogpt_platform/backend/backend/server/external/middleware.py
vendored
Normal file
@@ -0,0 +1,36 @@
from fastapi import HTTPException, Security
from fastapi.security import APIKeyHeader
from prisma.enums import APIKeyPermission

from backend.data.api_key import APIKeyInfo, has_permission, validate_api_key

api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)


async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
"""Base middleware for API key authentication"""
if api_key is None:
raise HTTPException(status_code=401, detail="Missing API key")

api_key_obj = await validate_api_key(api_key)

if not api_key_obj:
raise HTTPException(status_code=401, detail="Invalid API key")

return api_key_obj


def require_permission(permission: APIKeyPermission):
"""Dependency function for checking specific permissions"""

async def check_permission(
api_key: APIKeyInfo = Security(require_api_key),
) -> APIKeyInfo:
if not has_permission(api_key, permission):
raise HTTPException(
status_code=403,
detail=f"API key lacks the required permission '{permission}'",
)
return api_key

return check_permission
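A usage sketch for the new require_permission dependency factory; the route path and the APIKeyPermission member are assumptions for illustration, not part of this diff:

from fastapi import APIRouter, Security
from prisma.enums import APIKeyPermission

from backend.server.external.middleware import require_permission

router = APIRouter()

@router.get("/protected")
async def protected_endpoint(
    api_key=Security(require_permission(APIKeyPermission.READ_GRAPH)),  # hypothetical permission name
):
    # Reached only if the key is valid and carries the required permission
    return {"status": "authorized"}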
Some files were not shown because too many files have changed in this diff.