mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-14 17:47:57 -05:00
Compare commits
1 Commits
claude-cod
...
swiftyos/m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
eda0f16cb9 |
@@ -1,37 +0,0 @@
|
||||
{
|
||||
"worktreeCopyPatterns": [
|
||||
".env*",
|
||||
".vscode/**",
|
||||
".auth/**",
|
||||
".claude/**",
|
||||
"autogpt_platform/.env*",
|
||||
"autogpt_platform/backend/.env*",
|
||||
"autogpt_platform/frontend/.env*",
|
||||
"autogpt_platform/frontend/.auth/**",
|
||||
"autogpt_platform/db/docker/.env*"
|
||||
],
|
||||
"worktreeCopyIgnores": [
|
||||
"**/node_modules/**",
|
||||
"**/dist/**",
|
||||
"**/.git/**",
|
||||
"**/Thumbs.db",
|
||||
"**/.DS_Store",
|
||||
"**/.next/**",
|
||||
"**/__pycache__/**",
|
||||
"**/.ruff_cache/**",
|
||||
"**/.pytest_cache/**",
|
||||
"**/*.pyc",
|
||||
"**/playwright-report/**",
|
||||
"**/logs/**",
|
||||
"**/site/**"
|
||||
],
|
||||
"worktreePathTemplate": "$BASE_PATH.worktree",
|
||||
"postCreateCmd": [
|
||||
"cd autogpt_platform/autogpt_libs && poetry install",
|
||||
"cd autogpt_platform/backend && poetry install && poetry run prisma generate",
|
||||
"cd autogpt_platform/frontend && pnpm install",
|
||||
"cd docs && pip install -r requirements.txt"
|
||||
],
|
||||
"terminalCommand": "code .",
|
||||
"deleteBranchWithWorktree": false
|
||||
}
|
||||
@@ -16,7 +16,6 @@
|
||||
!autogpt_platform/backend/poetry.lock
|
||||
!autogpt_platform/backend/README.md
|
||||
!autogpt_platform/backend/.env
|
||||
!autogpt_platform/backend/gen_prisma_types_stub.py
|
||||
|
||||
# Platform - Market
|
||||
!autogpt_platform/market/market/
|
||||
|
||||
2
.github/workflows/claude-dependabot.yml
vendored
2
.github/workflows/claude-dependabot.yml
vendored
@@ -74,7 +74,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
|
||||
2
.github/workflows/claude.yml
vendored
2
.github/workflows/claude.yml
vendored
@@ -90,7 +90,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
|
||||
12
.github/workflows/copilot-setup-steps.yml
vendored
12
.github/workflows/copilot-setup-steps.yml
vendored
@@ -72,7 +72,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
@@ -108,16 +108,6 @@ jobs:
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
# Remove large unused tools to free disk space for Docker builds
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo docker system prune -af
|
||||
df -h
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
2
.github/workflows/platform-backend-ci.yml
vendored
2
.github/workflows/platform-backend-ci.yml
vendored
@@ -134,7 +134,7 @@ jobs:
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
run: poetry run prisma generate
|
||||
|
||||
- id: supabase
|
||||
name: Start Supabase
|
||||
|
||||
@@ -12,7 +12,6 @@ reset-db:
|
||||
rm -rf db/docker/volumes/db/data
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
cd backend && poetry run gen-prisma-stub
|
||||
|
||||
# View logs for core services
|
||||
logs-core:
|
||||
@@ -34,7 +33,6 @@ init-env:
|
||||
migrate:
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
cd backend && poetry run gen-prisma-stub
|
||||
|
||||
run-backend:
|
||||
cd backend && poetry run app
|
||||
|
||||
@@ -1,25 +1,29 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from .jwt_utils import bearer_jwt_auth
|
||||
|
||||
|
||||
def add_auth_responses_to_openapi(app: FastAPI) -> None:
|
||||
"""
|
||||
Patch a FastAPI instance's `openapi()` method to add 401 responses
|
||||
Set up custom OpenAPI schema generation that adds 401 responses
|
||||
to all authenticated endpoints.
|
||||
|
||||
This is needed when using HTTPBearer with auto_error=False to get proper
|
||||
401 responses instead of 403, but FastAPI only automatically adds security
|
||||
responses when auto_error=True.
|
||||
"""
|
||||
# Wrap current method to allow stacking OpenAPI schema modifiers like this
|
||||
wrapped_openapi = app.openapi
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = wrapped_openapi()
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
|
||||
# Add 401 response to all endpoints that have security requirements
|
||||
for path, methods in openapi_schema["paths"].items():
|
||||
|
||||
@@ -48,8 +48,7 @@ RUN poetry install --no-ansi --no-root
|
||||
# Generate Prisma client
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
||||
RUN poetry run prisma generate
|
||||
|
||||
FROM debian:13-slim AS server_dependencies
|
||||
|
||||
|
||||
@@ -108,7 +108,7 @@ import fastapi.testclient
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.api.features.myroute import router
|
||||
from backend.server.v2.myroute import router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
@@ -149,7 +149,7 @@ These provide the easiest way to set up authentication mocking in test modules:
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
from backend.api.features.myroute import router
|
||||
from backend.server.v2.myroute import router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.api.middleware.security import SecurityHeadersMiddleware
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
|
||||
from .v1.routes import v1_router
|
||||
|
||||
external_api = FastAPI(
|
||||
title="AutoGPT External API",
|
||||
description="External API for AutoGPT integrations",
|
||||
docs_url="/docs",
|
||||
version="1.0",
|
||||
)
|
||||
|
||||
external_api.add_middleware(SecurityHeadersMiddleware)
|
||||
external_api.include_router(v1_router, prefix="/v1")
|
||||
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
external_api,
|
||||
service_name="external-api",
|
||||
expose_endpoint=True,
|
||||
endpoint="/metrics",
|
||||
include_in_schema=True,
|
||||
)
|
||||
@@ -1,340 +0,0 @@
|
||||
"""Tests for analytics API endpoints."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from .analytics import router as analytics_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(analytics_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module."""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# /log_raw_metric endpoint tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_log_raw_metric_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw metric logging."""
|
||||
mock_result = Mock(id="metric-123-uuid")
|
||||
mock_log_metric = mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": "page_load_time",
|
||||
"metric_value": 2.5,
|
||||
"data_string": "/dashboard",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Unexpected response: {response.text}"
|
||||
assert response.json() == "metric-123-uuid"
|
||||
|
||||
mock_log_metric.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
metric_name="page_load_time",
|
||||
metric_value=2.5,
|
||||
data_string="/dashboard",
|
||||
)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"metric_id": response.json()}, indent=2, sort_keys=True),
|
||||
"analytics_log_metric_success",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"metric_value,metric_name,data_string,test_id",
|
||||
[
|
||||
(100, "api_calls_count", "external_api", "integer_value"),
|
||||
(0, "error_count", "no_errors", "zero_value"),
|
||||
(-5.2, "temperature_delta", "cooling", "negative_value"),
|
||||
(1.23456789, "precision_test", "float_precision", "float_precision"),
|
||||
(999999999, "large_number", "max_value", "large_number"),
|
||||
(0.0000001, "tiny_number", "min_value", "tiny_number"),
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_various_values(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
metric_value: float,
|
||||
metric_name: str,
|
||||
data_string: str,
|
||||
test_id: str,
|
||||
) -> None:
|
||||
"""Test raw metric logging with various metric values."""
|
||||
mock_result = Mock(id=f"metric-{test_id}-uuid")
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": metric_name,
|
||||
"metric_value": metric_value,
|
||||
"data_string": data_string,
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Failed for {test_id}: {response.text}"
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(
|
||||
{"metric_id": response.json(), "test_case": test_id},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
),
|
||||
f"analytics_metric_{test_id}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_data,expected_error",
|
||||
[
|
||||
({}, "Field required"),
|
||||
({"metric_name": "test"}, "Field required"),
|
||||
(
|
||||
{"metric_name": "test", "metric_value": "not_a_number", "data_string": "x"},
|
||||
"Input should be a valid number",
|
||||
),
|
||||
(
|
||||
{"metric_name": "", "metric_value": 1.0, "data_string": "test"},
|
||||
"String should have at least 1 character",
|
||||
),
|
||||
(
|
||||
{"metric_name": "test", "metric_value": 1.0, "data_string": ""},
|
||||
"String should have at least 1 character",
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"empty_request",
|
||||
"missing_metric_value_and_data_string",
|
||||
"invalid_metric_value_type",
|
||||
"empty_metric_name",
|
||||
"empty_data_string",
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_validation_errors(
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test validation errors for invalid metric requests."""
|
||||
response = client.post("/log_raw_metric", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json()
|
||||
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
|
||||
|
||||
error_text = json.dumps(error_detail)
|
||||
assert (
|
||||
expected_error in error_text
|
||||
), f"Expected '{expected_error}' in error response: {error_text}"
|
||||
|
||||
|
||||
def test_log_raw_metric_service_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error handling when analytics service fails."""
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Database connection failed"),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": "test_metric",
|
||||
"metric_value": 1.0,
|
||||
"data_string": "test",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 500
|
||||
error_detail = response.json()["detail"]
|
||||
assert "Database connection failed" in error_detail["message"]
|
||||
assert "hint" in error_detail
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# /log_raw_analytics endpoint tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_log_raw_analytics_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw analytics logging."""
|
||||
mock_result = Mock(id="analytics-789-uuid")
|
||||
mock_log_analytics = mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "user_action",
|
||||
"data": {
|
||||
"action": "button_click",
|
||||
"button_id": "submit_form",
|
||||
"timestamp": "2023-01-01T00:00:00Z",
|
||||
"metadata": {"form_type": "registration", "fields_filled": 5},
|
||||
},
|
||||
"data_index": "button_click_submit_form",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Unexpected response: {response.text}"
|
||||
assert response.json() == "analytics-789-uuid"
|
||||
|
||||
mock_log_analytics.assert_called_once_with(
|
||||
test_user_id,
|
||||
"user_action",
|
||||
request_data["data"],
|
||||
"button_click_submit_form",
|
||||
)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"analytics_id": response.json()}, indent=2, sort_keys=True),
|
||||
"analytics_log_analytics_success",
|
||||
)
|
||||
|
||||
|
||||
def test_log_raw_analytics_complex_data(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test raw analytics logging with complex nested data structures."""
|
||||
mock_result = Mock(id="analytics-complex-uuid")
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "agent_execution",
|
||||
"data": {
|
||||
"agent_id": "agent_123",
|
||||
"execution_id": "exec_456",
|
||||
"status": "completed",
|
||||
"duration_ms": 3500,
|
||||
"nodes_executed": 15,
|
||||
"blocks_used": [
|
||||
{"block_id": "llm_block", "count": 3},
|
||||
{"block_id": "http_block", "count": 5},
|
||||
{"block_id": "code_block", "count": 2},
|
||||
],
|
||||
"errors": [],
|
||||
"metadata": {
|
||||
"trigger": "manual",
|
||||
"user_tier": "premium",
|
||||
"environment": "production",
|
||||
},
|
||||
},
|
||||
"data_index": "agent_123_exec_456",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(
|
||||
{"analytics_id": response.json(), "logged_data": request_data["data"]},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
),
|
||||
"analytics_log_analytics_complex_data",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_data,expected_error",
|
||||
[
|
||||
({}, "Field required"),
|
||||
({"type": "test"}, "Field required"),
|
||||
(
|
||||
{"type": "test", "data": "not_a_dict", "data_index": "test"},
|
||||
"Input should be a valid dictionary",
|
||||
),
|
||||
({"type": "test", "data": {"key": "value"}}, "Field required"),
|
||||
],
|
||||
ids=[
|
||||
"empty_request",
|
||||
"missing_data_and_data_index",
|
||||
"invalid_data_type",
|
||||
"missing_data_index",
|
||||
],
|
||||
)
|
||||
def test_log_raw_analytics_validation_errors(
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test validation errors for invalid analytics requests."""
|
||||
response = client.post("/log_raw_analytics", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json()
|
||||
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
|
||||
|
||||
error_text = json.dumps(error_detail)
|
||||
assert (
|
||||
expected_error in error_text
|
||||
), f"Expected '{expected_error}' in error response: {error_text}"
|
||||
|
||||
|
||||
def test_log_raw_analytics_service_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error handling when analytics service fails."""
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Analytics DB unreachable"),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "test_event",
|
||||
"data": {"key": "value"},
|
||||
"data_index": "test_index",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 500
|
||||
error_detail = response.json()["detail"]
|
||||
assert "Analytics DB unreachable" in error_detail["message"]
|
||||
assert "hint" in error_detail
|
||||
@@ -1,41 +0,0 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
def sort_openapi(app: FastAPI) -> None:
|
||||
"""
|
||||
Patch a FastAPI instance's `openapi()` method to sort the endpoints,
|
||||
schemas, and responses.
|
||||
"""
|
||||
wrapped_openapi = app.openapi
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = wrapped_openapi()
|
||||
|
||||
# Sort endpoints
|
||||
openapi_schema["paths"] = dict(sorted(openapi_schema["paths"].items()))
|
||||
|
||||
# Sort endpoints -> methods
|
||||
for p in openapi_schema["paths"].keys():
|
||||
openapi_schema["paths"][p] = dict(
|
||||
sorted(openapi_schema["paths"][p].items())
|
||||
)
|
||||
|
||||
# Sort endpoints -> methods -> responses
|
||||
for m in openapi_schema["paths"][p].keys():
|
||||
openapi_schema["paths"][p][m]["responses"] = dict(
|
||||
sorted(openapi_schema["paths"][p][m]["responses"].items())
|
||||
)
|
||||
|
||||
# Sort schemas and responses as well
|
||||
for k in openapi_schema["components"].keys():
|
||||
openapi_schema["components"][k] = dict(
|
||||
sorted(openapi_schema["components"][k].items())
|
||||
)
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return openapi_schema
|
||||
|
||||
app.openapi = custom_openapi
|
||||
@@ -36,10 +36,10 @@ def main(**kwargs):
|
||||
Run all the processes required for the AutoGPT-server (REST and WebSocket APIs).
|
||||
"""
|
||||
|
||||
from backend.api.rest_api import AgentServer
|
||||
from backend.api.ws_api import WebsocketServer
|
||||
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
|
||||
from backend.notifications import NotificationManager
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.server.ws_api import WebsocketServer
|
||||
|
||||
run_processes(
|
||||
DatabaseManager().set_log_level("warning"),
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks.llm import (
|
||||
DEFAULT_LLM_MODEL,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
AIBlockBase,
|
||||
@@ -50,7 +49,7 @@ class AIConditionBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for evaluating the condition.",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -82,7 +81,7 @@ class AIConditionBlock(AIBlockBase):
|
||||
"condition": "the input is an email address",
|
||||
"yes_value": "Valid email",
|
||||
"no_value": "Not an email",
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
|
||||
@@ -6,9 +6,6 @@ import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import cast
|
||||
|
||||
from prisma.types import Serializable
|
||||
|
||||
from backend.sdk import (
|
||||
BaseWebhooksManager,
|
||||
@@ -87,9 +84,7 @@ class AirtableWebhookManager(BaseWebhooksManager):
|
||||
# update webhook config
|
||||
await update_webhook(
|
||||
webhook.id,
|
||||
config=cast(
|
||||
dict[str, Serializable], {"base_id": base_id, "cursor": response.cursor}
|
||||
),
|
||||
config={"base_id": base_id, "cursor": response.cursor},
|
||||
)
|
||||
|
||||
event_type = "notification"
|
||||
|
||||
@@ -1,631 +0,0 @@
|
||||
import json
|
||||
import shlex
|
||||
import uuid
|
||||
from typing import Literal, Optional
|
||||
|
||||
from e2b import AsyncSandbox as BaseAsyncSandbox
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
# Test credentials for E2B
|
||||
TEST_E2B_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="e2b",
|
||||
api_key=SecretStr("mock-e2b-api-key"),
|
||||
title="Mock E2B API key",
|
||||
expires_at=None,
|
||||
)
|
||||
TEST_E2B_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_E2B_CREDENTIALS.provider,
|
||||
"id": TEST_E2B_CREDENTIALS.id,
|
||||
"type": TEST_E2B_CREDENTIALS.type,
|
||||
"title": TEST_E2B_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
# Test credentials for Anthropic
|
||||
TEST_ANTHROPIC_CREDENTIALS = APIKeyCredentials(
|
||||
id="2e568a2b-b2ea-475a-8564-9a676bf31c56",
|
||||
provider="anthropic",
|
||||
api_key=SecretStr("mock-anthropic-api-key"),
|
||||
title="Mock Anthropic API key",
|
||||
expires_at=None,
|
||||
)
|
||||
TEST_ANTHROPIC_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_ANTHROPIC_CREDENTIALS.provider,
|
||||
"id": TEST_ANTHROPIC_CREDENTIALS.id,
|
||||
"type": TEST_ANTHROPIC_CREDENTIALS.type,
|
||||
"title": TEST_ANTHROPIC_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class ClaudeCodeBlock(Block):
|
||||
"""
|
||||
Execute tasks using Claude Code (Anthropic's AI coding assistant) in an E2B sandbox.
|
||||
|
||||
Claude Code can create files, install tools, run commands, and perform complex
|
||||
coding tasks autonomously within a secure sandbox environment.
|
||||
"""
|
||||
|
||||
# Use base template - we'll install Claude Code ourselves for latest version
|
||||
DEFAULT_TEMPLATE = "base"
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
e2b_credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.E2B], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description=(
|
||||
"API key for the E2B platform to create the sandbox. "
|
||||
"Get one on the [e2b website](https://e2b.dev/docs)"
|
||||
),
|
||||
)
|
||||
|
||||
anthropic_credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.ANTHROPIC], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description=(
|
||||
"API key for Anthropic to power Claude Code. "
|
||||
"Get one at [Anthropic's website](https://console.anthropic.com)"
|
||||
),
|
||||
)
|
||||
|
||||
prompt: str = SchemaField(
|
||||
description=(
|
||||
"The task or instruction for Claude Code to execute. "
|
||||
"Claude Code can create files, install packages, run commands, "
|
||||
"and perform complex coding tasks."
|
||||
),
|
||||
placeholder="Create a hello world index.html file",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
timeout: int = SchemaField(
|
||||
description=(
|
||||
"Sandbox timeout in seconds. Claude Code tasks can take "
|
||||
"a while, so set this appropriately for your task complexity. "
|
||||
"Note: This only applies when creating a new sandbox. "
|
||||
"When reconnecting to an existing sandbox via sandbox_id, "
|
||||
"the original timeout is retained."
|
||||
),
|
||||
default=300, # 5 minutes default
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
setup_commands: list[str] = SchemaField(
|
||||
description=(
|
||||
"Optional shell commands to run before executing Claude Code. "
|
||||
"Useful for installing dependencies or setting up the environment."
|
||||
),
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
working_directory: str = SchemaField(
|
||||
description="Working directory for Claude Code to operate in.",
|
||||
default="/home/user",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Session/continuation support
|
||||
session_id: str = SchemaField(
|
||||
description=(
|
||||
"Session ID to resume a previous conversation. "
|
||||
"Leave empty for a new conversation. "
|
||||
"Use the session_id from a previous run to continue that conversation."
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
sandbox_id: str = SchemaField(
|
||||
description=(
|
||||
"Sandbox ID to reconnect to an existing sandbox. "
|
||||
"Required when resuming a session (along with session_id). "
|
||||
"Use the sandbox_id from a previous run where dispose_sandbox was False."
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
conversation_history: str = SchemaField(
|
||||
description=(
|
||||
"Previous conversation history to continue from. "
|
||||
"Use this to restore context on a fresh sandbox if the previous one timed out. "
|
||||
"Pass the conversation_history output from a previous run."
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
dispose_sandbox: bool = SchemaField(
|
||||
description=(
|
||||
"Whether to dispose of the sandbox immediately after execution. "
|
||||
"Set to False if you want to continue the conversation later "
|
||||
"(you'll need both sandbox_id and session_id from the output)."
|
||||
),
|
||||
default=True,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class FileOutput(BaseModel):
|
||||
"""A file extracted from the sandbox."""
|
||||
|
||||
path: str
|
||||
relative_path: str # Path relative to working directory (for GitHub, etc.)
|
||||
name: str
|
||||
content: str
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
response: str = SchemaField(
|
||||
description="The output/response from Claude Code execution"
|
||||
)
|
||||
files: list["ClaudeCodeBlock.FileOutput"] = SchemaField(
|
||||
description=(
|
||||
"List of text files created/modified by Claude Code during this execution. "
|
||||
"Each file has 'path', 'relative_path', 'name', and 'content' fields."
|
||||
)
|
||||
)
|
||||
conversation_history: str = SchemaField(
|
||||
description=(
|
||||
"Full conversation history including this turn. "
|
||||
"Pass this to conversation_history input to continue on a fresh sandbox "
|
||||
"if the previous sandbox timed out."
|
||||
)
|
||||
)
|
||||
session_id: str = SchemaField(
|
||||
description=(
|
||||
"Session ID for this conversation. "
|
||||
"Pass this back along with sandbox_id to continue the conversation."
|
||||
)
|
||||
)
|
||||
sandbox_id: Optional[str] = SchemaField(
|
||||
description=(
|
||||
"ID of the sandbox instance. "
|
||||
"Pass this back along with session_id to continue the conversation. "
|
||||
"This is None if dispose_sandbox was True (sandbox was disposed)."
|
||||
),
|
||||
default=None,
|
||||
)
|
||||
error: str = SchemaField(description="Error message if execution failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4e34f4a5-9b89-4326-ba77-2dd6750b7194",
|
||||
description=(
|
||||
"Execute tasks using Claude Code in an E2B sandbox. "
|
||||
"Claude Code can create files, install tools, run commands, "
|
||||
"and perform complex coding tasks autonomously."
|
||||
),
|
||||
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.AI},
|
||||
input_schema=ClaudeCodeBlock.Input,
|
||||
output_schema=ClaudeCodeBlock.Output,
|
||||
test_credentials={
|
||||
"e2b_credentials": TEST_E2B_CREDENTIALS,
|
||||
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS,
|
||||
},
|
||||
test_input={
|
||||
"e2b_credentials": TEST_E2B_CREDENTIALS_INPUT,
|
||||
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS_INPUT,
|
||||
"prompt": "Create a hello world HTML file",
|
||||
"timeout": 300,
|
||||
"setup_commands": [],
|
||||
"working_directory": "/home/user",
|
||||
"session_id": "",
|
||||
"sandbox_id": "",
|
||||
"conversation_history": "",
|
||||
"dispose_sandbox": True,
|
||||
},
|
||||
test_output=[
|
||||
("response", "Created index.html with hello world content"),
|
||||
(
|
||||
"files",
|
||||
[
|
||||
{
|
||||
"path": "/home/user/index.html",
|
||||
"relative_path": "index.html",
|
||||
"name": "index.html",
|
||||
"content": "<html>Hello World</html>",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"conversation_history",
|
||||
"User: Create a hello world HTML file\n"
|
||||
"Claude: Created index.html with hello world content",
|
||||
),
|
||||
("session_id", str),
|
||||
("sandbox_id", None), # None because dispose_sandbox=True in test_input
|
||||
],
|
||||
test_mock={
|
||||
"execute_claude_code": lambda *args, **kwargs: (
|
||||
"Created index.html with hello world content", # response
|
||||
[
|
||||
ClaudeCodeBlock.FileOutput(
|
||||
path="/home/user/index.html",
|
||||
relative_path="index.html",
|
||||
name="index.html",
|
||||
content="<html>Hello World</html>",
|
||||
)
|
||||
], # files
|
||||
"User: Create a hello world HTML file\n"
|
||||
"Claude: Created index.html with hello world content", # conversation_history
|
||||
"test-session-id", # session_id
|
||||
"sandbox_id", # sandbox_id
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
async def execute_claude_code(
|
||||
self,
|
||||
e2b_api_key: str,
|
||||
anthropic_api_key: str,
|
||||
prompt: str,
|
||||
timeout: int,
|
||||
setup_commands: list[str],
|
||||
working_directory: str,
|
||||
session_id: str,
|
||||
existing_sandbox_id: str,
|
||||
conversation_history: str,
|
||||
dispose_sandbox: bool,
|
||||
) -> tuple[str, list["ClaudeCodeBlock.FileOutput"], str, str, str]:
|
||||
"""
|
||||
Execute Claude Code in an E2B sandbox.
|
||||
|
||||
Returns:
|
||||
Tuple of (response, files, conversation_history, session_id, sandbox_id)
|
||||
"""
|
||||
|
||||
# Validate that sandbox_id is provided when resuming a session
|
||||
if session_id and not existing_sandbox_id:
|
||||
raise ValueError(
|
||||
"sandbox_id is required when resuming a session with session_id. "
|
||||
"The session state is stored in the original sandbox. "
|
||||
"If the sandbox has timed out, use conversation_history instead "
|
||||
"to restore context on a fresh sandbox."
|
||||
)
|
||||
|
||||
sandbox = None
|
||||
|
||||
try:
|
||||
# Either reconnect to existing sandbox or create a new one
|
||||
if existing_sandbox_id:
|
||||
# Reconnect to existing sandbox for conversation continuation
|
||||
sandbox = await BaseAsyncSandbox.connect(
|
||||
sandbox_id=existing_sandbox_id,
|
||||
api_key=e2b_api_key,
|
||||
)
|
||||
else:
|
||||
# Create new sandbox
|
||||
sandbox = await BaseAsyncSandbox.create(
|
||||
template=self.DEFAULT_TEMPLATE,
|
||||
api_key=e2b_api_key,
|
||||
timeout=timeout,
|
||||
envs={"ANTHROPIC_API_KEY": anthropic_api_key},
|
||||
)
|
||||
|
||||
# Install Claude Code from npm (ensures we get the latest version)
|
||||
install_result = await sandbox.commands.run(
|
||||
"npm install -g @anthropic-ai/claude-code@latest",
|
||||
timeout=120, # 2 min timeout for install
|
||||
)
|
||||
if install_result.exit_code != 0:
|
||||
raise Exception(
|
||||
f"Failed to install Claude Code: {install_result.stderr}"
|
||||
)
|
||||
|
||||
# Run any user-provided setup commands
|
||||
for cmd in setup_commands:
|
||||
setup_result = await sandbox.commands.run(cmd)
|
||||
if setup_result.exit_code != 0:
|
||||
raise Exception(
|
||||
f"Setup command failed: {cmd}\n"
|
||||
f"Exit code: {setup_result.exit_code}\n"
|
||||
f"Stdout: {setup_result.stdout}\n"
|
||||
f"Stderr: {setup_result.stderr}"
|
||||
)
|
||||
|
||||
# Generate or use provided session ID
|
||||
current_session_id = session_id if session_id else str(uuid.uuid4())
|
||||
|
||||
# Build base Claude flags
|
||||
base_flags = "-p --dangerously-skip-permissions --output-format json"
|
||||
|
||||
# Add conversation history context if provided (for fresh sandbox continuation)
|
||||
history_flag = ""
|
||||
if conversation_history and not session_id:
|
||||
# Inject previous conversation as context via system prompt
|
||||
# Use consistent escaping via _escape_prompt helper
|
||||
escaped_history = self._escape_prompt(
|
||||
f"Previous conversation context: {conversation_history}"
|
||||
)
|
||||
history_flag = f" --append-system-prompt {escaped_history}"
|
||||
|
||||
# Build Claude command based on whether we're resuming or starting new
|
||||
# Use shlex.quote for working_directory and session IDs to prevent injection
|
||||
safe_working_dir = shlex.quote(working_directory)
|
||||
if session_id:
|
||||
# Resuming existing session (sandbox still alive)
|
||||
safe_session_id = shlex.quote(session_id)
|
||||
claude_command = (
|
||||
f"cd {safe_working_dir} && "
|
||||
f"echo {self._escape_prompt(prompt)} | "
|
||||
f"claude --resume {safe_session_id} {base_flags}"
|
||||
)
|
||||
else:
|
||||
# New session with specific ID
|
||||
safe_current_session_id = shlex.quote(current_session_id)
|
||||
claude_command = (
|
||||
f"cd {safe_working_dir} && "
|
||||
f"echo {self._escape_prompt(prompt)} | "
|
||||
f"claude --session-id {safe_current_session_id} {base_flags}{history_flag}"
|
||||
)
|
||||
|
||||
# Capture timestamp before running Claude Code to filter files later
|
||||
# Capture timestamp 1 second in the past to avoid race condition with file creation
|
||||
timestamp_result = await sandbox.commands.run(
|
||||
"date -u -d '1 second ago' +%Y-%m-%dT%H:%M:%S"
|
||||
)
|
||||
if timestamp_result.exit_code != 0:
|
||||
raise RuntimeError(
|
||||
f"Failed to capture timestamp: {timestamp_result.stderr}"
|
||||
)
|
||||
start_timestamp = (
|
||||
timestamp_result.stdout.strip() if timestamp_result.stdout else None
|
||||
)
|
||||
|
||||
result = await sandbox.commands.run(
|
||||
claude_command,
|
||||
timeout=0, # No command timeout - let sandbox timeout handle it
|
||||
)
|
||||
|
||||
# Check for command failure
|
||||
if result.exit_code != 0:
|
||||
error_msg = result.stderr or result.stdout or "Unknown error"
|
||||
raise Exception(
|
||||
f"Claude Code command failed with exit code {result.exit_code}:\n"
|
||||
f"{error_msg}"
|
||||
)
|
||||
|
||||
raw_output = result.stdout or ""
|
||||
sandbox_id = sandbox.sandbox_id
|
||||
|
||||
# Parse JSON output to extract response and build conversation history
|
||||
response = ""
|
||||
new_conversation_history = conversation_history or ""
|
||||
|
||||
try:
|
||||
# The JSON output contains the result
|
||||
output_data = json.loads(raw_output)
|
||||
response = output_data.get("result", raw_output)
|
||||
|
||||
# Build conversation history entry
|
||||
turn_entry = f"User: {prompt}\nClaude: {response}"
|
||||
if new_conversation_history:
|
||||
new_conversation_history = (
|
||||
f"{new_conversation_history}\n\n{turn_entry}"
|
||||
)
|
||||
else:
|
||||
new_conversation_history = turn_entry
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# If not valid JSON, use raw output
|
||||
response = raw_output
|
||||
turn_entry = f"User: {prompt}\nClaude: {response}"
|
||||
if new_conversation_history:
|
||||
new_conversation_history = (
|
||||
f"{new_conversation_history}\n\n{turn_entry}"
|
||||
)
|
||||
else:
|
||||
new_conversation_history = turn_entry
|
||||
|
||||
# Extract files created/modified during this run
|
||||
files = await self._extract_files(
|
||||
sandbox, working_directory, start_timestamp
|
||||
)
|
||||
|
||||
return (
|
||||
response,
|
||||
files,
|
||||
new_conversation_history,
|
||||
current_session_id,
|
||||
sandbox_id,
|
||||
)
|
||||
|
||||
finally:
|
||||
if dispose_sandbox and sandbox:
|
||||
await sandbox.kill()
|
||||
|
||||
async def _extract_files(
|
||||
self,
|
||||
sandbox: BaseAsyncSandbox,
|
||||
working_directory: str,
|
||||
since_timestamp: str | None = None,
|
||||
) -> list["ClaudeCodeBlock.FileOutput"]:
|
||||
"""
|
||||
Extract text files created/modified during this Claude Code execution.
|
||||
|
||||
Args:
|
||||
sandbox: The E2B sandbox instance
|
||||
working_directory: Directory to search for files
|
||||
since_timestamp: ISO timestamp - only return files modified after this time
|
||||
|
||||
Returns:
|
||||
List of FileOutput objects with path, relative_path, name, and content
|
||||
"""
|
||||
files: list[ClaudeCodeBlock.FileOutput] = []
|
||||
|
||||
# Text file extensions we can safely read as text
|
||||
text_extensions = {
|
||||
".txt",
|
||||
".md",
|
||||
".html",
|
||||
".htm",
|
||||
".css",
|
||||
".js",
|
||||
".ts",
|
||||
".jsx",
|
||||
".tsx",
|
||||
".json",
|
||||
".xml",
|
||||
".yaml",
|
||||
".yml",
|
||||
".toml",
|
||||
".ini",
|
||||
".cfg",
|
||||
".conf",
|
||||
".py",
|
||||
".rb",
|
||||
".php",
|
||||
".java",
|
||||
".c",
|
||||
".cpp",
|
||||
".h",
|
||||
".hpp",
|
||||
".cs",
|
||||
".go",
|
||||
".rs",
|
||||
".swift",
|
||||
".kt",
|
||||
".scala",
|
||||
".sh",
|
||||
".bash",
|
||||
".zsh",
|
||||
".sql",
|
||||
".graphql",
|
||||
".env",
|
||||
".gitignore",
|
||||
".dockerfile",
|
||||
"Dockerfile",
|
||||
".vue",
|
||||
".svelte",
|
||||
".astro",
|
||||
".mdx",
|
||||
".rst",
|
||||
".tex",
|
||||
".csv",
|
||||
".log",
|
||||
}
|
||||
|
||||
try:
|
||||
# List files recursively using find command
|
||||
# Exclude node_modules and .git directories, but allow hidden files
|
||||
# like .env and .gitignore (they're filtered by text_extensions later)
|
||||
# Filter by timestamp to only get files created/modified during this run
|
||||
safe_working_dir = shlex.quote(working_directory)
|
||||
timestamp_filter = ""
|
||||
if since_timestamp:
|
||||
timestamp_filter = f"-newermt {shlex.quote(since_timestamp)} "
|
||||
find_result = await sandbox.commands.run(
|
||||
f"find {safe_working_dir} -type f "
|
||||
f"{timestamp_filter}"
|
||||
f"-not -path '*/node_modules/*' "
|
||||
f"-not -path '*/.git/*' "
|
||||
f"2>/dev/null"
|
||||
)
|
||||
|
||||
if find_result.stdout:
|
||||
for file_path in find_result.stdout.strip().split("\n"):
|
||||
if not file_path:
|
||||
continue
|
||||
|
||||
# Check if it's a text file we can read
|
||||
is_text = any(
|
||||
file_path.endswith(ext) for ext in text_extensions
|
||||
) or file_path.endswith("Dockerfile")
|
||||
|
||||
if is_text:
|
||||
try:
|
||||
content = await sandbox.files.read(file_path)
|
||||
# Handle bytes or string
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8", errors="replace")
|
||||
|
||||
# Extract filename from path
|
||||
file_name = file_path.split("/")[-1]
|
||||
|
||||
# Calculate relative path by stripping working directory
|
||||
relative_path = file_path
|
||||
if file_path.startswith(working_directory):
|
||||
relative_path = file_path[len(working_directory) :]
|
||||
# Remove leading slash if present
|
||||
if relative_path.startswith("/"):
|
||||
relative_path = relative_path[1:]
|
||||
|
||||
files.append(
|
||||
ClaudeCodeBlock.FileOutput(
|
||||
path=file_path,
|
||||
relative_path=relative_path,
|
||||
name=file_name,
|
||||
content=content,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# Skip files that can't be read
|
||||
pass
|
||||
|
||||
except Exception:
|
||||
# If file extraction fails, return empty results
|
||||
pass
|
||||
|
||||
return files
|
||||
|
||||
def _escape_prompt(self, prompt: str) -> str:
|
||||
"""Escape the prompt for safe shell execution."""
|
||||
# Use single quotes and escape any single quotes in the prompt
|
||||
escaped = prompt.replace("'", "'\"'\"'")
|
||||
return f"'{escaped}'"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
e2b_credentials: APIKeyCredentials,
|
||||
anthropic_credentials: APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
(
|
||||
response,
|
||||
files,
|
||||
conversation_history,
|
||||
session_id,
|
||||
sandbox_id,
|
||||
) = await self.execute_claude_code(
|
||||
e2b_api_key=e2b_credentials.api_key.get_secret_value(),
|
||||
anthropic_api_key=anthropic_credentials.api_key.get_secret_value(),
|
||||
prompt=input_data.prompt,
|
||||
timeout=input_data.timeout,
|
||||
setup_commands=input_data.setup_commands,
|
||||
working_directory=input_data.working_directory,
|
||||
session_id=input_data.session_id,
|
||||
existing_sandbox_id=input_data.sandbox_id,
|
||||
conversation_history=input_data.conversation_history,
|
||||
dispose_sandbox=input_data.dispose_sandbox,
|
||||
)
|
||||
|
||||
yield "response", response
|
||||
# Always yield files (empty list if none) to match Output schema
|
||||
yield "files", [f.model_dump() for f in files]
|
||||
# Always yield conversation_history so user can restore context on fresh sandbox
|
||||
yield "conversation_history", conversation_history
|
||||
# Always yield session_id so user can continue conversation
|
||||
yield "session_id", session_id
|
||||
# Always yield sandbox_id (None if disposed) to match Output schema
|
||||
yield "sandbox_id", sandbox_id if not input_data.dispose_sandbox else None
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -182,10 +182,13 @@ class DataForSeoRelatedKeywordsBlock(Block):
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
# Handle missing key, null value, or valid list value
|
||||
if isinstance(first_result, dict):
|
||||
items = first_result.get("items") or []
|
||||
else:
|
||||
items = (
|
||||
first_result.get("items", [])
|
||||
if isinstance(first_result, dict)
|
||||
else []
|
||||
)
|
||||
# Ensure items is never None
|
||||
if items is None:
|
||||
items = []
|
||||
for item in items:
|
||||
# Extract keyword_data from the item
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,184 +0,0 @@
|
||||
"""
|
||||
Shared helpers for Human-In-The-Loop (HITL) review functionality.
|
||||
Used by both the dedicated HumanInTheLoopBlock and blocks that require human review.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from prisma.enums import ReviewStatus
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.execution import ExecutionContext, ExecutionStatus
|
||||
from backend.data.human_review import ReviewResult
|
||||
from backend.executor.manager import async_update_node_execution_status
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReviewDecision(BaseModel):
|
||||
"""Result of a review decision."""
|
||||
|
||||
should_proceed: bool
|
||||
message: str
|
||||
review_result: ReviewResult
|
||||
|
||||
|
||||
class HITLReviewHelper:
|
||||
"""Helper class for Human-In-The-Loop review operations."""
|
||||
|
||||
@staticmethod
|
||||
async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]:
|
||||
"""Create or retrieve a human review from the database."""
|
||||
return await get_database_manager_async_client().get_or_create_human_review(
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_node_execution_status(**kwargs) -> None:
|
||||
"""Update the execution status of a node."""
|
||||
await async_update_node_execution_status(
|
||||
db_client=get_database_manager_async_client(), **kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_review_processed_status(
|
||||
node_exec_id: str, processed: bool
|
||||
) -> None:
|
||||
"""Update the processed status of a review."""
|
||||
return await get_database_manager_async_client().update_review_processed_status(
|
||||
node_exec_id, processed
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _handle_review_request(
|
||||
input_data: Any,
|
||||
user_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: ExecutionContext,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
) -> Optional[ReviewResult]:
|
||||
"""
|
||||
Handle a review request for a block that requires human review.
|
||||
|
||||
Args:
|
||||
input_data: The input data to be reviewed
|
||||
user_id: ID of the user requesting the review
|
||||
node_exec_id: ID of the node execution
|
||||
graph_exec_id: ID of the graph execution
|
||||
graph_id: ID of the graph
|
||||
graph_version: Version of the graph
|
||||
execution_context: Current execution context
|
||||
block_name: Name of the block requesting review
|
||||
editable: Whether the reviewer can edit the data
|
||||
|
||||
Returns:
|
||||
ReviewResult if review is complete, None if waiting for human input
|
||||
|
||||
Raises:
|
||||
Exception: If review creation or status update fails
|
||||
"""
|
||||
# Skip review if safe mode is disabled - return auto-approved result
|
||||
if not execution_context.safe_mode:
|
||||
logger.info(
|
||||
f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled"
|
||||
)
|
||||
return ReviewResult(
|
||||
data=input_data,
|
||||
status=ReviewStatus.APPROVED,
|
||||
message="Auto-approved (safe mode disabled)",
|
||||
processed=True,
|
||||
node_exec_id=node_exec_id,
|
||||
)
|
||||
|
||||
result = await HITLReviewHelper.get_or_create_human_review(
|
||||
user_id=user_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
input_data=input_data,
|
||||
message=f"Review required for {block_name} execution",
|
||||
editable=editable,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
logger.info(
|
||||
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
|
||||
)
|
||||
await HITLReviewHelper.update_node_execution_status(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.REVIEW,
|
||||
)
|
||||
return None # Signal that execution should pause
|
||||
|
||||
# Mark review as processed if not already done
|
||||
if not result.processed:
|
||||
await HITLReviewHelper.update_review_processed_status(
|
||||
node_exec_id=node_exec_id, processed=True
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def handle_review_decision(
|
||||
input_data: Any,
|
||||
user_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: ExecutionContext,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
) -> Optional[ReviewDecision]:
|
||||
"""
|
||||
Handle a review request and return the decision in a single call.
|
||||
|
||||
Args:
|
||||
input_data: The input data to be reviewed
|
||||
user_id: ID of the user requesting the review
|
||||
node_exec_id: ID of the node execution
|
||||
graph_exec_id: ID of the graph execution
|
||||
graph_id: ID of the graph
|
||||
graph_version: Version of the graph
|
||||
execution_context: Current execution context
|
||||
block_name: Name of the block requesting review
|
||||
editable: Whether the reviewer can edit the data
|
||||
|
||||
Returns:
|
||||
ReviewDecision if review is complete (approved/rejected),
|
||||
None if execution should pause (awaiting review)
|
||||
"""
|
||||
review_result = await HITLReviewHelper._handle_review_request(
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_context=execution_context,
|
||||
block_name=block_name,
|
||||
editable=editable,
|
||||
)
|
||||
|
||||
if review_result is None:
|
||||
# Still awaiting review - return None to pause execution
|
||||
return None
|
||||
|
||||
# Review is complete, determine outcome
|
||||
should_proceed = review_result.status == ReviewStatus.APPROVED
|
||||
message = review_result.message or (
|
||||
"Execution approved by reviewer"
|
||||
if should_proceed
|
||||
else "Execution rejected by reviewer"
|
||||
)
|
||||
|
||||
return ReviewDecision(
|
||||
should_proceed=should_proceed, message=message, review_result=review_result
|
||||
)
|
||||
@@ -3,7 +3,6 @@ from typing import Any
|
||||
|
||||
from prisma.enums import ReviewStatus
|
||||
|
||||
from backend.blocks.helpers.review import HITLReviewHelper
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -12,9 +11,11 @@ from backend.data.block import (
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.execution import ExecutionContext, ExecutionStatus
|
||||
from backend.data.human_review import ReviewResult
|
||||
from backend.data.model import SchemaField
|
||||
from backend.executor.manager import async_update_node_execution_status
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -71,26 +72,32 @@ class HumanInTheLoopBlock(Block):
|
||||
("approved_data", {"name": "John Doe", "age": 30}),
|
||||
],
|
||||
test_mock={
|
||||
"handle_review_decision": lambda **kwargs: type(
|
||||
"ReviewDecision",
|
||||
(),
|
||||
{
|
||||
"should_proceed": True,
|
||||
"message": "Test approval message",
|
||||
"review_result": ReviewResult(
|
||||
data={"name": "John Doe", "age": 30},
|
||||
status=ReviewStatus.APPROVED,
|
||||
message="",
|
||||
processed=False,
|
||||
node_exec_id="test-node-exec-id",
|
||||
),
|
||||
},
|
||||
)(),
|
||||
"get_or_create_human_review": lambda *_args, **_kwargs: ReviewResult(
|
||||
data={"name": "John Doe", "age": 30},
|
||||
status=ReviewStatus.APPROVED,
|
||||
message="",
|
||||
processed=False,
|
||||
node_exec_id="test-node-exec-id",
|
||||
),
|
||||
"update_node_execution_status": lambda *_args, **_kwargs: None,
|
||||
"update_review_processed_status": lambda *_args, **_kwargs: None,
|
||||
},
|
||||
)
|
||||
|
||||
async def handle_review_decision(self, **kwargs):
|
||||
return await HITLReviewHelper.handle_review_decision(**kwargs)
|
||||
async def get_or_create_human_review(self, **kwargs):
|
||||
return await get_database_manager_async_client().get_or_create_human_review(
|
||||
**kwargs
|
||||
)
|
||||
|
||||
async def update_node_execution_status(self, **kwargs):
|
||||
return await async_update_node_execution_status(
|
||||
db_client=get_database_manager_async_client(), **kwargs
|
||||
)
|
||||
|
||||
async def update_review_processed_status(self, node_exec_id: str, processed: bool):
|
||||
return await get_database_manager_async_client().update_review_processed_status(
|
||||
node_exec_id, processed
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
@@ -102,7 +109,7 @@ class HumanInTheLoopBlock(Block):
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: ExecutionContext,
|
||||
**_kwargs,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
if not execution_context.safe_mode:
|
||||
logger.info(
|
||||
@@ -112,28 +119,48 @@ class HumanInTheLoopBlock(Block):
|
||||
yield "review_message", "Auto-approved (safe mode disabled)"
|
||||
return
|
||||
|
||||
decision = await self.handle_review_decision(
|
||||
input_data=input_data.data,
|
||||
user_id=user_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_context=execution_context,
|
||||
block_name=self.name,
|
||||
editable=input_data.editable,
|
||||
)
|
||||
try:
|
||||
result = await self.get_or_create_human_review(
|
||||
user_id=user_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
input_data=input_data.data,
|
||||
message=input_data.name,
|
||||
editable=input_data.editable,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in HITL block for node {node_exec_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
if decision is None:
|
||||
return
|
||||
if result is None:
|
||||
logger.info(
|
||||
f"HITL block pausing execution for node {node_exec_id} - awaiting human review"
|
||||
)
|
||||
try:
|
||||
await self.update_node_execution_status(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.REVIEW,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update node status for HITL block {node_exec_id}: {str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
status = decision.review_result.status
|
||||
if status == ReviewStatus.APPROVED:
|
||||
yield "approved_data", decision.review_result.data
|
||||
elif status == ReviewStatus.REJECTED:
|
||||
yield "rejected_data", decision.review_result.data
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected review status: {status}")
|
||||
if not result.processed:
|
||||
await self.update_review_processed_status(
|
||||
node_exec_id=node_exec_id, processed=True
|
||||
)
|
||||
|
||||
if decision.message:
|
||||
yield "review_message", decision.message
|
||||
if result.status == ReviewStatus.APPROVED:
|
||||
yield "approved_data", result.data
|
||||
if result.message:
|
||||
yield "review_message", result.message
|
||||
|
||||
elif result.status == ReviewStatus.REJECTED:
|
||||
yield "rejected_data", result.data
|
||||
if result.message:
|
||||
yield "review_message", result.message
|
||||
|
||||
@@ -92,9 +92,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
O1 = "o1"
|
||||
O1_MINI = "o1-mini"
|
||||
# GPT-5 models
|
||||
GPT5_2 = "gpt-5.2-2025-12-11"
|
||||
GPT5_1 = "gpt-5.1-2025-11-13"
|
||||
GPT5 = "gpt-5-2025-08-07"
|
||||
GPT5_1 = "gpt-5.1-2025-11-13"
|
||||
GPT5_MINI = "gpt-5-mini-2025-08-07"
|
||||
GPT5_NANO = "gpt-5-nano-2025-08-07"
|
||||
GPT5_CHAT = "gpt-5-chat-latest"
|
||||
@@ -195,9 +194,8 @@ MODEL_METADATA = {
|
||||
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
|
||||
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
|
||||
# GPT-5 models
|
||||
LlmModel.GPT5_2: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_MINI: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_NANO: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_CHAT: ModelMetadata("openai", 400000, 16384),
|
||||
@@ -305,8 +303,6 @@ MODEL_METADATA = {
|
||||
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000),
|
||||
}
|
||||
|
||||
DEFAULT_LLM_MODEL = LlmModel.GPT5_2
|
||||
|
||||
for model in LlmModel:
|
||||
if model not in MODEL_METADATA:
|
||||
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
|
||||
@@ -794,7 +790,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -859,7 +855,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
input_schema=AIStructuredResponseGeneratorBlock.Input,
|
||||
output_schema=AIStructuredResponseGeneratorBlock.Output,
|
||||
test_input={
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expected_format": {
|
||||
"key1": "value1",
|
||||
@@ -1225,7 +1221,7 @@ class AITextGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -1321,7 +1317,7 @@ class AITextSummarizerBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for summarizing the text.",
|
||||
)
|
||||
focus: str = SchemaField(
|
||||
@@ -1538,7 +1534,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for the conversation.",
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
@@ -1576,7 +1572,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
},
|
||||
{"role": "user", "content": "Where was it played?"},
|
||||
],
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
@@ -1639,7 +1635,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for generating the list.",
|
||||
advanced=True,
|
||||
)
|
||||
@@ -1696,7 +1692,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
|
||||
"fictional worlds."
|
||||
),
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"max_retries": 3,
|
||||
"force_json_output": False,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -18,7 +18,6 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import DEFAULT_USER_AGENT
|
||||
|
||||
|
||||
class GetWikipediaSummaryBlock(Block, GetRequest):
|
||||
@@ -40,27 +39,17 @@ class GetWikipediaSummaryBlock(Block, GetRequest):
|
||||
output_schema=GetWikipediaSummaryBlock.Output,
|
||||
test_input={"topic": "Artificial Intelligence"},
|
||||
test_output=("summary", "summary content"),
|
||||
test_mock={
|
||||
"get_request": lambda url, headers, json: {"extract": "summary content"}
|
||||
},
|
||||
test_mock={"get_request": lambda url, json: {"extract": "summary content"}},
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
topic = input_data.topic
|
||||
# URL-encode the topic to handle spaces and special characters
|
||||
encoded_topic = quote(topic, safe="")
|
||||
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{encoded_topic}"
|
||||
|
||||
# Set headers per Wikimedia robot policy (https://w.wiki/4wJS)
|
||||
# - User-Agent: Required, must identify the bot
|
||||
# - Accept-Encoding: gzip recommended to reduce bandwidth
|
||||
headers = {
|
||||
"User-Agent": DEFAULT_USER_AGENT,
|
||||
"Accept-Encoding": "gzip, deflate",
|
||||
}
|
||||
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"
|
||||
|
||||
# Note: User-Agent is now automatically set by the request library
|
||||
# to comply with Wikimedia's robot policy (https://w.wiki/4wJS)
|
||||
try:
|
||||
response = await self.get_request(url, headers=headers, json=True)
|
||||
response = await self.get_request(url, json=True)
|
||||
if "extract" not in response:
|
||||
raise ValueError(f"Unable to parse Wikipedia response: {response}")
|
||||
yield "summary", response["extract"]
|
||||
|
||||
@@ -226,7 +226,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
)
|
||||
model: llm.LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=llm.DEFAULT_LLM_MODEL,
|
||||
default=llm.LlmModel.GPT4O,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -391,12 +391,8 @@ class SmartDecisionMakerBlock(Block):
|
||||
"""
|
||||
block = sink_node.block
|
||||
|
||||
# Use custom name from node metadata if set, otherwise fall back to block.name
|
||||
custom_name = sink_node.metadata.get("customized_name")
|
||||
tool_name = custom_name if custom_name else block.name
|
||||
|
||||
tool_function: dict[str, Any] = {
|
||||
"name": SmartDecisionMakerBlock.cleanup(tool_name),
|
||||
"name": SmartDecisionMakerBlock.cleanup(block.name),
|
||||
"description": block.description,
|
||||
}
|
||||
sink_block_input_schema = block.input_schema
|
||||
@@ -493,24 +489,14 @@ class SmartDecisionMakerBlock(Block):
|
||||
f"Sink graph metadata not found: {graph_id} {graph_version}"
|
||||
)
|
||||
|
||||
# Use custom name from node metadata if set, otherwise fall back to graph name
|
||||
custom_name = sink_node.metadata.get("customized_name")
|
||||
tool_name = custom_name if custom_name else sink_graph_meta.name
|
||||
|
||||
tool_function: dict[str, Any] = {
|
||||
"name": SmartDecisionMakerBlock.cleanup(tool_name),
|
||||
"name": SmartDecisionMakerBlock.cleanup(sink_graph_meta.name),
|
||||
"description": sink_graph_meta.description,
|
||||
}
|
||||
|
||||
properties = {}
|
||||
field_mapping = {}
|
||||
|
||||
for link in links:
|
||||
field_name = link.sink_name
|
||||
|
||||
clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
|
||||
field_mapping[clean_field_name] = field_name
|
||||
|
||||
sink_block_input_schema = sink_node.input_default["input_schema"]
|
||||
sink_block_properties = sink_block_input_schema.get("properties", {}).get(
|
||||
link.sink_name, {}
|
||||
@@ -520,7 +506,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
if "description" in sink_block_properties
|
||||
else f"The {link.sink_name} of the tool"
|
||||
)
|
||||
properties[clean_field_name] = {
|
||||
properties[link.sink_name] = {
|
||||
"type": "string",
|
||||
"description": description,
|
||||
"default": json.dumps(sink_block_properties.get("default", None)),
|
||||
@@ -533,7 +519,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
"strict": True,
|
||||
}
|
||||
|
||||
tool_function["_field_mapping"] = field_mapping
|
||||
# Store node info for later use in output processing
|
||||
tool_function["_sink_node_id"] = sink_node.id
|
||||
|
||||
return {"type": "function", "function": tool_function}
|
||||
@@ -989,28 +975,10 @@ class SmartDecisionMakerBlock(Block):
|
||||
graph_version: int,
|
||||
execution_context: ExecutionContext,
|
||||
execution_processor: "ExecutionProcessor",
|
||||
nodes_to_skip: set[str] | None = None,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
tool_functions = await self._create_tool_node_signatures(node_id)
|
||||
original_tool_count = len(tool_functions)
|
||||
|
||||
# Filter out tools for nodes that should be skipped (e.g., missing optional credentials)
|
||||
if nodes_to_skip:
|
||||
tool_functions = [
|
||||
tf
|
||||
for tf in tool_functions
|
||||
if tf.get("function", {}).get("_sink_node_id") not in nodes_to_skip
|
||||
]
|
||||
|
||||
# Only raise error if we had tools but they were all filtered out
|
||||
if original_tool_count > 0 and not tool_functions:
|
||||
raise ValueError(
|
||||
"No available tools to execute - all downstream nodes are unavailable "
|
||||
"(possibly due to missing optional credentials)"
|
||||
)
|
||||
|
||||
yield "tool_functions", json.dumps(tool_functions)
|
||||
|
||||
conversation_history = input_data.conversation_history or []
|
||||
@@ -1161,9 +1129,8 @@ class SmartDecisionMakerBlock(Block):
|
||||
original_field_name = field_mapping.get(clean_arg_name, clean_arg_name)
|
||||
arg_value = tool_args.get(clean_arg_name)
|
||||
|
||||
# Use original_field_name directly (not sanitized) to match link sink_name
|
||||
# The field_mapping already translates from LLM's cleaned names to original names
|
||||
emit_key = f"tools_^_{sink_node_id}_~_{original_field_name}"
|
||||
sanitized_arg_name = self.cleanup(original_field_name)
|
||||
emit_key = f"tools_^_{sink_node_id}_~_{sanitized_arg_name}"
|
||||
|
||||
logger.debug(
|
||||
"[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s",
|
||||
|
||||
@@ -196,15 +196,6 @@ class TestXMLParserBlockSecurity:
|
||||
async for _ in block.run(XMLParserBlock.Input(input_xml=large_xml)):
|
||||
pass
|
||||
|
||||
async def test_rejects_text_outside_root(self):
|
||||
"""Ensure parser surfaces readable errors for invalid root text."""
|
||||
block = XMLParserBlock()
|
||||
invalid_xml = "<root><child>value</child></root> trailing"
|
||||
|
||||
with pytest.raises(ValueError, match="text outside the root element"):
|
||||
async for _ in block.run(XMLParserBlock.Input(input_xml=invalid_xml)):
|
||||
pass
|
||||
|
||||
|
||||
class TestStoreMediaFileSecurity:
|
||||
"""Test file storage security limits."""
|
||||
|
||||
@@ -28,7 +28,7 @@ class TestLLMStatsTracking:
|
||||
|
||||
response = await llm.llm_call(
|
||||
credentials=llm.TEST_CREDENTIALS,
|
||||
llm_model=llm.DEFAULT_LLM_MODEL,
|
||||
llm_model=llm.LlmModel.GPT4O,
|
||||
prompt=[{"role": "user", "content": "Hello"}],
|
||||
max_tokens=100,
|
||||
)
|
||||
@@ -65,7 +65,7 @@ class TestLLMStatsTracking:
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
expected_format={"key1": "desc1", "key2": "desc2"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore # type: ignore
|
||||
)
|
||||
|
||||
@@ -109,7 +109,7 @@ class TestLLMStatsTracking:
|
||||
# Run the block
|
||||
input_data = llm.AITextGeneratorBlock.Input(
|
||||
prompt="Generate text",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
@@ -170,7 +170,7 @@ class TestLLMStatsTracking:
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
expected_format={"key1": "desc1", "key2": "desc2"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2,
|
||||
)
|
||||
@@ -228,7 +228,7 @@ class TestLLMStatsTracking:
|
||||
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text=long_text,
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
max_tokens=100, # Small chunks
|
||||
chunk_overlap=10,
|
||||
@@ -299,7 +299,7 @@ class TestLLMStatsTracking:
|
||||
# Test with very short text (should only need 1 chunk + 1 final summary)
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="This is a short text.",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
max_tokens=1000, # Large enough to avoid chunking
|
||||
)
|
||||
@@ -346,7 +346,7 @@ class TestLLMStatsTracking:
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
],
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
@@ -387,7 +387,7 @@ class TestLLMStatsTracking:
|
||||
# Run the block
|
||||
input_data = llm.AIListGeneratorBlock.Input(
|
||||
focus="test items",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
max_retries=3,
|
||||
)
|
||||
@@ -469,7 +469,7 @@ class TestLLMStatsTracking:
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"result": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
@@ -513,7 +513,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
style=llm.SummaryStyle.BULLET_POINTS,
|
||||
)
|
||||
@@ -558,7 +558,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
style=llm.SummaryStyle.BULLET_POINTS,
|
||||
max_tokens=1000,
|
||||
@@ -593,7 +593,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
@@ -623,7 +623,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
max_tokens=1000,
|
||||
)
|
||||
@@ -654,7 +654,7 @@ class TestAITextSummarizerValidation:
|
||||
# Create input data
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@@ -5,10 +5,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.model import CreateGraph
|
||||
from backend.api.rest_api import AgentServer
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import ProviderName, User
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
from backend.util.test import SpinTestServer, wait_execution
|
||||
|
||||
@@ -233,7 +233,7 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
# Create test input
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Should I continue with this task?",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -335,7 +335,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2, # Set retry to 2 for testing
|
||||
agent_mode_max_iterations=0,
|
||||
@@ -402,7 +402,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -462,7 +462,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -526,7 +526,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -648,7 +648,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2,
|
||||
agent_mode_max_iterations=0,
|
||||
@@ -722,7 +722,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Simple prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -778,7 +778,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Another test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -931,7 +931,7 @@ async def test_smart_decision_maker_agent_mode():
|
||||
# Test agent mode with max_iterations = 3
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Complete this task using tools",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=3, # Enable agent mode with 3 max iterations
|
||||
)
|
||||
@@ -1020,7 +1020,7 @@ async def test_smart_decision_maker_traditional_mode_default():
|
||||
# Test default behavior (traditional mode)
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0, # Traditional mode
|
||||
)
|
||||
@@ -1057,153 +1057,3 @@ async def test_smart_decision_maker_traditional_mode_default():
|
||||
) # Should yield individual tool parameters
|
||||
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
|
||||
assert "conversations" in outputs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_uses_customized_name_for_blocks():
|
||||
"""Test that SmartDecisionMakerBlock uses customized_name from node metadata for tool names."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node with customized_name in metadata
|
||||
mock_node = MagicMock(spec=Node)
|
||||
mock_node.id = "test-node-id"
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
|
||||
mock_node.block = StoreValueBlock()
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "input"
|
||||
|
||||
# Call the function directly
|
||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
# Verify the tool name uses the customized name (cleaned up)
|
||||
assert result["type"] == "function"
|
||||
assert result["function"]["name"] == "my_custom_tool_name" # Cleaned version
|
||||
assert result["function"]["_sink_node_id"] == "test-node-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_falls_back_to_block_name():
|
||||
"""Test that SmartDecisionMakerBlock falls back to block.name when no customized_name."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node without customized_name
|
||||
mock_node = MagicMock(spec=Node)
|
||||
mock_node.id = "test-node-id"
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {} # No customized_name
|
||||
mock_node.block = StoreValueBlock()
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "input"
|
||||
|
||||
# Call the function directly
|
||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
# Verify the tool name uses the block's default name
|
||||
assert result["type"] == "function"
|
||||
assert result["function"]["name"] == "storevalueblock" # Default block name cleaned
|
||||
assert result["function"]["_sink_node_id"] == "test-node-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_uses_customized_name_for_agents():
|
||||
"""Test that SmartDecisionMakerBlock uses customized_name from metadata for agent nodes."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node with customized_name in metadata
|
||||
mock_node = MagicMock(spec=Node)
|
||||
mock_node.id = "test-agent-node-id"
|
||||
mock_node.metadata = {"customized_name": "My Custom Agent"}
|
||||
mock_node.input_default = {
|
||||
"graph_id": "test-graph-id",
|
||||
"graph_version": 1,
|
||||
"input_schema": {"properties": {"test_input": {"description": "Test input"}}},
|
||||
}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "test_input"
|
||||
|
||||
# Mock the database client
|
||||
mock_graph_meta = MagicMock()
|
||||
mock_graph_meta.name = "Original Agent Name"
|
||||
mock_graph_meta.description = "Agent description"
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
||||
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
# Verify the tool name uses the customized name (cleaned up)
|
||||
assert result["type"] == "function"
|
||||
assert result["function"]["name"] == "my_custom_agent" # Cleaned version
|
||||
assert result["function"]["_sink_node_id"] == "test-agent-node-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_agent_falls_back_to_graph_name():
|
||||
"""Test that agent node falls back to graph name when no customized_name."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node without customized_name
|
||||
mock_node = MagicMock(spec=Node)
|
||||
mock_node.id = "test-agent-node-id"
|
||||
mock_node.metadata = {} # No customized_name
|
||||
mock_node.input_default = {
|
||||
"graph_id": "test-graph-id",
|
||||
"graph_version": 1,
|
||||
"input_schema": {"properties": {"test_input": {"description": "Test input"}}},
|
||||
}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "test_input"
|
||||
|
||||
# Mock the database client
|
||||
mock_graph_meta = MagicMock()
|
||||
mock_graph_meta.name = "Original Agent Name"
|
||||
mock_graph_meta.description = "Agent description"
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
||||
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
# Verify the tool name uses the graph's default name
|
||||
assert result["type"] == "function"
|
||||
assert result["function"]["name"] == "original_agent_name" # Graph name cleaned
|
||||
assert result["function"]["_sink_node_id"] == "test-agent-node-id"
|
||||
|
||||
@@ -15,7 +15,6 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
|
||||
mock_node.block = CreateDictionaryBlock()
|
||||
mock_node.block_id = CreateDictionaryBlock().id
|
||||
mock_node.input_default = {}
|
||||
mock_node.metadata = {}
|
||||
|
||||
# Create mock links with dynamic dictionary fields
|
||||
mock_links = [
|
||||
@@ -78,7 +77,6 @@ async def test_smart_decision_maker_handles_dynamic_list_fields():
|
||||
mock_node.block = AddToListBlock()
|
||||
mock_node.block_id = AddToListBlock().id
|
||||
mock_node.input_default = {}
|
||||
mock_node.metadata = {}
|
||||
|
||||
# Create mock links with dynamic list fields
|
||||
mock_links = [
|
||||
|
||||
@@ -44,7 +44,6 @@ async def test_create_block_function_signature_with_dict_fields():
|
||||
mock_node.block = CreateDictionaryBlock()
|
||||
mock_node.block_id = CreateDictionaryBlock().id
|
||||
mock_node.input_default = {}
|
||||
mock_node.metadata = {}
|
||||
|
||||
# Create mock links with dynamic dictionary fields (source sanitized, sink original)
|
||||
mock_links = [
|
||||
@@ -107,7 +106,6 @@ async def test_create_block_function_signature_with_list_fields():
|
||||
mock_node.block = AddToListBlock()
|
||||
mock_node.block_id = AddToListBlock().id
|
||||
mock_node.input_default = {}
|
||||
mock_node.metadata = {}
|
||||
|
||||
# Create mock links with dynamic list fields
|
||||
mock_links = [
|
||||
@@ -161,7 +159,6 @@ async def test_create_block_function_signature_with_object_fields():
|
||||
mock_node.block = MatchTextPatternBlock()
|
||||
mock_node.block_id = MatchTextPatternBlock().id
|
||||
mock_node.input_default = {}
|
||||
mock_node.metadata = {}
|
||||
|
||||
# Create mock links with dynamic object fields
|
||||
mock_links = [
|
||||
@@ -211,13 +208,11 @@ async def test_create_tool_node_signatures():
|
||||
mock_dict_node.block = CreateDictionaryBlock()
|
||||
mock_dict_node.block_id = CreateDictionaryBlock().id
|
||||
mock_dict_node.input_default = {}
|
||||
mock_dict_node.metadata = {}
|
||||
|
||||
mock_list_node = Mock()
|
||||
mock_list_node.block = AddToListBlock()
|
||||
mock_list_node.block_id = AddToListBlock().id
|
||||
mock_list_node.input_default = {}
|
||||
mock_list_node.metadata = {}
|
||||
|
||||
# Mock links with dynamic fields
|
||||
dict_link1 = Mock(
|
||||
@@ -378,7 +373,7 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
input_data = block.input_schema(
|
||||
prompt="Create a user dictionary",
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT,
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
agent_mode_max_iterations=0, # Use traditional mode to test output yielding
|
||||
)
|
||||
|
||||
@@ -428,7 +423,6 @@ async def test_mixed_regular_and_dynamic_fields():
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "A test block"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.metadata = {}
|
||||
|
||||
# Mock the get_field_schema to return a proper schema for regular fields
|
||||
def get_field_schema(field_name):
|
||||
@@ -600,7 +594,7 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
input_data = block.input_schema(
|
||||
prompt="Test prompt",
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT,
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
retry=3, # Allow retries
|
||||
agent_mode_max_iterations=1,
|
||||
)
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .blog import WordPressCreatePostBlock, WordPressGetAllPostsBlock
|
||||
from .blog import WordPressCreatePostBlock
|
||||
|
||||
__all__ = ["WordPressCreatePostBlock", "WordPressGetAllPostsBlock"]
|
||||
__all__ = ["WordPressCreatePostBlock"]
|
||||
|
||||
@@ -161,7 +161,7 @@ async def oauth_exchange_code_for_tokens(
|
||||
grant_type="authorization_code",
|
||||
).model_dump(exclude_none=True)
|
||||
|
||||
response = await Requests(raise_for_status=False).post(
|
||||
response = await Requests().post(
|
||||
f"{WORDPRESS_BASE_URL}oauth2/token",
|
||||
headers=headers,
|
||||
data=data,
|
||||
@@ -205,7 +205,7 @@ async def oauth_refresh_tokens(
|
||||
grant_type="refresh_token",
|
||||
).model_dump(exclude_none=True)
|
||||
|
||||
response = await Requests(raise_for_status=False).post(
|
||||
response = await Requests().post(
|
||||
f"{WORDPRESS_BASE_URL}oauth2/token",
|
||||
headers=headers,
|
||||
data=data,
|
||||
@@ -252,7 +252,7 @@ async def validate_token(
|
||||
"token": token,
|
||||
}
|
||||
|
||||
response = await Requests(raise_for_status=False).get(
|
||||
response = await Requests().get(
|
||||
f"{WORDPRESS_BASE_URL}oauth2/token-info",
|
||||
params=params,
|
||||
)
|
||||
@@ -296,7 +296,7 @@ async def make_api_request(
|
||||
|
||||
url = f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}"
|
||||
|
||||
request_method = getattr(Requests(raise_for_status=False), method.lower())
|
||||
request_method = getattr(Requests(), method.lower())
|
||||
response = await request_method(
|
||||
url,
|
||||
headers=headers,
|
||||
@@ -476,7 +476,6 @@ async def create_post(
|
||||
data["tags"] = ",".join(str(t) for t in data["tags"])
|
||||
|
||||
# Make the API request
|
||||
site = normalize_site(site)
|
||||
endpoint = f"/rest/v1.1/sites/{site}/posts/new"
|
||||
|
||||
headers = {
|
||||
@@ -484,7 +483,7 @@ async def create_post(
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
}
|
||||
|
||||
response = await Requests(raise_for_status=False).post(
|
||||
response = await Requests().post(
|
||||
f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}",
|
||||
headers=headers,
|
||||
data=data,
|
||||
@@ -500,132 +499,3 @@ async def create_post(
|
||||
)
|
||||
error_message = error_data.get("message", response.text)
|
||||
raise ValueError(f"Failed to create post: {response.status} - {error_message}")
|
||||
|
||||
|
||||
class Post(BaseModel):
|
||||
"""Response model for individual posts in a posts list response.
|
||||
|
||||
This is a simplified version compared to PostResponse, as the list endpoint
|
||||
returns less detailed information than the create/get single post endpoints.
|
||||
"""
|
||||
|
||||
ID: int
|
||||
site_ID: int
|
||||
author: PostAuthor
|
||||
date: datetime
|
||||
modified: datetime
|
||||
title: str
|
||||
URL: str
|
||||
short_URL: str
|
||||
content: str | None = None
|
||||
excerpt: str | None = None
|
||||
slug: str
|
||||
guid: str
|
||||
status: str
|
||||
sticky: bool
|
||||
password: str | None = ""
|
||||
parent: Union[Dict[str, Any], bool, None] = None
|
||||
type: str
|
||||
discussion: Dict[str, Union[str, bool, int]] | None = None
|
||||
likes_enabled: bool | None = None
|
||||
sharing_enabled: bool | None = None
|
||||
like_count: int | None = None
|
||||
i_like: bool | None = None
|
||||
is_reblogged: bool | None = None
|
||||
is_following: bool | None = None
|
||||
global_ID: str | None = None
|
||||
featured_image: str | None = None
|
||||
post_thumbnail: Dict[str, Any] | None = None
|
||||
format: str | None = None
|
||||
geo: Union[Dict[str, Any], bool, None] = None
|
||||
menu_order: int | None = None
|
||||
page_template: str | None = None
|
||||
publicize_URLs: List[str] | None = None
|
||||
terms: Dict[str, Dict[str, Any]] | None = None
|
||||
tags: Dict[str, Dict[str, Any]] | None = None
|
||||
categories: Dict[str, Dict[str, Any]] | None = None
|
||||
attachments: Dict[str, Dict[str, Any]] | None = None
|
||||
attachment_count: int | None = None
|
||||
metadata: List[Dict[str, Any]] | None = None
|
||||
meta: Dict[str, Any] | None = None
|
||||
capabilities: Dict[str, bool] | None = None
|
||||
revisions: List[int] | None = None
|
||||
other_URLs: Dict[str, Any] | None = None
|
||||
|
||||
|
||||
class PostsResponse(BaseModel):
|
||||
"""Response model for WordPress posts list."""
|
||||
|
||||
found: int
|
||||
posts: List[Post]
|
||||
meta: Dict[str, Any]
|
||||
|
||||
|
||||
def normalize_site(site: str) -> str:
|
||||
"""
|
||||
Normalize a site identifier by stripping protocol and trailing slashes.
|
||||
|
||||
Args:
|
||||
site: Site URL, domain, or ID (e.g., "https://myblog.wordpress.com/", "myblog.wordpress.com", "123456789")
|
||||
|
||||
Returns:
|
||||
Normalized site identifier (domain or ID only)
|
||||
"""
|
||||
site = site.strip()
|
||||
if site.startswith("https://"):
|
||||
site = site[8:]
|
||||
elif site.startswith("http://"):
|
||||
site = site[7:]
|
||||
return site.rstrip("/")
|
||||
|
||||
|
||||
async def get_posts(
|
||||
credentials: Credentials,
|
||||
site: str,
|
||||
status: PostStatus | None = None,
|
||||
number: int = 100,
|
||||
offset: int = 0,
|
||||
) -> PostsResponse:
|
||||
"""
|
||||
Get posts from a WordPress site.
|
||||
|
||||
Args:
|
||||
credentials: OAuth credentials
|
||||
site: Site ID or domain (e.g., "myblog.wordpress.com" or "123456789")
|
||||
status: Filter by post status using PostStatus enum, or None for all
|
||||
number: Number of posts to retrieve (max 100)
|
||||
offset: Number of posts to skip (for pagination)
|
||||
|
||||
Returns:
|
||||
PostsResponse with the list of posts
|
||||
"""
|
||||
site = normalize_site(site)
|
||||
endpoint = f"/rest/v1.1/sites/{site}/posts"
|
||||
|
||||
headers = {
|
||||
"Authorization": credentials.auth_header(),
|
||||
}
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
"number": max(1, min(number, 100)), # 1–100 posts per request
|
||||
"offset": offset,
|
||||
}
|
||||
|
||||
if status:
|
||||
params["status"] = status.value
|
||||
response = await Requests(raise_for_status=False).get(
|
||||
f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}",
|
||||
headers=headers,
|
||||
params=params,
|
||||
)
|
||||
|
||||
if response.ok:
|
||||
return PostsResponse.model_validate(response.json())
|
||||
|
||||
error_data = (
|
||||
response.json()
|
||||
if response.headers.get("content-type", "").startswith("application/json")
|
||||
else {}
|
||||
)
|
||||
error_message = error_data.get("message", response.text)
|
||||
raise ValueError(f"Failed to get posts: {response.status} - {error_message}")
|
||||
|
||||
@@ -9,15 +9,7 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import (
|
||||
CreatePostRequest,
|
||||
Post,
|
||||
PostResponse,
|
||||
PostsResponse,
|
||||
PostStatus,
|
||||
create_post,
|
||||
get_posts,
|
||||
)
|
||||
from ._api import CreatePostRequest, PostResponse, PostStatus, create_post
|
||||
from ._config import wordpress
|
||||
|
||||
|
||||
@@ -57,15 +49,8 @@ class WordPressCreatePostBlock(Block):
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="URLs of images to sideload and attach to the post", default=[]
|
||||
)
|
||||
publish_as_draft: bool = SchemaField(
|
||||
description="If True, publishes the post as a draft. If False, publishes it publicly.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
site: str = SchemaField(
|
||||
description="The site ID or domain (pass-through for chaining with other blocks)"
|
||||
)
|
||||
post_id: int = SchemaField(description="The ID of the created post")
|
||||
post_url: str = SchemaField(description="The full URL of the created post")
|
||||
short_url: str = SchemaField(description="The shortened wp.me URL")
|
||||
@@ -93,9 +78,7 @@ class WordPressCreatePostBlock(Block):
|
||||
tags=input_data.tags,
|
||||
featured_image=input_data.featured_image,
|
||||
media_urls=input_data.media_urls,
|
||||
status=(
|
||||
PostStatus.DRAFT if input_data.publish_as_draft else PostStatus.PUBLISH
|
||||
),
|
||||
status=PostStatus.PUBLISH,
|
||||
)
|
||||
|
||||
post_response: PostResponse = await create_post(
|
||||
@@ -104,69 +87,7 @@ class WordPressCreatePostBlock(Block):
|
||||
post_data=post_request,
|
||||
)
|
||||
|
||||
yield "site", input_data.site
|
||||
yield "post_id", post_response.ID
|
||||
yield "post_url", post_response.URL
|
||||
yield "short_url", post_response.short_URL
|
||||
yield "post_data", post_response.model_dump()
|
||||
|
||||
|
||||
class WordPressGetAllPostsBlock(Block):
|
||||
"""
|
||||
Fetches all posts from a WordPress.com site or Jetpack-enabled site.
|
||||
Supports filtering by status and pagination.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = wordpress.credentials_field()
|
||||
site: str = SchemaField(
|
||||
description="Site ID or domain (e.g., 'myblog.wordpress.com' or '123456789')"
|
||||
)
|
||||
status: PostStatus | None = SchemaField(
|
||||
description="Filter by post status, or None for all",
|
||||
default=None,
|
||||
)
|
||||
number: int = SchemaField(
|
||||
description="Number of posts to retrieve (max 100 per request)", default=20
|
||||
)
|
||||
offset: int = SchemaField(
|
||||
description="Number of posts to skip (for pagination)", default=0
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
site: str = SchemaField(
|
||||
description="The site ID or domain (pass-through for chaining with other blocks)"
|
||||
)
|
||||
found: int = SchemaField(description="Total number of posts found")
|
||||
posts: list[Post] = SchemaField(
|
||||
description="List of post objects with their details"
|
||||
)
|
||||
post: Post = SchemaField(
|
||||
description="Individual post object (yielded for each post)"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="97728fa7-7f6f-4789-ba0c-f2c114119536",
|
||||
description="Fetch all posts from WordPress.com or Jetpack sites",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: Credentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
posts_response: PostsResponse = await get_posts(
|
||||
credentials=credentials,
|
||||
site=input_data.site,
|
||||
status=input_data.status,
|
||||
number=input_data.number,
|
||||
offset=input_data.offset,
|
||||
)
|
||||
|
||||
yield "site", input_data.site
|
||||
yield "found", posts_response.found
|
||||
yield "posts", posts_response.posts
|
||||
for post in posts_response.posts:
|
||||
yield "post", post
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from gravitasml.parser import Parser
|
||||
from gravitasml.token import Token, tokenize
|
||||
from gravitasml.token import tokenize
|
||||
|
||||
from backend.data.block import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
|
||||
from backend.data.model import SchemaField
|
||||
@@ -25,38 +25,6 @@ class XMLParserBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_tokens(tokens: list[Token]) -> None:
|
||||
"""Ensure the XML has a single root element and no stray text."""
|
||||
if not tokens:
|
||||
raise ValueError("XML input is empty.")
|
||||
|
||||
depth = 0
|
||||
root_seen = False
|
||||
|
||||
for token in tokens:
|
||||
if token.type == "TAG_OPEN":
|
||||
if depth == 0 and root_seen:
|
||||
raise ValueError("XML must have a single root element.")
|
||||
depth += 1
|
||||
if depth == 1:
|
||||
root_seen = True
|
||||
elif token.type == "TAG_CLOSE":
|
||||
depth -= 1
|
||||
if depth < 0:
|
||||
raise SyntaxError("Unexpected closing tag in XML input.")
|
||||
elif token.type in {"TEXT", "ESCAPE"}:
|
||||
if depth == 0 and token.value:
|
||||
raise ValueError(
|
||||
"XML contains text outside the root element; "
|
||||
"wrap content in a single root tag."
|
||||
)
|
||||
|
||||
if depth != 0:
|
||||
raise SyntaxError("Unclosed tag detected in XML input.")
|
||||
if not root_seen:
|
||||
raise ValueError("XML must include a root element.")
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# Security fix: Add size limits to prevent XML bomb attacks
|
||||
MAX_XML_SIZE = 10 * 1024 * 1024 # 10MB limit for XML input
|
||||
@@ -67,9 +35,7 @@ class XMLParserBlock(Block):
|
||||
)
|
||||
|
||||
try:
|
||||
tokens = list(tokenize(input_data.input_xml))
|
||||
self._validate_tokens(tokens)
|
||||
|
||||
tokens = tokenize(input_data.input_xml)
|
||||
parser = Parser(tokens)
|
||||
parsed_result = parser.parse()
|
||||
yield "parsed_xml", parsed_result
|
||||
|
||||
@@ -111,8 +111,6 @@ class TranscribeYoutubeVideoBlock(Block):
|
||||
return parsed_url.path.split("/")[2]
|
||||
if parsed_url.path[:3] == "/v/":
|
||||
return parsed_url.path.split("/")[2]
|
||||
if parsed_url.path.startswith("/shorts/"):
|
||||
return parsed_url.path.split("/")[2]
|
||||
raise ValueError(f"Invalid YouTube URL: {url}")
|
||||
|
||||
def get_transcript(
|
||||
|
||||
@@ -244,7 +244,11 @@ def websocket(server_address: str, graph_exec_id: str):
|
||||
|
||||
import websockets.asyncio.client
|
||||
|
||||
from backend.api.ws_api import WSMessage, WSMethod, WSSubscribeGraphExecutionRequest
|
||||
from backend.server.ws_api import (
|
||||
WSMessage,
|
||||
WSMethod,
|
||||
WSSubscribeGraphExecutionRequest,
|
||||
)
|
||||
|
||||
async def send_message(server_address: str):
|
||||
uri = f"ws://{server_address}"
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"""
|
||||
Script to generate OpenAPI JSON specification for the FastAPI app.
|
||||
|
||||
This script imports the FastAPI app from backend.api.rest_api and outputs
|
||||
This script imports the FastAPI app from backend.server.rest_api and outputs
|
||||
the OpenAPI specification as JSON to stdout or a specified file.
|
||||
|
||||
Usage:
|
||||
@@ -46,7 +46,7 @@ def main(output: Path, pretty: bool):
|
||||
|
||||
def get_openapi_schema():
|
||||
"""Get the OpenAPI schema from the FastAPI app"""
|
||||
from backend.api.rest_api import app
|
||||
from backend.server.rest_api import app
|
||||
|
||||
return app.openapi()
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from backend.api.features.library.model import LibraryAgentPreset
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
|
||||
from .graph import NodeModel
|
||||
from .integrations import Webhook # noqa: F401
|
||||
|
||||
58
autogpt_platform/backend/backend/data/auth/__init__.py
Normal file
58
autogpt_platform/backend/backend/data/auth/__init__.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
Native authentication module for AutoGPT Platform.
|
||||
|
||||
This module provides authentication functionality that replaces Supabase Auth,
|
||||
including:
|
||||
- Password hashing with Argon2id
|
||||
- JWT token generation and validation
|
||||
- Magic links for email verification and password reset
|
||||
- Email service for auth-related emails
|
||||
- User migration from Supabase
|
||||
|
||||
Usage:
|
||||
from backend.data.auth.password import hash_password, verify_password
|
||||
from backend.data.auth.tokens import create_access_token, create_token_pair
|
||||
from backend.data.auth.magic_links import create_password_reset_link
|
||||
from backend.data.auth.email_service import get_auth_email_service
|
||||
"""
|
||||
|
||||
from backend.data.auth.email_service import AuthEmailService, get_auth_email_service
|
||||
from backend.data.auth.magic_links import (
|
||||
MagicLinkPurpose,
|
||||
create_email_verification_link,
|
||||
create_password_reset_link,
|
||||
verify_email_token,
|
||||
verify_password_reset_token,
|
||||
)
|
||||
from backend.data.auth.password import hash_password, needs_rehash, verify_password
|
||||
from backend.data.auth.tokens import (
|
||||
TokenPair,
|
||||
create_access_token,
|
||||
create_token_pair,
|
||||
decode_access_token,
|
||||
revoke_all_user_refresh_tokens,
|
||||
validate_refresh_token,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Password
|
||||
"hash_password",
|
||||
"verify_password",
|
||||
"needs_rehash",
|
||||
# Tokens
|
||||
"TokenPair",
|
||||
"create_access_token",
|
||||
"create_token_pair",
|
||||
"decode_access_token",
|
||||
"validate_refresh_token",
|
||||
"revoke_all_user_refresh_tokens",
|
||||
# Magic Links
|
||||
"MagicLinkPurpose",
|
||||
"create_email_verification_link",
|
||||
"create_password_reset_link",
|
||||
"verify_email_token",
|
||||
"verify_password_reset_token",
|
||||
# Email Service
|
||||
"AuthEmailService",
|
||||
"get_auth_email_service",
|
||||
]
|
||||
271
autogpt_platform/backend/backend/data/auth/email_service.py
Normal file
271
autogpt_platform/backend/backend/data/auth/email_service.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""
|
||||
Email service for authentication flows.
|
||||
|
||||
Uses Postmark to send transactional emails for:
|
||||
- Email verification
|
||||
- Password reset
|
||||
- Account security notifications
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pathlib
|
||||
from typing import Optional
|
||||
|
||||
from jinja2 import Template
|
||||
from postmarker.core import PostmarkClient
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
# Template directory
|
||||
TEMPLATE_DIR = pathlib.Path(__file__).parent / "templates"
|
||||
|
||||
|
||||
class AuthEmailService:
|
||||
"""Email service for authentication-related emails."""
|
||||
|
||||
def __init__(self):
|
||||
if settings.secrets.postmark_server_api_token:
|
||||
self.postmark = PostmarkClient(
|
||||
server_token=settings.secrets.postmark_server_api_token
|
||||
)
|
||||
self.enabled = True
|
||||
else:
|
||||
logger.warning(
|
||||
"Postmark server API token not found, auth emails disabled"
|
||||
)
|
||||
self.postmark = None
|
||||
self.enabled = False
|
||||
|
||||
self.sender_email = settings.config.postmark_sender_email
|
||||
self.frontend_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
|
||||
def _send_email(
|
||||
self,
|
||||
to_email: str,
|
||||
subject: str,
|
||||
html_body: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Send an email via Postmark.
|
||||
|
||||
Returns True if sent successfully, False otherwise.
|
||||
"""
|
||||
if not self.enabled or not self.postmark:
|
||||
logger.warning(f"Email not sent (disabled): {subject} to {to_email}")
|
||||
return False
|
||||
|
||||
try:
|
||||
self.postmark.emails.send(
|
||||
From=self.sender_email,
|
||||
To=to_email,
|
||||
Subject=subject,
|
||||
HtmlBody=html_body,
|
||||
)
|
||||
logger.info(f"Auth email sent: {subject} to {to_email}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send auth email: {e}")
|
||||
return False
|
||||
|
||||
def send_verification_email(self, email: str, token: str) -> bool:
|
||||
"""
|
||||
Send email verification link.
|
||||
|
||||
Args:
|
||||
email: Recipient email address
|
||||
token: Verification token
|
||||
|
||||
Returns:
|
||||
True if sent successfully
|
||||
"""
|
||||
verify_url = f"{self.frontend_url}/auth/verify-email?token={token}"
|
||||
|
||||
subject = "Verify your email address"
|
||||
html_body = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<style>
|
||||
body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; }}
|
||||
.container {{ max-width: 600px; margin: 0 auto; padding: 20px; }}
|
||||
.button {{ display: inline-block; padding: 12px 24px; background-color: #5046e5; color: white; text-decoration: none; border-radius: 6px; font-weight: 500; }}
|
||||
.footer {{ margin-top: 30px; font-size: 12px; color: #666; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h2>Verify your email address</h2>
|
||||
<p>Thanks for signing up! Please verify your email address by clicking the button below:</p>
|
||||
<p style="margin: 30px 0;">
|
||||
<a href="{verify_url}" class="button">Verify Email</a>
|
||||
</p>
|
||||
<p>Or copy and paste this link into your browser:</p>
|
||||
<p style="word-break: break-all; color: #666;">{verify_url}</p>
|
||||
<p>This link will expire in 24 hours.</p>
|
||||
<div class="footer">
|
||||
<p>If you didn't create an account, you can safely ignore this email.</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return self._send_email(email, subject, html_body)
|
||||
|
||||
def send_password_reset_email(self, email: str, token: str) -> bool:
|
||||
"""
|
||||
Send password reset link.
|
||||
|
||||
Args:
|
||||
email: Recipient email address
|
||||
token: Password reset token
|
||||
|
||||
Returns:
|
||||
True if sent successfully
|
||||
"""
|
||||
reset_url = f"{self.frontend_url}/reset-password?token={token}"
|
||||
|
||||
subject = "Reset your password"
|
||||
html_body = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<style>
|
||||
body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; }}
|
||||
.container {{ max-width: 600px; margin: 0 auto; padding: 20px; }}
|
||||
.button {{ display: inline-block; padding: 12px 24px; background-color: #5046e5; color: white; text-decoration: none; border-radius: 6px; font-weight: 500; }}
|
||||
.warning {{ background-color: #fef3c7; border: 1px solid #f59e0b; padding: 12px; border-radius: 6px; margin: 20px 0; }}
|
||||
.footer {{ margin-top: 30px; font-size: 12px; color: #666; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h2>Reset your password</h2>
|
||||
<p>We received a request to reset your password. Click the button below to choose a new password:</p>
|
||||
<p style="margin: 30px 0;">
|
||||
<a href="{reset_url}" class="button">Reset Password</a>
|
||||
</p>
|
||||
<p>Or copy and paste this link into your browser:</p>
|
||||
<p style="word-break: break-all; color: #666;">{reset_url}</p>
|
||||
<div class="warning">
|
||||
<strong>This link will expire in 15 minutes.</strong>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>If you didn't request a password reset, you can safely ignore this email. Your password will remain unchanged.</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return self._send_email(email, subject, html_body)
|
||||
|
||||
def send_password_changed_notification(self, email: str) -> bool:
|
||||
"""
|
||||
Send notification that password was changed.
|
||||
|
||||
Args:
|
||||
email: Recipient email address
|
||||
|
||||
Returns:
|
||||
True if sent successfully
|
||||
"""
|
||||
subject = "Your password was changed"
|
||||
html_body = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<style>
|
||||
body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; }}
|
||||
.container {{ max-width: 600px; margin: 0 auto; padding: 20px; }}
|
||||
.warning {{ background-color: #fee2e2; border: 1px solid #ef4444; padding: 12px; border-radius: 6px; margin: 20px 0; }}
|
||||
.footer {{ margin-top: 30px; font-size: 12px; color: #666; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h2>Password Changed</h2>
|
||||
<p>Your password was successfully changed.</p>
|
||||
<div class="warning">
|
||||
<strong>If you didn't make this change</strong>, please contact support immediately and reset your password.
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>This is an automated security notification.</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return self._send_email(email, subject, html_body)
|
||||
|
||||
def send_migrated_user_password_reset(self, email: str, token: str) -> bool:
|
||||
"""
|
||||
Send password reset email for users migrated from Supabase.
|
||||
|
||||
Args:
|
||||
email: Recipient email address
|
||||
token: Password reset token
|
||||
|
||||
Returns:
|
||||
True if sent successfully
|
||||
"""
|
||||
reset_url = f"{self.frontend_url}/reset-password?token={token}"
|
||||
|
||||
subject = "Action Required: Set your password"
|
||||
html_body = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<style>
|
||||
body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; }}
|
||||
.container {{ max-width: 600px; margin: 0 auto; padding: 20px; }}
|
||||
.button {{ display: inline-block; padding: 12px 24px; background-color: #5046e5; color: white; text-decoration: none; border-radius: 6px; font-weight: 500; }}
|
||||
.info {{ background-color: #dbeafe; border: 1px solid #3b82f6; padding: 12px; border-radius: 6px; margin: 20px 0; }}
|
||||
.footer {{ margin-top: 30px; font-size: 12px; color: #666; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h2>Set Your Password</h2>
|
||||
<div class="info">
|
||||
<strong>We've upgraded our authentication system!</strong>
|
||||
<p style="margin: 8px 0 0 0;">For enhanced security, please set a new password to continue using your account.</p>
|
||||
</div>
|
||||
<p>Click the button below to set your password:</p>
|
||||
<p style="margin: 30px 0;">
|
||||
<a href="{reset_url}" class="button">Set Password</a>
|
||||
</p>
|
||||
<p>Or copy and paste this link into your browser:</p>
|
||||
<p style="word-break: break-all; color: #666;">{reset_url}</p>
|
||||
<p>This link will expire in 24 hours.</p>
|
||||
<div class="footer">
|
||||
<p>If you signed up with Google, no action is needed - simply continue signing in with Google.</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return self._send_email(email, subject, html_body)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_email_service: Optional[AuthEmailService] = None
|
||||
|
||||
|
||||
def get_auth_email_service() -> AuthEmailService:
|
||||
"""Get the singleton auth email service instance."""
|
||||
global _email_service
|
||||
if _email_service is None:
|
||||
_email_service = AuthEmailService()
|
||||
return _email_service
|
||||
253
autogpt_platform/backend/backend/data/auth/magic_links.py
Normal file
253
autogpt_platform/backend/backend/data/auth/magic_links.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
Magic link service for email verification and password reset.
|
||||
|
||||
Magic links are single-use, time-limited tokens sent via email that allow
|
||||
users to verify their email address or reset their password without entering
|
||||
the old password.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from prisma.models import UserAuthMagicLink
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Magic link TTLs
|
||||
EMAIL_VERIFICATION_TTL = timedelta(hours=24)
|
||||
PASSWORD_RESET_TTL = timedelta(minutes=15)
|
||||
|
||||
# Token prefix for identification
|
||||
MAGIC_LINK_PREFIX = "agpt_ml_"
|
||||
|
||||
|
||||
class MagicLinkPurpose(str, Enum):
|
||||
"""Purpose of the magic link."""
|
||||
|
||||
EMAIL_VERIFICATION = "email_verification"
|
||||
PASSWORD_RESET = "password_reset"
|
||||
|
||||
|
||||
class MagicLinkInfo(BaseModel):
|
||||
"""Information about a valid magic link."""
|
||||
|
||||
email: str
|
||||
purpose: MagicLinkPurpose
|
||||
user_id: Optional[str] = None # Set for password reset, not for signup verification
|
||||
|
||||
|
||||
def generate_magic_link_token() -> str:
|
||||
"""
|
||||
Generate a cryptographically secure magic link token.
|
||||
|
||||
Returns:
|
||||
A prefixed random token string.
|
||||
"""
|
||||
random_bytes = secrets.token_urlsafe(32)
|
||||
return f"{MAGIC_LINK_PREFIX}{random_bytes}"
|
||||
|
||||
|
||||
def hash_magic_link_token(token: str) -> str:
|
||||
"""
|
||||
Hash a magic link token for storage.
|
||||
|
||||
Uses SHA256 for deterministic lookup.
|
||||
|
||||
Args:
|
||||
token: The plaintext magic link token.
|
||||
|
||||
Returns:
|
||||
The SHA256 hex digest.
|
||||
"""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
async def create_magic_link(
|
||||
email: str,
|
||||
purpose: MagicLinkPurpose,
|
||||
user_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a magic link token and store it in the database.
|
||||
|
||||
Args:
|
||||
email: The email address associated with the link.
|
||||
purpose: The purpose of the magic link.
|
||||
user_id: Optional user ID (for password reset).
|
||||
|
||||
Returns:
|
||||
The plaintext magic link token.
|
||||
"""
|
||||
token = generate_magic_link_token()
|
||||
token_hash = hash_magic_link_token(token)
|
||||
|
||||
# Determine TTL based on purpose
|
||||
if purpose == MagicLinkPurpose.PASSWORD_RESET:
|
||||
ttl = PASSWORD_RESET_TTL
|
||||
else:
|
||||
ttl = EMAIL_VERIFICATION_TTL
|
||||
|
||||
expires_at = datetime.now(timezone.utc) + ttl
|
||||
|
||||
# Invalidate any existing magic links for this email and purpose
|
||||
await UserAuthMagicLink.prisma().update_many(
|
||||
where={
|
||||
"email": email,
|
||||
"purpose": purpose.value,
|
||||
"usedAt": None,
|
||||
},
|
||||
data={"usedAt": datetime.now(timezone.utc)}, # Mark as used to invalidate
|
||||
)
|
||||
|
||||
# Create new magic link
|
||||
await UserAuthMagicLink.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"tokenHash": token_hash,
|
||||
"email": email,
|
||||
"purpose": purpose.value,
|
||||
"userId": user_id,
|
||||
"expiresAt": expires_at,
|
||||
}
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
|
||||
async def validate_magic_link(
|
||||
token: str,
|
||||
expected_purpose: Optional[MagicLinkPurpose] = None,
|
||||
) -> Optional[MagicLinkInfo]:
|
||||
"""
|
||||
Validate a magic link token without consuming it.
|
||||
|
||||
Args:
|
||||
token: The plaintext magic link token.
|
||||
expected_purpose: Optional expected purpose to validate against.
|
||||
|
||||
Returns:
|
||||
MagicLinkInfo if valid, None otherwise.
|
||||
"""
|
||||
token_hash = hash_magic_link_token(token)
|
||||
|
||||
where_clause: dict = {
|
||||
"tokenHash": token_hash,
|
||||
"usedAt": None,
|
||||
"expiresAt": {"gt": datetime.now(timezone.utc)},
|
||||
}
|
||||
|
||||
if expected_purpose:
|
||||
where_clause["purpose"] = expected_purpose.value
|
||||
|
||||
db_link = await UserAuthMagicLink.prisma().find_first(where=where_clause)
|
||||
|
||||
if not db_link:
|
||||
return None
|
||||
|
||||
return MagicLinkInfo(
|
||||
email=db_link.email,
|
||||
purpose=MagicLinkPurpose(db_link.purpose),
|
||||
user_id=db_link.userId,
|
||||
)
|
||||
|
||||
|
||||
async def consume_magic_link(
|
||||
token: str,
|
||||
expected_purpose: Optional[MagicLinkPurpose] = None,
|
||||
) -> Optional[MagicLinkInfo]:
|
||||
"""
|
||||
Validate and consume a magic link token (single-use).
|
||||
|
||||
Args:
|
||||
token: The plaintext magic link token.
|
||||
expected_purpose: Optional expected purpose to validate against.
|
||||
|
||||
Returns:
|
||||
MagicLinkInfo if valid and successfully consumed, None otherwise.
|
||||
"""
|
||||
# First validate
|
||||
link_info = await validate_magic_link(token, expected_purpose)
|
||||
if not link_info:
|
||||
return None
|
||||
|
||||
# Then consume (mark as used)
|
||||
token_hash = hash_magic_link_token(token)
|
||||
result = await UserAuthMagicLink.prisma().update_many(
|
||||
where={
|
||||
"tokenHash": token_hash,
|
||||
"usedAt": None,
|
||||
},
|
||||
data={"usedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
if result == 0:
|
||||
# Race condition - link was consumed by another request
|
||||
logger.warning("Magic link was already consumed (race condition)")
|
||||
return None
|
||||
|
||||
return link_info
|
||||
|
||||
|
||||
async def create_email_verification_link(email: str) -> str:
|
||||
"""
|
||||
Create an email verification magic link.
|
||||
|
||||
Args:
|
||||
email: The email address to verify.
|
||||
|
||||
Returns:
|
||||
The plaintext magic link token.
|
||||
"""
|
||||
return await create_magic_link(email, MagicLinkPurpose.EMAIL_VERIFICATION)
|
||||
|
||||
|
||||
async def create_password_reset_link(email: str, user_id: str) -> str:
|
||||
"""
|
||||
Create a password reset magic link.
|
||||
|
||||
Args:
|
||||
email: The user's email address.
|
||||
user_id: The user's ID.
|
||||
|
||||
Returns:
|
||||
The plaintext magic link token.
|
||||
"""
|
||||
return await create_magic_link(
|
||||
email, MagicLinkPurpose.PASSWORD_RESET, user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
async def verify_email_token(token: str) -> Optional[str]:
|
||||
"""
|
||||
Verify an email verification token and consume it.
|
||||
|
||||
Args:
|
||||
token: The magic link token.
|
||||
|
||||
Returns:
|
||||
The email address if valid, None otherwise.
|
||||
"""
|
||||
link_info = await consume_magic_link(token, MagicLinkPurpose.EMAIL_VERIFICATION)
|
||||
return link_info.email if link_info else None
|
||||
|
||||
|
||||
async def verify_password_reset_token(token: str) -> Optional[tuple[str, str]]:
|
||||
"""
|
||||
Verify a password reset token and consume it.
|
||||
|
||||
Args:
|
||||
token: The magic link token.
|
||||
|
||||
Returns:
|
||||
Tuple of (user_id, email) if valid, None otherwise.
|
||||
"""
|
||||
link_info = await consume_magic_link(token, MagicLinkPurpose.PASSWORD_RESET)
|
||||
if not link_info or not link_info.user_id:
|
||||
return None
|
||||
return link_info.user_id, link_info.email
|
||||
441
autogpt_platform/backend/backend/data/auth/migration.py
Normal file
441
autogpt_platform/backend/backend/data/auth/migration.py
Normal file
@@ -0,0 +1,441 @@
|
||||
"""
|
||||
Migration script for moving users from Supabase Auth to native FastAPI auth.
|
||||
|
||||
This script handles:
|
||||
1. Marking existing users as migrated from Supabase
|
||||
2. Sending password reset emails to migrated users
|
||||
3. Tracking migration progress
|
||||
4. Generating reports
|
||||
|
||||
Usage:
|
||||
# Dry run - see what would happen
|
||||
python -m backend.data.auth.migration --dry-run
|
||||
|
||||
# Mark users as migrated (no emails)
|
||||
python -m backend.data.auth.migration --mark-migrated
|
||||
|
||||
# Send password reset emails to migrated users
|
||||
python -m backend.data.auth.migration --send-emails --batch-size 100
|
||||
|
||||
# Full migration (mark + send emails)
|
||||
python -m backend.data.auth.migration --full-migration --batch-size 100
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import csv
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from prisma.models import User
|
||||
|
||||
from backend.data.auth.email_service import get_auth_email_service
|
||||
from backend.data.auth.magic_links import create_password_reset_link
|
||||
from backend.data.db import connect, disconnect
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MigrationStats:
|
||||
"""Track migration statistics."""
|
||||
|
||||
def __init__(self):
|
||||
self.total_users = 0
|
||||
self.already_migrated = 0
|
||||
self.marked_migrated = 0
|
||||
self.emails_sent = 0
|
||||
self.emails_failed = 0
|
||||
self.oauth_users_skipped = 0
|
||||
self.errors = []
|
||||
|
||||
def __str__(self):
|
||||
return f"""
|
||||
Migration Statistics:
|
||||
---------------------
|
||||
Total users processed: {self.total_users}
|
||||
Already migrated: {self.already_migrated}
|
||||
Newly marked as migrated: {self.marked_migrated}
|
||||
Password reset emails sent: {self.emails_sent}
|
||||
Email failures: {self.emails_failed}
|
||||
OAuth users skipped: {self.oauth_users_skipped}
|
||||
Errors: {len(self.errors)}
|
||||
"""
|
||||
|
||||
|
||||
async def get_users_to_migrate(
|
||||
batch_size: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[User]:
|
||||
"""
|
||||
Get users that need to be migrated.
|
||||
|
||||
Returns users where:
|
||||
- authProvider is "supabase" or NULL
|
||||
- migratedFromSupabase is False or NULL
|
||||
- passwordHash is NULL (they haven't set a native password)
|
||||
"""
|
||||
users = await User.prisma().find_many(
|
||||
where={
|
||||
"OR": [
|
||||
{"authProvider": "supabase"},
|
||||
{"authProvider": None},
|
||||
],
|
||||
"migratedFromSupabase": False,
|
||||
"passwordHash": None,
|
||||
},
|
||||
take=batch_size,
|
||||
skip=offset,
|
||||
order={"createdAt": "asc"},
|
||||
)
|
||||
return users
|
||||
|
||||
|
||||
async def get_migrated_users_needing_email(
|
||||
batch_size: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[User]:
|
||||
"""
|
||||
Get migrated users who haven't set their password yet.
|
||||
|
||||
These users need a password reset email.
|
||||
"""
|
||||
users = await User.prisma().find_many(
|
||||
where={
|
||||
"migratedFromSupabase": True,
|
||||
"passwordHash": None,
|
||||
"authProvider": {"not": "google"}, # Skip OAuth users
|
||||
},
|
||||
take=batch_size,
|
||||
skip=offset,
|
||||
order={"createdAt": "asc"},
|
||||
)
|
||||
return users
|
||||
|
||||
|
||||
async def mark_user_as_migrated(user: User, dry_run: bool = False) -> bool:
|
||||
"""
|
||||
Mark a user as migrated from Supabase.
|
||||
|
||||
Sets migratedFromSupabase=True and authProvider="supabase".
|
||||
"""
|
||||
if dry_run:
|
||||
logger.info(f"[DRY RUN] Would mark user {user.id} ({user.email}) as migrated")
|
||||
return True
|
||||
|
||||
try:
|
||||
await User.prisma().update(
|
||||
where={"id": user.id},
|
||||
data={
|
||||
"migratedFromSupabase": True,
|
||||
"authProvider": "supabase",
|
||||
},
|
||||
)
|
||||
logger.info(f"Marked user {user.id} ({user.email}) as migrated")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to mark user {user.id} as migrated: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def send_migration_email(
|
||||
user: User,
|
||||
email_service,
|
||||
dry_run: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Send a password reset email to a migrated user.
|
||||
"""
|
||||
if dry_run:
|
||||
logger.info(f"[DRY RUN] Would send migration email to {user.email}")
|
||||
return True
|
||||
|
||||
try:
|
||||
token = await create_password_reset_link(user.email, user.id)
|
||||
success = email_service.send_migrated_user_password_reset(user.email, token)
|
||||
|
||||
if success:
|
||||
logger.info(f"Sent migration email to {user.email}")
|
||||
else:
|
||||
logger.warning(f"Failed to send migration email to {user.email}")
|
||||
|
||||
return success
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending migration email to {user.email}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def run_migration(
|
||||
mark_migrated: bool = False,
|
||||
send_emails: bool = False,
|
||||
batch_size: int = 100,
|
||||
dry_run: bool = False,
|
||||
email_delay: float = 0.5, # Delay between emails to avoid rate limiting
|
||||
) -> MigrationStats:
|
||||
"""
|
||||
Run the migration process.
|
||||
|
||||
Args:
|
||||
mark_migrated: Mark users as migrated from Supabase
|
||||
send_emails: Send password reset emails to migrated users
|
||||
batch_size: Number of users to process at a time
|
||||
dry_run: If True, don't make any changes
|
||||
email_delay: Seconds to wait between sending emails
|
||||
|
||||
Returns:
|
||||
MigrationStats with results
|
||||
"""
|
||||
stats = MigrationStats()
|
||||
email_service = get_auth_email_service() if send_emails else None
|
||||
|
||||
# Phase 1: Mark users as migrated
|
||||
if mark_migrated:
|
||||
logger.info("Phase 1: Marking users as migrated...")
|
||||
offset = 0
|
||||
|
||||
while True:
|
||||
users = await get_users_to_migrate(batch_size, offset)
|
||||
if not users:
|
||||
break
|
||||
|
||||
for user in users:
|
||||
stats.total_users += 1
|
||||
|
||||
# Skip OAuth users
|
||||
if user.authProvider == "google":
|
||||
stats.oauth_users_skipped += 1
|
||||
continue
|
||||
|
||||
success = await mark_user_as_migrated(user, dry_run)
|
||||
if success:
|
||||
stats.marked_migrated += 1
|
||||
else:
|
||||
stats.errors.append(f"Failed to mark {user.email}")
|
||||
|
||||
offset += batch_size
|
||||
logger.info(f"Processed {offset} users...")
|
||||
|
||||
# Phase 2: Send password reset emails
|
||||
if send_emails:
|
||||
logger.info("Phase 2: Sending password reset emails...")
|
||||
offset = 0
|
||||
|
||||
while True:
|
||||
users = await get_migrated_users_needing_email(batch_size, offset)
|
||||
if not users:
|
||||
break
|
||||
|
||||
for user in users:
|
||||
stats.total_users += 1
|
||||
|
||||
success = await send_migration_email(user, email_service, dry_run)
|
||||
if success:
|
||||
stats.emails_sent += 1
|
||||
else:
|
||||
stats.emails_failed += 1
|
||||
stats.errors.append(f"Failed to email {user.email}")
|
||||
|
||||
# Rate limiting
|
||||
if not dry_run and email_delay > 0:
|
||||
await asyncio.sleep(email_delay)
|
||||
|
||||
offset += batch_size
|
||||
logger.info(f"Processed {offset} users for email...")
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
async def generate_migration_report(output_path: Optional[str] = None) -> str:
|
||||
"""
|
||||
Generate a CSV report of all users and their migration status.
|
||||
"""
|
||||
if output_path is None:
|
||||
output_path = f"migration_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
|
||||
|
||||
users = await User.prisma().find_many(
|
||||
order={"createdAt": "asc"},
|
||||
)
|
||||
|
||||
with open(output_path, "w", newline="") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow([
|
||||
"user_id",
|
||||
"email",
|
||||
"auth_provider",
|
||||
"migrated_from_supabase",
|
||||
"has_password",
|
||||
"email_verified",
|
||||
"created_at",
|
||||
"needs_action",
|
||||
])
|
||||
|
||||
for user in users:
|
||||
needs_action = (
|
||||
user.migratedFromSupabase
|
||||
and user.passwordHash is None
|
||||
and user.authProvider != "google"
|
||||
)
|
||||
|
||||
writer.writerow([
|
||||
user.id,
|
||||
user.email,
|
||||
user.authProvider or "unknown",
|
||||
user.migratedFromSupabase,
|
||||
user.passwordHash is not None,
|
||||
user.emailVerified,
|
||||
user.createdAt.isoformat() if user.createdAt else "",
|
||||
"YES" if needs_action else "NO",
|
||||
])
|
||||
|
||||
logger.info(f"Report saved to {output_path}")
|
||||
return output_path
|
||||
|
||||
|
||||
async def count_migration_status():
|
||||
"""
|
||||
Get counts of users in different migration states.
|
||||
"""
|
||||
total = await User.prisma().count()
|
||||
|
||||
already_native = await User.prisma().count(
|
||||
where={"authProvider": "password", "passwordHash": {"not": None}}
|
||||
)
|
||||
|
||||
oauth_users = await User.prisma().count(
|
||||
where={"authProvider": "google"}
|
||||
)
|
||||
|
||||
migrated_pending = await User.prisma().count(
|
||||
where={
|
||||
"migratedFromSupabase": True,
|
||||
"passwordHash": None,
|
||||
"authProvider": {"not": "google"},
|
||||
}
|
||||
)
|
||||
|
||||
not_migrated = await User.prisma().count(
|
||||
where={
|
||||
"migratedFromSupabase": False,
|
||||
"authProvider": {"in": ["supabase", None]},
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"already_native": already_native,
|
||||
"oauth_users": oauth_users,
|
||||
"migrated_pending_password": migrated_pending,
|
||||
"not_yet_migrated": not_migrated,
|
||||
}
|
||||
|
||||
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Migrate users from Supabase Auth to native FastAPI auth"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Don't make any changes, just show what would happen",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mark-migrated",
|
||||
action="store_true",
|
||||
help="Mark existing Supabase users as migrated",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--send-emails",
|
||||
action="store_true",
|
||||
help="Send password reset emails to migrated users",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full-migration",
|
||||
action="store_true",
|
||||
help="Run full migration (mark + send emails)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of users to process at a time (default: 100)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--email-delay",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Seconds to wait between emails (default: 0.5)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report",
|
||||
action="store_true",
|
||||
help="Generate a CSV report of migration status",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--status",
|
||||
action="store_true",
|
||||
help="Show current migration status counts",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Connect to database
|
||||
await connect()
|
||||
|
||||
try:
|
||||
if args.status:
|
||||
counts = await count_migration_status()
|
||||
print("\nMigration Status:")
|
||||
print("-" * 40)
|
||||
print(f"Total users: {counts['total']}")
|
||||
print(f"Already using native auth: {counts['already_native']}")
|
||||
print(f"OAuth users (Google): {counts['oauth_users']}")
|
||||
print(f"Migrated, pending password: {counts['migrated_pending_password']}")
|
||||
print(f"Not yet migrated: {counts['not_yet_migrated']}")
|
||||
return
|
||||
|
||||
if args.report:
|
||||
await generate_migration_report()
|
||||
return
|
||||
|
||||
if args.full_migration:
|
||||
args.mark_migrated = True
|
||||
args.send_emails = True
|
||||
|
||||
if not args.mark_migrated and not args.send_emails:
|
||||
parser.print_help()
|
||||
print("\nError: Must specify --mark-migrated, --send-emails, --full-migration, --report, or --status")
|
||||
return
|
||||
|
||||
if args.dry_run:
|
||||
logger.info("=" * 50)
|
||||
logger.info("DRY RUN MODE - No changes will be made")
|
||||
logger.info("=" * 50)
|
||||
|
||||
stats = await run_migration(
|
||||
mark_migrated=args.mark_migrated,
|
||||
send_emails=args.send_emails,
|
||||
batch_size=args.batch_size,
|
||||
dry_run=args.dry_run,
|
||||
email_delay=args.email_delay,
|
||||
)
|
||||
|
||||
print(stats)
|
||||
|
||||
if stats.errors:
|
||||
print("\nErrors encountered:")
|
||||
for error in stats.errors[:10]: # Show first 10 errors
|
||||
print(f" - {error}")
|
||||
if len(stats.errors) > 10:
|
||||
print(f" ... and {len(stats.errors) - 10} more")
|
||||
|
||||
finally:
|
||||
await disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
70
autogpt_platform/backend/backend/data/auth/password.py
Normal file
70
autogpt_platform/backend/backend/data/auth/password.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
Password hashing service using Argon2id.
|
||||
|
||||
OWASP 2024 recommended configuration:
|
||||
- time_cost: 2 iterations
|
||||
- memory_cost: 19456 KiB (19 MiB)
|
||||
- parallelism: 1
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from argon2 import PasswordHasher
|
||||
from argon2.exceptions import InvalidHashError, VerifyMismatchError
|
||||
from argon2.profiles import RFC_9106_LOW_MEMORY
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Use RFC 9106 low-memory profile (OWASP recommended)
|
||||
# time_cost=2, memory_cost=19456, parallelism=1
|
||||
_hasher = PasswordHasher.from_parameters(RFC_9106_LOW_MEMORY)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""
|
||||
Hash a password using Argon2id.
|
||||
|
||||
Args:
|
||||
password: The plaintext password to hash.
|
||||
|
||||
Returns:
|
||||
The hashed password string (includes algorithm params and salt).
|
||||
"""
|
||||
return _hasher.hash(password)
|
||||
|
||||
|
||||
def verify_password(password_hash: str, password: str) -> bool:
|
||||
"""
|
||||
Verify a password against a hash.
|
||||
|
||||
Args:
|
||||
password_hash: The stored password hash.
|
||||
password: The plaintext password to verify.
|
||||
|
||||
Returns:
|
||||
True if the password matches, False otherwise.
|
||||
"""
|
||||
try:
|
||||
_hasher.verify(password_hash, password)
|
||||
return True
|
||||
except VerifyMismatchError:
|
||||
return False
|
||||
except InvalidHashError:
|
||||
logger.warning("Invalid password hash format encountered")
|
||||
return False
|
||||
|
||||
|
||||
def needs_rehash(password_hash: str) -> bool:
|
||||
"""
|
||||
Check if a password hash needs to be rehashed.
|
||||
|
||||
This returns True if the hash was created with different parameters
|
||||
than the current configuration, allowing for transparent upgrades.
|
||||
|
||||
Args:
|
||||
password_hash: The stored password hash.
|
||||
|
||||
Returns:
|
||||
True if the hash should be rehashed, False otherwise.
|
||||
"""
|
||||
return _hasher.check_needs_rehash(password_hash)
|
||||
270
autogpt_platform/backend/backend/data/auth/tokens.py
Normal file
270
autogpt_platform/backend/backend/data/auth/tokens.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""
|
||||
JWT token generation and validation for user authentication.
|
||||
|
||||
This module generates tokens compatible with Supabase JWT format to ensure
|
||||
a smooth migration without requiring frontend changes.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
import jwt
|
||||
from prisma.models import UserAuthRefreshToken
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autogpt_libs.auth.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Token TTLs
|
||||
ACCESS_TOKEN_TTL = timedelta(hours=1)
|
||||
REFRESH_TOKEN_TTL = timedelta(days=30)
|
||||
|
||||
# Refresh token prefix for identification
|
||||
REFRESH_TOKEN_PREFIX = "agpt_rt_"
|
||||
|
||||
|
||||
class TokenPair(BaseModel):
|
||||
"""Access and refresh token pair."""
|
||||
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
expires_in: int # seconds until access token expires
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class JWTPayload(BaseModel):
|
||||
"""JWT payload structure matching Supabase format."""
|
||||
|
||||
sub: str # user ID
|
||||
email: str
|
||||
phone: str = ""
|
||||
role: str = "authenticated"
|
||||
aud: str = "authenticated"
|
||||
iat: int # issued at (unix timestamp)
|
||||
exp: int # expiration (unix timestamp)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
user_id: str,
|
||||
email: str,
|
||||
role: str = "authenticated",
|
||||
phone: str = "",
|
||||
) -> str:
|
||||
"""
|
||||
Create a JWT access token.
|
||||
|
||||
The token format matches Supabase JWT structure so existing backend
|
||||
validation code continues to work without modification.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
email: The user's email address.
|
||||
role: The user's role (default: "authenticated").
|
||||
phone: The user's phone number (optional).
|
||||
|
||||
Returns:
|
||||
The encoded JWT token string.
|
||||
"""
|
||||
settings = get_settings()
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"email": email,
|
||||
"phone": phone,
|
||||
"role": role,
|
||||
"aud": "authenticated",
|
||||
"iat": int(now.timestamp()),
|
||||
"exp": int((now + ACCESS_TOKEN_TTL).timestamp()),
|
||||
}
|
||||
|
||||
return jwt.encode(payload, settings.JWT_VERIFY_KEY, algorithm=settings.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> Optional[JWTPayload]:
|
||||
"""
|
||||
Decode and validate a JWT access token.
|
||||
|
||||
Args:
|
||||
token: The JWT token string.
|
||||
|
||||
Returns:
|
||||
The decoded payload if valid, None otherwise.
|
||||
"""
|
||||
settings = get_settings()
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.JWT_VERIFY_KEY,
|
||||
algorithms=[settings.JWT_ALGORITHM],
|
||||
audience="authenticated",
|
||||
)
|
||||
return JWTPayload(**payload)
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.debug("Token has expired")
|
||||
return None
|
||||
except jwt.InvalidTokenError as e:
|
||||
logger.debug(f"Invalid token: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def generate_refresh_token() -> str:
|
||||
"""
|
||||
Generate a cryptographically secure refresh token.
|
||||
|
||||
Returns:
|
||||
A prefixed random token string.
|
||||
"""
|
||||
random_bytes = secrets.token_urlsafe(32)
|
||||
return f"{REFRESH_TOKEN_PREFIX}{random_bytes}"
|
||||
|
||||
|
||||
def hash_refresh_token(token: str) -> str:
|
||||
"""
|
||||
Hash a refresh token for storage.
|
||||
|
||||
Uses SHA256 for deterministic lookup (unlike passwords which use Argon2).
|
||||
|
||||
Args:
|
||||
token: The plaintext refresh token.
|
||||
|
||||
Returns:
|
||||
The SHA256 hex digest.
|
||||
"""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
async def create_refresh_token_db(
|
||||
user_id: str,
|
||||
token: Optional[str] = None,
|
||||
) -> tuple[str, datetime]:
|
||||
"""
|
||||
Create a refresh token and store it in the database.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
token: Optional pre-generated token (used in OAuth flow).
|
||||
|
||||
Returns:
|
||||
Tuple of (plaintext token, expiration datetime).
|
||||
"""
|
||||
if token is None:
|
||||
token = generate_refresh_token()
|
||||
|
||||
token_hash = hash_refresh_token(token)
|
||||
expires_at = datetime.now(timezone.utc) + REFRESH_TOKEN_TTL
|
||||
|
||||
await UserAuthRefreshToken.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"tokenHash": token_hash,
|
||||
"userId": user_id,
|
||||
"expiresAt": expires_at,
|
||||
}
|
||||
)
|
||||
|
||||
return token, expires_at
|
||||
|
||||
|
||||
async def validate_refresh_token(token: str) -> Optional[str]:
|
||||
"""
|
||||
Validate a refresh token and return the associated user ID.
|
||||
|
||||
Args:
|
||||
token: The plaintext refresh token.
|
||||
|
||||
Returns:
|
||||
The user ID if valid, None otherwise.
|
||||
"""
|
||||
token_hash = hash_refresh_token(token)
|
||||
|
||||
db_token = await UserAuthRefreshToken.prisma().find_first(
|
||||
where={
|
||||
"tokenHash": token_hash,
|
||||
"revokedAt": None,
|
||||
"expiresAt": {"gt": datetime.now(timezone.utc)},
|
||||
}
|
||||
)
|
||||
|
||||
if not db_token:
|
||||
return None
|
||||
|
||||
return db_token.userId
|
||||
|
||||
|
||||
async def revoke_refresh_token(token: str) -> bool:
|
||||
"""
|
||||
Revoke a refresh token.
|
||||
|
||||
Args:
|
||||
token: The plaintext refresh token.
|
||||
|
||||
Returns:
|
||||
True if a token was revoked, False otherwise.
|
||||
"""
|
||||
token_hash = hash_refresh_token(token)
|
||||
|
||||
result = await UserAuthRefreshToken.prisma().update_many(
|
||||
where={
|
||||
"tokenHash": token_hash,
|
||||
"revokedAt": None,
|
||||
},
|
||||
data={"revokedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
return result > 0
|
||||
|
||||
|
||||
async def revoke_all_user_refresh_tokens(user_id: str) -> int:
|
||||
"""
|
||||
Revoke all refresh tokens for a user.
|
||||
|
||||
Used for global logout or security events.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
Number of tokens revoked.
|
||||
"""
|
||||
result = await UserAuthRefreshToken.prisma().update_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"revokedAt": None,
|
||||
},
|
||||
data={"revokedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def create_token_pair(
|
||||
user_id: str,
|
||||
email: str,
|
||||
role: str = "authenticated",
|
||||
) -> TokenPair:
|
||||
"""
|
||||
Create a complete token pair (access + refresh).
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
email: The user's email.
|
||||
role: The user's role.
|
||||
|
||||
Returns:
|
||||
TokenPair with access_token, refresh_token, and metadata.
|
||||
"""
|
||||
access_token = create_access_token(user_id, email, role)
|
||||
refresh_token, _ = await create_refresh_token_db(user_id)
|
||||
|
||||
return TokenPair(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_in=int(ACCESS_TOKEN_TTL.total_seconds()),
|
||||
)
|
||||
@@ -50,8 +50,6 @@ from .model import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
from .graph import Link
|
||||
|
||||
app_config = Config()
|
||||
@@ -474,7 +472,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.block_type = block_type
|
||||
self.webhook_config = webhook_config
|
||||
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
|
||||
self.requires_human_review: bool = False
|
||||
|
||||
if self.webhook_config:
|
||||
if isinstance(self.webhook_config, BlockWebhookConfig):
|
||||
@@ -617,77 +614,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
block_id=self.id,
|
||||
) from ex
|
||||
|
||||
async def is_block_exec_need_review(
|
||||
self,
|
||||
input_data: BlockInput,
|
||||
*,
|
||||
user_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: "ExecutionContext",
|
||||
**kwargs,
|
||||
) -> tuple[bool, BlockInput]:
|
||||
"""
|
||||
Check if this block execution needs human review and handle the review process.
|
||||
|
||||
Returns:
|
||||
Tuple of (should_pause, input_data_to_use)
|
||||
- should_pause: True if execution should be paused for review
|
||||
- input_data_to_use: The input data to use (may be modified by reviewer)
|
||||
"""
|
||||
# Skip review if not required or safe mode is disabled
|
||||
if not self.requires_human_review or not execution_context.safe_mode:
|
||||
return False, input_data
|
||||
|
||||
from backend.blocks.helpers.review import HITLReviewHelper
|
||||
|
||||
# Handle the review request and get decision
|
||||
decision = await HITLReviewHelper.handle_review_decision(
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_context=execution_context,
|
||||
block_name=self.name,
|
||||
editable=True,
|
||||
)
|
||||
|
||||
if decision is None:
|
||||
# We're awaiting review - pause execution
|
||||
return True, input_data
|
||||
|
||||
if not decision.should_proceed:
|
||||
# Review was rejected, raise an error to stop execution
|
||||
raise BlockExecutionError(
|
||||
message=f"Block execution rejected by reviewer: {decision.message}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
|
||||
# Review was approved - use the potentially modified data
|
||||
# ReviewResult.data must be a dict for block inputs
|
||||
reviewed_data = decision.review_result.data
|
||||
if not isinstance(reviewed_data, dict):
|
||||
raise BlockExecutionError(
|
||||
message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
return False, reviewed_data
|
||||
|
||||
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
||||
# Check for review requirement and get potentially modified input data
|
||||
should_pause, input_data = await self.is_block_exec_need_review(
|
||||
input_data, **kwargs
|
||||
)
|
||||
if should_pause:
|
||||
return
|
||||
|
||||
# Validate the input data (original or reviewer-modified) once
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise BlockInputError(
|
||||
message=f"Unable to execute block with invalid input data: {error}",
|
||||
@@ -695,7 +622,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
block_id=self.id,
|
||||
)
|
||||
|
||||
# Use the validated input data
|
||||
async for output_name, output_data in self.run(
|
||||
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
|
||||
**kwargs,
|
||||
|
||||
@@ -59,13 +59,12 @@ from backend.integrations.credentials_store import (
|
||||
|
||||
MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.O3: 4,
|
||||
LlmModel.O3_MINI: 2,
|
||||
LlmModel.O1: 16,
|
||||
LlmModel.O3_MINI: 2, # $1.10 / $4.40
|
||||
LlmModel.O1: 16, # $15 / $60
|
||||
LlmModel.O1_MINI: 4,
|
||||
# GPT-5 models
|
||||
LlmModel.GPT5_2: 6,
|
||||
LlmModel.GPT5_1: 5,
|
||||
LlmModel.GPT5: 2,
|
||||
LlmModel.GPT5_1: 5,
|
||||
LlmModel.GPT5_MINI: 1,
|
||||
LlmModel.GPT5_NANO: 1,
|
||||
LlmModel.GPT5_CHAT: 5,
|
||||
@@ -88,7 +87,7 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.AIML_API_LLAMA3_3_70B: 1,
|
||||
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
|
||||
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
|
||||
LlmModel.LLAMA3_3_70B: 1,
|
||||
LlmModel.LLAMA3_3_70B: 1, # $0.59 / $0.79
|
||||
LlmModel.LLAMA3_1_8B: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_3: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_2: 1,
|
||||
|
||||
@@ -16,7 +16,6 @@ from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBala
|
||||
from prisma.types import CreditRefundRequestCreateInput, CreditTransactionWhereInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.admin.model import UserHistoryResponse
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.data.includes import MAX_CREDIT_REFUND_REQUESTS_FETCH
|
||||
@@ -30,6 +29,7 @@ from backend.data.model import (
|
||||
from backend.data.notifications import NotificationEventModel, RefundRequestData
|
||||
from backend.data.user import get_user_by_id, get_user_email_by_id
|
||||
from backend.notifications.notifications import queue_notification_async
|
||||
from backend.server.v2.admin.model import UserHistoryResponse
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
from backend.util.json import SafeJson, dumps
|
||||
@@ -341,19 +341,6 @@ class UserCreditBase(ABC):
|
||||
|
||||
if result:
|
||||
# UserBalance is already updated by the CTE
|
||||
|
||||
# Clear insufficient funds notification flags when credits are added
|
||||
# so user can receive alerts again if they run out in the future.
|
||||
if transaction.amount > 0 and transaction.type in [
|
||||
CreditTransactionType.GRANT,
|
||||
CreditTransactionType.TOP_UP,
|
||||
]:
|
||||
from backend.executor.manager import (
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
|
||||
await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
return result[0]["balance"]
|
||||
|
||||
async def _add_transaction(
|
||||
@@ -543,22 +530,6 @@ class UserCreditBase(ABC):
|
||||
if result:
|
||||
new_balance, tx_key = result[0]["balance"], result[0]["transactionKey"]
|
||||
# UserBalance is already updated by the CTE
|
||||
|
||||
# Clear insufficient funds notification flags when credits are added
|
||||
# so user can receive alerts again if they run out in the future.
|
||||
if (
|
||||
amount > 0
|
||||
and is_active
|
||||
and transaction_type
|
||||
in [CreditTransactionType.GRANT, CreditTransactionType.TOP_UP]
|
||||
):
|
||||
# Lazy import to avoid circular dependency with executor.manager
|
||||
from backend.executor.manager import (
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
|
||||
await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
return new_balance, tx_key
|
||||
|
||||
# If no result, either user doesn't exist or insufficient balance
|
||||
|
||||
@@ -111,7 +111,7 @@ def get_database_schema() -> str:
|
||||
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
|
||||
"""Execute raw SQL query with proper schema handling."""
|
||||
schema = get_database_schema()
|
||||
schema_prefix = f'"{schema}".' if schema != "public" else ""
|
||||
schema_prefix = f"{schema}." if schema != "public" else ""
|
||||
formatted_query = query_template.format(schema_prefix=schema_prefix)
|
||||
|
||||
import prisma as prisma_module
|
||||
|
||||
@@ -383,7 +383,6 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
self,
|
||||
execution_context: ExecutionContext,
|
||||
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
):
|
||||
return GraphExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
@@ -391,7 +390,6 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
graph_version=self.graph_version or 0,
|
||||
graph_exec_id=self.id,
|
||||
nodes_input_masks=compiled_nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip or set(),
|
||||
execution_context=execution_context,
|
||||
)
|
||||
|
||||
@@ -1147,8 +1145,6 @@ class GraphExecutionEntry(BaseModel):
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None
|
||||
nodes_to_skip: set[str] = Field(default_factory=set)
|
||||
"""Node IDs that should be skipped due to optional credentials not being configured."""
|
||||
execution_context: ExecutionContext = Field(default_factory=ExecutionContext)
|
||||
|
||||
|
||||
|
||||
@@ -94,15 +94,6 @@ class Node(BaseDbModel):
|
||||
input_links: list[Link] = []
|
||||
output_links: list[Link] = []
|
||||
|
||||
@property
|
||||
def credentials_optional(self) -> bool:
|
||||
"""
|
||||
Whether credentials are optional for this node.
|
||||
When True and credentials are not configured, the node will be skipped
|
||||
during execution rather than causing a validation error.
|
||||
"""
|
||||
return self.metadata.get("credentials_optional", False)
|
||||
|
||||
@property
|
||||
def block(self) -> AnyBlockSchema | "_UnknownBlockBase":
|
||||
"""Get the block for this node. Returns UnknownBlock if block is deleted/missing."""
|
||||
@@ -244,10 +235,7 @@ class BaseGraph(BaseDbModel):
|
||||
return any(
|
||||
node.block_id
|
||||
for node in self.nodes
|
||||
if (
|
||||
node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
|
||||
or node.block.requires_human_review
|
||||
)
|
||||
if node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -338,35 +326,7 @@ class Graph(BaseGraph):
|
||||
@computed_field
|
||||
@property
|
||||
def credentials_input_schema(self) -> dict[str, Any]:
|
||||
schema = self._credentials_input_schema.jsonschema()
|
||||
|
||||
# Determine which credential fields are required based on credentials_optional metadata
|
||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||
required_fields = []
|
||||
|
||||
# Build a map of node_id -> node for quick lookup
|
||||
all_nodes = {node.id: node for node in self.nodes}
|
||||
for sub_graph in self.sub_graphs:
|
||||
for node in sub_graph.nodes:
|
||||
all_nodes[node.id] = node
|
||||
|
||||
for field_key, (
|
||||
_field_info,
|
||||
node_field_pairs,
|
||||
) in graph_credentials_inputs.items():
|
||||
# A field is required if ANY node using it has credentials_optional=False
|
||||
is_required = False
|
||||
for node_id, _field_name in node_field_pairs:
|
||||
node = all_nodes.get(node_id)
|
||||
if node and not node.credentials_optional:
|
||||
is_required = True
|
||||
break
|
||||
|
||||
if is_required:
|
||||
required_fields.append(field_key)
|
||||
|
||||
schema["required"] = required_fields
|
||||
return schema
|
||||
return self._credentials_input_schema.jsonschema()
|
||||
|
||||
@property
|
||||
def _credentials_input_schema(self) -> type[BlockSchema]:
|
||||
|
||||
@@ -6,14 +6,14 @@ import fastapi.exceptions
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.api.features.store.model as store
|
||||
from backend.api.model import CreateGraph
|
||||
import backend.server.v2.store.model as store
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.data.block import BlockSchema, BlockSchemaInput
|
||||
from backend.data.graph import Graph, Link, Node
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.usecases.sample import create_test_user
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
@@ -396,58 +396,3 @@ async def test_access_store_listing_graph(server: SpinTestServer):
|
||||
created_graph.id, created_graph.version, "3e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
||||
)
|
||||
assert got_graph is not None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for Optional Credentials Feature
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_node_credentials_optional_default():
|
||||
"""Test that credentials_optional defaults to False when not set in metadata."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={},
|
||||
)
|
||||
assert node.credentials_optional is False
|
||||
|
||||
|
||||
def test_node_credentials_optional_true():
|
||||
"""Test that credentials_optional returns True when explicitly set."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={"credentials_optional": True},
|
||||
)
|
||||
assert node.credentials_optional is True
|
||||
|
||||
|
||||
def test_node_credentials_optional_false():
|
||||
"""Test that credentials_optional returns False when explicitly set to False."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={"credentials_optional": False},
|
||||
)
|
||||
assert node.credentials_optional is False
|
||||
|
||||
|
||||
def test_node_credentials_optional_with_other_metadata():
|
||||
"""Test that credentials_optional works correctly with other metadata present."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={
|
||||
"position": {"x": 100, "y": 200},
|
||||
"customized_name": "My Custom Node",
|
||||
"credentials_optional": True,
|
||||
},
|
||||
)
|
||||
assert node.credentials_optional is True
|
||||
assert node.metadata["position"] == {"x": 100, "y": 200}
|
||||
assert node.metadata["customized_name"] == "My Custom Node"
|
||||
|
||||
@@ -13,7 +13,7 @@ from prisma.models import PendingHumanReview
|
||||
from prisma.types import PendingHumanReviewUpdateInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.executions.review.model import (
|
||||
from backend.server.v2.executions.review.model import (
|
||||
PendingHumanReviewModel,
|
||||
SafeJsonData,
|
||||
)
|
||||
|
||||
@@ -23,7 +23,7 @@ from backend.util.exceptions import NotFoundError
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.api.features.library.model import LibraryAgentPreset
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
|
||||
from .db import BaseDbModel
|
||||
from .graph import NodeModel
|
||||
@@ -79,7 +79,7 @@ class WebhookWithRelations(Webhook):
|
||||
# integrations.py → library/model.py → integrations.py (for Webhook)
|
||||
# Runtime import is used in WebhookWithRelations.from_db() method instead
|
||||
# Import at runtime to avoid circular dependency
|
||||
from backend.api.features.library.model import LibraryAgentPreset
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
|
||||
return WebhookWithRelations(
|
||||
**Webhook.from_db(webhook).model_dump(),
|
||||
@@ -285,8 +285,8 @@ async def unlink_webhook_from_graph(
|
||||
user_id: The ID of the user (for authorization)
|
||||
"""
|
||||
# Avoid circular imports
|
||||
from backend.api.features.library.db import set_preset_webhook
|
||||
from backend.data.graph import set_node_webhook
|
||||
from backend.server.v2.library.db import set_preset_webhook
|
||||
|
||||
# Find all nodes in this graph that use this webhook
|
||||
nodes = await AgentNode.prisma().find_many(
|
||||
|
||||
@@ -4,8 +4,8 @@ from typing import AsyncGenerator
|
||||
|
||||
from pydantic import BaseModel, field_serializer
|
||||
|
||||
from backend.api.model import NotificationPayload
|
||||
from backend.data.event_bus import AsyncRedisEventBus
|
||||
from backend.server.model import NotificationPayload
|
||||
from backend.util.settings import Settings
|
||||
|
||||
|
||||
|
||||
@@ -9,8 +9,6 @@ from prisma.enums import OnboardingStep
|
||||
from prisma.models import UserOnboarding
|
||||
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
|
||||
|
||||
from backend.api.features.store.model import StoreAgentDetails
|
||||
from backend.api.model import OnboardingNotificationPayload
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.data.notification_bus import (
|
||||
@@ -18,6 +16,8 @@ from backend.data.notification_bus import (
|
||||
NotificationEvent,
|
||||
)
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.server.model import OnboardingNotificationPayload
|
||||
from backend.server.v2.store.model import StoreAgentDetails
|
||||
from backend.util.cache import cached
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.timezone_utils import get_user_timezone_or_utc
|
||||
@@ -442,8 +442,6 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
versions=agent.versions,
|
||||
agentGraphVersions=agent.agentGraphVersions,
|
||||
agentGraphId=agent.agentGraphId,
|
||||
last_updated=agent.updated_at,
|
||||
)
|
||||
for agent in recommended_agents
|
||||
|
||||
@@ -2,11 +2,6 @@ import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cast
|
||||
|
||||
from backend.api.features.library.db import (
|
||||
add_store_agent_to_library,
|
||||
list_library_agents,
|
||||
)
|
||||
from backend.api.features.store.db import get_store_agent_details, get_store_agents
|
||||
from backend.data import db
|
||||
from backend.data.analytics import (
|
||||
get_accuracy_trends_and_alerts,
|
||||
@@ -66,6 +61,8 @@ from backend.data.user import (
|
||||
get_user_notification_preference,
|
||||
update_user_integrations,
|
||||
)
|
||||
from backend.server.v2.library.db import add_store_agent_to_library, list_library_agents
|
||||
from backend.server.v2.store.db import get_store_agent_details, get_store_agents
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
|
||||
@@ -48,8 +48,27 @@ from backend.data.notifications import (
|
||||
ZeroBalanceData,
|
||||
)
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.activity_status_generator import (
|
||||
generate_activity_status_for_execution,
|
||||
)
|
||||
from backend.executor.utils import (
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_EXCHANGE,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_ROUTING_KEY,
|
||||
CancelExecutionEvent,
|
||||
ExecutionOutputEntry,
|
||||
LogMetadata,
|
||||
NodeExecutionProgress,
|
||||
block_usage_cost,
|
||||
create_execution_queue_config,
|
||||
execution_usage_cost,
|
||||
validate_exec,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.server.v2.AutoMod.manager import automod_manager
|
||||
from backend.util import json
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
@@ -76,24 +95,7 @@ from backend.util.retry import (
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .activity_status_generator import generate_activity_status_for_execution
|
||||
from .automod.manager import automod_manager
|
||||
from .cluster_lock import ClusterLock
|
||||
from .utils import (
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_EXCHANGE,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_ROUTING_KEY,
|
||||
CancelExecutionEvent,
|
||||
ExecutionOutputEntry,
|
||||
LogMetadata,
|
||||
NodeExecutionProgress,
|
||||
block_usage_cost,
|
||||
create_execution_queue_config,
|
||||
execution_usage_cost,
|
||||
validate_exec,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
|
||||
@@ -114,40 +116,6 @@ utilization_gauge = Gauge(
|
||||
"Ratio of active graph runs to max graph workers",
|
||||
)
|
||||
|
||||
# Redis key prefix for tracking insufficient funds Discord notifications.
|
||||
# We only send one notification per user per agent until they top up credits.
|
||||
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
|
||||
# TTL for the notification flag (30 days) - acts as a fallback cleanup
|
||||
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60
|
||||
|
||||
|
||||
async def clear_insufficient_funds_notifications(user_id: str) -> int:
|
||||
"""
|
||||
Clear all insufficient funds notification flags for a user.
|
||||
|
||||
This should be called when a user tops up their credits, allowing
|
||||
Discord notifications to be sent again if they run out of funds.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to clear notifications for.
|
||||
|
||||
Returns:
|
||||
The number of keys that were deleted.
|
||||
"""
|
||||
try:
|
||||
redis_client = await redis.get_redis_async()
|
||||
pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
keys = [key async for key in redis_client.scan_iter(match=pattern)]
|
||||
if keys:
|
||||
return await redis_client.delete(*keys)
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to clear insufficient funds notification flags for user "
|
||||
f"{user_id}: {e}"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
# Thread-local storage for ExecutionProcessor instances
|
||||
_tls = threading.local()
|
||||
@@ -178,7 +146,6 @@ async def execute_node(
|
||||
execution_processor: "ExecutionProcessor",
|
||||
execution_stats: NodeExecutionStats | None = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Execute a node in the graph. This will trigger a block execution on a node,
|
||||
@@ -246,7 +213,6 @@ async def execute_node(
|
||||
"user_id": user_id,
|
||||
"execution_context": execution_context,
|
||||
"execution_processor": execution_processor,
|
||||
"nodes_to_skip": nodes_to_skip or set(),
|
||||
}
|
||||
|
||||
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
|
||||
@@ -544,7 +510,6 @@ class ExecutionProcessor:
|
||||
node_exec_progress: NodeExecutionProgress,
|
||||
nodes_input_masks: Optional[NodesInputMasks],
|
||||
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
) -> NodeExecutionStats:
|
||||
log_metadata = LogMetadata(
|
||||
logger=_logger,
|
||||
@@ -567,7 +532,6 @@ class ExecutionProcessor:
|
||||
db_client=db_client,
|
||||
log_metadata=log_metadata,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip,
|
||||
)
|
||||
if isinstance(status, BaseException):
|
||||
raise status
|
||||
@@ -613,7 +577,6 @@ class ExecutionProcessor:
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
) -> ExecutionStatus:
|
||||
status = ExecutionStatus.RUNNING
|
||||
|
||||
@@ -650,7 +613,6 @@ class ExecutionProcessor:
|
||||
execution_processor=self,
|
||||
execution_stats=stats,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip,
|
||||
):
|
||||
await persist_output(output_name, output_data)
|
||||
|
||||
@@ -962,21 +924,6 @@ class ExecutionProcessor:
|
||||
|
||||
queued_node_exec = execution_queue.get()
|
||||
|
||||
# Check if this node should be skipped due to optional credentials
|
||||
if queued_node_exec.node_id in graph_exec.nodes_to_skip:
|
||||
log_metadata.info(
|
||||
f"Skipping node execution {queued_node_exec.node_exec_id} "
|
||||
f"for node {queued_node_exec.node_id} - optional credentials not configured"
|
||||
)
|
||||
# Mark the node as completed without executing
|
||||
# No outputs will be produced, so downstream nodes won't trigger
|
||||
update_node_execution_status(
|
||||
db_client=db_client,
|
||||
exec_id=queued_node_exec.node_exec_id,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
)
|
||||
continue
|
||||
|
||||
log_metadata.debug(
|
||||
f"Dispatching node execution {queued_node_exec.node_exec_id} "
|
||||
f"for node {queued_node_exec.node_id}",
|
||||
@@ -1037,7 +984,6 @@ class ExecutionProcessor:
|
||||
execution_stats,
|
||||
execution_stats_lock,
|
||||
),
|
||||
nodes_to_skip=graph_exec.nodes_to_skip,
|
||||
),
|
||||
self.node_execution_loop,
|
||||
)
|
||||
@@ -1317,40 +1263,12 @@ class ExecutionProcessor:
|
||||
graph_id: str,
|
||||
e: InsufficientBalanceError,
|
||||
):
|
||||
# Check if we've already sent a notification for this user+agent combo.
|
||||
# We only send one notification per user per agent until they top up credits.
|
||||
redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
|
||||
try:
|
||||
redis_client = redis.get_redis()
|
||||
# SET NX returns True only if the key was newly set (didn't exist)
|
||||
is_new_notification = redis_client.set(
|
||||
redis_key,
|
||||
"1",
|
||||
nx=True,
|
||||
ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
|
||||
)
|
||||
if not is_new_notification:
|
||||
# Already notified for this user+agent, skip all notifications
|
||||
logger.debug(
|
||||
f"Skipping duplicate insufficient funds notification for "
|
||||
f"user={user_id}, graph={graph_id}"
|
||||
)
|
||||
return
|
||||
except Exception as redis_error:
|
||||
# If Redis fails, log and continue to send the notification
|
||||
# (better to occasionally duplicate than to never notify)
|
||||
logger.warning(
|
||||
f"Failed to check/set insufficient funds notification flag in Redis: "
|
||||
f"{redis_error}"
|
||||
)
|
||||
|
||||
shortfall = abs(e.amount) - e.balance
|
||||
metadata = db_client.get_graph_metadata(graph_id)
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
|
||||
# Queue user email notification
|
||||
queue_notification(
|
||||
NotificationEventModel(
|
||||
user_id=user_id,
|
||||
@@ -1364,7 +1282,6 @@ class ExecutionProcessor:
|
||||
)
|
||||
)
|
||||
|
||||
# Send Discord system alert
|
||||
try:
|
||||
user_email = db_client.get_user_email_by_id(user_id)
|
||||
|
||||
|
||||
@@ -1,560 +0,0 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import NotificationType
|
||||
|
||||
from backend.data.notifications import ZeroBalanceData
|
||||
from backend.executor.manager import (
|
||||
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX,
|
||||
ExecutionProcessor,
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
async def async_iter(items):
|
||||
"""Helper to create an async iterator from a list."""
|
||||
for item in items:
|
||||
yield item
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_sends_discord_alert_first_time(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that the first insufficient funds notification sends a Discord alert."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72, # $0.72
|
||||
amount=-714, # Attempting to spend $7.14
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Mock Redis to simulate first-time notification (set returns True)
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
mock_redis_client.set.return_value = True # Key was newly set
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify notification was queued
|
||||
mock_queue_notif.assert_called_once()
|
||||
notification_call = mock_queue_notif.call_args[0][0]
|
||||
assert notification_call.type == NotificationType.ZERO_BALANCE
|
||||
assert notification_call.user_id == user_id
|
||||
assert isinstance(notification_call.data, ZeroBalanceData)
|
||||
assert notification_call.data.current_balance == 72
|
||||
|
||||
# Verify Redis was checked with correct key pattern
|
||||
expected_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
|
||||
mock_redis_client.set.assert_called_once()
|
||||
call_args = mock_redis_client.set.call_args
|
||||
assert call_args[0][0] == expected_key
|
||||
assert call_args[1]["nx"] is True
|
||||
|
||||
# Verify Discord alert was sent
|
||||
mock_client.discord_system_alert.assert_called_once()
|
||||
discord_message = mock_client.discord_system_alert.call_args[0][0]
|
||||
assert "Insufficient Funds Alert" in discord_message
|
||||
assert "test@example.com" in discord_message
|
||||
assert "Test Agent" in discord_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_skips_duplicate_notifications(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that duplicate insufficient funds notifications skip both email and Discord."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72,
|
||||
amount=-714,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Mock Redis to simulate duplicate notification (set returns False/None)
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
mock_redis_client.set.return_value = None # Key already existed
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
mock_db_client.get_graph_metadata.return_value = MagicMock(name="Test Agent")
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify email notification was NOT queued (deduplication worked)
|
||||
mock_queue_notif.assert_not_called()
|
||||
|
||||
# Verify Discord alert was NOT sent (deduplication worked)
|
||||
mock_client.discord_system_alert.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that different agents for the same user get separate Discord alerts."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id_1 = "test-graph-111"
|
||||
graph_id_2 = "test-graph-222"
|
||||
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72,
|
||||
amount=-714,
|
||||
)
|
||||
|
||||
with patch("backend.executor.manager.queue_notification"), patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
# Both calls return True (first time for each agent)
|
||||
mock_redis_client.set.return_value = True
|
||||
|
||||
mock_db_client = MagicMock()
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# First agent notification
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id_1,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Second agent notification
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id_2,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify Discord alerts were sent for both agents
|
||||
assert mock_client.discord_system_alert.call_count == 2
|
||||
|
||||
# Verify Redis was called with different keys
|
||||
assert mock_redis_client.set.call_count == 2
|
||||
calls = mock_redis_client.set.call_args_list
|
||||
assert (
|
||||
calls[0][0][0]
|
||||
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_1}"
|
||||
)
|
||||
assert (
|
||||
calls[1][0][0]
|
||||
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_2}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_clear_insufficient_funds_notifications(server: SpinTestServer):
|
||||
"""Test that clearing notifications removes all keys for a user."""
|
||||
|
||||
user_id = "test-user-123"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
# get_redis_async is an async function, so we need AsyncMock for it
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
# Mock scan_iter to return some keys as an async iterator
|
||||
mock_keys = [
|
||||
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1",
|
||||
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-2",
|
||||
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-3",
|
||||
]
|
||||
mock_redis_client.scan_iter.return_value = async_iter(mock_keys)
|
||||
# delete is awaited, so use AsyncMock
|
||||
mock_redis_client.delete = AsyncMock(return_value=3)
|
||||
|
||||
# Clear notifications
|
||||
result = await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
# Verify correct pattern was used
|
||||
expected_pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
mock_redis_client.scan_iter.assert_called_once_with(match=expected_pattern)
|
||||
|
||||
# Verify delete was called with all keys
|
||||
mock_redis_client.delete.assert_called_once_with(*mock_keys)
|
||||
|
||||
# Verify return value
|
||||
assert result == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_clear_insufficient_funds_notifications_no_keys(server: SpinTestServer):
|
||||
"""Test clearing notifications when there are no keys to clear."""
|
||||
|
||||
user_id = "test-user-no-notifications"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
# get_redis_async is an async function, so we need AsyncMock for it
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
# Mock scan_iter to return no keys as an async iterator
|
||||
mock_redis_client.scan_iter.return_value = async_iter([])
|
||||
|
||||
# Clear notifications
|
||||
result = await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
# Verify delete was not called
|
||||
mock_redis_client.delete.assert_not_called()
|
||||
|
||||
# Verify return value
|
||||
assert result == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_clear_insufficient_funds_notifications_handles_redis_error(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that clearing notifications handles Redis errors gracefully."""
|
||||
|
||||
user_id = "test-user-redis-error"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
# Mock get_redis_async to raise an error
|
||||
mock_redis_module.get_redis_async = AsyncMock(
|
||||
side_effect=Exception("Redis connection failed")
|
||||
)
|
||||
|
||||
# Clear notifications should not raise, just return 0
|
||||
result = await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
# Verify it returned 0 (graceful failure)
|
||||
assert result == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_continues_on_redis_error(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that both email and Discord notifications are still sent when Redis fails."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72,
|
||||
amount=-714,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Mock Redis to raise an error
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
mock_redis_client.set.side_effect = Exception("Redis connection error")
|
||||
|
||||
mock_db_client = MagicMock()
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify email notification was still queued despite Redis error
|
||||
mock_queue_notif.assert_called_once()
|
||||
|
||||
# Verify Discord alert was still sent despite Redis error
|
||||
mock_client.discord_system_alert.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_clears_notifications_on_grant(server: SpinTestServer):
|
||||
"""Test that _add_transaction clears notification flags when adding GRANT credits."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-grant-clear"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 1000, "transactionKey": "test-tx-key"}]
|
||||
|
||||
# Mock async Redis for notification clearing
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
mock_redis_client.scan_iter.return_value = async_iter(
|
||||
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
|
||||
)
|
||||
mock_redis_client.delete = AsyncMock(return_value=1)
|
||||
|
||||
# Create a concrete instance
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with GRANT type (should clear notifications)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=500, # Positive amount
|
||||
transaction_type=CreditTransactionType.GRANT,
|
||||
is_active=True, # Active transaction
|
||||
)
|
||||
|
||||
# Verify notification clearing was called
|
||||
mock_redis_module.get_redis_async.assert_called_once()
|
||||
mock_redis_client.scan_iter.assert_called_once_with(
|
||||
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_clears_notifications_on_top_up(server: SpinTestServer):
|
||||
"""Test that _add_transaction clears notification flags when adding TOP_UP credits."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-topup-clear"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 2000, "transactionKey": "test-tx-key-2"}]
|
||||
|
||||
# Mock async Redis for notification clearing
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
mock_redis_client.scan_iter.return_value = async_iter([])
|
||||
mock_redis_client.delete = AsyncMock(return_value=0)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with TOP_UP type (should clear notifications)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=1000, # Positive amount
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
# Verify notification clearing was attempted
|
||||
mock_redis_module.get_redis_async.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_skips_clearing_for_inactive_transaction(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that _add_transaction does NOT clear notifications for inactive transactions."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-inactive"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 500, "transactionKey": "test-tx-key-3"}]
|
||||
|
||||
# Mock async Redis
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with is_active=False (should NOT clear notifications)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=500,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
is_active=False, # Inactive - pending Stripe payment
|
||||
)
|
||||
|
||||
# Verify notification clearing was NOT called
|
||||
mock_redis_module.get_redis_async.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_skips_clearing_for_usage_transaction(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that _add_transaction does NOT clear notifications for USAGE transactions."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-usage"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 400, "transactionKey": "test-tx-key-4"}]
|
||||
|
||||
# Mock async Redis
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with USAGE type (spending, should NOT clear)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=-100, # Negative - spending credits
|
||||
transaction_type=CreditTransactionType.USAGE,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
# Verify notification clearing was NOT called
|
||||
mock_redis_module.get_redis_async.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_enable_transaction_clears_notifications(server: SpinTestServer):
|
||||
"""Test that _enable_transaction clears notification flags when enabling a TOP_UP."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-enable"
|
||||
|
||||
with patch("backend.data.credit.CreditTransaction") as mock_credit_tx, patch(
|
||||
"backend.data.credit.query_raw_with_schema"
|
||||
) as mock_query, patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
# Mock finding the pending transaction
|
||||
mock_transaction = MagicMock()
|
||||
mock_transaction.amount = 1000
|
||||
mock_transaction.type = CreditTransactionType.TOP_UP
|
||||
mock_credit_tx.prisma.return_value.find_first = AsyncMock(
|
||||
return_value=mock_transaction
|
||||
)
|
||||
|
||||
# Mock the query to return updated balance
|
||||
mock_query.return_value = [{"balance": 1500}]
|
||||
|
||||
# Mock async Redis for notification clearing
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
mock_redis_client.scan_iter.return_value = async_iter(
|
||||
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
|
||||
)
|
||||
mock_redis_client.delete = AsyncMock(return_value=1)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _enable_transaction (simulates Stripe checkout completion)
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
await credit_model._enable_transaction(
|
||||
transaction_key="cs_test_123",
|
||||
user_id=user_id,
|
||||
metadata=SafeJson({"payment": "completed"}),
|
||||
)
|
||||
|
||||
# Verify notification clearing was called
|
||||
mock_redis_module.get_redis_async.assert_called_once()
|
||||
mock_redis_client.scan_iter.assert_called_once_with(
|
||||
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
)
|
||||
@@ -3,16 +3,16 @@ import logging
|
||||
import fastapi.responses
|
||||
import pytest
|
||||
|
||||
import backend.api.features.library.model
|
||||
import backend.api.features.store.model
|
||||
from backend.api.model import CreateGraph
|
||||
from backend.api.rest_api import AgentServer
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.store.model
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.data_manipulation import FindInDictionaryBlock
|
||||
from backend.blocks.io import AgentInputBlock
|
||||
from backend.blocks.maths import CalculatorBlock, Operation
|
||||
from backend.data import execution, graph
|
||||
from backend.data.model import User
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
from backend.util.test import SpinTestServer, wait_execution
|
||||
|
||||
@@ -356,7 +356,7 @@ async def test_execute_preset(server: SpinTestServer):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
# Create preset with initial values
|
||||
preset = backend.api.features.library.model.LibraryAgentPresetCreatable(
|
||||
preset = backend.server.v2.library.model.LibraryAgentPresetCreatable(
|
||||
name="Test Preset With Clash",
|
||||
description="Test preset with clashing input values",
|
||||
graph_id=test_graph.id,
|
||||
@@ -444,7 +444,7 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
# Create preset with initial values
|
||||
preset = backend.api.features.library.model.LibraryAgentPresetCreatable(
|
||||
preset = backend.server.v2.library.model.LibraryAgentPresetCreatable(
|
||||
name="Test Preset With Clash",
|
||||
description="Test preset with clashing input values",
|
||||
graph_id=test_graph.id,
|
||||
@@ -485,7 +485,7 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
|
||||
store_submission_request = backend.api.features.store.model.StoreSubmissionRequest(
|
||||
store_submission_request = backend.server.v2.store.model.StoreSubmissionRequest(
|
||||
agent_id=test_graph.id,
|
||||
agent_version=test_graph.version,
|
||||
slug=test_graph.id,
|
||||
@@ -514,7 +514,7 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
|
||||
admin_user = await create_test_user(alt_user=True)
|
||||
await server.agent_server.test_review_store_listing(
|
||||
backend.api.features.store.model.ReviewSubmissionRequest(
|
||||
backend.server.v2.store.model.ReviewSubmissionRequest(
|
||||
store_listing_version_id=slv_id,
|
||||
is_approved=True,
|
||||
comments="Test comments",
|
||||
@@ -523,7 +523,7 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
)
|
||||
|
||||
# Add the approved store listing to the admin user's library so they can execute it
|
||||
from backend.api.features.library.db import add_store_agent_to_library
|
||||
from backend.server.v2.library.db import add_store_agent_to_library
|
||||
|
||||
await add_store_agent_to_library(
|
||||
store_listing_version_id=slv_id, user_id=admin_user.id
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from backend.api.model import CreateGraph
|
||||
from backend.data import db
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
@@ -239,19 +239,14 @@ async def _validate_node_input_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> tuple[dict[str, dict[str, str]], set[str]]:
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Checks all credentials for all nodes of the graph and returns structured errors
|
||||
and a set of nodes that should be skipped due to optional missing credentials.
|
||||
Checks all credentials for all nodes of the graph and returns structured errors.
|
||||
|
||||
Returns:
|
||||
tuple[
|
||||
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node,
|
||||
set[node_id]: Nodes that should be skipped (optional credentials not configured)
|
||||
]
|
||||
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node
|
||||
"""
|
||||
credential_errors: dict[str, dict[str, str]] = defaultdict(dict)
|
||||
nodes_to_skip: set[str] = set()
|
||||
|
||||
for node in graph.nodes:
|
||||
block = node.block
|
||||
@@ -261,46 +256,27 @@ async def _validate_node_input_credentials(
|
||||
if not credentials_fields:
|
||||
continue
|
||||
|
||||
# Track if any credential field is missing for this node
|
||||
has_missing_credentials = False
|
||||
|
||||
for field_name, credentials_meta_type in credentials_fields.items():
|
||||
try:
|
||||
# Check nodes_input_masks first, then input_default
|
||||
field_value = None
|
||||
if (
|
||||
nodes_input_masks
|
||||
and (node_input_mask := nodes_input_masks.get(node.id))
|
||||
and field_name in node_input_mask
|
||||
):
|
||||
field_value = node_input_mask[field_name]
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node_input_mask[field_name]
|
||||
)
|
||||
elif field_name in node.input_default:
|
||||
# For optional credentials, don't use input_default - treat as missing
|
||||
# This prevents stale credential IDs from failing validation
|
||||
if node.credentials_optional:
|
||||
field_value = None
|
||||
else:
|
||||
field_value = node.input_default[field_name]
|
||||
|
||||
# Check if credentials are missing (None, empty, or not present)
|
||||
if field_value is None or (
|
||||
isinstance(field_value, dict) and not field_value.get("id")
|
||||
):
|
||||
has_missing_credentials = True
|
||||
# If node has credentials_optional flag, mark for skipping instead of error
|
||||
if node.credentials_optional:
|
||||
continue # Don't add error, will be marked for skip after loop
|
||||
else:
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = "These credentials are required"
|
||||
continue
|
||||
|
||||
credentials_meta = credentials_meta_type.model_validate(field_value)
|
||||
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node.input_default[field_name]
|
||||
)
|
||||
else:
|
||||
# Missing credentials
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = "These credentials are required"
|
||||
continue
|
||||
except ValidationError as e:
|
||||
# Validation error means credentials were provided but invalid
|
||||
# This should always be an error, even if optional
|
||||
credential_errors[node.id][field_name] = f"Invalid credentials: {e}"
|
||||
continue
|
||||
|
||||
@@ -311,7 +287,6 @@ async def _validate_node_input_credentials(
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle any errors fetching credentials
|
||||
# If credentials were explicitly configured but unavailable, it's an error
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = f"Credentials not available: {e}"
|
||||
@@ -338,19 +313,7 @@ async def _validate_node_input_credentials(
|
||||
] = "Invalid credentials: type/provider mismatch"
|
||||
continue
|
||||
|
||||
# If node has optional credentials and any are missing, mark for skipping
|
||||
# But only if there are no other errors for this node
|
||||
if (
|
||||
has_missing_credentials
|
||||
and node.credentials_optional
|
||||
and node.id not in credential_errors
|
||||
):
|
||||
nodes_to_skip.add(node.id)
|
||||
logger.info(
|
||||
f"Node #{node.id} will be skipped: optional credentials not configured"
|
||||
)
|
||||
|
||||
return credential_errors, nodes_to_skip
|
||||
return credential_errors
|
||||
|
||||
|
||||
def make_node_credentials_input_map(
|
||||
@@ -392,25 +355,21 @@ async def validate_graph_with_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> tuple[Mapping[str, Mapping[str, str]], set[str]]:
|
||||
) -> Mapping[str, Mapping[str, str]]:
|
||||
"""
|
||||
Validate graph including credentials and return structured errors per node,
|
||||
along with a set of nodes that should be skipped due to optional missing credentials.
|
||||
Validate graph including credentials and return structured errors per node.
|
||||
|
||||
Returns:
|
||||
tuple[
|
||||
dict[node_id, dict[field_name, error_message]]: Validation errors per node,
|
||||
set[node_id]: Nodes that should be skipped (optional credentials not configured)
|
||||
]
|
||||
dict[node_id, dict[field_name, error_message]]: Validation errors per node
|
||||
"""
|
||||
# Get input validation errors
|
||||
node_input_errors = GraphModel.validate_graph_get_errors(
|
||||
graph, for_run=True, nodes_input_masks=nodes_input_masks
|
||||
)
|
||||
|
||||
# Get credential input/availability/validation errors and nodes to skip
|
||||
node_credential_input_errors, nodes_to_skip = (
|
||||
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
|
||||
# Get credential input/availability/validation errors
|
||||
node_credential_input_errors = await _validate_node_input_credentials(
|
||||
graph, user_id, nodes_input_masks
|
||||
)
|
||||
|
||||
# Merge credential errors with structural errors
|
||||
@@ -419,7 +378,7 @@ async def validate_graph_with_credentials(
|
||||
node_input_errors[node_id] = {}
|
||||
node_input_errors[node_id].update(field_errors)
|
||||
|
||||
return node_input_errors, nodes_to_skip
|
||||
return node_input_errors
|
||||
|
||||
|
||||
async def _construct_starting_node_execution_input(
|
||||
@@ -427,7 +386,7 @@ async def _construct_starting_node_execution_input(
|
||||
user_id: str,
|
||||
graph_inputs: BlockInput,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> tuple[list[tuple[str, BlockInput]], set[str]]:
|
||||
) -> list[tuple[str, BlockInput]]:
|
||||
"""
|
||||
Validates and prepares the input data for executing a graph.
|
||||
This function checks the graph for starting nodes, validates the input data
|
||||
@@ -441,14 +400,11 @@ async def _construct_starting_node_execution_input(
|
||||
node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`
|
||||
|
||||
Returns:
|
||||
tuple[
|
||||
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID
|
||||
and the corresponding input data for that node.
|
||||
set[str]: Node IDs that should be skipped (optional credentials not configured)
|
||||
]
|
||||
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
|
||||
the corresponding input data for that node.
|
||||
"""
|
||||
# Use new validation function that includes credentials
|
||||
validation_errors, nodes_to_skip = await validate_graph_with_credentials(
|
||||
validation_errors = await validate_graph_with_credentials(
|
||||
graph, user_id, nodes_input_masks
|
||||
)
|
||||
n_error_nodes = len(validation_errors)
|
||||
@@ -489,7 +445,7 @@ async def _construct_starting_node_execution_input(
|
||||
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
|
||||
)
|
||||
|
||||
return nodes_input, nodes_to_skip
|
||||
return nodes_input
|
||||
|
||||
|
||||
async def validate_and_construct_node_execution_input(
|
||||
@@ -500,7 +456,7 @@ async def validate_and_construct_node_execution_input(
|
||||
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
is_sub_graph: bool = False,
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks, set[str]]:
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]:
|
||||
"""
|
||||
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
|
||||
This centralizes the logic used by both scheduler validation and actual execution.
|
||||
@@ -517,7 +473,6 @@ async def validate_and_construct_node_execution_input(
|
||||
GraphModel: Full graph object for the given `graph_id`.
|
||||
list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
|
||||
dict[str, BlockInput]: Node input masks including all passed-in credentials.
|
||||
set[str]: Node IDs that should be skipped (optional credentials not configured).
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the graph is not found.
|
||||
@@ -559,16 +514,14 @@ async def validate_and_construct_node_execution_input(
|
||||
nodes_input_masks or {},
|
||||
)
|
||||
|
||||
starting_nodes_input, nodes_to_skip = (
|
||||
await _construct_starting_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=graph_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
starting_nodes_input = await _construct_starting_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=graph_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
|
||||
return graph, starting_nodes_input, nodes_input_masks, nodes_to_skip
|
||||
return graph, starting_nodes_input, nodes_input_masks
|
||||
|
||||
|
||||
def _merge_nodes_input_masks(
|
||||
@@ -826,9 +779,6 @@ async def add_graph_execution(
|
||||
|
||||
# Use existing execution's compiled input masks
|
||||
compiled_nodes_input_masks = graph_exec.nodes_input_masks or {}
|
||||
# For resumed executions, nodes_to_skip was already determined at creation time
|
||||
# TODO: Consider storing nodes_to_skip in DB if we need to preserve it across resumes
|
||||
nodes_to_skip: set[str] = set()
|
||||
|
||||
logger.info(f"Resuming graph execution #{graph_exec.id} for graph #{graph_id}")
|
||||
else:
|
||||
@@ -837,7 +787,7 @@ async def add_graph_execution(
|
||||
)
|
||||
|
||||
# Create new execution
|
||||
graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip = (
|
||||
graph, starting_nodes_input, compiled_nodes_input_masks = (
|
||||
await validate_and_construct_node_execution_input(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
@@ -886,7 +836,6 @@ async def add_graph_execution(
|
||||
try:
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry(
|
||||
compiled_nodes_input_masks=compiled_nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip,
|
||||
execution_context=execution_context,
|
||||
)
|
||||
logger.info(f"Publishing execution {graph_exec.id} to execution queue")
|
||||
|
||||
@@ -367,13 +367,10 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
)
|
||||
|
||||
# Setup mock returns
|
||||
# The function returns (graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip)
|
||||
nodes_to_skip: set[str] = set()
|
||||
mock_validate.return_value = (
|
||||
mock_graph,
|
||||
starting_nodes_input,
|
||||
compiled_nodes_input_masks,
|
||||
nodes_to_skip,
|
||||
)
|
||||
mock_prisma.is_connected.return_value = True
|
||||
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
|
||||
@@ -459,212 +456,3 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
# Both executions should succeed (though they create different objects)
|
||||
assert result1 == mock_graph_exec
|
||||
assert result2 == mock_graph_exec_2
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for Optional Credentials Feature
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_node_input_credentials_returns_nodes_to_skip(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
Test that _validate_node_input_credentials returns nodes_to_skip set
|
||||
for nodes with credentials_optional=True and missing credentials.
|
||||
"""
|
||||
from backend.executor.utils import _validate_node_input_credentials
|
||||
|
||||
# Create a mock node with credentials_optional=True
|
||||
mock_node = mocker.MagicMock()
|
||||
mock_node.id = "node-with-optional-creds"
|
||||
mock_node.credentials_optional = True
|
||||
mock_node.input_default = {} # No credentials configured
|
||||
|
||||
# Create a mock block with credentials field
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_credentials_field_type = mocker.MagicMock()
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||
"credentials": mock_credentials_field_type
|
||||
}
|
||||
mock_node.block = mock_block
|
||||
|
||||
# Create mock graph
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.nodes = [mock_node]
|
||||
|
||||
# Call the function
|
||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user-id",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# Node should be in nodes_to_skip, not in errors
|
||||
assert mock_node.id in nodes_to_skip
|
||||
assert mock_node.id not in errors
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_node_input_credentials_required_missing_creds_error(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
Test that _validate_node_input_credentials returns errors
|
||||
for nodes with credentials_optional=False and missing credentials.
|
||||
"""
|
||||
from backend.executor.utils import _validate_node_input_credentials
|
||||
|
||||
# Create a mock node with credentials_optional=False (required)
|
||||
mock_node = mocker.MagicMock()
|
||||
mock_node.id = "node-with-required-creds"
|
||||
mock_node.credentials_optional = False
|
||||
mock_node.input_default = {} # No credentials configured
|
||||
|
||||
# Create a mock block with credentials field
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_credentials_field_type = mocker.MagicMock()
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||
"credentials": mock_credentials_field_type
|
||||
}
|
||||
mock_node.block = mock_block
|
||||
|
||||
# Create mock graph
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.nodes = [mock_node]
|
||||
|
||||
# Call the function
|
||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user-id",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# Node should be in errors, not in nodes_to_skip
|
||||
assert mock_node.id in errors
|
||||
assert "credentials" in errors[mock_node.id]
|
||||
assert "required" in errors[mock_node.id]["credentials"].lower()
|
||||
assert mock_node.id not in nodes_to_skip
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_graph_with_credentials_returns_nodes_to_skip(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
Test that validate_graph_with_credentials returns nodes_to_skip set
|
||||
from _validate_node_input_credentials.
|
||||
"""
|
||||
from backend.executor.utils import validate_graph_with_credentials
|
||||
|
||||
# Mock _validate_node_input_credentials to return specific values
|
||||
mock_validate = mocker.patch(
|
||||
"backend.executor.utils._validate_node_input_credentials"
|
||||
)
|
||||
expected_errors = {"node1": {"field": "error"}}
|
||||
expected_nodes_to_skip = {"node2", "node3"}
|
||||
mock_validate.return_value = (expected_errors, expected_nodes_to_skip)
|
||||
|
||||
# Mock GraphModel with validate_graph_get_errors method
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.validate_graph_get_errors.return_value = {}
|
||||
|
||||
# Call the function
|
||||
errors, nodes_to_skip = await validate_graph_with_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user-id",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# Verify nodes_to_skip is passed through
|
||||
assert nodes_to_skip == expected_nodes_to_skip
|
||||
assert "node1" in errors
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
||||
"""
|
||||
Test that add_graph_execution properly passes nodes_to_skip
|
||||
to the graph execution entry.
|
||||
"""
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.executor.utils import add_graph_execution
|
||||
|
||||
# Mock data
|
||||
graph_id = "test-graph-id"
|
||||
user_id = "test-user-id"
|
||||
inputs = {"test_input": "test_value"}
|
||||
graph_version = 1
|
||||
|
||||
# Mock the graph object
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.version = graph_version
|
||||
|
||||
# Starting nodes and masks
|
||||
starting_nodes_input = [("node1", {"input1": "value1"})]
|
||||
compiled_nodes_input_masks = {}
|
||||
nodes_to_skip = {"skipped-node-1", "skipped-node-2"}
|
||||
|
||||
# Mock the graph execution object
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec.id = "execution-id-123"
|
||||
mock_graph_exec.node_executions = []
|
||||
|
||||
# Track what's passed to to_graph_execution_entry
|
||||
captured_kwargs = {}
|
||||
|
||||
def capture_to_entry(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return mocker.MagicMock()
|
||||
|
||||
mock_graph_exec.to_graph_execution_entry.side_effect = capture_to_entry
|
||||
|
||||
# Setup mocks
|
||||
mock_validate = mocker.patch(
|
||||
"backend.executor.utils.validate_and_construct_node_execution_input"
|
||||
)
|
||||
mock_edb = mocker.patch("backend.executor.utils.execution_db")
|
||||
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||
mock_udb = mocker.patch("backend.executor.utils.user_db")
|
||||
mock_gdb = mocker.patch("backend.executor.utils.graph_db")
|
||||
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
|
||||
# Setup returns - include nodes_to_skip in the tuple
|
||||
mock_validate.return_value = (
|
||||
mock_graph,
|
||||
starting_nodes_input,
|
||||
compiled_nodes_input_masks,
|
||||
nodes_to_skip, # This should be passed through
|
||||
)
|
||||
mock_prisma.is_connected.return_value = True
|
||||
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
|
||||
mock_edb.update_graph_execution_stats = mocker.AsyncMock(
|
||||
return_value=mock_graph_exec
|
||||
)
|
||||
mock_edb.update_node_execution_status_batch = mocker.AsyncMock()
|
||||
|
||||
mock_user = mocker.MagicMock()
|
||||
mock_user.timezone = "UTC"
|
||||
mock_settings = mocker.MagicMock()
|
||||
mock_settings.human_in_the_loop_safe_mode = True
|
||||
|
||||
mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user)
|
||||
mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings)
|
||||
mock_get_queue.return_value = mocker.AsyncMock()
|
||||
mock_get_event_bus.return_value = mocker.MagicMock(publish=mocker.AsyncMock())
|
||||
|
||||
# Call the function
|
||||
await add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
graph_version=graph_version,
|
||||
)
|
||||
|
||||
# Verify nodes_to_skip was passed to to_graph_execution_entry
|
||||
assert "nodes_to_skip" in captured_kwargs
|
||||
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip
|
||||
|
||||
@@ -8,7 +8,6 @@ from .discord import DiscordOAuthHandler
|
||||
from .github import GitHubOAuthHandler
|
||||
from .google import GoogleOAuthHandler
|
||||
from .notion import NotionOAuthHandler
|
||||
from .reddit import RedditOAuthHandler
|
||||
from .twitter import TwitterOAuthHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -21,7 +20,6 @@ _ORIGINAL_HANDLERS = [
|
||||
GitHubOAuthHandler,
|
||||
GoogleOAuthHandler,
|
||||
NotionOAuthHandler,
|
||||
RedditOAuthHandler,
|
||||
TwitterOAuthHandler,
|
||||
TodoistOAuthHandler,
|
||||
]
|
||||
|
||||
@@ -1,208 +0,0 @@
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import Requests
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
class RedditOAuthHandler(BaseOAuthHandler):
|
||||
"""
|
||||
Reddit OAuth 2.0 handler.
|
||||
|
||||
Based on the documentation at:
|
||||
- https://github.com/reddit-archive/reddit/wiki/OAuth2
|
||||
|
||||
Notes:
|
||||
- Reddit requires `duration=permanent` to get refresh tokens
|
||||
- Access tokens expire after 1 hour (3600 seconds)
|
||||
- Reddit requires HTTP Basic Auth for token requests
|
||||
- Reddit requires a unique User-Agent header
|
||||
"""
|
||||
|
||||
PROVIDER_NAME = ProviderName.REDDIT
|
||||
DEFAULT_SCOPES: ClassVar[list[str]] = [
|
||||
"identity", # Get username, verify auth
|
||||
"read", # Access posts and comments
|
||||
"submit", # Submit new posts and comments
|
||||
"edit", # Edit own posts and comments
|
||||
"history", # Access user's post history
|
||||
"privatemessages", # Access inbox and send private messages
|
||||
"flair", # Access and set flair on posts/subreddits
|
||||
]
|
||||
|
||||
AUTHORIZE_URL = "https://www.reddit.com/api/v1/authorize"
|
||||
TOKEN_URL = "https://www.reddit.com/api/v1/access_token"
|
||||
USERNAME_URL = "https://oauth.reddit.com/api/v1/me"
|
||||
REVOKE_URL = "https://www.reddit.com/api/v1/revoke_token"
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_login_url(
|
||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||
) -> str:
|
||||
"""Generate Reddit OAuth 2.0 authorization URL"""
|
||||
scopes = self.handle_default_scopes(scopes)
|
||||
|
||||
params = {
|
||||
"response_type": "code",
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"scope": " ".join(scopes),
|
||||
"state": state,
|
||||
"duration": "permanent", # Required for refresh tokens
|
||||
}
|
||||
|
||||
return f"{self.AUTHORIZE_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||
) -> OAuth2Credentials:
|
||||
"""Exchange authorization code for access tokens"""
|
||||
scopes = self.handle_default_scopes(scopes)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": settings.config.reddit_user_agent,
|
||||
}
|
||||
|
||||
data = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
}
|
||||
|
||||
# Reddit requires HTTP Basic Auth for token requests
|
||||
auth = (self.client_id, self.client_secret)
|
||||
|
||||
response = await Requests().post(
|
||||
self.TOKEN_URL, headers=headers, data=data, auth=auth
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
error_text = response.text()
|
||||
raise ValueError(
|
||||
f"Reddit token exchange failed: {response.status} - {error_text}"
|
||||
)
|
||||
|
||||
tokens = response.json()
|
||||
|
||||
if "error" in tokens:
|
||||
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")
|
||||
|
||||
username = await self._get_username(tokens["access_token"])
|
||||
|
||||
return OAuth2Credentials(
|
||||
provider=self.PROVIDER_NAME,
|
||||
title=None,
|
||||
username=username,
|
||||
access_token=tokens["access_token"],
|
||||
refresh_token=tokens.get("refresh_token"),
|
||||
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
|
||||
refresh_token_expires_at=None, # Reddit refresh tokens don't expire
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
async def _get_username(self, access_token: str) -> str:
|
||||
"""Get the username from the access token"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"User-Agent": settings.config.reddit_user_agent,
|
||||
}
|
||||
|
||||
response = await Requests().get(self.USERNAME_URL, headers=headers)
|
||||
|
||||
if not response.ok:
|
||||
raise ValueError(f"Failed to get Reddit username: {response.status}")
|
||||
|
||||
data = response.json()
|
||||
return data.get("name", "unknown")
|
||||
|
||||
async def _refresh_tokens(
|
||||
self, credentials: OAuth2Credentials
|
||||
) -> OAuth2Credentials:
|
||||
"""Refresh access tokens using refresh token"""
|
||||
if not credentials.refresh_token:
|
||||
raise ValueError("No refresh token available")
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": settings.config.reddit_user_agent,
|
||||
}
|
||||
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": credentials.refresh_token.get_secret_value(),
|
||||
}
|
||||
|
||||
auth = (self.client_id, self.client_secret)
|
||||
|
||||
response = await Requests().post(
|
||||
self.TOKEN_URL, headers=headers, data=data, auth=auth
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
error_text = response.text()
|
||||
raise ValueError(
|
||||
f"Reddit token refresh failed: {response.status} - {error_text}"
|
||||
)
|
||||
|
||||
tokens = response.json()
|
||||
|
||||
if "error" in tokens:
|
||||
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")
|
||||
|
||||
username = await self._get_username(tokens["access_token"])
|
||||
|
||||
# Reddit may or may not return a new refresh token
|
||||
new_refresh_token = tokens.get("refresh_token")
|
||||
if new_refresh_token:
|
||||
refresh_token: SecretStr | None = SecretStr(new_refresh_token)
|
||||
elif credentials.refresh_token:
|
||||
# Keep the existing refresh token
|
||||
refresh_token = credentials.refresh_token
|
||||
else:
|
||||
refresh_token = None
|
||||
|
||||
return OAuth2Credentials(
|
||||
id=credentials.id,
|
||||
provider=self.PROVIDER_NAME,
|
||||
title=credentials.title,
|
||||
username=username,
|
||||
access_token=tokens["access_token"],
|
||||
refresh_token=refresh_token,
|
||||
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
|
||||
refresh_token_expires_at=None,
|
||||
scopes=credentials.scopes,
|
||||
)
|
||||
|
||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
||||
"""Revoke the access token"""
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": settings.config.reddit_user_agent,
|
||||
}
|
||||
|
||||
data = {
|
||||
"token": credentials.access_token.get_secret_value(),
|
||||
"token_type_hint": "access_token",
|
||||
}
|
||||
|
||||
auth = (self.client_id, self.client_secret)
|
||||
|
||||
response = await Requests().post(
|
||||
self.REVOKE_URL, headers=headers, data=data, auth=auth
|
||||
)
|
||||
|
||||
# Reddit returns 204 No Content on successful revocation
|
||||
return response.ok
|
||||
@@ -149,10 +149,10 @@ async def setup_webhook_for_block(
|
||||
async def migrate_legacy_triggered_graphs():
|
||||
from prisma.models import AgentGraph
|
||||
|
||||
from backend.api.features.library.db import create_preset
|
||||
from backend.api.features.library.model import LibraryAgentPresetCreatable
|
||||
from backend.data.graph import AGENT_GRAPH_INCLUDE, GraphModel, set_node_webhook
|
||||
from backend.data.model import is_credentials_field_name
|
||||
from backend.server.v2.library.db import create_preset
|
||||
from backend.server.v2.library.model import LibraryAgentPresetCreatable
|
||||
|
||||
triggered_graphs = [
|
||||
GraphModel.from_db(_graph)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from backend.api.rest_api import AgentServer
|
||||
from backend.app import run_processes
|
||||
from backend.server.rest_api import AgentServer
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -3,12 +3,12 @@ from typing import Dict, Set
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from backend.api.model import NotificationPayload, WSMessage, WSMethod
|
||||
from backend.data.execution import (
|
||||
ExecutionEventType,
|
||||
GraphExecutionEvent,
|
||||
NodeExecutionEvent,
|
||||
)
|
||||
from backend.server.model import NotificationPayload, WSMessage, WSMethod
|
||||
|
||||
_EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
|
||||
ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,
|
||||
@@ -4,13 +4,13 @@ from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
from fastapi import WebSocket
|
||||
|
||||
from backend.api.conn_manager import ConnectionManager
|
||||
from backend.api.model import NotificationPayload, WSMessage, WSMethod
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
GraphExecutionEvent,
|
||||
NodeExecutionEvent,
|
||||
)
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.model import NotificationPayload, WSMessage, WSMethod
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
29
autogpt_platform/backend/backend/server/external/api.py
vendored
Normal file
29
autogpt_platform/backend/backend/server/external/api.py
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
|
||||
from .routes.integrations import integrations_router
|
||||
from .routes.tools import tools_router
|
||||
from .routes.v1 import v1_router
|
||||
|
||||
external_app = FastAPI(
|
||||
title="AutoGPT External API",
|
||||
description="External API for AutoGPT integrations",
|
||||
docs_url="/docs",
|
||||
version="1.0",
|
||||
)
|
||||
|
||||
external_app.add_middleware(SecurityHeadersMiddleware)
|
||||
external_app.include_router(v1_router, prefix="/v1")
|
||||
external_app.include_router(tools_router, prefix="/v1")
|
||||
external_app.include_router(integrations_router, prefix="/v1")
|
||||
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
external_app,
|
||||
service_name="external-api",
|
||||
expose_endpoint=True,
|
||||
endpoint="/metrics",
|
||||
include_in_schema=True,
|
||||
)
|
||||
@@ -16,8 +16,6 @@ from fastapi import APIRouter, Body, HTTPException, Path, Security, status
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.integrations.models import get_all_provider_names
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
@@ -30,6 +28,8 @@ from backend.data.model import (
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.server.integrations.models import get_all_provider_names
|
||||
from backend.util.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -14,11 +14,11 @@ from fastapi import APIRouter, Security
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
|
||||
from backend.api.features.chat.tools.models import ToolResponseBase
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools import find_agent_tool, run_agent_tool
|
||||
from backend.server.v2.chat.tools.models import ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -8,29 +8,23 @@ from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.model as store_model
|
||||
import backend.data.block
|
||||
from backend.api.external.middleware import require_permission
|
||||
import backend.server.v2.store.cache as store_cache
|
||||
import backend.server.v2.store.model as store_model
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data import user as user_db
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .integrations import integrations_router
|
||||
from .tools import tools_router
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
v1_router = APIRouter()
|
||||
|
||||
v1_router.include_router(integrations_router)
|
||||
v1_router.include_router(tools_router)
|
||||
|
||||
|
||||
class UserInfoResponse(BaseModel):
|
||||
id: str
|
||||
@@ -17,8 +17,6 @@ from fastapi import (
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
||||
|
||||
from backend.api.features.library.db import set_preset_webhook, update_preset
|
||||
from backend.api.features.library.model import LibraryAgentPreset
|
||||
from backend.data.graph import NodeModel, get_graph, set_node_webhook
|
||||
from backend.data.integrations import (
|
||||
WebhookEvent,
|
||||
@@ -47,6 +45,13 @@ from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.server.integrations.models import (
|
||||
ProviderConstants,
|
||||
ProviderNamesResponse,
|
||||
get_all_provider_names,
|
||||
)
|
||||
from backend.server.v2.library.db import set_preset_webhook, update_preset
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
from backend.util.exceptions import (
|
||||
GraphNotInLibraryError,
|
||||
MissingConfigError,
|
||||
@@ -55,8 +60,6 @@ from backend.util.exceptions import (
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .models import ProviderConstants, ProviderNamesResponse, get_all_provider_names
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.integrations.oauth import BaseOAuthHandler
|
||||
|
||||
@@ -3,7 +3,7 @@ from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.applications import Starlette
|
||||
|
||||
from backend.api.middleware.security import SecurityHeadersMiddleware
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -16,33 +16,37 @@ from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.routing import APIRoute
|
||||
from prisma.errors import PrismaError
|
||||
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.builder
|
||||
import backend.api.features.builder.routes
|
||||
import backend.api.features.chat.routes as chat_routes
|
||||
import backend.api.features.executions.review.routes
|
||||
import backend.api.features.library.db
|
||||
import backend.api.features.library.model
|
||||
import backend.api.features.library.routes
|
||||
import backend.api.features.oauth
|
||||
import backend.api.features.otto.routes
|
||||
import backend.api.features.postmark.postmark
|
||||
import backend.api.features.store.model
|
||||
import backend.api.features.store.routes
|
||||
import backend.api.features.v1
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.server.routers.oauth
|
||||
import backend.server.routers.postmark.postmark
|
||||
import backend.server.routers.user_auth
|
||||
import backend.server.routers.v1
|
||||
import backend.server.v2.admin.credit_admin_routes
|
||||
import backend.server.v2.admin.execution_analytics_routes
|
||||
import backend.server.v2.admin.store_admin_routes
|
||||
import backend.server.v2.builder
|
||||
import backend.server.v2.builder.routes
|
||||
import backend.server.v2.chat.routes as chat_routes
|
||||
import backend.server.v2.executions.review.routes
|
||||
import backend.server.v2.library.db
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.library.routes
|
||||
import backend.server.v2.otto.routes
|
||||
import backend.server.v2.store.model
|
||||
import backend.server.v2.store.routes
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
from backend.server.external.api import external_app
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
from backend.server.utils.cors import build_cors_params
|
||||
from backend.util import json
|
||||
from backend.util.cloud_storage import shutdown_cloud_storage_handler
|
||||
from backend.util.exceptions import (
|
||||
@@ -53,13 +57,6 @@ from backend.util.exceptions import (
|
||||
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
||||
from backend.util.service import UnhealthyServiceError
|
||||
|
||||
from .external.fastapi_app import external_api
|
||||
from .features.analytics import router as analytics_router
|
||||
from .features.integrations.router import router as integrations_router
|
||||
from .middleware.security import SecurityHeadersMiddleware
|
||||
from .utils.cors import build_cors_params
|
||||
from .utils.openapi import sort_openapi
|
||||
|
||||
settings = backend.util.settings.Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -113,7 +110,7 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
with launch_darkly_context():
|
||||
@@ -180,9 +177,6 @@ app.add_middleware(GZipMiddleware, minimum_size=50_000) # 50KB threshold
|
||||
# Add 401 responses to authenticated endpoints in OpenAPI spec
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Sort OpenAPI schema to eliminate diff on refactors
|
||||
sort_openapi(app)
|
||||
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
app,
|
||||
@@ -261,52 +255,42 @@ app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
|
||||
app.add_exception_handler(ValueError, handle_internal_http_error(400))
|
||||
app.add_exception_handler(Exception, handle_internal_http_error(500))
|
||||
|
||||
app.include_router(backend.api.features.v1.v1_router, tags=["v1"], prefix="/api")
|
||||
app.include_router(backend.server.routers.v1.v1_router, tags=["v1"], prefix="/api")
|
||||
app.include_router(
|
||||
integrations_router,
|
||||
prefix="/api/integrations",
|
||||
tags=["v1", "integrations"],
|
||||
backend.server.v2.store.routes.router, tags=["v2"], prefix="/api/store"
|
||||
)
|
||||
app.include_router(
|
||||
analytics_router,
|
||||
prefix="/api/analytics",
|
||||
tags=["analytics"],
|
||||
backend.server.v2.builder.routes.router, tags=["v2"], prefix="/api/builder"
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.store.routes.router, tags=["v2"], prefix="/api/store"
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.builder.routes.router, tags=["v2"], prefix="/api/builder"
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.store_admin_routes.router,
|
||||
backend.server.v2.admin.store_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/store",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.credit_admin_routes.router,
|
||||
backend.server.v2.admin.credit_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/credits",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.execution_analytics_routes.router,
|
||||
backend.server.v2.admin.execution_analytics_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/executions",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.executions.review.routes.router,
|
||||
backend.server.v2.executions.review.routes.router,
|
||||
tags=["v2", "executions", "review"],
|
||||
prefix="/api/review",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.library.routes.router, tags=["v2"], prefix="/api/library"
|
||||
backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.otto.routes.router, tags=["v2", "otto"], prefix="/api/otto"
|
||||
backend.server.v2.otto.routes.router, tags=["v2", "otto"], prefix="/api/otto"
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
backend.api.features.postmark.postmark.router,
|
||||
backend.server.routers.postmark.postmark.router,
|
||||
tags=["v1", "email"],
|
||||
prefix="/api/email",
|
||||
)
|
||||
@@ -316,12 +300,17 @@ app.include_router(
|
||||
prefix="/api/chat",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.oauth.router,
|
||||
backend.server.routers.oauth.router,
|
||||
tags=["oauth"],
|
||||
prefix="/api/oauth",
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.routers.user_auth.router,
|
||||
tags=["user-auth"],
|
||||
prefix="/api",
|
||||
)
|
||||
|
||||
app.mount("/external-api", external_api)
|
||||
app.mount("/external-api", external_app)
|
||||
|
||||
|
||||
@app.get(path="/health", tags=["health"], dependencies=[])
|
||||
@@ -374,7 +363,7 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
graph_version: Optional[int] = None,
|
||||
node_input: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
return await backend.api.features.v1.execute_graph(
|
||||
return await backend.server.routers.v1.execute_graph(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
@@ -389,16 +378,16 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
user_id: str,
|
||||
for_export: bool = False,
|
||||
):
|
||||
return await backend.api.features.v1.get_graph(
|
||||
return await backend.server.routers.v1.get_graph(
|
||||
graph_id, user_id, graph_version, for_export
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_create_graph(
|
||||
create_graph: backend.api.features.v1.CreateGraph,
|
||||
create_graph: backend.server.routers.v1.CreateGraph,
|
||||
user_id: str,
|
||||
):
|
||||
return await backend.api.features.v1.create_new_graph(create_graph, user_id)
|
||||
return await backend.server.routers.v1.create_new_graph(create_graph, user_id)
|
||||
|
||||
@staticmethod
|
||||
async def test_get_graph_run_status(graph_exec_id: str, user_id: str):
|
||||
@@ -414,45 +403,45 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
@staticmethod
|
||||
async def test_delete_graph(graph_id: str, user_id: str):
|
||||
"""Used for clean-up after a test run"""
|
||||
await backend.api.features.library.db.delete_library_agent_by_graph_id(
|
||||
await backend.server.v2.library.db.delete_library_agent_by_graph_id(
|
||||
graph_id=graph_id, user_id=user_id
|
||||
)
|
||||
return await backend.api.features.v1.delete_graph(graph_id, user_id)
|
||||
return await backend.server.routers.v1.delete_graph(graph_id, user_id)
|
||||
|
||||
@staticmethod
|
||||
async def test_get_presets(user_id: str, page: int = 1, page_size: int = 10):
|
||||
return await backend.api.features.library.routes.presets.list_presets(
|
||||
return await backend.server.v2.library.routes.presets.list_presets(
|
||||
user_id=user_id, page=page, page_size=page_size
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_get_preset(preset_id: str, user_id: str):
|
||||
return await backend.api.features.library.routes.presets.get_preset(
|
||||
return await backend.server.v2.library.routes.presets.get_preset(
|
||||
preset_id=preset_id, user_id=user_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_create_preset(
|
||||
preset: backend.api.features.library.model.LibraryAgentPresetCreatable,
|
||||
preset: backend.server.v2.library.model.LibraryAgentPresetCreatable,
|
||||
user_id: str,
|
||||
):
|
||||
return await backend.api.features.library.routes.presets.create_preset(
|
||||
return await backend.server.v2.library.routes.presets.create_preset(
|
||||
preset=preset, user_id=user_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_update_preset(
|
||||
preset_id: str,
|
||||
preset: backend.api.features.library.model.LibraryAgentPresetUpdatable,
|
||||
preset: backend.server.v2.library.model.LibraryAgentPresetUpdatable,
|
||||
user_id: str,
|
||||
):
|
||||
return await backend.api.features.library.routes.presets.update_preset(
|
||||
return await backend.server.v2.library.routes.presets.update_preset(
|
||||
preset_id=preset_id, preset=preset, user_id=user_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_delete_preset(preset_id: str, user_id: str):
|
||||
return await backend.api.features.library.routes.presets.delete_preset(
|
||||
return await backend.server.v2.library.routes.presets.delete_preset(
|
||||
preset_id=preset_id, user_id=user_id
|
||||
)
|
||||
|
||||
@@ -462,7 +451,7 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
user_id: str,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
return await backend.api.features.library.routes.presets.execute_preset(
|
||||
return await backend.server.v2.library.routes.presets.execute_preset(
|
||||
preset_id=preset_id,
|
||||
user_id=user_id,
|
||||
inputs=inputs or {},
|
||||
@@ -471,20 +460,18 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
|
||||
@staticmethod
|
||||
async def test_create_store_listing(
|
||||
request: backend.api.features.store.model.StoreSubmissionRequest, user_id: str
|
||||
request: backend.server.v2.store.model.StoreSubmissionRequest, user_id: str
|
||||
):
|
||||
return await backend.api.features.store.routes.create_submission(
|
||||
request, user_id
|
||||
)
|
||||
return await backend.server.v2.store.routes.create_submission(request, user_id)
|
||||
|
||||
### ADMIN ###
|
||||
|
||||
@staticmethod
|
||||
async def test_review_store_listing(
|
||||
request: backend.api.features.store.model.ReviewSubmissionRequest,
|
||||
request: backend.server.v2.store.model.ReviewSubmissionRequest,
|
||||
user_id: str,
|
||||
):
|
||||
return await backend.api.features.admin.store_admin_routes.review_submission(
|
||||
return await backend.server.v2.admin.store_admin_routes.review_submission(
|
||||
request.store_listing_version_id, request, user_id
|
||||
)
|
||||
|
||||
@@ -494,7 +481,10 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
provider: ProviderName,
|
||||
credentials: Credentials,
|
||||
) -> Credentials:
|
||||
from .features.integrations.router import create_credentials, get_credential
|
||||
from backend.server.integrations.router import (
|
||||
create_credentials,
|
||||
get_credential,
|
||||
)
|
||||
|
||||
try:
|
||||
return await create_credentials(
|
||||
@@ -6,11 +6,10 @@ from typing import Annotated
|
||||
import fastapi
|
||||
import pydantic
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from autogpt_libs.auth.dependencies import requires_user
|
||||
|
||||
import backend.data.analytics
|
||||
|
||||
router = fastapi.APIRouter(dependencies=[fastapi.Security(requires_user)])
|
||||
router = fastapi.APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
"""Example of analytics tests with improved error handling and assertions."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.routers.analytics as analytics_routes
|
||||
from backend.server.test_helpers import (
|
||||
assert_error_response_structure,
|
||||
assert_mock_called_with_partial,
|
||||
assert_response_status,
|
||||
safe_parse_json,
|
||||
)
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(analytics_routes.router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_log_raw_metric_success_improved(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw metric logging with improved assertions."""
|
||||
# Mock the analytics function
|
||||
mock_result = Mock(id="metric-123-uuid")
|
||||
|
||||
mock_log_metric = mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": "page_load_time",
|
||||
"metric_value": 2.5,
|
||||
"data_string": "/dashboard",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
# Improved assertions with better error messages
|
||||
assert_response_status(response, 200, "Metric logging should succeed")
|
||||
response_data = safe_parse_json(response, "Metric response parsing")
|
||||
|
||||
assert response_data == "metric-123-uuid", f"Unexpected response: {response_data}"
|
||||
|
||||
# Verify the function was called with correct parameters
|
||||
assert_mock_called_with_partial(
|
||||
mock_log_metric,
|
||||
user_id=test_user_id,
|
||||
metric_name="page_load_time",
|
||||
metric_value=2.5,
|
||||
data_string="/dashboard",
|
||||
)
|
||||
|
||||
# Snapshot test the response
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"metric_id": response_data}, indent=2, sort_keys=True),
|
||||
"analytics_log_metric_success_improved",
|
||||
)
|
||||
|
||||
|
||||
def test_log_raw_metric_invalid_request_improved() -> None:
|
||||
"""Test invalid metric request with improved error assertions."""
|
||||
# Test missing required fields
|
||||
response = client.post("/log_raw_metric", json={})
|
||||
|
||||
error_data = assert_error_response_structure(
|
||||
response, expected_status=422, expected_error_fields=["loc", "msg", "type"]
|
||||
)
|
||||
|
||||
# Verify specific error details
|
||||
detail = error_data["detail"]
|
||||
assert isinstance(detail, list), "Error detail should be a list"
|
||||
assert len(detail) > 0, "Should have at least one error"
|
||||
|
||||
# Check that required fields are mentioned in errors
|
||||
error_fields = [error["loc"][-1] for error in detail if "loc" in error]
|
||||
assert "metric_name" in error_fields, "Should report missing metric_name"
|
||||
assert "metric_value" in error_fields, "Should report missing metric_value"
|
||||
assert "data_string" in error_fields, "Should report missing data_string"
|
||||
|
||||
|
||||
def test_log_raw_metric_type_validation_improved(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Test metric type validation with improved assertions."""
|
||||
# Mock the analytics function to avoid event loop issues
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=Mock(id="test-id"),
|
||||
)
|
||||
|
||||
invalid_requests = [
|
||||
{
|
||||
"data": {
|
||||
"metric_name": "test",
|
||||
"metric_value": "not_a_number", # Invalid type
|
||||
"data_string": "test",
|
||||
},
|
||||
"expected_error": "Input should be a valid number",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"metric_name": "", # Empty string
|
||||
"metric_value": 1.0,
|
||||
"data_string": "test",
|
||||
},
|
||||
"expected_error": "String should have at least 1 character",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"metric_name": "test",
|
||||
"metric_value": 123, # Valid number
|
||||
"data_string": "", # Empty data_string
|
||||
},
|
||||
"expected_error": "String should have at least 1 character",
|
||||
},
|
||||
]
|
||||
|
||||
for test_case in invalid_requests:
|
||||
response = client.post("/log_raw_metric", json=test_case["data"])
|
||||
|
||||
error_data = assert_error_response_structure(response, expected_status=422)
|
||||
|
||||
# Check that expected error is in the response
|
||||
error_text = json.dumps(error_data)
|
||||
assert (
|
||||
test_case["expected_error"] in error_text
|
||||
or test_case["expected_error"].lower() in error_text.lower()
|
||||
), f"Expected error '{test_case['expected_error']}' not found in: {error_text}"
|
||||
@@ -0,0 +1,115 @@
|
||||
"""Example of parametrized tests for analytics endpoints."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.routers.analytics as analytics_routes
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(analytics_routes.router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"metric_value,metric_name,data_string,test_id",
|
||||
[
|
||||
(100, "api_calls_count", "external_api", "integer_value"),
|
||||
(0, "error_count", "no_errors", "zero_value"),
|
||||
(-5.2, "temperature_delta", "cooling", "negative_value"),
|
||||
(1.23456789, "precision_test", "float_precision", "float_precision"),
|
||||
(999999999, "large_number", "max_value", "large_number"),
|
||||
(0.0000001, "tiny_number", "min_value", "tiny_number"),
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_values_parametrized(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
metric_value: float,
|
||||
metric_name: str,
|
||||
data_string: str,
|
||||
test_id: str,
|
||||
) -> None:
|
||||
"""Test raw metric logging with various metric values using parametrize."""
|
||||
# Mock the analytics function
|
||||
mock_result = Mock(id=f"metric-{test_id}-uuid")
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": metric_name,
|
||||
"metric_value": metric_value,
|
||||
"data_string": data_string,
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
# Better error handling
|
||||
assert response.status_code == 200, f"Failed for {test_id}: {response.text}"
|
||||
response_data = response.json()
|
||||
|
||||
# Snapshot test the response
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(
|
||||
{"metric_id": response_data, "test_case": test_id}, indent=2, sort_keys=True
|
||||
),
|
||||
f"analytics_metric_{test_id}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_data,expected_error",
|
||||
[
|
||||
({}, "Field required"), # Missing all fields
|
||||
({"metric_name": "test"}, "Field required"), # Missing metric_value
|
||||
(
|
||||
{"metric_name": "test", "metric_value": "not_a_number"},
|
||||
"Input should be a valid number",
|
||||
), # Invalid type
|
||||
(
|
||||
{"metric_name": "", "metric_value": 1.0, "data_string": "test"},
|
||||
"String should have at least 1 character",
|
||||
), # Empty name
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_invalid_requests_parametrized(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test invalid metric requests with parametrize."""
|
||||
# Mock the analytics function to avoid event loop issues
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=Mock(id="test-id"),
|
||||
)
|
||||
|
||||
response = client.post("/log_raw_metric", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json()
|
||||
assert "detail" in error_detail
|
||||
# Verify error message contains expected error
|
||||
error_text = json.dumps(error_detail)
|
||||
assert expected_error in error_text or expected_error.lower() in error_text.lower()
|
||||
@@ -0,0 +1,284 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.routers.analytics as analytics_routes
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(analytics_routes.router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_log_raw_metric_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw metric logging"""
|
||||
|
||||
# Mock the analytics function
|
||||
mock_result = Mock(id="metric-123-uuid")
|
||||
|
||||
mock_log_metric = mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": "page_load_time",
|
||||
"metric_value": 2.5,
|
||||
"data_string": "/dashboard",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data == "metric-123-uuid"
|
||||
|
||||
# Verify the function was called with correct parameters
|
||||
mock_log_metric.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
metric_name="page_load_time",
|
||||
metric_value=2.5,
|
||||
data_string="/dashboard",
|
||||
)
|
||||
|
||||
# Snapshot test the response
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"metric_id": response.json()}, indent=2, sort_keys=True),
|
||||
"analytics_log_metric_success",
|
||||
)
|
||||
|
||||
|
||||
def test_log_raw_metric_various_values(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test raw metric logging with various metric values"""
|
||||
|
||||
# Mock the analytics function
|
||||
mock_result = Mock(id="metric-456-uuid")
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
# Test with integer value
|
||||
request_data = {
|
||||
"metric_name": "api_calls_count",
|
||||
"metric_value": 100,
|
||||
"data_string": "external_api",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Test with zero value
|
||||
request_data = {
|
||||
"metric_name": "error_count",
|
||||
"metric_value": 0,
|
||||
"data_string": "no_errors",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Test with negative value
|
||||
request_data = {
|
||||
"metric_name": "temperature_delta",
|
||||
"metric_value": -5.2,
|
||||
"data_string": "cooling",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Snapshot the last response
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"metric_id": response.json()}, indent=2, sort_keys=True),
|
||||
"analytics_log_metric_various_values",
|
||||
)
|
||||
|
||||
|
||||
def test_log_raw_analytics_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw analytics logging"""
|
||||
|
||||
# Mock the analytics function
|
||||
mock_result = Mock(id="analytics-789-uuid")
|
||||
|
||||
mock_log_analytics = mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "user_action",
|
||||
"data": {
|
||||
"action": "button_click",
|
||||
"button_id": "submit_form",
|
||||
"timestamp": "2023-01-01T00:00:00Z",
|
||||
"metadata": {
|
||||
"form_type": "registration",
|
||||
"fields_filled": 5,
|
||||
},
|
||||
},
|
||||
"data_index": "button_click_submit_form",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data == "analytics-789-uuid"
|
||||
|
||||
# Verify the function was called with correct parameters
|
||||
mock_log_analytics.assert_called_once_with(
|
||||
test_user_id,
|
||||
"user_action",
|
||||
request_data["data"],
|
||||
"button_click_submit_form",
|
||||
)
|
||||
|
||||
# Snapshot test the response
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"analytics_id": response_data}, indent=2, sort_keys=True),
|
||||
"analytics_log_analytics_success",
|
||||
)
|
||||
|
||||
|
||||
def test_log_raw_analytics_complex_data(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test raw analytics logging with complex nested data"""
|
||||
|
||||
# Mock the analytics function
|
||||
mock_result = Mock(id="analytics-complex-uuid")
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "agent_execution",
|
||||
"data": {
|
||||
"agent_id": "agent_123",
|
||||
"execution_id": "exec_456",
|
||||
"status": "completed",
|
||||
"duration_ms": 3500,
|
||||
"nodes_executed": 15,
|
||||
"blocks_used": [
|
||||
{"block_id": "llm_block", "count": 3},
|
||||
{"block_id": "http_block", "count": 5},
|
||||
{"block_id": "code_block", "count": 2},
|
||||
],
|
||||
"errors": [],
|
||||
"metadata": {
|
||||
"trigger": "manual",
|
||||
"user_tier": "premium",
|
||||
"environment": "production",
|
||||
},
|
||||
},
|
||||
"data_index": "agent_123_exec_456",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
|
||||
# Snapshot test the complex data structure
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(
|
||||
{
|
||||
"analytics_id": response_data,
|
||||
"logged_data": request_data["data"],
|
||||
},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
),
|
||||
"analytics_log_analytics_complex_data",
|
||||
)
|
||||
|
||||
|
||||
def test_log_raw_metric_invalid_request() -> None:
|
||||
"""Test raw metric logging with invalid request data"""
|
||||
# Missing required fields
|
||||
response = client.post("/log_raw_metric", json={})
|
||||
assert response.status_code == 422
|
||||
|
||||
# Invalid metric_value type
|
||||
response = client.post(
|
||||
"/log_raw_metric",
|
||||
json={
|
||||
"metric_name": "test",
|
||||
"metric_value": "not_a_number",
|
||||
"data_string": "test",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
# Missing data_string
|
||||
response = client.post(
|
||||
"/log_raw_metric",
|
||||
json={
|
||||
"metric_name": "test",
|
||||
"metric_value": 1.0,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_log_raw_analytics_invalid_request() -> None:
|
||||
"""Test raw analytics logging with invalid request data"""
|
||||
# Missing required fields
|
||||
response = client.post("/log_raw_analytics", json={})
|
||||
assert response.status_code == 422
|
||||
|
||||
# Invalid data type (should be dict)
|
||||
response = client.post(
|
||||
"/log_raw_analytics",
|
||||
json={
|
||||
"type": "test",
|
||||
"data": "not_a_dict",
|
||||
"data_index": "test",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
# Missing data_index
|
||||
response = client.post(
|
||||
"/log_raw_analytics",
|
||||
json={
|
||||
"type": "test",
|
||||
"data": {"key": "value"},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
@@ -5,11 +5,11 @@ Implements OAuth 2.0 Authorization Code flow with PKCE support.
|
||||
|
||||
Flow:
|
||||
1. User clicks "Login with AutoGPT" in 3rd party app
|
||||
2. App redirects user to /auth/authorize with client_id, redirect_uri, scope, state
|
||||
2. App redirects user to /oauth/authorize with client_id, redirect_uri, scope, state
|
||||
3. User sees consent screen (if not already logged in, redirects to login first)
|
||||
4. User approves → backend creates authorization code
|
||||
5. User redirected back to app with code
|
||||
6. App exchanges code for access/refresh tokens at /api/oauth/token
|
||||
6. App exchanges code for access/refresh tokens at /oauth/token
|
||||
7. App uses access token to call external API endpoints
|
||||
"""
|
||||
|
||||
@@ -28,7 +28,7 @@ from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
|
||||
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
|
||||
from prisma.models import User as PrismaUser
|
||||
|
||||
from backend.api.rest_api import app
|
||||
from backend.server.rest_api import app
|
||||
|
||||
keysmith = APIKeySmith()
|
||||
|
||||
@@ -4,15 +4,12 @@ from typing import Annotated
|
||||
from fastapi import APIRouter, Body, HTTPException, Query, Security
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from backend.api.utils.api_key_auth import APIKeyAuthenticator
|
||||
from backend.data.user import (
|
||||
get_user_by_email,
|
||||
set_user_email_verification,
|
||||
unsubscribe_user_by_token,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .models import (
|
||||
from backend.server.routers.postmark.models import (
|
||||
PostmarkBounceEnum,
|
||||
PostmarkBounceWebhook,
|
||||
PostmarkClickWebhook,
|
||||
@@ -22,6 +19,8 @@ from .models import (
|
||||
PostmarkSubscriptionChangeWebhook,
|
||||
PostmarkWebhook,
|
||||
)
|
||||
from backend.server.utils.api_key_auth import APIKeyAuthenticator
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
910
autogpt_platform/backend/backend/server/routers/user_auth.py
Normal file
910
autogpt_platform/backend/backend/server/routers/user_auth.py
Normal file
@@ -0,0 +1,910 @@
|
||||
"""
|
||||
User authentication router for native FastAPI auth.
|
||||
|
||||
This router provides endpoints that are compatible with the Supabase Auth API
|
||||
structure, allowing the frontend to migrate without code changes.
|
||||
|
||||
Endpoints:
|
||||
- POST /auth/signup - Register a new user
|
||||
- POST /auth/login - Login with email/password
|
||||
- POST /auth/logout - Logout (clear session)
|
||||
- POST /auth/refresh - Refresh access token
|
||||
- GET /auth/me - Get current user
|
||||
- POST /auth/password/reset - Request password reset email
|
||||
- POST /auth/password/set - Set new password from reset link
|
||||
- GET /auth/verify-email - Verify email from magic link
|
||||
- GET /auth/oauth/google/authorize - Get Google OAuth URL
|
||||
- GET /auth/oauth/google/callback - Handle Google OAuth callback
|
||||
|
||||
Admin Endpoints:
|
||||
- GET /auth/admin/users - List users (admin only)
|
||||
- GET /auth/admin/users/{user_id} - Get user details (admin only)
|
||||
- POST /auth/admin/users/{user_id}/impersonate - Get impersonation token (admin only)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from prisma.models import User
|
||||
|
||||
from backend.data.auth.email_service import get_auth_email_service
|
||||
from backend.data.auth.magic_links import (
|
||||
create_email_verification_link,
|
||||
create_password_reset_link,
|
||||
verify_email_token,
|
||||
verify_password_reset_token,
|
||||
)
|
||||
from backend.data.auth.password import hash_password, needs_rehash, verify_password
|
||||
from backend.data.auth.tokens import (
|
||||
ACCESS_TOKEN_TTL,
|
||||
REFRESH_TOKEN_TTL,
|
||||
create_access_token,
|
||||
create_refresh_token_db,
|
||||
decode_access_token,
|
||||
revoke_all_user_refresh_tokens,
|
||||
revoke_refresh_token,
|
||||
validate_refresh_token,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["user-auth"])
|
||||
|
||||
# Cookie configuration
|
||||
ACCESS_TOKEN_COOKIE = "access_token"
|
||||
REFRESH_TOKEN_COOKIE = "refresh_token"
|
||||
OAUTH_STATE_COOKIE = "oauth_state"
|
||||
|
||||
# Header for admin impersonation (matches existing autogpt_libs pattern)
|
||||
IMPERSONATION_HEADER = "X-Act-As-User-Id"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Admin Role Detection
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _get_admin_domains() -> set[str]:
|
||||
"""Get set of email domains that grant admin role."""
|
||||
domains_str = settings.config.admin_email_domains
|
||||
if not domains_str:
|
||||
return set()
|
||||
return {d.strip().lower() for d in domains_str.split(",") if d.strip()}
|
||||
|
||||
|
||||
def _get_admin_emails() -> set[str]:
|
||||
"""Get set of specific email addresses that grant admin role."""
|
||||
emails_str = settings.config.admin_emails
|
||||
if not emails_str:
|
||||
return set()
|
||||
return {e.strip().lower() for e in emails_str.split(",") if e.strip()}
|
||||
|
||||
|
||||
def get_user_role(email: str) -> str:
|
||||
"""
|
||||
Determine user role based on email.
|
||||
|
||||
Returns "admin" if:
|
||||
- Email domain is in admin_email_domains list
|
||||
- Email is in admin_emails list
|
||||
|
||||
Otherwise returns "authenticated".
|
||||
"""
|
||||
email_lower = email.lower()
|
||||
domain = email_lower.split("@")[-1] if "@" in email_lower else ""
|
||||
|
||||
# Check specific emails first
|
||||
if email_lower in _get_admin_emails():
|
||||
return "admin"
|
||||
|
||||
# Check domains
|
||||
if domain in _get_admin_domains():
|
||||
return "admin"
|
||||
|
||||
return "authenticated"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Request/Response Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class SignupRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class PasswordResetRequest(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class PasswordSetRequest(BaseModel):
|
||||
token: str
|
||||
password: str
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
email_verified: bool
|
||||
name: Optional[str] = None
|
||||
created_at: datetime
|
||||
role: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def from_db(user: User, include_role: bool = False) -> "UserResponse":
|
||||
return UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
email_verified=user.emailVerified,
|
||||
name=user.name,
|
||||
created_at=user.createdAt,
|
||||
role=get_user_role(user.email) if include_role else None,
|
||||
)
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
"""Response matching Supabase auth response structure."""
|
||||
|
||||
user: UserResponse
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
expires_in: int
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class AdminUserListResponse(BaseModel):
|
||||
users: List[UserResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
|
||||
class ImpersonationResponse(BaseModel):
|
||||
access_token: str
|
||||
impersonated_user: UserResponse
|
||||
expires_in: int
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Cookie Helpers
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _is_production() -> bool:
|
||||
return os.getenv("APP_ENV", "local").lower() in ("production", "prod")
|
||||
|
||||
|
||||
def _set_auth_cookies(response: Response, access_token: str, refresh_token: str):
|
||||
"""Set authentication cookies on the response."""
|
||||
secure = _is_production()
|
||||
|
||||
# Access token: accessible to JavaScript for API calls
|
||||
response.set_cookie(
|
||||
key=ACCESS_TOKEN_COOKIE,
|
||||
value=access_token,
|
||||
httponly=False, # JS needs access for Authorization header
|
||||
secure=secure,
|
||||
samesite="lax",
|
||||
max_age=int(ACCESS_TOKEN_TTL.total_seconds()),
|
||||
path="/",
|
||||
)
|
||||
|
||||
# Refresh token: httpOnly, restricted path
|
||||
response.set_cookie(
|
||||
key=REFRESH_TOKEN_COOKIE,
|
||||
value=refresh_token,
|
||||
httponly=True, # Not accessible to JavaScript
|
||||
secure=secure,
|
||||
samesite="strict",
|
||||
max_age=int(REFRESH_TOKEN_TTL.total_seconds()),
|
||||
path="/api/auth/refresh", # Only sent to refresh endpoint
|
||||
)
|
||||
|
||||
|
||||
def _clear_auth_cookies(response: Response):
|
||||
"""Clear authentication cookies."""
|
||||
response.delete_cookie(key=ACCESS_TOKEN_COOKIE, path="/")
|
||||
response.delete_cookie(key=REFRESH_TOKEN_COOKIE, path="/api/auth/refresh")
|
||||
|
||||
|
||||
def _get_access_token(request: Request) -> Optional[str]:
|
||||
"""Get access token from cookie or Authorization header."""
|
||||
# Try cookie first
|
||||
token = request.cookies.get(ACCESS_TOKEN_COOKIE)
|
||||
if token:
|
||||
return token
|
||||
|
||||
# Try Authorization header
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
return auth_header[7:]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Auth Dependencies
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def get_current_user_from_token(request: Request) -> Optional[User]:
|
||||
"""Get the current user from the access token."""
|
||||
access_token = _get_access_token(request)
|
||||
if not access_token:
|
||||
return None
|
||||
|
||||
payload = decode_access_token(access_token)
|
||||
if not payload:
|
||||
return None
|
||||
|
||||
return await User.prisma().find_unique(where={"id": payload.sub})
|
||||
|
||||
|
||||
async def require_auth(request: Request) -> User:
|
||||
"""Require authentication - returns user or raises 401."""
|
||||
user = await get_current_user_from_token(request)
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
return user
|
||||
|
||||
|
||||
async def require_admin(request: Request) -> User:
|
||||
"""Require admin authentication - returns user or raises 401/403."""
|
||||
user = await require_auth(request)
|
||||
role = get_user_role(user.email)
|
||||
if role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
return user
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Authentication Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post("/signup", response_model=MessageResponse)
|
||||
async def signup(data: SignupRequest):
|
||||
"""
|
||||
Register a new user.
|
||||
|
||||
Returns a message prompting the user to verify their email.
|
||||
No automatic login until email is verified.
|
||||
"""
|
||||
# Check if email already exists
|
||||
existing = await User.prisma().find_unique(where={"email": data.email})
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="Email already registered")
|
||||
|
||||
# Validate password strength
|
||||
if len(data.password) < 8:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Password must be at least 8 characters"
|
||||
)
|
||||
|
||||
# Create user with hashed password
|
||||
password_hash = hash_password(data.password)
|
||||
user = await User.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"email": data.email,
|
||||
"passwordHash": password_hash,
|
||||
"authProvider": "password",
|
||||
"emailVerified": False,
|
||||
}
|
||||
)
|
||||
|
||||
# Create verification link and send email
|
||||
token = await create_email_verification_link(data.email)
|
||||
email_service = get_auth_email_service()
|
||||
email_sent = email_service.send_verification_email(data.email, token)
|
||||
|
||||
if not email_sent:
|
||||
logger.warning(f"Failed to send verification email to {data.email}")
|
||||
# Still log the token for development
|
||||
logger.info(f"Verification token for {data.email}: {token}")
|
||||
|
||||
return MessageResponse(
|
||||
message="Please check your email to verify your account"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=AuthResponse)
|
||||
async def login(data: LoginRequest, response: Response):
|
||||
"""
|
||||
Login with email and password.
|
||||
|
||||
Sets httpOnly cookies for session management.
|
||||
"""
|
||||
user = await User.prisma().find_unique(where={"email": data.email})
|
||||
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Invalid email or password")
|
||||
|
||||
# Check if this is a migrated user without password
|
||||
if user.passwordHash is None:
|
||||
if user.migratedFromSupabase:
|
||||
# Send password reset email for migrated user
|
||||
token = await create_password_reset_link(data.email, user.id)
|
||||
email_service = get_auth_email_service()
|
||||
email_service.send_migrated_user_password_reset(data.email, token)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Please check your email to set your password",
|
||||
)
|
||||
else:
|
||||
# OAuth user trying to login with password
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"This account uses {user.authProvider} login",
|
||||
)
|
||||
|
||||
# Verify password
|
||||
if not verify_password(user.passwordHash, data.password):
|
||||
raise HTTPException(status_code=401, detail="Invalid email or password")
|
||||
|
||||
# Check if email is verified
|
||||
if not user.emailVerified:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Please verify your email before logging in"
|
||||
)
|
||||
|
||||
# Rehash password if needed (transparent security upgrade)
|
||||
if needs_rehash(user.passwordHash):
|
||||
new_hash = hash_password(data.password)
|
||||
await User.prisma().update(
|
||||
where={"id": user.id}, data={"passwordHash": new_hash}
|
||||
)
|
||||
|
||||
# Create tokens
|
||||
role = get_user_role(user.email)
|
||||
access_token = create_access_token(user.id, user.email, role)
|
||||
refresh_token, _ = await create_refresh_token_db(user.id)
|
||||
|
||||
# Set cookies
|
||||
_set_auth_cookies(response, access_token, refresh_token)
|
||||
|
||||
return AuthResponse(
|
||||
user=UserResponse.from_db(user),
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_in=int(ACCESS_TOKEN_TTL.total_seconds()),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/logout", response_model=MessageResponse)
|
||||
async def logout(request: Request, response: Response, scope: str = Query("local")):
|
||||
"""
|
||||
Logout the current user.
|
||||
|
||||
Args:
|
||||
scope: "local" to clear current session, "global" to revoke all sessions.
|
||||
"""
|
||||
# Get refresh token to revoke
|
||||
refresh_token = request.cookies.get(REFRESH_TOKEN_COOKIE)
|
||||
|
||||
if scope == "global":
|
||||
# Get user from access token
|
||||
access_token = _get_access_token(request)
|
||||
if access_token:
|
||||
payload = decode_access_token(access_token)
|
||||
if payload:
|
||||
await revoke_all_user_refresh_tokens(payload.sub)
|
||||
elif refresh_token:
|
||||
await revoke_refresh_token(refresh_token)
|
||||
|
||||
# Clear cookies
|
||||
_clear_auth_cookies(response)
|
||||
|
||||
return MessageResponse(message="Logged out successfully")
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=AuthResponse)
|
||||
async def refresh(request: Request, response: Response):
|
||||
"""
|
||||
Refresh the access token using the refresh token.
|
||||
"""
|
||||
refresh_token = request.cookies.get(REFRESH_TOKEN_COOKIE)
|
||||
|
||||
if not refresh_token:
|
||||
raise HTTPException(status_code=401, detail="No refresh token")
|
||||
|
||||
# Validate refresh token
|
||||
user_id = await validate_refresh_token(refresh_token)
|
||||
if not user_id:
|
||||
_clear_auth_cookies(response)
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired refresh token")
|
||||
|
||||
# Get user
|
||||
user = await User.prisma().find_unique(where={"id": user_id})
|
||||
if not user:
|
||||
_clear_auth_cookies(response)
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
# Revoke old refresh token
|
||||
await revoke_refresh_token(refresh_token)
|
||||
|
||||
# Create new tokens
|
||||
role = get_user_role(user.email)
|
||||
new_access_token = create_access_token(user.id, user.email, role)
|
||||
new_refresh_token, _ = await create_refresh_token_db(user.id)
|
||||
|
||||
# Set new cookies
|
||||
_set_auth_cookies(response, new_access_token, new_refresh_token)
|
||||
|
||||
return AuthResponse(
|
||||
user=UserResponse.from_db(user),
|
||||
access_token=new_access_token,
|
||||
refresh_token=new_refresh_token,
|
||||
expires_in=int(ACCESS_TOKEN_TTL.total_seconds()),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_current_user(request: Request):
|
||||
"""
|
||||
Get the currently authenticated user.
|
||||
|
||||
Supports admin impersonation via X-Act-As-User-Id header.
|
||||
"""
|
||||
access_token = _get_access_token(request)
|
||||
|
||||
if not access_token:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
payload = decode_access_token(access_token)
|
||||
if not payload:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
|
||||
# Check for impersonation header
|
||||
impersonate_user_id = request.headers.get(IMPERSONATION_HEADER, "").strip()
|
||||
if impersonate_user_id:
|
||||
# Verify caller is admin
|
||||
if payload.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Only admins can impersonate users"
|
||||
)
|
||||
|
||||
# Log impersonation for audit
|
||||
logger.info(
|
||||
f"Admin impersonation: {payload.sub} ({payload.email}) "
|
||||
f"viewing as user {impersonate_user_id}"
|
||||
)
|
||||
|
||||
user = await User.prisma().find_unique(where={"id": impersonate_user_id})
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="Impersonated user not found")
|
||||
else:
|
||||
user = await User.prisma().find_unique(where={"id": payload.sub})
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
return UserResponse.from_db(user, include_role=True)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Password Reset Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post("/password/reset", response_model=MessageResponse)
|
||||
async def request_password_reset(data: PasswordResetRequest):
|
||||
"""
|
||||
Request a password reset email.
|
||||
"""
|
||||
user = await User.prisma().find_unique(where={"email": data.email})
|
||||
|
||||
# Always return success to prevent email enumeration
|
||||
if not user:
|
||||
return MessageResponse(message="If the email exists, a reset link has been sent")
|
||||
|
||||
# Don't allow password reset for OAuth-only users
|
||||
if user.authProvider not in ("password", "supabase"):
|
||||
return MessageResponse(message="If the email exists, a reset link has been sent")
|
||||
|
||||
# Create reset link and send email
|
||||
token = await create_password_reset_link(data.email, user.id)
|
||||
email_service = get_auth_email_service()
|
||||
email_service.send_password_reset_email(data.email, token)
|
||||
|
||||
return MessageResponse(message="If the email exists, a reset link has been sent")
|
||||
|
||||
|
||||
@router.post("/password/set", response_model=MessageResponse)
|
||||
async def set_password(data: PasswordSetRequest, response: Response):
|
||||
"""
|
||||
Set a new password using a reset token.
|
||||
"""
|
||||
# Validate token
|
||||
result = await verify_password_reset_token(data.token)
|
||||
if not result:
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired reset token")
|
||||
|
||||
user_id, email = result
|
||||
|
||||
# Validate password strength
|
||||
if len(data.password) < 8:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Password must be at least 8 characters"
|
||||
)
|
||||
|
||||
# Update password and verify email (if not already)
|
||||
password_hash = hash_password(data.password)
|
||||
await User.prisma().update(
|
||||
where={"id": user_id},
|
||||
data={
|
||||
"passwordHash": password_hash,
|
||||
"emailVerified": True,
|
||||
"emailVerifiedAt": datetime.now(timezone.utc),
|
||||
"authProvider": "password",
|
||||
"migratedFromSupabase": False, # Clear migration flag
|
||||
},
|
||||
)
|
||||
|
||||
# Send notification that password was changed
|
||||
email_service = get_auth_email_service()
|
||||
email_service.send_password_changed_notification(email)
|
||||
|
||||
# Revoke all existing sessions for security
|
||||
await revoke_all_user_refresh_tokens(user_id)
|
||||
|
||||
# Clear any existing cookies
|
||||
_clear_auth_cookies(response)
|
||||
|
||||
return MessageResponse(message="Password updated successfully. Please log in.")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Email Verification
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get("/verify-email", response_model=MessageResponse)
|
||||
async def verify_email(token: str = Query(...)):
|
||||
"""
|
||||
Verify email address from magic link.
|
||||
"""
|
||||
email = await verify_email_token(token)
|
||||
if not email:
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired verification link")
|
||||
|
||||
# Update user as verified
|
||||
user = await User.prisma().find_unique(where={"email": email})
|
||||
if user:
|
||||
await User.prisma().update(
|
||||
where={"id": user.id},
|
||||
data={
|
||||
"emailVerified": True,
|
||||
"emailVerifiedAt": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
|
||||
return MessageResponse(message="Email verified successfully. You can now log in.")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Google OAuth
|
||||
# ============================================================================
|
||||
|
||||
GOOGLE_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
GOOGLE_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
GOOGLE_USERINFO_URL = "https://www.googleapis.com/oauth2/v2/userinfo"
|
||||
|
||||
|
||||
def _get_google_config():
|
||||
"""Get Google OAuth configuration from environment."""
|
||||
client_id = os.getenv("GOOGLE_CLIENT_ID")
|
||||
client_secret = os.getenv("GOOGLE_CLIENT_SECRET")
|
||||
redirect_uri = os.getenv("GOOGLE_REDIRECT_URI", "")
|
||||
|
||||
if not client_id or not client_secret:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Google OAuth not configured"
|
||||
)
|
||||
|
||||
return client_id, client_secret, redirect_uri
|
||||
|
||||
|
||||
@router.get("/oauth/google/authorize")
|
||||
async def google_authorize(
|
||||
response: Response,
|
||||
redirect_to: str = Query("/marketplace", description="URL to redirect after auth"),
|
||||
):
|
||||
"""
|
||||
Initiate Google OAuth flow.
|
||||
|
||||
Returns the authorization URL to redirect the user to.
|
||||
"""
|
||||
client_id, _, redirect_uri = _get_google_config()
|
||||
|
||||
# Generate state for CSRF protection
|
||||
state = secrets.token_urlsafe(32)
|
||||
|
||||
# Store state and redirect_to in cookie
|
||||
secure = _is_production()
|
||||
response.set_cookie(
|
||||
key=OAUTH_STATE_COOKIE,
|
||||
value=f"{state}|{redirect_to}",
|
||||
httponly=True,
|
||||
secure=secure,
|
||||
samesite="lax",
|
||||
max_age=600, # 10 minutes
|
||||
)
|
||||
|
||||
# Build authorization URL
|
||||
params = {
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"scope": "openid email profile",
|
||||
"state": state,
|
||||
"access_type": "offline",
|
||||
"prompt": "consent",
|
||||
}
|
||||
|
||||
auth_url = f"{GOOGLE_AUTH_URL}?{urlencode(params)}"
|
||||
|
||||
return {"url": auth_url}
|
||||
|
||||
|
||||
@router.get("/oauth/google/callback")
|
||||
async def google_callback(
|
||||
request: Request,
|
||||
response: Response,
|
||||
code: str = Query(...),
|
||||
state: str = Query(...),
|
||||
):
|
||||
"""
|
||||
Handle Google OAuth callback.
|
||||
|
||||
Exchanges the authorization code for tokens and creates/updates the user.
|
||||
"""
|
||||
client_id, client_secret, redirect_uri = _get_google_config()
|
||||
|
||||
# Verify state
|
||||
stored_state_cookie = request.cookies.get(OAUTH_STATE_COOKIE)
|
||||
if not stored_state_cookie:
|
||||
raise HTTPException(status_code=400, detail="Missing OAuth state")
|
||||
|
||||
stored_state, redirect_to = stored_state_cookie.split("|", 1)
|
||||
if state != stored_state:
|
||||
raise HTTPException(status_code=400, detail="Invalid OAuth state")
|
||||
|
||||
# Clear state cookie
|
||||
response.delete_cookie(key=OAUTH_STATE_COOKIE)
|
||||
|
||||
# Exchange code for tokens
|
||||
async with httpx.AsyncClient() as client:
|
||||
token_response = await client.post(
|
||||
GOOGLE_TOKEN_URL,
|
||||
data={
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
)
|
||||
|
||||
if token_response.status_code != 200:
|
||||
logger.error(f"Google token exchange failed: {token_response.text}")
|
||||
raise HTTPException(status_code=400, detail="Failed to exchange code")
|
||||
|
||||
tokens = token_response.json()
|
||||
google_access_token = tokens.get("access_token")
|
||||
|
||||
# Get user info
|
||||
userinfo_response = await client.get(
|
||||
GOOGLE_USERINFO_URL,
|
||||
headers={"Authorization": f"Bearer {google_access_token}"},
|
||||
)
|
||||
|
||||
if userinfo_response.status_code != 200:
|
||||
raise HTTPException(status_code=400, detail="Failed to get user info")
|
||||
|
||||
userinfo = userinfo_response.json()
|
||||
|
||||
email = userinfo.get("email")
|
||||
if not email:
|
||||
raise HTTPException(status_code=400, detail="Email not provided by Google")
|
||||
|
||||
# Get or create user
|
||||
user = await User.prisma().find_unique(where={"email": email})
|
||||
|
||||
if user:
|
||||
# Update existing user if needed
|
||||
if user.authProvider == "supabase":
|
||||
await User.prisma().update(
|
||||
where={"id": user.id},
|
||||
data={"authProvider": "google"},
|
||||
)
|
||||
else:
|
||||
# Create new user
|
||||
user = await User.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"email": email,
|
||||
"name": userinfo.get("name"),
|
||||
"emailVerified": True, # Google verifies emails
|
||||
"emailVerifiedAt": datetime.now(timezone.utc),
|
||||
"authProvider": "google",
|
||||
}
|
||||
)
|
||||
|
||||
# Create tokens
|
||||
role = get_user_role(email)
|
||||
access_token = create_access_token(user.id, user.email, role)
|
||||
refresh_token, _ = await create_refresh_token_db(user.id)
|
||||
|
||||
# Set cookies
|
||||
_set_auth_cookies(response, access_token, refresh_token)
|
||||
|
||||
# Redirect to frontend
|
||||
frontend_url = os.getenv("FRONTEND_BASE_URL", "http://localhost:3000")
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
return RedirectResponse(url=f"{frontend_url}{redirect_to}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Admin Routes
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get("/admin/users", response_model=AdminUserListResponse)
|
||||
async def list_users(
|
||||
request: Request,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=100),
|
||||
search: Optional[str] = Query(None, description="Search by email"),
|
||||
admin_user: User = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
List all users (admin only).
|
||||
"""
|
||||
skip = (page - 1) * page_size
|
||||
|
||||
where_clause = {}
|
||||
if search:
|
||||
where_clause["email"] = {"contains": search, "mode": "insensitive"}
|
||||
|
||||
users = await User.prisma().find_many(
|
||||
where=where_clause,
|
||||
skip=skip,
|
||||
take=page_size,
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
|
||||
total = await User.prisma().count(where=where_clause)
|
||||
|
||||
return AdminUserListResponse(
|
||||
users=[UserResponse.from_db(u, include_role=True) for u in users],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/admin/users/{user_id}", response_model=UserResponse)
|
||||
async def get_user_by_id(
|
||||
user_id: str,
|
||||
admin_user: User = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
Get a specific user by ID (admin only).
|
||||
"""
|
||||
user = await User.prisma().find_unique(where={"id": user_id})
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
return UserResponse.from_db(user, include_role=True)
|
||||
|
||||
|
||||
@router.post("/admin/users/{user_id}/impersonate", response_model=ImpersonationResponse)
|
||||
async def impersonate_user(
|
||||
request: Request,
|
||||
user_id: str,
|
||||
admin_user: User = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
Get an access token to impersonate a user (admin only).
|
||||
|
||||
This token can be used with the Authorization header to act as the user.
|
||||
All actions are logged for audit purposes.
|
||||
"""
|
||||
target_user = await User.prisma().find_unique(where={"id": user_id})
|
||||
if not target_user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Log the impersonation
|
||||
logger.warning(
|
||||
f"ADMIN IMPERSONATION: Admin {admin_user.id} ({admin_user.email}) "
|
||||
f"generated impersonation token for user {target_user.id} ({target_user.email})"
|
||||
)
|
||||
|
||||
# Create an access token for the target user (but with original role for safety)
|
||||
# The impersonation is tracked via the audit log
|
||||
role = get_user_role(target_user.email)
|
||||
access_token = create_access_token(target_user.id, target_user.email, role)
|
||||
|
||||
return ImpersonationResponse(
|
||||
access_token=access_token,
|
||||
impersonated_user=UserResponse.from_db(target_user, include_role=True),
|
||||
expires_in=int(ACCESS_TOKEN_TTL.total_seconds()),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/users/{user_id}/force-password-reset", response_model=MessageResponse)
|
||||
async def force_password_reset(
|
||||
user_id: str,
|
||||
admin_user: User = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
Force send a password reset email to a user (admin only).
|
||||
|
||||
Useful for helping users who are locked out.
|
||||
"""
|
||||
user = await User.prisma().find_unique(where={"id": user_id})
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Create and send password reset
|
||||
token = await create_password_reset_link(user.email, user.id)
|
||||
email_service = get_auth_email_service()
|
||||
email_sent = email_service.send_password_reset_email(user.email, token)
|
||||
|
||||
# Log the action
|
||||
logger.info(
|
||||
f"Admin {admin_user.id} ({admin_user.email}) "
|
||||
f"triggered password reset for user {user.id} ({user.email})"
|
||||
)
|
||||
|
||||
if email_sent:
|
||||
return MessageResponse(message=f"Password reset email sent to {user.email}")
|
||||
else:
|
||||
return MessageResponse(message="Email service unavailable, reset link logged")
|
||||
|
||||
|
||||
@router.post("/admin/users/{user_id}/revoke-sessions", response_model=MessageResponse)
|
||||
async def revoke_user_sessions(
|
||||
user_id: str,
|
||||
admin_user: User = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
Revoke all sessions for a user (admin only).
|
||||
|
||||
Useful for security incidents.
|
||||
"""
|
||||
user = await User.prisma().find_unique(where={"id": user_id})
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
count = await revoke_all_user_refresh_tokens(user_id)
|
||||
|
||||
# Log the action
|
||||
logger.warning(
|
||||
f"Admin {admin_user.id} ({admin_user.email}) "
|
||||
f"revoked all sessions for user {user.id} ({user.email}). "
|
||||
f"Revoked {count} refresh tokens."
|
||||
)
|
||||
|
||||
return MessageResponse(message=f"Revoked {count} sessions for user {user.email}")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user