Compare commits

..

7 Commits

Author SHA1 Message Date
Swifty
a86d750cf5 Merge branch 'fix/integrations-credential-type' into swiftyos/dev 2025-12-04 16:14:51 +01:00
Swifty
13bd648731 Merge branch 'swiftyos/vector-search' into swiftyos/dev 2025-12-04 16:14:47 +01:00
Swifty
3d7ee7cc29 Merge branch 'swiftyos/add-default-agents' into swiftyos/dev 2025-12-04 16:14:44 +01:00
Swifty
1ea52934cd add store agents for seeding test databases 2025-12-04 16:07:58 +01:00
Swifty
7b6db6e260 add vector search 2025-12-04 16:05:47 +01:00
Swifty
2c9563353e formatting 2025-12-04 09:35:53 +01:00
Swifty
fb2a70e2d8 pass credential type 2025-12-04 09:21:12 +01:00
860 changed files with 20749 additions and 67802 deletions

View File

@@ -1,37 +0,0 @@
{
"worktreeCopyPatterns": [
".env*",
".vscode/**",
".auth/**",
".claude/**",
"autogpt_platform/.env*",
"autogpt_platform/backend/.env*",
"autogpt_platform/frontend/.env*",
"autogpt_platform/frontend/.auth/**",
"autogpt_platform/db/docker/.env*"
],
"worktreeCopyIgnores": [
"**/node_modules/**",
"**/dist/**",
"**/.git/**",
"**/Thumbs.db",
"**/.DS_Store",
"**/.next/**",
"**/__pycache__/**",
"**/.ruff_cache/**",
"**/.pytest_cache/**",
"**/*.pyc",
"**/playwright-report/**",
"**/logs/**",
"**/site/**"
],
"worktreePathTemplate": "$BASE_PATH.worktree",
"postCreateCmd": [
"cd autogpt_platform/autogpt_libs && poetry install",
"cd autogpt_platform/backend && poetry install && poetry run prisma generate",
"cd autogpt_platform/frontend && pnpm install",
"cd docs && pip install -r requirements.txt"
],
"terminalCommand": "code .",
"deleteBranchWithWorktree": false
}

View File

@@ -16,7 +16,6 @@
!autogpt_platform/backend/poetry.lock
!autogpt_platform/backend/README.md
!autogpt_platform/backend/.env
!autogpt_platform/backend/gen_prisma_types_stub.py
# Platform - Market
!autogpt_platform/market/market/

View File

@@ -74,7 +74,7 @@ jobs:
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
run: poetry run prisma generate
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js

View File

@@ -90,7 +90,7 @@ jobs:
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
run: poetry run prisma generate
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js

View File

@@ -72,7 +72,7 @@ jobs:
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
run: poetry run prisma generate
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
@@ -108,16 +108,6 @@ jobs:
# run: pnpm playwright install --with-deps chromium
# Docker setup for development environment
- name: Free up disk space
run: |
# Remove large unused tools to free disk space for Docker builds
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
sudo docker system prune -af
df -h
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -1,74 +0,0 @@
name: Block Documentation Sync Check
on:
push:
branches: [master, dev]
paths:
- "autogpt_platform/backend/backend/blocks/**"
- "docs/content/platform/blocks/**"
- "autogpt_platform/backend/scripts/generate_block_docs.py"
- ".github/workflows/docs-block-sync.yml"
pull_request:
branches: [master, dev]
paths:
- "autogpt_platform/backend/backend/blocks/**"
- "docs/content/platform/blocks/**"
- "autogpt_platform/backend/scripts/generate_block_docs.py"
- ".github/workflows/docs-block-sync.yml"
jobs:
check-docs-sync:
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 1
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
restore-keys: |
poetry-${{ runner.os }}-
- name: Install Poetry
run: |
cd autogpt_platform/backend
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
echo "$HOME/.local/bin" >> $GITHUB_PATH
- name: Install dependencies
working-directory: autogpt_platform/backend
run: |
poetry install --only main
poetry run prisma generate
- name: Check block documentation is in sync
working-directory: autogpt_platform/backend
run: |
echo "Checking if block documentation is in sync with code..."
poetry run python scripts/generate_block_docs.py --check
- name: Show diff if out of sync
if: failure()
run: |
echo "::error::Block documentation is out of sync with code!"
echo ""
echo "To fix this, run the following command locally:"
echo " cd autogpt_platform/backend && poetry run python scripts/generate_block_docs.py"
echo ""
echo "Then commit the updated documentation files."
echo ""
echo "Changes detected:"
git diff docs/content/platform/blocks/ || true

View File

@@ -1,94 +0,0 @@
name: Claude Block Docs Review
on:
pull_request:
types: [opened, synchronize]
paths:
- "docs/content/platform/blocks/**"
- "autogpt_platform/backend/backend/blocks/**"
jobs:
claude-review:
# Only run for PRs from members/collaborators
if: |
github.event.pull_request.author_association == 'OWNER' ||
github.event.pull_request.author_association == 'MEMBER' ||
github.event.pull_request.author_association == 'COLLABORATOR'
runs-on: ubuntu-latest
timeout-minutes: 15
permissions:
contents: read
pull-requests: write
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
restore-keys: |
poetry-${{ runner.os }}-
- name: Install Poetry
run: |
cd autogpt_platform/backend
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
echo "$HOME/.local/bin" >> $GITHUB_PATH
- name: Install dependencies
working-directory: autogpt_platform/backend
run: |
poetry install --only main
poetry run prisma generate
- name: Run Claude Code Review
uses: anthropics/claude-code-action@v1
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
claude_args: |
--allowedTools "Read,Glob,Grep,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*)"
prompt: |
You are reviewing a PR that modifies block documentation or block code for AutoGPT.
## Your Task
Review the changes in this PR and provide constructive feedback. Focus on:
1. **Documentation Accuracy**: For any block code changes, verify that:
- Input/output tables in docs match the actual block schemas
- Description text accurately reflects what the block does
- Any new blocks have corresponding documentation
2. **Manual Content Quality**: Check manual sections (marked with `<!-- MANUAL: -->` markers):
- "How it works" sections should have clear technical explanations
- "Possible use case" sections should have practical, real-world examples
- Content should be helpful for users trying to understand the blocks
3. **Template Compliance**: Ensure docs follow the standard template:
- What it is (brief intro)
- What it does (description)
- How it works (technical explanation)
- Inputs table
- Outputs table
- Possible use case
4. **Cross-references**: Check that links and anchors are correct
## Review Process
1. First, get the PR diff to see what changed: `gh pr diff ${{ github.event.pull_request.number }}`
2. Read any modified block files to understand the implementation
3. Read corresponding documentation files to verify accuracy
4. Provide your feedback as a PR comment
Be constructive and specific. If everything looks good, say so!
If there are issues, explain what's wrong and suggest how to fix it.

View File

@@ -1,193 +0,0 @@
name: Enhance Block Documentation
on:
workflow_dispatch:
inputs:
block_pattern:
description: 'Block file pattern to enhance (e.g., "google/*.md" or "*" for all blocks)'
required: true
default: '*'
type: string
dry_run:
description: 'Dry run mode - show proposed changes without committing'
type: boolean
default: true
max_blocks:
description: 'Maximum number of blocks to process (0 for unlimited)'
type: number
default: 10
jobs:
enhance-docs:
runs-on: ubuntu-latest
timeout-minutes: 45
permissions:
contents: write
pull-requests: write
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 1
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
restore-keys: |
poetry-${{ runner.os }}-
- name: Install Poetry
run: |
cd autogpt_platform/backend
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
echo "$HOME/.local/bin" >> $GITHUB_PATH
- name: Install dependencies
working-directory: autogpt_platform/backend
run: |
poetry install --only main
poetry run prisma generate
- name: Run Claude Enhancement
uses: anthropics/claude-code-action@v1
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
claude_args: |
--allowedTools "Read,Edit,Glob,Grep,Write,Bash(git:*),Bash(gh:*),Bash(find:*),Bash(ls:*)"
prompt: |
You are enhancing block documentation for AutoGPT. Your task is to improve the MANUAL sections
of block documentation files by reading the actual block implementations and writing helpful content.
## Configuration
- Block pattern: ${{ inputs.block_pattern }}
- Dry run: ${{ inputs.dry_run }}
- Max blocks to process: ${{ inputs.max_blocks }}
## Your Task
1. **Find Documentation Files**
Find block documentation files matching the pattern in `docs/content/platform/blocks/`
Pattern: ${{ inputs.block_pattern }}
Use: `find docs/content/platform/blocks -name "*.md" -type f`
2. **For Each Documentation File** (up to ${{ inputs.max_blocks }} files):
a. Read the documentation file
b. Identify which block(s) it documents (look for the block class name)
c. Find and read the corresponding block implementation in `autogpt_platform/backend/backend/blocks/`
d. Improve the MANUAL sections:
**"How it works" section** (within `<!-- MANUAL: how_it_works -->` markers):
- Explain the technical flow of the block
- Describe what APIs or services it connects to
- Note any important configuration or prerequisites
- Keep it concise but informative (2-4 paragraphs)
**"Possible use case" section** (within `<!-- MANUAL: use_case -->` markers):
- Provide 2-3 practical, real-world examples
- Make them specific and actionable
- Show how this block could be used in an automation workflow
3. **Important Rules**
- ONLY modify content within `<!-- MANUAL: -->` and `<!-- END MANUAL -->` markers
- Do NOT modify auto-generated sections (inputs/outputs tables, descriptions)
- Keep content accurate based on the actual block implementation
- Write for users who may not be technical experts
4. **Output**
${{ inputs.dry_run == true && 'DRY RUN MODE: Show proposed changes for each file but do NOT actually edit the files. Describe what you would change.' || 'LIVE MODE: Actually edit the files to improve the documentation.' }}
## Example Improvements
**Before (How it works):**
```
_Add technical explanation here._
```
**After (How it works):**
```
This block connects to the GitHub API to retrieve issue information. When executed,
it authenticates using your GitHub credentials and fetches issue details including
title, body, labels, and assignees.
The block requires a valid GitHub OAuth connection with repository access permissions.
It supports both public and private repositories you have access to.
```
**Before (Possible use case):**
```
_Add practical use case examples here._
```
**After (Possible use case):**
```
**Customer Support Automation**: Monitor a GitHub repository for new issues with
the "bug" label, then automatically create a ticket in your support system and
notify the on-call engineer via Slack.
**Release Notes Generation**: When a new release is published, gather all closed
issues since the last release and generate a summary for your changelog.
```
Begin by finding and listing the documentation files to process.
- name: Create PR with enhanced documentation
if: ${{ inputs.dry_run == false }}
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
# Check if there are changes
if git diff --quiet docs/content/platform/blocks/; then
echo "No changes to commit"
exit 0
fi
# Configure git
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
# Create branch and commit
BRANCH_NAME="docs/enhance-blocks-$(date +%Y%m%d-%H%M%S)"
git checkout -b "$BRANCH_NAME"
git add docs/content/platform/blocks/
git commit -m "docs: enhance block documentation with LLM-generated content
Pattern: ${{ inputs.block_pattern }}
Max blocks: ${{ inputs.max_blocks }}
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>"
# Push and create PR
git push -u origin "$BRANCH_NAME"
gh pr create \
--title "docs: LLM-enhanced block documentation" \
--body "## Summary
This PR contains LLM-enhanced documentation for block files matching pattern: \`${{ inputs.block_pattern }}\`
The following manual sections were improved:
- **How it works**: Technical explanations based on block implementations
- **Possible use case**: Practical, real-world examples
## Review Checklist
- [ ] Content is accurate based on block implementations
- [ ] Examples are practical and helpful
- [ ] No auto-generated sections were modified
---
🤖 Generated with [Claude Code](https://claude.com/claude-code)" \
--base dev

View File

@@ -134,7 +134,7 @@ jobs:
run: poetry install
- name: Generate Prisma Client
run: poetry run prisma generate && poetry run gen-prisma-stub
run: poetry run prisma generate
- id: supabase
name: Start Supabase

View File

@@ -11,7 +11,7 @@ jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v10
- uses: actions/stale@v9
with:
# operations-per-run: 5000
stale-issue-message: >

View File

@@ -61,6 +61,6 @@ jobs:
pull-requests: write
runs-on: ubuntu-latest
steps:
- uses: actions/labeler@v6
- uses: actions/labeler@v5
with:
sync-labels: true

View File

@@ -12,7 +12,6 @@ reset-db:
rm -rf db/docker/volumes/db/data
cd backend && poetry run prisma migrate deploy
cd backend && poetry run prisma generate
cd backend && poetry run gen-prisma-stub
# View logs for core services
logs-core:
@@ -34,7 +33,6 @@ init-env:
migrate:
cd backend && poetry run prisma migrate deploy
cd backend && poetry run prisma generate
cd backend && poetry run gen-prisma-stub
run-backend:
cd backend && poetry run app
@@ -46,7 +44,7 @@ test-data:
cd backend && poetry run python test/test_data_creator.py
load-store-agents:
cd backend && poetry run load-store-agents
cd backend && poetry run python test/load_store_agents.py
help:
@echo "Usage: make <target>"

View File

@@ -57,9 +57,6 @@ class APIKeySmith:
def hash_key(self, raw_key: str) -> tuple[str, str]:
"""Migrate a legacy hash to secure hash format."""
if not raw_key.startswith(self.PREFIX):
raise ValueError("Key without 'agpt_' prefix would fail validation")
salt = self._generate_salt()
hash = self._hash_key_with_salt(raw_key, salt)
return hash, salt.hex()

View File

@@ -1,25 +1,29 @@
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from .jwt_utils import bearer_jwt_auth
def add_auth_responses_to_openapi(app: FastAPI) -> None:
"""
Patch a FastAPI instance's `openapi()` method to add 401 responses
Set up custom OpenAPI schema generation that adds 401 responses
to all authenticated endpoints.
This is needed when using HTTPBearer with auto_error=False to get proper
401 responses instead of 403, but FastAPI only automatically adds security
responses when auto_error=True.
"""
# Wrap current method to allow stacking OpenAPI schema modifiers like this
wrapped_openapi = app.openapi
def custom_openapi():
if app.openapi_schema:
return app.openapi_schema
openapi_schema = wrapped_openapi()
openapi_schema = get_openapi(
title=app.title,
version=app.version,
description=app.description,
routes=app.routes,
)
# Add 401 response to all endpoints that have security requirements
for path, methods in openapi_schema["paths"].items():

View File

@@ -48,8 +48,7 @@ RUN poetry install --no-ansi --no-root
# Generate Prisma client
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
RUN poetry run prisma generate && poetry run gen-prisma-stub
RUN poetry run prisma generate
FROM debian:13-slim AS server_dependencies

View File

@@ -108,7 +108,7 @@ import fastapi.testclient
import pytest
from pytest_snapshot.plugin import Snapshot
from backend.api.features.myroute import router
from backend.server.v2.myroute import router
app = fastapi.FastAPI()
app.include_router(router)
@@ -149,7 +149,7 @@ These provide the easiest way to set up authentication mocking in test modules:
import fastapi
import fastapi.testclient
import pytest
from backend.api.features.myroute import router
from backend.server.v2.myroute import router
app = fastapi.FastAPI()
app.include_router(router)

View File

@@ -1,25 +0,0 @@
from fastapi import FastAPI
from backend.api.middleware.security import SecurityHeadersMiddleware
from backend.monitoring.instrumentation import instrument_fastapi
from .v1.routes import v1_router
external_api = FastAPI(
title="AutoGPT External API",
description="External API for AutoGPT integrations",
docs_url="/docs",
version="1.0",
)
external_api.add_middleware(SecurityHeadersMiddleware)
external_api.include_router(v1_router, prefix="/v1")
# Add Prometheus instrumentation
instrument_fastapi(
external_api,
service_name="external-api",
expose_endpoint=True,
endpoint="/metrics",
include_in_schema=True,
)

View File

@@ -1,107 +0,0 @@
from fastapi import HTTPException, Security, status
from fastapi.security import APIKeyHeader, HTTPAuthorizationCredentials, HTTPBearer
from prisma.enums import APIKeyPermission
from backend.data.auth.api_key import APIKeyInfo, validate_api_key
from backend.data.auth.base import APIAuthorizationInfo
from backend.data.auth.oauth import (
InvalidClientError,
InvalidTokenError,
OAuthAccessTokenInfo,
validate_access_token,
)
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
bearer_auth = HTTPBearer(auto_error=False)
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
"""Middleware for API key authentication only"""
if api_key is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing API key"
)
api_key_obj = await validate_api_key(api_key)
if not api_key_obj:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
)
return api_key_obj
async def require_access_token(
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
) -> OAuthAccessTokenInfo:
"""Middleware for OAuth access token authentication only"""
if bearer is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing Authorization header",
)
try:
token_info, _ = await validate_access_token(bearer.credentials)
except (InvalidClientError, InvalidTokenError) as e:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
return token_info
async def require_auth(
api_key: str | None = Security(api_key_header),
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
) -> APIAuthorizationInfo:
"""
Unified authentication middleware supporting both API keys and OAuth tokens.
Supports two authentication methods, which are checked in order:
1. X-API-Key header (existing API key authentication)
2. Authorization: Bearer <token> header (OAuth access token)
Returns:
APIAuthorizationInfo: base class of both APIKeyInfo and OAuthAccessTokenInfo.
"""
# Try API key first
if api_key is not None:
api_key_info = await validate_api_key(api_key)
if api_key_info:
return api_key_info
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
)
# Try OAuth bearer token
if bearer is not None:
try:
token_info, _ = await validate_access_token(bearer.credentials)
return token_info
except (InvalidClientError, InvalidTokenError) as e:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
# No credentials provided
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing authentication. Provide API key or access token.",
)
def require_permission(permission: APIKeyPermission):
"""
Dependency function for checking specific permissions
(works with API keys and OAuth tokens)
"""
async def check_permission(
auth: APIAuthorizationInfo = Security(require_auth),
) -> APIAuthorizationInfo:
if permission not in auth.scopes:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Missing required permission: {permission.value}",
)
return auth
return check_permission

View File

@@ -1,340 +0,0 @@
"""Tests for analytics API endpoints."""
import json
from unittest.mock import AsyncMock, Mock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from pytest_snapshot.plugin import Snapshot
from .analytics import router as analytics_router
app = fastapi.FastAPI()
app.include_router(analytics_router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module."""
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
# =============================================================================
# /log_raw_metric endpoint tests
# =============================================================================
def test_log_raw_metric_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
test_user_id: str,
) -> None:
"""Test successful raw metric logging."""
mock_result = Mock(id="metric-123-uuid")
mock_log_metric = mocker.patch(
"backend.data.analytics.log_raw_metric",
new_callable=AsyncMock,
return_value=mock_result,
)
request_data = {
"metric_name": "page_load_time",
"metric_value": 2.5,
"data_string": "/dashboard",
}
response = client.post("/log_raw_metric", json=request_data)
assert response.status_code == 200, f"Unexpected response: {response.text}"
assert response.json() == "metric-123-uuid"
mock_log_metric.assert_called_once_with(
user_id=test_user_id,
metric_name="page_load_time",
metric_value=2.5,
data_string="/dashboard",
)
configured_snapshot.assert_match(
json.dumps({"metric_id": response.json()}, indent=2, sort_keys=True),
"analytics_log_metric_success",
)
@pytest.mark.parametrize(
"metric_value,metric_name,data_string,test_id",
[
(100, "api_calls_count", "external_api", "integer_value"),
(0, "error_count", "no_errors", "zero_value"),
(-5.2, "temperature_delta", "cooling", "negative_value"),
(1.23456789, "precision_test", "float_precision", "float_precision"),
(999999999, "large_number", "max_value", "large_number"),
(0.0000001, "tiny_number", "min_value", "tiny_number"),
],
)
def test_log_raw_metric_various_values(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
metric_value: float,
metric_name: str,
data_string: str,
test_id: str,
) -> None:
"""Test raw metric logging with various metric values."""
mock_result = Mock(id=f"metric-{test_id}-uuid")
mocker.patch(
"backend.data.analytics.log_raw_metric",
new_callable=AsyncMock,
return_value=mock_result,
)
request_data = {
"metric_name": metric_name,
"metric_value": metric_value,
"data_string": data_string,
}
response = client.post("/log_raw_metric", json=request_data)
assert response.status_code == 200, f"Failed for {test_id}: {response.text}"
configured_snapshot.assert_match(
json.dumps(
{"metric_id": response.json(), "test_case": test_id},
indent=2,
sort_keys=True,
),
f"analytics_metric_{test_id}",
)
@pytest.mark.parametrize(
"invalid_data,expected_error",
[
({}, "Field required"),
({"metric_name": "test"}, "Field required"),
(
{"metric_name": "test", "metric_value": "not_a_number", "data_string": "x"},
"Input should be a valid number",
),
(
{"metric_name": "", "metric_value": 1.0, "data_string": "test"},
"String should have at least 1 character",
),
(
{"metric_name": "test", "metric_value": 1.0, "data_string": ""},
"String should have at least 1 character",
),
],
ids=[
"empty_request",
"missing_metric_value_and_data_string",
"invalid_metric_value_type",
"empty_metric_name",
"empty_data_string",
],
)
def test_log_raw_metric_validation_errors(
invalid_data: dict,
expected_error: str,
) -> None:
"""Test validation errors for invalid metric requests."""
response = client.post("/log_raw_metric", json=invalid_data)
assert response.status_code == 422
error_detail = response.json()
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
error_text = json.dumps(error_detail)
assert (
expected_error in error_text
), f"Expected '{expected_error}' in error response: {error_text}"
def test_log_raw_metric_service_error(
mocker: pytest_mock.MockFixture,
test_user_id: str,
) -> None:
"""Test error handling when analytics service fails."""
mocker.patch(
"backend.data.analytics.log_raw_metric",
new_callable=AsyncMock,
side_effect=Exception("Database connection failed"),
)
request_data = {
"metric_name": "test_metric",
"metric_value": 1.0,
"data_string": "test",
}
response = client.post("/log_raw_metric", json=request_data)
assert response.status_code == 500
error_detail = response.json()["detail"]
assert "Database connection failed" in error_detail["message"]
assert "hint" in error_detail
# =============================================================================
# /log_raw_analytics endpoint tests
# =============================================================================
def test_log_raw_analytics_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
test_user_id: str,
) -> None:
"""Test successful raw analytics logging."""
mock_result = Mock(id="analytics-789-uuid")
mock_log_analytics = mocker.patch(
"backend.data.analytics.log_raw_analytics",
new_callable=AsyncMock,
return_value=mock_result,
)
request_data = {
"type": "user_action",
"data": {
"action": "button_click",
"button_id": "submit_form",
"timestamp": "2023-01-01T00:00:00Z",
"metadata": {"form_type": "registration", "fields_filled": 5},
},
"data_index": "button_click_submit_form",
}
response = client.post("/log_raw_analytics", json=request_data)
assert response.status_code == 200, f"Unexpected response: {response.text}"
assert response.json() == "analytics-789-uuid"
mock_log_analytics.assert_called_once_with(
test_user_id,
"user_action",
request_data["data"],
"button_click_submit_form",
)
configured_snapshot.assert_match(
json.dumps({"analytics_id": response.json()}, indent=2, sort_keys=True),
"analytics_log_analytics_success",
)
def test_log_raw_analytics_complex_data(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test raw analytics logging with complex nested data structures."""
mock_result = Mock(id="analytics-complex-uuid")
mocker.patch(
"backend.data.analytics.log_raw_analytics",
new_callable=AsyncMock,
return_value=mock_result,
)
request_data = {
"type": "agent_execution",
"data": {
"agent_id": "agent_123",
"execution_id": "exec_456",
"status": "completed",
"duration_ms": 3500,
"nodes_executed": 15,
"blocks_used": [
{"block_id": "llm_block", "count": 3},
{"block_id": "http_block", "count": 5},
{"block_id": "code_block", "count": 2},
],
"errors": [],
"metadata": {
"trigger": "manual",
"user_tier": "premium",
"environment": "production",
},
},
"data_index": "agent_123_exec_456",
}
response = client.post("/log_raw_analytics", json=request_data)
assert response.status_code == 200
configured_snapshot.assert_match(
json.dumps(
{"analytics_id": response.json(), "logged_data": request_data["data"]},
indent=2,
sort_keys=True,
),
"analytics_log_analytics_complex_data",
)
@pytest.mark.parametrize(
"invalid_data,expected_error",
[
({}, "Field required"),
({"type": "test"}, "Field required"),
(
{"type": "test", "data": "not_a_dict", "data_index": "test"},
"Input should be a valid dictionary",
),
({"type": "test", "data": {"key": "value"}}, "Field required"),
],
ids=[
"empty_request",
"missing_data_and_data_index",
"invalid_data_type",
"missing_data_index",
],
)
def test_log_raw_analytics_validation_errors(
invalid_data: dict,
expected_error: str,
) -> None:
"""Test validation errors for invalid analytics requests."""
response = client.post("/log_raw_analytics", json=invalid_data)
assert response.status_code == 422
error_detail = response.json()
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
error_text = json.dumps(error_detail)
assert (
expected_error in error_text
), f"Expected '{expected_error}' in error response: {error_text}"
def test_log_raw_analytics_service_error(
mocker: pytest_mock.MockFixture,
test_user_id: str,
) -> None:
"""Test error handling when analytics service fails."""
mocker.patch(
"backend.data.analytics.log_raw_analytics",
new_callable=AsyncMock,
side_effect=Exception("Analytics DB unreachable"),
)
request_data = {
"type": "test_event",
"data": {"key": "value"},
"data_index": "test_index",
}
response = client.post("/log_raw_analytics", json=request_data)
assert response.status_code == 500
error_detail = response.json()["detail"]
assert "Analytics DB unreachable" in error_detail["message"]
assert "hint" in error_detail

View File

@@ -1,689 +0,0 @@
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from difflib import SequenceMatcher
from typing import Sequence
import prisma
import backend.api.features.library.db as library_db
import backend.api.features.library.model as library_model
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
import backend.data.block
from backend.blocks import load_all_blocks
from backend.blocks.llm import LlmModel
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
from backend.data.db import query_raw_with_schema
from backend.integrations.providers import ProviderName
from backend.util.cache import cached
from backend.util.models import Pagination
from .model import (
BlockCategoryResponse,
BlockResponse,
BlockType,
CountResponse,
FilterType,
Provider,
ProviderResponse,
SearchEntry,
)
logger = logging.getLogger(__name__)
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
MAX_LIBRARY_AGENT_RESULTS = 100
MAX_MARKETPLACE_AGENT_RESULTS = 100
MIN_SCORE_FOR_FILTERED_RESULTS = 10.0
SearchResultItem = BlockInfo | library_model.LibraryAgent | store_model.StoreAgent
@dataclass
class _ScoredItem:
item: SearchResultItem
filter_type: FilterType
score: float
sort_key: str
@dataclass
class _SearchCacheEntry:
items: list[SearchResultItem]
total_items: dict[FilterType, int]
def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse]:
categories: dict[BlockCategory, BlockCategoryResponse] = {}
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
# Skip disabled blocks
if block.disabled:
continue
# Skip blocks that don't have categories (all should have at least one)
if not block.categories:
continue
# Add block to the categories
for category in block.categories:
if category not in categories:
categories[category] = BlockCategoryResponse(
name=category.name.lower(),
total_blocks=0,
blocks=[],
)
categories[category].total_blocks += 1
# Append if the category has less than the specified number of blocks
if len(categories[category].blocks) < category_blocks:
categories[category].blocks.append(block.get_info())
# Sort categories by name
return sorted(categories.values(), key=lambda x: x.name)
def get_blocks(
*,
category: str | None = None,
type: BlockType | None = None,
provider: ProviderName | None = None,
page: int = 1,
page_size: int = 50,
) -> BlockResponse:
"""
Get blocks based on either category, type or provider.
Providing nothing fetches all block types.
"""
# Only one of category, type, or provider can be specified
if (category and type) or (category and provider) or (type and provider):
raise ValueError("Only one of category, type, or provider can be specified")
blocks: list[AnyBlockSchema] = []
skip = (page - 1) * page_size
take = page_size
total = 0
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
# Skip disabled blocks
if block.disabled:
continue
# Skip blocks that don't match the category
if category and category not in {c.name.lower() for c in block.categories}:
continue
# Skip blocks that don't match the type
if (
(type == "input" and block.block_type.value != "Input")
or (type == "output" and block.block_type.value != "Output")
or (type == "action" and block.block_type.value in ("Input", "Output"))
):
continue
# Skip blocks that don't match the provider
if provider:
credentials_info = block.input_schema.get_credentials_fields_info().values()
if not any(provider in info.provider for info in credentials_info):
continue
total += 1
if skip > 0:
skip -= 1
continue
if take > 0:
take -= 1
blocks.append(block)
return BlockResponse(
blocks=[b.get_info() for b in blocks],
pagination=Pagination(
total_items=total,
total_pages=(total + page_size - 1) // page_size,
current_page=page,
page_size=page_size,
),
)
def get_block_by_id(block_id: str) -> BlockInfo | None:
"""
Get a specific block by its ID.
"""
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
if block.id == block_id:
return block.get_info()
return None
async def update_search(user_id: str, search: SearchEntry) -> str:
"""
Upsert a search request for the user and return the search ID.
"""
if search.search_id:
# Update existing search
await prisma.models.BuilderSearchHistory.prisma().update(
where={
"id": search.search_id,
},
data={
"searchQuery": search.search_query or "",
"filter": search.filter or [], # type: ignore
"byCreator": search.by_creator or [],
},
)
return search.search_id
else:
# Create new search
new_search = await prisma.models.BuilderSearchHistory.prisma().create(
data={
"userId": user_id,
"searchQuery": search.search_query or "",
"filter": search.filter or [], # type: ignore
"byCreator": search.by_creator or [],
}
)
return new_search.id
async def get_recent_searches(user_id: str, limit: int = 5) -> list[SearchEntry]:
"""
Get the user's most recent search requests.
"""
searches = await prisma.models.BuilderSearchHistory.prisma().find_many(
where={
"userId": user_id,
},
order={
"updatedAt": "desc",
},
take=limit,
)
return [
SearchEntry(
search_query=s.searchQuery,
filter=s.filter, # type: ignore
by_creator=s.byCreator,
search_id=s.id,
)
for s in searches
]
async def get_sorted_search_results(
*,
user_id: str,
search_query: str | None,
filters: Sequence[FilterType],
by_creator: Sequence[str] | None = None,
) -> _SearchCacheEntry:
normalized_filters: tuple[FilterType, ...] = tuple(sorted(set(filters or [])))
normalized_creators: tuple[str, ...] = tuple(sorted(set(by_creator or [])))
return await _build_cached_search_results(
user_id=user_id,
search_query=search_query or "",
filters=normalized_filters,
by_creator=normalized_creators,
)
@cached(ttl_seconds=300, shared_cache=True)
async def _build_cached_search_results(
user_id: str,
search_query: str,
filters: tuple[FilterType, ...],
by_creator: tuple[str, ...],
) -> _SearchCacheEntry:
normalized_query = (search_query or "").strip().lower()
include_blocks = "blocks" in filters
include_integrations = "integrations" in filters
include_library_agents = "my_agents" in filters
include_marketplace_agents = "marketplace_agents" in filters
scored_items: list[_ScoredItem] = []
total_items: dict[FilterType, int] = {
"blocks": 0,
"integrations": 0,
"marketplace_agents": 0,
"my_agents": 0,
}
block_results, block_total, integration_total = _collect_block_results(
normalized_query=normalized_query,
include_blocks=include_blocks,
include_integrations=include_integrations,
)
scored_items.extend(block_results)
total_items["blocks"] = block_total
total_items["integrations"] = integration_total
if include_library_agents:
library_response = await library_db.list_library_agents(
user_id=user_id,
search_term=search_query or None,
page=1,
page_size=MAX_LIBRARY_AGENT_RESULTS,
)
total_items["my_agents"] = library_response.pagination.total_items
scored_items.extend(
_build_library_items(
agents=library_response.agents,
normalized_query=normalized_query,
)
)
if include_marketplace_agents:
marketplace_response = await store_db.get_store_agents(
creators=list(by_creator) or None,
search_query=search_query or None,
page=1,
page_size=MAX_MARKETPLACE_AGENT_RESULTS,
)
total_items["marketplace_agents"] = marketplace_response.pagination.total_items
scored_items.extend(
_build_marketplace_items(
agents=marketplace_response.agents,
normalized_query=normalized_query,
)
)
sorted_items = sorted(
scored_items,
key=lambda entry: (-entry.score, entry.sort_key, entry.filter_type),
)
return _SearchCacheEntry(
items=[entry.item for entry in sorted_items],
total_items=total_items,
)
def _collect_block_results(
*,
normalized_query: str,
include_blocks: bool,
include_integrations: bool,
) -> tuple[list[_ScoredItem], int, int]:
results: list[_ScoredItem] = []
block_count = 0
integration_count = 0
if not include_blocks and not include_integrations:
return results, block_count, integration_count
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
if block.disabled:
continue
block_info = block.get_info()
credentials = list(block.input_schema.get_credentials_fields().values())
is_integration = len(credentials) > 0
if is_integration and not include_integrations:
continue
if not is_integration and not include_blocks:
continue
score = _score_block(block, block_info, normalized_query)
if not _should_include_item(score, normalized_query):
continue
filter_type: FilterType = "integrations" if is_integration else "blocks"
if is_integration:
integration_count += 1
else:
block_count += 1
results.append(
_ScoredItem(
item=block_info,
filter_type=filter_type,
score=score,
sort_key=_get_item_name(block_info),
)
)
return results, block_count, integration_count
def _build_library_items(
*,
agents: list[library_model.LibraryAgent],
normalized_query: str,
) -> list[_ScoredItem]:
results: list[_ScoredItem] = []
for agent in agents:
score = _score_library_agent(agent, normalized_query)
if not _should_include_item(score, normalized_query):
continue
results.append(
_ScoredItem(
item=agent,
filter_type="my_agents",
score=score,
sort_key=_get_item_name(agent),
)
)
return results
def _build_marketplace_items(
*,
agents: list[store_model.StoreAgent],
normalized_query: str,
) -> list[_ScoredItem]:
results: list[_ScoredItem] = []
for agent in agents:
score = _score_store_agent(agent, normalized_query)
if not _should_include_item(score, normalized_query):
continue
results.append(
_ScoredItem(
item=agent,
filter_type="marketplace_agents",
score=score,
sort_key=_get_item_name(agent),
)
)
return results
def get_providers(
query: str = "",
page: int = 1,
page_size: int = 50,
) -> ProviderResponse:
providers = []
query = query.lower()
skip = (page - 1) * page_size
take = page_size
all_providers = _get_all_providers()
for provider in all_providers.values():
if (
query not in provider.name.value.lower()
and query not in provider.description.lower()
):
continue
if skip > 0:
skip -= 1
continue
if take > 0:
take -= 1
providers.append(provider)
total = len(all_providers)
return ProviderResponse(
providers=providers,
pagination=Pagination(
total_items=total,
total_pages=(total + page_size - 1) // page_size,
current_page=page,
page_size=page_size,
),
)
async def get_counts(user_id: str) -> CountResponse:
my_agents = await prisma.models.LibraryAgent.prisma().count(
where={
"userId": user_id,
"isDeleted": False,
"isArchived": False,
}
)
counts = await _get_static_counts()
return CountResponse(
my_agents=my_agents,
**counts,
)
@cached(ttl_seconds=3600)
async def _get_static_counts():
"""
Get counts of blocks, integrations, and marketplace agents.
This is cached to avoid unnecessary database queries and calculations.
"""
all_blocks = 0
input_blocks = 0
action_blocks = 0
output_blocks = 0
integrations = 0
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
if block.disabled:
continue
all_blocks += 1
if block.block_type.value == "Input":
input_blocks += 1
elif block.block_type.value == "Output":
output_blocks += 1
else:
action_blocks += 1
credentials = list(block.input_schema.get_credentials_fields().values())
if len(credentials) > 0:
integrations += 1
marketplace_agents = await prisma.models.StoreAgent.prisma().count()
return {
"all_blocks": all_blocks,
"input_blocks": input_blocks,
"action_blocks": action_blocks,
"output_blocks": output_blocks,
"integrations": integrations,
"marketplace_agents": marketplace_agents,
}
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
for field in schema_cls.model_fields.values():
if field.annotation == LlmModel:
# Check if query matches any value in llm_models
if any(query in name for name in llm_models):
return True
return False
def _score_block(
block: AnyBlockSchema,
block_info: BlockInfo,
normalized_query: str,
) -> float:
if not normalized_query:
return 0.0
name = block_info.name.lower()
description = block_info.description.lower()
score = _score_primary_fields(name, description, normalized_query)
category_text = " ".join(
category.get("category", "").lower() for category in block_info.categories
)
score += _score_additional_field(category_text, normalized_query, 12, 6)
credentials_info = block.input_schema.get_credentials_fields_info().values()
provider_names = [
provider.value.lower()
for info in credentials_info
for provider in info.provider
]
provider_text = " ".join(provider_names)
score += _score_additional_field(provider_text, normalized_query, 15, 6)
if _matches_llm_model(block.input_schema, normalized_query):
score += 20
return score
def _score_library_agent(
agent: library_model.LibraryAgent,
normalized_query: str,
) -> float:
if not normalized_query:
return 0.0
name = agent.name.lower()
description = (agent.description or "").lower()
instructions = (agent.instructions or "").lower()
score = _score_primary_fields(name, description, normalized_query)
score += _score_additional_field(instructions, normalized_query, 15, 6)
score += _score_additional_field(
agent.creator_name.lower(), normalized_query, 10, 5
)
return score
def _score_store_agent(
agent: store_model.StoreAgent,
normalized_query: str,
) -> float:
if not normalized_query:
return 0.0
name = agent.agent_name.lower()
description = agent.description.lower()
sub_heading = agent.sub_heading.lower()
score = _score_primary_fields(name, description, normalized_query)
score += _score_additional_field(sub_heading, normalized_query, 12, 6)
score += _score_additional_field(agent.creator.lower(), normalized_query, 10, 5)
return score
def _score_primary_fields(name: str, description: str, query: str) -> float:
score = 0.0
if name == query:
score += 120
elif name.startswith(query):
score += 90
elif query in name:
score += 60
score += SequenceMatcher(None, name, query).ratio() * 50
if description:
if query in description:
score += 30
score += SequenceMatcher(None, description, query).ratio() * 25
return score
def _score_additional_field(
value: str,
query: str,
contains_weight: float,
similarity_weight: float,
) -> float:
if not value or not query:
return 0.0
score = 0.0
if query in value:
score += contains_weight
score += SequenceMatcher(None, value, query).ratio() * similarity_weight
return score
def _should_include_item(score: float, normalized_query: str) -> bool:
if not normalized_query:
return True
return score >= MIN_SCORE_FOR_FILTERED_RESULTS
def _get_item_name(item: SearchResultItem) -> str:
if isinstance(item, BlockInfo):
return item.name.lower()
if isinstance(item, library_model.LibraryAgent):
return item.name.lower()
return item.agent_name.lower()
@cached(ttl_seconds=3600)
def _get_all_providers() -> dict[ProviderName, Provider]:
providers: dict[ProviderName, Provider] = {}
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
if block.disabled:
continue
credentials_info = block.input_schema.get_credentials_fields_info().values()
for info in credentials_info:
for provider in info.provider: # provider is a ProviderName enum member
if provider in providers:
providers[provider].integration_count += 1
else:
providers[provider] = Provider(
name=provider, description="", integration_count=1
)
return providers
@cached(ttl_seconds=3600)
async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
suggested_blocks = []
# Sum the number of executions for each block type
# Prisma cannot group by nested relations, so we do a raw query
# Calculate the cutoff timestamp
timestamp_threshold = datetime.now(timezone.utc) - timedelta(days=30)
results = await query_raw_with_schema(
"""
SELECT
agent_node."agentBlockId" AS block_id,
COUNT(execution.id) AS execution_count
FROM {schema_prefix}"AgentNodeExecution" execution
JOIN {schema_prefix}"AgentNode" agent_node ON execution."agentNodeId" = agent_node.id
WHERE execution."endedTime" >= $1::timestamp
GROUP BY agent_node."agentBlockId"
ORDER BY execution_count DESC;
""",
timestamp_threshold,
)
# Get the top blocks based on execution count
# But ignore Input and Output blocks
blocks: list[tuple[BlockInfo, int]] = []
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
if block.disabled or block.block_type in (
backend.data.block.BlockType.INPUT,
backend.data.block.BlockType.OUTPUT,
backend.data.block.BlockType.AGENT,
):
continue
# Find the execution count for this block
execution_count = next(
(row["execution_count"] for row in results if row["block_id"] == block.id),
0,
)
blocks.append((block.get_info(), execution_count))
# Sort blocks by execution count
blocks.sort(key=lambda x: x[1], reverse=True)
suggested_blocks = [block[0] for block in blocks]
# Return the top blocks
return suggested_blocks[:count]

View File

@@ -1,833 +0,0 @@
"""
OAuth 2.0 Provider Endpoints
Implements OAuth 2.0 Authorization Code flow with PKCE support.
Flow:
1. User clicks "Login with AutoGPT" in 3rd party app
2. App redirects user to /auth/authorize with client_id, redirect_uri, scope, state
3. User sees consent screen (if not already logged in, redirects to login first)
4. User approves → backend creates authorization code
5. User redirected back to app with code
6. App exchanges code for access/refresh tokens at /api/oauth/token
7. App uses access token to call external API endpoints
"""
import io
import logging
import os
import uuid
from datetime import datetime
from typing import Literal, Optional
from urllib.parse import urlencode
from autogpt_libs.auth import get_user_id
from fastapi import APIRouter, Body, HTTPException, Security, UploadFile, status
from gcloud.aio import storage as async_storage
from PIL import Image
from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field
from backend.data.auth.oauth import (
InvalidClientError,
InvalidGrantError,
OAuthApplicationInfo,
TokenIntrospectionResult,
consume_authorization_code,
create_access_token,
create_authorization_code,
create_refresh_token,
get_oauth_application,
get_oauth_application_by_id,
introspect_token,
list_user_oauth_applications,
refresh_tokens,
revoke_access_token,
revoke_refresh_token,
update_oauth_application,
validate_client_credentials,
validate_redirect_uri,
validate_scopes,
)
from backend.util.settings import Settings
from backend.util.virus_scanner import scan_content_safe
settings = Settings()
logger = logging.getLogger(__name__)
router = APIRouter()
# ============================================================================
# Request/Response Models
# ============================================================================
class TokenResponse(BaseModel):
"""OAuth 2.0 token response"""
token_type: Literal["Bearer"] = "Bearer"
access_token: str
access_token_expires_at: datetime
refresh_token: str
refresh_token_expires_at: datetime
scopes: list[str]
class ErrorResponse(BaseModel):
"""OAuth 2.0 error response"""
error: str
error_description: Optional[str] = None
class OAuthApplicationPublicInfo(BaseModel):
"""Public information about an OAuth application (for consent screen)"""
name: str
description: Optional[str] = None
logo_url: Optional[str] = None
scopes: list[str]
# ============================================================================
# Application Info Endpoint
# ============================================================================
@router.get(
"/app/{client_id}",
responses={
404: {"description": "Application not found or disabled"},
},
)
async def get_oauth_app_info(
client_id: str, user_id: str = Security(get_user_id)
) -> OAuthApplicationPublicInfo:
"""
Get public information about an OAuth application.
This endpoint is used by the consent screen to display application details
to the user before they authorize access.
Returns:
- name: Application name
- description: Application description (if provided)
- scopes: List of scopes the application is allowed to request
"""
app = await get_oauth_application(client_id)
if not app or not app.is_active:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Application not found",
)
return OAuthApplicationPublicInfo(
name=app.name,
description=app.description,
logo_url=app.logo_url,
scopes=[s.value for s in app.scopes],
)
# ============================================================================
# Authorization Endpoint
# ============================================================================
class AuthorizeRequest(BaseModel):
"""OAuth 2.0 authorization request"""
client_id: str = Field(description="Client identifier")
redirect_uri: str = Field(description="Redirect URI")
scopes: list[str] = Field(description="List of scopes")
state: str = Field(description="Anti-CSRF token from client")
response_type: str = Field(
default="code", description="Must be 'code' for authorization code flow"
)
code_challenge: str = Field(description="PKCE code challenge (required)")
code_challenge_method: Literal["S256", "plain"] = Field(
default="S256", description="PKCE code challenge method (S256 recommended)"
)
class AuthorizeResponse(BaseModel):
"""OAuth 2.0 authorization response with redirect URL"""
redirect_url: str = Field(description="URL to redirect the user to")
@router.post("/authorize")
async def authorize(
request: AuthorizeRequest = Body(),
user_id: str = Security(get_user_id),
) -> AuthorizeResponse:
"""
OAuth 2.0 Authorization Endpoint
User must be logged in (authenticated with Supabase JWT).
This endpoint creates an authorization code and returns a redirect URL.
PKCE (Proof Key for Code Exchange) is REQUIRED for all authorization requests.
The frontend consent screen should call this endpoint after the user approves,
then redirect the user to the returned `redirect_url`.
Request Body:
- client_id: The OAuth application's client ID
- redirect_uri: Where to redirect after authorization (must match registered URI)
- scopes: List of permissions (e.g., "EXECUTE_GRAPH READ_GRAPH")
- state: Anti-CSRF token provided by client (will be returned in redirect)
- response_type: Must be "code" (for authorization code flow)
- code_challenge: PKCE code challenge (required)
- code_challenge_method: "S256" (recommended) or "plain"
Returns:
- redirect_url: The URL to redirect the user to (includes authorization code)
Error cases return a redirect_url with error parameters, or raise HTTPException
for critical errors (like invalid redirect_uri).
"""
try:
# Validate response_type
if request.response_type != "code":
return _error_redirect_url(
request.redirect_uri,
request.state,
"unsupported_response_type",
"Only 'code' response type is supported",
)
# Get application
app = await get_oauth_application(request.client_id)
if not app:
return _error_redirect_url(
request.redirect_uri,
request.state,
"invalid_client",
"Unknown client_id",
)
if not app.is_active:
return _error_redirect_url(
request.redirect_uri,
request.state,
"invalid_client",
"Application is not active",
)
# Validate redirect URI
if not validate_redirect_uri(app, request.redirect_uri):
# For invalid redirect_uri, we can't redirect safely
# Must return error instead
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(
"Invalid redirect_uri. "
f"Must be one of: {', '.join(app.redirect_uris)}"
),
)
# Parse and validate scopes
try:
requested_scopes = [APIKeyPermission(s.strip()) for s in request.scopes]
except ValueError as e:
return _error_redirect_url(
request.redirect_uri,
request.state,
"invalid_scope",
f"Invalid scope: {e}",
)
if not requested_scopes:
return _error_redirect_url(
request.redirect_uri,
request.state,
"invalid_scope",
"At least one scope is required",
)
if not validate_scopes(app, requested_scopes):
return _error_redirect_url(
request.redirect_uri,
request.state,
"invalid_scope",
"Application is not authorized for all requested scopes. "
f"Allowed: {', '.join(s.value for s in app.scopes)}",
)
# Create authorization code
auth_code = await create_authorization_code(
application_id=app.id,
user_id=user_id,
scopes=requested_scopes,
redirect_uri=request.redirect_uri,
code_challenge=request.code_challenge,
code_challenge_method=request.code_challenge_method,
)
# Build redirect URL with authorization code
params = {
"code": auth_code.code,
"state": request.state,
}
redirect_url = f"{request.redirect_uri}?{urlencode(params)}"
logger.info(
f"Authorization code issued for user #{user_id} "
f"and app {app.name} (#{app.id})"
)
return AuthorizeResponse(redirect_url=redirect_url)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in authorization endpoint: {e}", exc_info=True)
return _error_redirect_url(
request.redirect_uri,
request.state,
"server_error",
"An unexpected error occurred",
)
def _error_redirect_url(
redirect_uri: str,
state: str,
error: str,
error_description: Optional[str] = None,
) -> AuthorizeResponse:
"""Helper to build redirect URL with OAuth error parameters"""
params = {
"error": error,
"state": state,
}
if error_description:
params["error_description"] = error_description
redirect_url = f"{redirect_uri}?{urlencode(params)}"
return AuthorizeResponse(redirect_url=redirect_url)
# ============================================================================
# Token Endpoint
# ============================================================================
class TokenRequestByCode(BaseModel):
grant_type: Literal["authorization_code"]
code: str = Field(description="Authorization code")
redirect_uri: str = Field(
description="Redirect URI (must match authorization request)"
)
client_id: str
client_secret: str
code_verifier: str = Field(description="PKCE code verifier")
class TokenRequestByRefreshToken(BaseModel):
grant_type: Literal["refresh_token"]
refresh_token: str
client_id: str
client_secret: str
@router.post("/token")
async def token(
request: TokenRequestByCode | TokenRequestByRefreshToken = Body(),
) -> TokenResponse:
"""
OAuth 2.0 Token Endpoint
Exchanges authorization code or refresh token for access token.
Grant Types:
1. authorization_code: Exchange authorization code for tokens
- Required: grant_type, code, redirect_uri, client_id, client_secret
- Optional: code_verifier (required if PKCE was used)
2. refresh_token: Exchange refresh token for new access token
- Required: grant_type, refresh_token, client_id, client_secret
Returns:
- access_token: Bearer token for API access (1 hour TTL)
- token_type: "Bearer"
- expires_in: Seconds until access token expires
- refresh_token: Token for refreshing access (30 days TTL)
- scopes: List of scopes
"""
# Validate client credentials
try:
app = await validate_client_credentials(
request.client_id, request.client_secret
)
except InvalidClientError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
)
# Handle authorization_code grant
if request.grant_type == "authorization_code":
# Consume authorization code
try:
user_id, scopes = await consume_authorization_code(
code=request.code,
application_id=app.id,
redirect_uri=request.redirect_uri,
code_verifier=request.code_verifier,
)
except InvalidGrantError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
# Create access and refresh tokens
access_token = await create_access_token(app.id, user_id, scopes)
refresh_token = await create_refresh_token(app.id, user_id, scopes)
logger.info(
f"Access token issued for user #{user_id} and app {app.name} (#{app.id})"
"via authorization code"
)
if not access_token.token or not refresh_token.token:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to generate tokens",
)
return TokenResponse(
token_type="Bearer",
access_token=access_token.token.get_secret_value(),
access_token_expires_at=access_token.expires_at,
refresh_token=refresh_token.token.get_secret_value(),
refresh_token_expires_at=refresh_token.expires_at,
scopes=list(s.value for s in scopes),
)
# Handle refresh_token grant
elif request.grant_type == "refresh_token":
# Refresh access token
try:
new_access_token, new_refresh_token = await refresh_tokens(
request.refresh_token, app.id
)
except InvalidGrantError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
logger.info(
f"Tokens refreshed for user #{new_access_token.user_id} "
f"by app {app.name} (#{app.id})"
)
if not new_access_token.token or not new_refresh_token.token:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to generate tokens",
)
return TokenResponse(
token_type="Bearer",
access_token=new_access_token.token.get_secret_value(),
access_token_expires_at=new_access_token.expires_at,
refresh_token=new_refresh_token.token.get_secret_value(),
refresh_token_expires_at=new_refresh_token.expires_at,
scopes=list(s.value for s in new_access_token.scopes),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported grant_type: {request.grant_type}. "
"Must be 'authorization_code' or 'refresh_token'",
)
# ============================================================================
# Token Introspection Endpoint
# ============================================================================
@router.post("/introspect")
async def introspect(
token: str = Body(description="Token to introspect"),
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
None, description="Hint about token type ('access_token' or 'refresh_token')"
),
client_id: str = Body(description="Client identifier"),
client_secret: str = Body(description="Client secret"),
) -> TokenIntrospectionResult:
"""
OAuth 2.0 Token Introspection Endpoint (RFC 7662)
Allows clients to check if a token is valid and get its metadata.
Returns:
- active: Whether the token is currently active
- scopes: List of authorized scopes (if active)
- client_id: The client the token was issued to (if active)
- user_id: The user the token represents (if active)
- exp: Expiration timestamp (if active)
- token_type: "access_token" or "refresh_token" (if active)
"""
# Validate client credentials
try:
await validate_client_credentials(client_id, client_secret)
except InvalidClientError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
)
# Introspect the token
return await introspect_token(token, token_type_hint)
# ============================================================================
# Token Revocation Endpoint
# ============================================================================
@router.post("/revoke")
async def revoke(
token: str = Body(description="Token to revoke"),
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
None, description="Hint about token type ('access_token' or 'refresh_token')"
),
client_id: str = Body(description="Client identifier"),
client_secret: str = Body(description="Client secret"),
):
"""
OAuth 2.0 Token Revocation Endpoint (RFC 7009)
Allows clients to revoke an access or refresh token.
Note: Revoking a refresh token does NOT revoke associated access tokens.
Revoking an access token does NOT revoke the associated refresh token.
"""
# Validate client credentials
try:
app = await validate_client_credentials(client_id, client_secret)
except InvalidClientError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
)
# Try to revoke as access token first
# Note: We pass app.id to ensure the token belongs to the authenticated app
if token_type_hint != "refresh_token":
revoked = await revoke_access_token(token, app.id)
if revoked:
logger.info(
f"Access token revoked for app {app.name} (#{app.id}); "
f"user #{revoked.user_id}"
)
return {"status": "ok"}
# Try to revoke as refresh token
revoked = await revoke_refresh_token(token, app.id)
if revoked:
logger.info(
f"Refresh token revoked for app {app.name} (#{app.id}); "
f"user #{revoked.user_id}"
)
return {"status": "ok"}
# Per RFC 7009, revocation endpoint returns 200 even if token not found
# or if token belongs to a different application.
# This prevents token scanning attacks.
logger.warning(f"Unsuccessful token revocation attempt by app {app.name} #{app.id}")
return {"status": "ok"}
# ============================================================================
# Application Management Endpoints (for app owners)
# ============================================================================
@router.get("/apps/mine")
async def list_my_oauth_apps(
user_id: str = Security(get_user_id),
) -> list[OAuthApplicationInfo]:
"""
List all OAuth applications owned by the current user.
Returns a list of OAuth applications with their details including:
- id, name, description, logo_url
- client_id (public identifier)
- redirect_uris, grant_types, scopes
- is_active status
- created_at, updated_at timestamps
Note: client_secret is never returned for security reasons.
"""
return await list_user_oauth_applications(user_id)
@router.patch("/apps/{app_id}/status")
async def update_app_status(
app_id: str,
user_id: str = Security(get_user_id),
is_active: bool = Body(description="Whether the app should be active", embed=True),
) -> OAuthApplicationInfo:
"""
Enable or disable an OAuth application.
Only the application owner can update the status.
When disabled, the application cannot be used for new authorizations
and existing access tokens will fail validation.
Returns the updated application info.
"""
updated_app = await update_oauth_application(
app_id=app_id,
owner_id=user_id,
is_active=is_active,
)
if not updated_app:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Application not found or you don't have permission to update it",
)
action = "enabled" if is_active else "disabled"
logger.info(f"OAuth app {updated_app.name} (#{app_id}) {action} by user #{user_id}")
return updated_app
class UpdateAppLogoRequest(BaseModel):
logo_url: str = Field(description="URL of the uploaded logo image")
@router.patch("/apps/{app_id}/logo")
async def update_app_logo(
app_id: str,
request: UpdateAppLogoRequest = Body(),
user_id: str = Security(get_user_id),
) -> OAuthApplicationInfo:
"""
Update the logo URL for an OAuth application.
Only the application owner can update the logo.
The logo should be uploaded first using the media upload endpoint,
then this endpoint is called with the resulting URL.
Logo requirements:
- Must be square (1:1 aspect ratio)
- Minimum 512x512 pixels
- Maximum 2048x2048 pixels
Returns the updated application info.
"""
if (
not (app := await get_oauth_application_by_id(app_id))
or app.owner_id != user_id
):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth App not found",
)
# Delete the current app logo file (if any and it's in our cloud storage)
await _delete_app_current_logo_file(app)
updated_app = await update_oauth_application(
app_id=app_id,
owner_id=user_id,
logo_url=request.logo_url,
)
if not updated_app:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Application not found or you don't have permission to update it",
)
logger.info(
f"OAuth app {updated_app.name} (#{app_id}) logo updated by user #{user_id}"
)
return updated_app
# Logo upload constraints
LOGO_MIN_SIZE = 512
LOGO_MAX_SIZE = 2048
LOGO_ALLOWED_TYPES = {"image/jpeg", "image/png", "image/webp"}
LOGO_MAX_FILE_SIZE = 3 * 1024 * 1024 # 3MB
@router.post("/apps/{app_id}/logo/upload")
async def upload_app_logo(
app_id: str,
file: UploadFile,
user_id: str = Security(get_user_id),
) -> OAuthApplicationInfo:
"""
Upload a logo image for an OAuth application.
Requirements:
- Image must be square (1:1 aspect ratio)
- Minimum 512x512 pixels
- Maximum 2048x2048 pixels
- Allowed formats: JPEG, PNG, WebP
- Maximum file size: 3MB
The image is uploaded to cloud storage and the app's logoUrl is updated.
Returns the updated application info.
"""
# Verify ownership to reduce vulnerability to DoS(torage) or DoM(oney) attacks
if (
not (app := await get_oauth_application_by_id(app_id))
or app.owner_id != user_id
):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth App not found",
)
# Check GCS configuration
if not settings.config.media_gcs_bucket_name:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Media storage is not configured",
)
# Validate content type
content_type = file.content_type
if content_type not in LOGO_ALLOWED_TYPES:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid file type. Allowed: JPEG, PNG, WebP. Got: {content_type}",
)
# Read file content
try:
file_bytes = await file.read()
except Exception as e:
logger.error(f"Error reading logo file: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to read uploaded file",
)
# Check file size
if len(file_bytes) > LOGO_MAX_FILE_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(
"File too large. "
f"Maximum size is {LOGO_MAX_FILE_SIZE // 1024 // 1024}MB"
),
)
# Validate image dimensions
try:
image = Image.open(io.BytesIO(file_bytes))
width, height = image.size
if width != height:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Logo must be square. Got {width}x{height}",
)
if width < LOGO_MIN_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Logo too small. Minimum {LOGO_MIN_SIZE}x{LOGO_MIN_SIZE}. "
f"Got {width}x{height}",
)
if width > LOGO_MAX_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Logo too large. Maximum {LOGO_MAX_SIZE}x{LOGO_MAX_SIZE}. "
f"Got {width}x{height}",
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error validating logo image: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid image file",
)
# Scan for viruses
filename = file.filename or "logo"
await scan_content_safe(file_bytes, filename=filename)
# Generate unique filename
file_ext = os.path.splitext(filename)[1].lower() or ".png"
unique_filename = f"{uuid.uuid4()}{file_ext}"
storage_path = f"oauth-apps/{app_id}/logo/{unique_filename}"
# Upload to GCS
try:
async with async_storage.Storage() as async_client:
bucket_name = settings.config.media_gcs_bucket_name
await async_client.upload(
bucket_name, storage_path, file_bytes, content_type=content_type
)
logo_url = f"https://storage.googleapis.com/{bucket_name}/{storage_path}"
except Exception as e:
logger.error(f"Error uploading logo to GCS: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to upload logo",
)
# Delete the current app logo file (if any and it's in our cloud storage)
await _delete_app_current_logo_file(app)
# Update the app with the new logo URL
updated_app = await update_oauth_application(
app_id=app_id,
owner_id=user_id,
logo_url=logo_url,
)
if not updated_app:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Application not found or you don't have permission to update it",
)
logger.info(
f"OAuth app {updated_app.name} (#{app_id}) logo uploaded by user #{user_id}"
)
return updated_app
async def _delete_app_current_logo_file(app: OAuthApplicationInfo):
"""
Delete the current logo file for the given app, if there is one in our cloud storage
"""
bucket_name = settings.config.media_gcs_bucket_name
storage_base_url = f"https://storage.googleapis.com/{bucket_name}/"
if app.logo_url and app.logo_url.startswith(storage_base_url):
# Parse blob path from URL: https://storage.googleapis.com/{bucket}/{path}
old_path = app.logo_url.replace(storage_base_url, "")
try:
async with async_storage.Storage() as async_client:
await async_client.delete(bucket_name, old_path)
logger.info(f"Deleted old logo for OAuth app #{app.id}: {old_path}")
except Exception as e:
# Log but don't fail - the new logo was uploaded successfully
logger.warning(
f"Failed to delete old logo for OAuth app #{app.id}: {e}", exc_info=e
)

File diff suppressed because it is too large Load Diff

View File

@@ -1,41 +0,0 @@
from fastapi import FastAPI
def sort_openapi(app: FastAPI) -> None:
"""
Patch a FastAPI instance's `openapi()` method to sort the endpoints,
schemas, and responses.
"""
wrapped_openapi = app.openapi
def custom_openapi():
if app.openapi_schema:
return app.openapi_schema
openapi_schema = wrapped_openapi()
# Sort endpoints
openapi_schema["paths"] = dict(sorted(openapi_schema["paths"].items()))
# Sort endpoints -> methods
for p in openapi_schema["paths"].keys():
openapi_schema["paths"][p] = dict(
sorted(openapi_schema["paths"][p].items())
)
# Sort endpoints -> methods -> responses
for m in openapi_schema["paths"][p].keys():
openapi_schema["paths"][p][m]["responses"] = dict(
sorted(openapi_schema["paths"][p][m]["responses"].items())
)
# Sort schemas and responses as well
for k in openapi_schema["components"].keys():
openapi_schema["components"][k] = dict(
sorted(openapi_schema["components"][k].items())
)
app.openapi_schema = openapi_schema
return openapi_schema
app.openapi = custom_openapi

View File

@@ -36,10 +36,10 @@ def main(**kwargs):
Run all the processes required for the AutoGPT-server (REST and WebSocket APIs).
"""
from backend.api.rest_api import AgentServer
from backend.api.ws_api import WebsocketServer
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
from backend.notifications import NotificationManager
from backend.server.rest_api import AgentServer
from backend.server.ws_api import WebsocketServer
run_processes(
DatabaseManager().set_log_level("warning"),

View File

@@ -1,7 +1,6 @@
from typing import Any
from backend.blocks.llm import (
DEFAULT_LLM_MODEL,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
AIBlockBase,
@@ -50,7 +49,7 @@ class AIConditionBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for evaluating the condition.",
advanced=False,
)
@@ -82,7 +81,7 @@ class AIConditionBlock(AIBlockBase):
"condition": "the input is an email address",
"yes_value": "Valid email",
"no_value": "Not an email",
"model": DEFAULT_LLM_MODEL,
"model": LlmModel.GPT4O,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,

View File

@@ -20,7 +20,6 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.exceptions import BlockExecutionError
from backend.util.request import Requests
TEST_CREDENTIALS = APIKeyCredentials(
@@ -247,11 +246,7 @@ class AIShortformVideoCreatorBlock(Block):
await asyncio.sleep(10)
logger.error("Video creation timed out")
raise BlockExecutionError(
message="Video creation timed out",
block_name=self.name,
block_id=self.id,
)
raise TimeoutError("Video creation timed out")
def __init__(self):
super().__init__(
@@ -427,11 +422,7 @@ class AIAdMakerVideoCreatorBlock(Block):
await asyncio.sleep(10)
logger.error("Video creation timed out")
raise BlockExecutionError(
message="Video creation timed out",
block_name=self.name,
block_id=self.id,
)
raise TimeoutError("Video creation timed out")
def __init__(self):
super().__init__(
@@ -608,11 +599,7 @@ class AIScreenshotToVideoAdBlock(Block):
await asyncio.sleep(10)
logger.error("Video creation timed out")
raise BlockExecutionError(
message="Video creation timed out",
block_name=self.name,
block_id=self.id,
)
raise TimeoutError("Video creation timed out")
def __init__(self):
super().__init__(

View File

@@ -1371,7 +1371,7 @@ async def create_base(
if tables:
params["tables"] = tables
logger.debug(f"Creating Airtable base with params: {params}")
print(params)
response = await Requests().post(
"https://api.airtable.com/v0/meta/bases",

View File

@@ -6,9 +6,6 @@ import hashlib
import hmac
import logging
from enum import Enum
from typing import cast
from prisma.types import Serializable
from backend.sdk import (
BaseWebhooksManager,
@@ -87,9 +84,7 @@ class AirtableWebhookManager(BaseWebhooksManager):
# update webhook config
await update_webhook(
webhook.id,
config=cast(
dict[str, Serializable], {"base_id": base_id, "cursor": response.cursor}
),
config={"base_id": base_id, "cursor": response.cursor},
)
event_type = "notification"

View File

@@ -81,7 +81,7 @@ class StoreValueBlock(Block):
def __init__(self):
super().__init__(
id="1ff065e9-88e8-4358-9d82-8dc91f622ba9",
description="A basic block that stores and forwards a value throughout workflows, allowing it to be reused without changes across multiple blocks.",
description="This block forwards an input value as output, allowing reuse without change.",
categories={BlockCategory.BASIC},
input_schema=StoreValueBlock.Input,
output_schema=StoreValueBlock.Output,
@@ -111,7 +111,7 @@ class PrintToConsoleBlock(Block):
def __init__(self):
super().__init__(
id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
description="A debugging block that outputs text to the console for monitoring and troubleshooting workflow execution.",
description="Print the given text to the console, this is used for a debugging purpose.",
categories={BlockCategory.BASIC},
input_schema=PrintToConsoleBlock.Input,
output_schema=PrintToConsoleBlock.Output,
@@ -137,7 +137,7 @@ class NoteBlock(Block):
def __init__(self):
super().__init__(
id="cc10ff7b-7753-4ff2-9af6-9399b1a7eddc",
description="A visual annotation block that displays a sticky note in the workflow editor for documentation and organization purposes.",
description="This block is used to display a sticky note with the given text.",
categories={BlockCategory.BASIC},
input_schema=NoteBlock.Input,
output_schema=NoteBlock.Output,

View File

@@ -106,10 +106,7 @@ class ConditionBlock(Block):
ComparisonOperator.LESS_THAN_OR_EQUAL: lambda a, b: a <= b,
}
try:
result = comparison_funcs[operator](value1, value2)
except Exception as e:
raise ValueError(f"Comparison failed: {e}") from e
result = comparison_funcs[operator](value1, value2)
yield "result", result

View File

@@ -159,7 +159,7 @@ class FindInDictionaryBlock(Block):
def __init__(self):
super().__init__(
id="0e50422c-6dee-4145-83d6-3a5a392f65de",
description="A block that looks up a value in a dictionary, list, or object by key or index and returns the corresponding value.",
description="Lookup the given key in the input dictionary/object/list and return the value.",
input_schema=FindInDictionaryBlock.Input,
output_schema=FindInDictionaryBlock.Output,
test_input=[

View File

@@ -182,10 +182,13 @@ class DataForSeoRelatedKeywordsBlock(Block):
if results and len(results) > 0:
# results is a list, get the first element
first_result = results[0] if isinstance(results, list) else results
# Handle missing key, null value, or valid list value
if isinstance(first_result, dict):
items = first_result.get("items") or []
else:
items = (
first_result.get("items", [])
if isinstance(first_result, dict)
else []
)
# Ensure items is never None
if items is None:
items = []
for item in items:
# Extract keyword_data from the item

View File

@@ -15,7 +15,6 @@ from backend.sdk import (
SchemaField,
cost,
)
from backend.util.exceptions import BlockExecutionError
from ._config import firecrawl
@@ -60,18 +59,11 @@ class FirecrawlExtractBlock(Block):
) -> BlockOutput:
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
try:
extract_result = app.extract(
urls=input_data.urls,
prompt=input_data.prompt,
schema=input_data.output_schema,
enable_web_search=input_data.enable_web_search,
)
except Exception as e:
raise BlockExecutionError(
message=f"Extract failed: {e}",
block_name=self.name,
block_id=self.id,
) from e
extract_result = app.extract(
urls=input_data.urls,
prompt=input_data.prompt,
schema=input_data.output_schema,
enable_web_search=input_data.enable_web_search,
)
yield "data", extract_result.data

View File

@@ -19,7 +19,6 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.exceptions import ModerationError
from backend.util.file import MediaFileType, store_media_file
TEST_CREDENTIALS = APIKeyCredentials(
@@ -154,8 +153,6 @@ class AIImageEditorBlock(Block):
),
aspect_ratio=input_data.aspect_ratio.value,
seed=input_data.seed,
user_id=user_id,
graph_exec_id=graph_exec_id,
)
yield "output_image", result
@@ -167,8 +164,6 @@ class AIImageEditorBlock(Block):
input_image_b64: Optional[str],
aspect_ratio: str,
seed: Optional[int],
user_id: str,
graph_exec_id: str,
) -> MediaFileType:
client = ReplicateClient(api_token=api_key.get_secret_value())
input_params = {
@@ -178,21 +173,11 @@ class AIImageEditorBlock(Block):
**({"seed": seed} if seed is not None else {}),
}
try:
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore
model_name,
input=input_params,
wait=False,
)
except Exception as e:
if "flagged as sensitive" in str(e).lower():
raise ModerationError(
message="Content was flagged as sensitive by the model provider",
user_id=user_id,
graph_exec_id=graph_exec_id,
moderation_type="model_provider",
)
raise ValueError(f"Model execution failed: {e}") from e
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore
model_name,
input=input_params,
wait=False,
)
if isinstance(output, list) and output:
output = output[0]

View File

@@ -1,108 +0,0 @@
{
"action": "created",
"discussion": {
"repository_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
"category": {
"id": 12345678,
"node_id": "DIC_kwDOJKSTjM4CXXXX",
"repository_id": 614765452,
"emoji": ":pray:",
"name": "Q&A",
"description": "Ask the community for help",
"created_at": "2023-03-16T09:21:07Z",
"updated_at": "2023-03-16T09:21:07Z",
"slug": "q-a",
"is_answerable": true
},
"answer_html_url": null,
"answer_chosen_at": null,
"answer_chosen_by": null,
"html_url": "https://github.com/Significant-Gravitas/AutoGPT/discussions/9999",
"id": 5000000001,
"node_id": "D_kwDOJKSTjM4AYYYY",
"number": 9999,
"title": "How do I configure custom blocks?",
"user": {
"login": "curious-user",
"id": 22222222,
"node_id": "MDQ6VXNlcjIyMjIyMjIy",
"avatar_url": "https://avatars.githubusercontent.com/u/22222222?v=4",
"url": "https://api.github.com/users/curious-user",
"html_url": "https://github.com/curious-user",
"type": "User",
"site_admin": false
},
"state": "open",
"state_reason": null,
"locked": false,
"comments": 0,
"created_at": "2024-12-01T17:00:00Z",
"updated_at": "2024-12-01T17:00:00Z",
"author_association": "NONE",
"active_lock_reason": null,
"body": "## Question\n\nI'm trying to create a custom block for my specific use case. I've read the documentation but I'm not sure how to:\n\n1. Define the input/output schema\n2. Handle authentication\n3. Test my block locally\n\nCan someone point me to examples or provide guidance?\n\n## Environment\n\n- AutoGPT Platform version: latest\n- Python: 3.11",
"reactions": {
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/discussions/9999/reactions",
"total_count": 0,
"+1": 0,
"-1": 0,
"laugh": 0,
"hooray": 0,
"confused": 0,
"heart": 0,
"rocket": 0,
"eyes": 0
},
"timeline_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/discussions/9999/timeline"
},
"repository": {
"id": 614765452,
"node_id": "R_kgDOJKSTjA",
"name": "AutoGPT",
"full_name": "Significant-Gravitas/AutoGPT",
"private": false,
"owner": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"url": "https://api.github.com/users/Significant-Gravitas",
"html_url": "https://github.com/Significant-Gravitas",
"type": "Organization",
"site_admin": false
},
"html_url": "https://github.com/Significant-Gravitas/AutoGPT",
"description": "AutoGPT is the vision of accessible AI for everyone, to use and to build on.",
"fork": false,
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
"created_at": "2023-03-16T09:21:07Z",
"updated_at": "2024-12-01T17:00:00Z",
"pushed_at": "2024-12-01T12:00:00Z",
"stargazers_count": 170000,
"watchers_count": 170000,
"language": "Python",
"has_discussions": true,
"forks_count": 45000,
"visibility": "public",
"default_branch": "master"
},
"organization": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"url": "https://api.github.com/orgs/Significant-Gravitas",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"description": ""
},
"sender": {
"login": "curious-user",
"id": 22222222,
"node_id": "MDQ6VXNlcjIyMjIyMjIy",
"avatar_url": "https://avatars.githubusercontent.com/u/22222222?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/curious-user",
"html_url": "https://github.com/curious-user",
"type": "User",
"site_admin": false
}
}

View File

@@ -1,112 +0,0 @@
{
"action": "opened",
"issue": {
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345",
"repository_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
"labels_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/labels{/name}",
"comments_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/comments",
"events_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/events",
"html_url": "https://github.com/Significant-Gravitas/AutoGPT/issues/12345",
"id": 2000000001,
"node_id": "I_kwDOJKSTjM5wXXXX",
"number": 12345,
"title": "Bug: Application crashes when processing large files",
"user": {
"login": "bug-reporter",
"id": 11111111,
"node_id": "MDQ6VXNlcjExMTExMTEx",
"avatar_url": "https://avatars.githubusercontent.com/u/11111111?v=4",
"url": "https://api.github.com/users/bug-reporter",
"html_url": "https://github.com/bug-reporter",
"type": "User",
"site_admin": false
},
"labels": [
{
"id": 5272676214,
"node_id": "LA_kwDOJKSTjM8AAAABOkandg",
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/labels/bug",
"name": "bug",
"color": "d73a4a",
"default": true,
"description": "Something isn't working"
}
],
"state": "open",
"locked": false,
"assignee": null,
"assignees": [],
"milestone": null,
"comments": 0,
"created_at": "2024-12-01T16:00:00Z",
"updated_at": "2024-12-01T16:00:00Z",
"closed_at": null,
"author_association": "NONE",
"active_lock_reason": null,
"body": "## Description\n\nWhen I try to process a file larger than 100MB, the application crashes with an out of memory error.\n\n## Steps to Reproduce\n\n1. Open the application\n2. Select a file larger than 100MB\n3. Click 'Process'\n4. Application crashes\n\n## Expected Behavior\n\nThe application should handle large files gracefully.\n\n## Environment\n\n- OS: Ubuntu 22.04\n- Python: 3.11\n- AutoGPT Version: 1.0.0",
"reactions": {
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/reactions",
"total_count": 0,
"+1": 0,
"-1": 0,
"laugh": 0,
"hooray": 0,
"confused": 0,
"heart": 0,
"rocket": 0,
"eyes": 0
},
"timeline_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/timeline",
"state_reason": null
},
"repository": {
"id": 614765452,
"node_id": "R_kgDOJKSTjA",
"name": "AutoGPT",
"full_name": "Significant-Gravitas/AutoGPT",
"private": false,
"owner": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"url": "https://api.github.com/users/Significant-Gravitas",
"html_url": "https://github.com/Significant-Gravitas",
"type": "Organization",
"site_admin": false
},
"html_url": "https://github.com/Significant-Gravitas/AutoGPT",
"description": "AutoGPT is the vision of accessible AI for everyone, to use and to build on.",
"fork": false,
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
"created_at": "2023-03-16T09:21:07Z",
"updated_at": "2024-12-01T16:00:00Z",
"pushed_at": "2024-12-01T12:00:00Z",
"stargazers_count": 170000,
"watchers_count": 170000,
"language": "Python",
"forks_count": 45000,
"open_issues_count": 190,
"visibility": "public",
"default_branch": "master"
},
"organization": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"url": "https://api.github.com/orgs/Significant-Gravitas",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"description": ""
},
"sender": {
"login": "bug-reporter",
"id": 11111111,
"node_id": "MDQ6VXNlcjExMTExMTEx",
"avatar_url": "https://avatars.githubusercontent.com/u/11111111?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/bug-reporter",
"html_url": "https://github.com/bug-reporter",
"type": "User",
"site_admin": false
}
}

View File

@@ -1,97 +0,0 @@
{
"action": "published",
"release": {
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/releases/123456789",
"assets_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/releases/123456789/assets",
"upload_url": "https://uploads.github.com/repos/Significant-Gravitas/AutoGPT/releases/123456789/assets{?name,label}",
"html_url": "https://github.com/Significant-Gravitas/AutoGPT/releases/tag/v1.0.0",
"id": 123456789,
"author": {
"login": "ntindle",
"id": 12345678,
"node_id": "MDQ6VXNlcjEyMzQ1Njc4",
"avatar_url": "https://avatars.githubusercontent.com/u/12345678?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/ntindle",
"html_url": "https://github.com/ntindle",
"type": "User",
"site_admin": false
},
"node_id": "RE_kwDOJKSTjM4HWwAA",
"tag_name": "v1.0.0",
"target_commitish": "master",
"name": "AutoGPT Platform v1.0.0",
"draft": false,
"prerelease": false,
"created_at": "2024-12-01T10:00:00Z",
"published_at": "2024-12-01T12:00:00Z",
"assets": [
{
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/releases/assets/987654321",
"id": 987654321,
"node_id": "RA_kwDOJKSTjM4HWwBB",
"name": "autogpt-v1.0.0.zip",
"label": "Release Package",
"content_type": "application/zip",
"state": "uploaded",
"size": 52428800,
"download_count": 0,
"created_at": "2024-12-01T11:30:00Z",
"updated_at": "2024-12-01T11:35:00Z",
"browser_download_url": "https://github.com/Significant-Gravitas/AutoGPT/releases/download/v1.0.0/autogpt-v1.0.0.zip"
}
],
"tarball_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/tarball/v1.0.0",
"zipball_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/zipball/v1.0.0",
"body": "## What's New\n\n- Feature 1: Amazing new capability\n- Feature 2: Performance improvements\n- Bug fixes and stability improvements\n\n## Breaking Changes\n\nNone\n\n## Contributors\n\nThanks to all our contributors!"
},
"repository": {
"id": 614765452,
"node_id": "R_kgDOJKSTjA",
"name": "AutoGPT",
"full_name": "Significant-Gravitas/AutoGPT",
"private": false,
"owner": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"url": "https://api.github.com/users/Significant-Gravitas",
"html_url": "https://github.com/Significant-Gravitas",
"type": "Organization",
"site_admin": false
},
"html_url": "https://github.com/Significant-Gravitas/AutoGPT",
"description": "AutoGPT is the vision of accessible AI for everyone, to use and to build on.",
"fork": false,
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
"created_at": "2023-03-16T09:21:07Z",
"updated_at": "2024-12-01T12:00:00Z",
"pushed_at": "2024-12-01T12:00:00Z",
"stargazers_count": 170000,
"watchers_count": 170000,
"language": "Python",
"forks_count": 45000,
"visibility": "public",
"default_branch": "master"
},
"organization": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"url": "https://api.github.com/orgs/Significant-Gravitas",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"description": ""
},
"sender": {
"login": "ntindle",
"id": 12345678,
"node_id": "MDQ6VXNlcjEyMzQ1Njc4",
"avatar_url": "https://avatars.githubusercontent.com/u/12345678?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/ntindle",
"html_url": "https://github.com/ntindle",
"type": "User",
"site_admin": false
}
}

View File

@@ -1,53 +0,0 @@
{
"action": "created",
"starred_at": "2024-12-01T15:30:00Z",
"repository": {
"id": 614765452,
"node_id": "R_kgDOJKSTjA",
"name": "AutoGPT",
"full_name": "Significant-Gravitas/AutoGPT",
"private": false,
"owner": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"url": "https://api.github.com/users/Significant-Gravitas",
"html_url": "https://github.com/Significant-Gravitas",
"type": "Organization",
"site_admin": false
},
"html_url": "https://github.com/Significant-Gravitas/AutoGPT",
"description": "AutoGPT is the vision of accessible AI for everyone, to use and to build on.",
"fork": false,
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
"created_at": "2023-03-16T09:21:07Z",
"updated_at": "2024-12-01T15:30:00Z",
"pushed_at": "2024-12-01T12:00:00Z",
"stargazers_count": 170001,
"watchers_count": 170001,
"language": "Python",
"forks_count": 45000,
"visibility": "public",
"default_branch": "master"
},
"organization": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"url": "https://api.github.com/orgs/Significant-Gravitas",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"description": ""
},
"sender": {
"login": "awesome-contributor",
"id": 98765432,
"node_id": "MDQ6VXNlcjk4NzY1NDMy",
"avatar_url": "https://avatars.githubusercontent.com/u/98765432?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/awesome-contributor",
"html_url": "https://github.com/awesome-contributor",
"type": "User",
"site_admin": false
}
}

View File

@@ -51,7 +51,7 @@ class GithubCommentBlock(Block):
def __init__(self):
super().__init__(
id="a8db4d8d-db1c-4a25-a1b0-416a8c33602b",
description="A block that posts comments on GitHub issues or pull requests using the GitHub API.",
description="This block posts a comment on a specified GitHub issue or pull request.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCommentBlock.Input,
output_schema=GithubCommentBlock.Output,
@@ -151,7 +151,7 @@ class GithubUpdateCommentBlock(Block):
def __init__(self):
super().__init__(
id="b3f4d747-10e3-4e69-8c51-f2be1d99c9a7",
description="A block that updates an existing comment on a GitHub issue or pull request.",
description="This block updates a comment on a specified GitHub issue or pull request.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubUpdateCommentBlock.Input,
output_schema=GithubUpdateCommentBlock.Output,
@@ -249,7 +249,7 @@ class GithubListCommentsBlock(Block):
def __init__(self):
super().__init__(
id="c4b5fb63-0005-4a11-b35a-0c2467bd6b59",
description="A block that retrieves all comments from a GitHub issue or pull request, including comment metadata and content.",
description="This block lists all comments for a specified GitHub issue or pull request.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListCommentsBlock.Input,
output_schema=GithubListCommentsBlock.Output,
@@ -363,7 +363,7 @@ class GithubMakeIssueBlock(Block):
def __init__(self):
super().__init__(
id="691dad47-f494-44c3-a1e8-05b7990f2dab",
description="A block that creates new issues on GitHub repositories with a title and body content.",
description="This block creates a new issue on a specified GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMakeIssueBlock.Input,
output_schema=GithubMakeIssueBlock.Output,
@@ -433,7 +433,7 @@ class GithubReadIssueBlock(Block):
def __init__(self):
super().__init__(
id="6443c75d-032a-4772-9c08-230c707c8acc",
description="A block that retrieves information about a specific GitHub issue, including its title, body content, and creator.",
description="This block reads the body, title, and user of a specified GitHub issue.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadIssueBlock.Input,
output_schema=GithubReadIssueBlock.Output,
@@ -510,7 +510,7 @@ class GithubListIssuesBlock(Block):
def __init__(self):
super().__init__(
id="c215bfd7-0e57-4573-8f8c-f7d4963dcd74",
description="A block that retrieves a list of issues from a GitHub repository with their titles and URLs.",
description="This block lists all issues for a specified GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListIssuesBlock.Input,
output_schema=GithubListIssuesBlock.Output,
@@ -597,7 +597,7 @@ class GithubAddLabelBlock(Block):
def __init__(self):
super().__init__(
id="98bd6b77-9506-43d5-b669-6b9733c4b1f1",
description="A block that adds a label to a GitHub issue or pull request for categorization and organization.",
description="This block adds a label to a specified GitHub issue or pull request.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubAddLabelBlock.Input,
output_schema=GithubAddLabelBlock.Output,
@@ -657,7 +657,7 @@ class GithubRemoveLabelBlock(Block):
def __init__(self):
super().__init__(
id="78f050c5-3e3a-48c0-9e5b-ef1ceca5589c",
description="A block that removes a label from a GitHub issue or pull request.",
description="This block removes a label from a specified GitHub issue or pull request.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubRemoveLabelBlock.Input,
output_schema=GithubRemoveLabelBlock.Output,
@@ -720,7 +720,7 @@ class GithubAssignIssueBlock(Block):
def __init__(self):
super().__init__(
id="90507c72-b0ff-413a-886a-23bbbd66f542",
description="A block that assigns a GitHub user to an issue for task ownership and tracking.",
description="This block assigns a user to a specified GitHub issue.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubAssignIssueBlock.Input,
output_schema=GithubAssignIssueBlock.Output,
@@ -786,7 +786,7 @@ class GithubUnassignIssueBlock(Block):
def __init__(self):
super().__init__(
id="d154002a-38f4-46c2-962d-2488f2b05ece",
description="A block that removes a user's assignment from a GitHub issue.",
description="This block unassigns a user from a specified GitHub issue.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubUnassignIssueBlock.Input,
output_schema=GithubUnassignIssueBlock.Output,

View File

@@ -159,391 +159,3 @@ class GithubPullRequestTriggerBlock(GitHubTriggerBase, Block):
# --8<-- [end:GithubTriggerExample]
class GithubStarTriggerBlock(GitHubTriggerBase, Block):
"""Trigger block for GitHub star events - useful for milestone celebrations."""
EXAMPLE_PAYLOAD_FILE = (
Path(__file__).parent / "example_payloads" / "star.created.json"
)
class Input(GitHubTriggerBase.Input):
class EventsFilter(BaseModel):
"""
https://docs.github.com/en/webhooks/webhook-events-and-payloads#star
"""
created: bool = False
deleted: bool = False
events: EventsFilter = SchemaField(
title="Events", description="The star events to subscribe to"
)
class Output(GitHubTriggerBase.Output):
event: str = SchemaField(
description="The star event that triggered the webhook ('created' or 'deleted')"
)
starred_at: str = SchemaField(
description="ISO timestamp when the repo was starred (empty if deleted)"
)
stargazers_count: int = SchemaField(
description="Current number of stars on the repository"
)
repository_name: str = SchemaField(
description="Full name of the repository (owner/repo)"
)
repository_url: str = SchemaField(description="URL to the repository")
def __init__(self):
from backend.integrations.webhooks.github import GithubWebhookType
example_payload = json.loads(
self.EXAMPLE_PAYLOAD_FILE.read_text(encoding="utf-8")
)
super().__init__(
id="551e0a35-100b-49b7-89b8-3031322239b6",
description="This block triggers on GitHub star events. "
"Useful for celebrating milestones (e.g., 1k, 10k stars) or tracking engagement.",
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
input_schema=GithubStarTriggerBlock.Input,
output_schema=GithubStarTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.GITHUB,
webhook_type=GithubWebhookType.REPO,
resource_format="{repo}",
event_filter_input="events",
event_format="star.{event}",
),
test_input={
"repo": "Significant-Gravitas/AutoGPT",
"events": {"created": True},
"credentials": TEST_CREDENTIALS_INPUT,
"payload": example_payload,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", example_payload),
("triggered_by_user", example_payload["sender"]),
("event", example_payload["action"]),
("starred_at", example_payload.get("starred_at", "")),
("stargazers_count", example_payload["repository"]["stargazers_count"]),
("repository_name", example_payload["repository"]["full_name"]),
("repository_url", example_payload["repository"]["html_url"]),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput: # type: ignore
async for name, value in super().run(input_data, **kwargs):
yield name, value
yield "event", input_data.payload["action"]
yield "starred_at", input_data.payload.get("starred_at", "")
yield "stargazers_count", input_data.payload["repository"]["stargazers_count"]
yield "repository_name", input_data.payload["repository"]["full_name"]
yield "repository_url", input_data.payload["repository"]["html_url"]
class GithubReleaseTriggerBlock(GitHubTriggerBase, Block):
"""Trigger block for GitHub release events - ideal for announcing new versions."""
EXAMPLE_PAYLOAD_FILE = (
Path(__file__).parent / "example_payloads" / "release.published.json"
)
class Input(GitHubTriggerBase.Input):
class EventsFilter(BaseModel):
"""
https://docs.github.com/en/webhooks/webhook-events-and-payloads#release
"""
published: bool = False
unpublished: bool = False
created: bool = False
edited: bool = False
deleted: bool = False
prereleased: bool = False
released: bool = False
events: EventsFilter = SchemaField(
title="Events", description="The release events to subscribe to"
)
class Output(GitHubTriggerBase.Output):
event: str = SchemaField(
description="The release event that triggered the webhook (e.g., 'published')"
)
release: dict = SchemaField(description="The full release object")
release_url: str = SchemaField(description="URL to the release page")
tag_name: str = SchemaField(description="The release tag name (e.g., 'v1.0.0')")
release_name: str = SchemaField(description="Human-readable release name")
body: str = SchemaField(description="Release notes/description")
prerelease: bool = SchemaField(description="Whether this is a prerelease")
draft: bool = SchemaField(description="Whether this is a draft release")
assets: list = SchemaField(description="List of release assets/files")
def __init__(self):
from backend.integrations.webhooks.github import GithubWebhookType
example_payload = json.loads(
self.EXAMPLE_PAYLOAD_FILE.read_text(encoding="utf-8")
)
super().__init__(
id="2052dd1b-74e1-46ac-9c87-c7a0e057b60b",
description="This block triggers on GitHub release events. "
"Perfect for automating announcements to Discord, Twitter, or other platforms.",
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
input_schema=GithubReleaseTriggerBlock.Input,
output_schema=GithubReleaseTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.GITHUB,
webhook_type=GithubWebhookType.REPO,
resource_format="{repo}",
event_filter_input="events",
event_format="release.{event}",
),
test_input={
"repo": "Significant-Gravitas/AutoGPT",
"events": {"published": True},
"credentials": TEST_CREDENTIALS_INPUT,
"payload": example_payload,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", example_payload),
("triggered_by_user", example_payload["sender"]),
("event", example_payload["action"]),
("release", example_payload["release"]),
("release_url", example_payload["release"]["html_url"]),
("tag_name", example_payload["release"]["tag_name"]),
("release_name", example_payload["release"]["name"]),
("body", example_payload["release"]["body"]),
("prerelease", example_payload["release"]["prerelease"]),
("draft", example_payload["release"]["draft"]),
("assets", example_payload["release"]["assets"]),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput: # type: ignore
async for name, value in super().run(input_data, **kwargs):
yield name, value
release = input_data.payload["release"]
yield "event", input_data.payload["action"]
yield "release", release
yield "release_url", release["html_url"]
yield "tag_name", release["tag_name"]
yield "release_name", release.get("name", "")
yield "body", release.get("body", "")
yield "prerelease", release["prerelease"]
yield "draft", release["draft"]
yield "assets", release["assets"]
class GithubIssuesTriggerBlock(GitHubTriggerBase, Block):
"""Trigger block for GitHub issues events - great for triage and notifications."""
EXAMPLE_PAYLOAD_FILE = (
Path(__file__).parent / "example_payloads" / "issues.opened.json"
)
class Input(GitHubTriggerBase.Input):
class EventsFilter(BaseModel):
"""
https://docs.github.com/en/webhooks/webhook-events-and-payloads#issues
"""
opened: bool = False
edited: bool = False
deleted: bool = False
closed: bool = False
reopened: bool = False
assigned: bool = False
unassigned: bool = False
labeled: bool = False
unlabeled: bool = False
locked: bool = False
unlocked: bool = False
transferred: bool = False
milestoned: bool = False
demilestoned: bool = False
pinned: bool = False
unpinned: bool = False
events: EventsFilter = SchemaField(
title="Events", description="The issue events to subscribe to"
)
class Output(GitHubTriggerBase.Output):
event: str = SchemaField(
description="The issue event that triggered the webhook (e.g., 'opened')"
)
number: int = SchemaField(description="The issue number")
issue: dict = SchemaField(description="The full issue object")
issue_url: str = SchemaField(description="URL to the issue")
issue_title: str = SchemaField(description="The issue title")
issue_body: str = SchemaField(description="The issue body/description")
labels: list = SchemaField(description="List of labels on the issue")
assignees: list = SchemaField(description="List of assignees")
state: str = SchemaField(description="Issue state ('open' or 'closed')")
def __init__(self):
from backend.integrations.webhooks.github import GithubWebhookType
example_payload = json.loads(
self.EXAMPLE_PAYLOAD_FILE.read_text(encoding="utf-8")
)
super().__init__(
id="b2605464-e486-4bf4-aad3-d8a213c8a48a",
description="This block triggers on GitHub issues events. "
"Useful for automated triage, notifications, and welcoming first-time contributors.",
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
input_schema=GithubIssuesTriggerBlock.Input,
output_schema=GithubIssuesTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.GITHUB,
webhook_type=GithubWebhookType.REPO,
resource_format="{repo}",
event_filter_input="events",
event_format="issues.{event}",
),
test_input={
"repo": "Significant-Gravitas/AutoGPT",
"events": {"opened": True},
"credentials": TEST_CREDENTIALS_INPUT,
"payload": example_payload,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", example_payload),
("triggered_by_user", example_payload["sender"]),
("event", example_payload["action"]),
("number", example_payload["issue"]["number"]),
("issue", example_payload["issue"]),
("issue_url", example_payload["issue"]["html_url"]),
("issue_title", example_payload["issue"]["title"]),
("issue_body", example_payload["issue"]["body"]),
("labels", example_payload["issue"]["labels"]),
("assignees", example_payload["issue"]["assignees"]),
("state", example_payload["issue"]["state"]),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput: # type: ignore
async for name, value in super().run(input_data, **kwargs):
yield name, value
issue = input_data.payload["issue"]
yield "event", input_data.payload["action"]
yield "number", issue["number"]
yield "issue", issue
yield "issue_url", issue["html_url"]
yield "issue_title", issue["title"]
yield "issue_body", issue.get("body") or ""
yield "labels", issue["labels"]
yield "assignees", issue["assignees"]
yield "state", issue["state"]
class GithubDiscussionTriggerBlock(GitHubTriggerBase, Block):
"""Trigger block for GitHub discussion events - perfect for community Q&A sync."""
EXAMPLE_PAYLOAD_FILE = (
Path(__file__).parent / "example_payloads" / "discussion.created.json"
)
class Input(GitHubTriggerBase.Input):
class EventsFilter(BaseModel):
"""
https://docs.github.com/en/webhooks/webhook-events-and-payloads#discussion
"""
created: bool = False
edited: bool = False
deleted: bool = False
answered: bool = False
unanswered: bool = False
labeled: bool = False
unlabeled: bool = False
locked: bool = False
unlocked: bool = False
category_changed: bool = False
transferred: bool = False
pinned: bool = False
unpinned: bool = False
events: EventsFilter = SchemaField(
title="Events", description="The discussion events to subscribe to"
)
class Output(GitHubTriggerBase.Output):
event: str = SchemaField(
description="The discussion event that triggered the webhook"
)
number: int = SchemaField(description="The discussion number")
discussion: dict = SchemaField(description="The full discussion object")
discussion_url: str = SchemaField(description="URL to the discussion")
title: str = SchemaField(description="The discussion title")
body: str = SchemaField(description="The discussion body")
category: dict = SchemaField(description="The discussion category object")
category_name: str = SchemaField(description="Name of the category")
state: str = SchemaField(description="Discussion state")
def __init__(self):
from backend.integrations.webhooks.github import GithubWebhookType
example_payload = json.loads(
self.EXAMPLE_PAYLOAD_FILE.read_text(encoding="utf-8")
)
super().__init__(
id="87f847b3-d81a-424e-8e89-acadb5c9d52b",
description="This block triggers on GitHub Discussions events. "
"Great for syncing Q&A to Discord or auto-responding to common questions. "
"Note: Discussions must be enabled on the repository.",
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
input_schema=GithubDiscussionTriggerBlock.Input,
output_schema=GithubDiscussionTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.GITHUB,
webhook_type=GithubWebhookType.REPO,
resource_format="{repo}",
event_filter_input="events",
event_format="discussion.{event}",
),
test_input={
"repo": "Significant-Gravitas/AutoGPT",
"events": {"created": True},
"credentials": TEST_CREDENTIALS_INPUT,
"payload": example_payload,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", example_payload),
("triggered_by_user", example_payload["sender"]),
("event", example_payload["action"]),
("number", example_payload["discussion"]["number"]),
("discussion", example_payload["discussion"]),
("discussion_url", example_payload["discussion"]["html_url"]),
("title", example_payload["discussion"]["title"]),
("body", example_payload["discussion"]["body"]),
("category", example_payload["discussion"]["category"]),
("category_name", example_payload["discussion"]["category"]["name"]),
("state", example_payload["discussion"]["state"]),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput: # type: ignore
async for name, value in super().run(input_data, **kwargs):
yield name, value
discussion = input_data.payload["discussion"]
yield "event", input_data.payload["action"]
yield "number", discussion["number"]
yield "discussion", discussion
yield "discussion_url", discussion["html_url"]
yield "title", discussion["title"]
yield "body", discussion.get("body") or ""
yield "category", discussion["category"]
yield "category_name", discussion["category"]["name"]
yield "state", discussion["state"]

File diff suppressed because it is too large Load Diff

View File

@@ -353,7 +353,7 @@ class GmailReadBlock(GmailBase):
def __init__(self):
super().__init__(
id="25310c70-b89b-43ba-b25c-4dfa7e2a481c",
description="A block that retrieves and reads emails from a Gmail account based on search criteria, returning detailed message information including subject, sender, body, and attachments.",
description="This block reads emails from Gmail.",
categories={BlockCategory.COMMUNICATION},
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
input_schema=GmailReadBlock.Input,
@@ -743,7 +743,7 @@ class GmailListLabelsBlock(GmailBase):
def __init__(self):
super().__init__(
id="3e1c2c1c-c689-4520-b956-1f3bf4e02bb7",
description="A block that retrieves all labels (categories) from a Gmail account for organizing and categorizing emails.",
description="This block lists all labels in Gmail.",
categories={BlockCategory.COMMUNICATION},
input_schema=GmailListLabelsBlock.Input,
output_schema=GmailListLabelsBlock.Output,
@@ -807,7 +807,7 @@ class GmailAddLabelBlock(GmailBase):
def __init__(self):
super().__init__(
id="f884b2fb-04f4-4265-9658-14f433926ac9",
description="A block that adds a label to a specific email message in Gmail, creating the label if it doesn't exist.",
description="This block adds a label to a Gmail message.",
categories={BlockCategory.COMMUNICATION},
input_schema=GmailAddLabelBlock.Input,
output_schema=GmailAddLabelBlock.Output,
@@ -893,7 +893,7 @@ class GmailRemoveLabelBlock(GmailBase):
def __init__(self):
super().__init__(
id="0afc0526-aba1-4b2b-888e-a22b7c3f359d",
description="A block that removes a label from a specific email message in a Gmail account.",
description="This block removes a label from a Gmail message.",
categories={BlockCategory.COMMUNICATION},
input_schema=GmailRemoveLabelBlock.Input,
output_schema=GmailRemoveLabelBlock.Output,
@@ -961,7 +961,7 @@ class GmailGetThreadBlock(GmailBase):
def __init__(self):
super().__init__(
id="21a79166-9df7-4b5f-9f36-96f639d86112",
description="A block that retrieves an entire Gmail thread (email conversation) by ID, returning all messages with decoded bodies for reading complete conversations.",
description="Get a full Gmail thread by ID",
categories={BlockCategory.COMMUNICATION},
input_schema=GmailGetThreadBlock.Input,
output_schema=GmailGetThreadBlock.Output,

File diff suppressed because it is too large Load Diff

View File

@@ -1,184 +0,0 @@
"""
Shared helpers for Human-In-The-Loop (HITL) review functionality.
Used by both the dedicated HumanInTheLoopBlock and blocks that require human review.
"""
import logging
from typing import Any, Optional
from prisma.enums import ReviewStatus
from pydantic import BaseModel
from backend.data.execution import ExecutionContext, ExecutionStatus
from backend.data.human_review import ReviewResult
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client
logger = logging.getLogger(__name__)
class ReviewDecision(BaseModel):
"""Result of a review decision."""
should_proceed: bool
message: str
review_result: ReviewResult
class HITLReviewHelper:
"""Helper class for Human-In-The-Loop review operations."""
@staticmethod
async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]:
"""Create or retrieve a human review from the database."""
return await get_database_manager_async_client().get_or_create_human_review(
**kwargs
)
@staticmethod
async def update_node_execution_status(**kwargs) -> None:
"""Update the execution status of a node."""
await async_update_node_execution_status(
db_client=get_database_manager_async_client(), **kwargs
)
@staticmethod
async def update_review_processed_status(
node_exec_id: str, processed: bool
) -> None:
"""Update the processed status of a review."""
return await get_database_manager_async_client().update_review_processed_status(
node_exec_id, processed
)
@staticmethod
async def _handle_review_request(
input_data: Any,
user_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
block_name: str = "Block",
editable: bool = False,
) -> Optional[ReviewResult]:
"""
Handle a review request for a block that requires human review.
Args:
input_data: The input data to be reviewed
user_id: ID of the user requesting the review
node_exec_id: ID of the node execution
graph_exec_id: ID of the graph execution
graph_id: ID of the graph
graph_version: Version of the graph
execution_context: Current execution context
block_name: Name of the block requesting review
editable: Whether the reviewer can edit the data
Returns:
ReviewResult if review is complete, None if waiting for human input
Raises:
Exception: If review creation or status update fails
"""
# Skip review if safe mode is disabled - return auto-approved result
if not execution_context.safe_mode:
logger.info(
f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled"
)
return ReviewResult(
data=input_data,
status=ReviewStatus.APPROVED,
message="Auto-approved (safe mode disabled)",
processed=True,
node_exec_id=node_exec_id,
)
result = await HITLReviewHelper.get_or_create_human_review(
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
input_data=input_data,
message=f"Review required for {block_name} execution",
editable=editable,
)
if result is None:
logger.info(
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
)
await HITLReviewHelper.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
return None # Signal that execution should pause
# Mark review as processed if not already done
if not result.processed:
await HITLReviewHelper.update_review_processed_status(
node_exec_id=node_exec_id, processed=True
)
return result
@staticmethod
async def handle_review_decision(
input_data: Any,
user_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
block_name: str = "Block",
editable: bool = False,
) -> Optional[ReviewDecision]:
"""
Handle a review request and return the decision in a single call.
Args:
input_data: The input data to be reviewed
user_id: ID of the user requesting the review
node_exec_id: ID of the node execution
graph_exec_id: ID of the graph execution
graph_id: ID of the graph
graph_version: Version of the graph
execution_context: Current execution context
block_name: Name of the block requesting review
editable: Whether the reviewer can edit the data
Returns:
ReviewDecision if review is complete (approved/rejected),
None if execution should pause (awaiting review)
"""
review_result = await HITLReviewHelper._handle_review_request(
input_data=input_data,
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=block_name,
editable=editable,
)
if review_result is None:
# Still awaiting review - return None to pause execution
return None
# Review is complete, determine outcome
should_proceed = review_result.status == ReviewStatus.APPROVED
message = review_result.message or (
"Execution approved by reviewer"
if should_proceed
else "Execution rejected by reviewer"
)
return ReviewDecision(
should_proceed=should_proceed, message=message, review_result=review_result
)

View File

@@ -1,9 +1,8 @@
import logging
from typing import Any
from typing import Any, Literal
from prisma.enums import ReviewStatus
from backend.blocks.helpers.review import HITLReviewHelper
from backend.data.block import (
Block,
BlockCategory,
@@ -12,9 +11,11 @@ from backend.data.block import (
BlockSchemaOutput,
BlockType,
)
from backend.data.execution import ExecutionContext
from backend.data.execution import ExecutionContext, ExecutionStatus
from backend.data.human_review import ReviewResult
from backend.data.model import SchemaField
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client
logger = logging.getLogger(__name__)
@@ -44,11 +45,11 @@ class HumanInTheLoopBlock(Block):
)
class Output(BlockSchemaOutput):
approved_data: Any = SchemaField(
description="The data when approved (may be modified by reviewer)"
reviewed_data: Any = SchemaField(
description="The data after human review (may be modified)"
)
rejected_data: Any = SchemaField(
description="The data when rejected (may be modified by reviewer)"
status: Literal["approved", "rejected"] = SchemaField(
description="Status of the review: 'approved' or 'rejected'"
)
review_message: str = SchemaField(
description="Any message provided by the reviewer", default=""
@@ -68,29 +69,36 @@ class HumanInTheLoopBlock(Block):
"editable": True,
},
test_output=[
("approved_data", {"name": "John Doe", "age": 30}),
("status", "approved"),
("reviewed_data", {"name": "John Doe", "age": 30}),
],
test_mock={
"handle_review_decision": lambda **kwargs: type(
"ReviewDecision",
(),
{
"should_proceed": True,
"message": "Test approval message",
"review_result": ReviewResult(
data={"name": "John Doe", "age": 30},
status=ReviewStatus.APPROVED,
message="",
processed=False,
node_exec_id="test-node-exec-id",
),
},
)(),
"get_or_create_human_review": lambda *_args, **_kwargs: ReviewResult(
data={"name": "John Doe", "age": 30},
status=ReviewStatus.APPROVED,
message="",
processed=False,
node_exec_id="test-node-exec-id",
),
"update_node_execution_status": lambda *_args, **_kwargs: None,
"update_review_processed_status": lambda *_args, **_kwargs: None,
},
)
async def handle_review_decision(self, **kwargs):
return await HITLReviewHelper.handle_review_decision(**kwargs)
async def get_or_create_human_review(self, **kwargs):
return await get_database_manager_async_client().get_or_create_human_review(
**kwargs
)
async def update_node_execution_status(self, **kwargs):
return await async_update_node_execution_status(
db_client=get_database_manager_async_client(), **kwargs
)
async def update_review_processed_status(self, node_exec_id: str, processed: bool):
return await get_database_manager_async_client().update_review_processed_status(
node_exec_id, processed
)
async def run(
self,
@@ -102,38 +110,60 @@ class HumanInTheLoopBlock(Block):
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
**_kwargs,
**kwargs,
) -> BlockOutput:
if not execution_context.safe_mode:
logger.info(
f"HITL block skipping review for node {node_exec_id} - safe mode disabled"
)
yield "approved_data", input_data.data
yield "status", "approved"
yield "reviewed_data", input_data.data
yield "review_message", "Auto-approved (safe mode disabled)"
return
decision = await self.handle_review_decision(
input_data=input_data.data,
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=self.name,
editable=input_data.editable,
)
try:
result = await self.get_or_create_human_review(
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
input_data=input_data.data,
message=input_data.name,
editable=input_data.editable,
)
except Exception as e:
logger.error(f"Error in HITL block for node {node_exec_id}: {str(e)}")
raise
if decision is None:
return
if result is None:
logger.info(
f"HITL block pausing execution for node {node_exec_id} - awaiting human review"
)
try:
await self.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
return
except Exception as e:
logger.error(
f"Failed to update node status for HITL block {node_exec_id}: {str(e)}"
)
raise
status = decision.review_result.status
if status == ReviewStatus.APPROVED:
yield "approved_data", decision.review_result.data
elif status == ReviewStatus.REJECTED:
yield "rejected_data", decision.review_result.data
else:
raise RuntimeError(f"Unexpected review status: {status}")
if not result.processed:
await self.update_review_processed_status(
node_exec_id=node_exec_id, processed=True
)
if decision.message:
yield "review_message", decision.message
if result.status == ReviewStatus.APPROVED:
yield "status", "approved"
yield "reviewed_data", result.data
if result.message:
yield "review_message", result.message
elif result.status == ReviewStatus.REJECTED:
yield "status", "rejected"
if result.message:
yield "review_message", result.message

View File

@@ -2,6 +2,7 @@ from enum import Enum
from typing import Any, Dict, Literal, Optional
from pydantic import SecretStr
from requests.exceptions import RequestException
from backend.data.block import (
Block,
@@ -331,8 +332,8 @@ class IdeogramModelBlock(Block):
try:
response = await Requests().post(url, headers=headers, json=data)
return response.json()["data"][0]["url"]
except Exception as e:
raise ValueError(f"Failed to fetch image with V3 endpoint: {e}") from e
except RequestException as e:
raise Exception(f"Failed to fetch image with V3 endpoint: {str(e)}")
async def _run_model_legacy(
self,
@@ -384,8 +385,8 @@ class IdeogramModelBlock(Block):
try:
response = await Requests().post(url, headers=headers, json=data)
return response.json()["data"][0]["url"]
except Exception as e:
raise ValueError(f"Failed to fetch image with legacy endpoint: {e}") from e
except RequestException as e:
raise Exception(f"Failed to fetch image with legacy endpoint: {str(e)}")
async def upscale_image(self, api_key: SecretStr, image_url: str):
url = "https://api.ideogram.ai/upscale"
@@ -412,5 +413,5 @@ class IdeogramModelBlock(Block):
return (response.json())["data"][0]["url"]
except Exception as e:
raise ValueError(f"Failed to upscale image: {e}") from e
except RequestException as e:
raise Exception(f"Failed to upscale image: {str(e)}")

View File

@@ -76,7 +76,7 @@ class AgentInputBlock(Block):
super().__init__(
**{
"id": "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
"description": "A block that accepts and processes user input values within a workflow, supporting various input types and validation.",
"description": "Base block for user inputs.",
"input_schema": AgentInputBlock.Input,
"output_schema": AgentInputBlock.Output,
"test_input": [
@@ -168,7 +168,7 @@ class AgentOutputBlock(Block):
def __init__(self):
super().__init__(
id="363ae599-353e-4804-937e-b2ee3cef3da4",
description="A block that records and formats workflow results for display to users, with optional Jinja2 template formatting support.",
description="Stores the output of the graph for users to see.",
input_schema=AgentOutputBlock.Input,
output_schema=AgentOutputBlock.Output,
test_input=[

View File

@@ -16,7 +16,6 @@ from backend.data.block import (
BlockSchemaOutput,
)
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
class SearchTheWebBlock(Block, GetRequest):
@@ -57,17 +56,7 @@ class SearchTheWebBlock(Block, GetRequest):
# Prepend the Jina Search URL to the encoded query
jina_search_url = f"https://s.jina.ai/{encoded_query}"
try:
results = await self.get_request(
jina_search_url, headers=headers, json=False
)
except Exception as e:
raise BlockExecutionError(
message=f"Search failed: {e}",
block_name=self.name,
block_id=self.id,
) from e
results = await self.get_request(jina_search_url, headers=headers, json=False)
# Output the search results
yield "results", results

View File

@@ -92,9 +92,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
O1 = "o1"
O1_MINI = "o1-mini"
# GPT-5 models
GPT5_2 = "gpt-5.2-2025-12-11"
GPT5_1 = "gpt-5.1-2025-11-13"
GPT5 = "gpt-5-2025-08-07"
GPT5_1 = "gpt-5.1-2025-11-13"
GPT5_MINI = "gpt-5-mini-2025-08-07"
GPT5_NANO = "gpt-5-nano-2025-08-07"
GPT5_CHAT = "gpt-5-chat-latest"
@@ -195,9 +194,8 @@ MODEL_METADATA = {
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
# GPT-5 models
LlmModel.GPT5_2: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_MINI: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_NANO: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_CHAT: ModelMetadata("openai", 400000, 16384),
@@ -305,8 +303,6 @@ MODEL_METADATA = {
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000),
}
DEFAULT_LLM_MODEL = LlmModel.GPT5_2
for model in LlmModel:
if model not in MODEL_METADATA:
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
@@ -794,7 +790,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -854,12 +850,12 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
def __init__(self):
super().__init__(
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
description="A block that generates structured JSON responses using a Large Language Model (LLM), with schema validation and format enforcement.",
description="Call a Large Language Model (LLM) to generate formatted object based on the given prompt.",
categories={BlockCategory.AI},
input_schema=AIStructuredResponseGeneratorBlock.Input,
output_schema=AIStructuredResponseGeneratorBlock.Output,
test_input={
"model": DEFAULT_LLM_MODEL,
"model": LlmModel.GPT4O,
"credentials": TEST_CREDENTIALS_INPUT,
"expected_format": {
"key1": "value1",
@@ -1225,7 +1221,7 @@ class AITextGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -1265,7 +1261,7 @@ class AITextGeneratorBlock(AIBlockBase):
def __init__(self):
super().__init__(
id="1f292d4a-41a4-4977-9684-7c8d560b9f91",
description="A block that produces text responses using a Large Language Model (LLM) based on customizable prompts and system instructions.",
description="Call a Large Language Model (LLM) to generate a string based on the given prompt.",
categories={BlockCategory.AI},
input_schema=AITextGeneratorBlock.Input,
output_schema=AITextGeneratorBlock.Output,
@@ -1321,7 +1317,7 @@ class AITextSummarizerBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for summarizing the text.",
)
focus: str = SchemaField(
@@ -1361,7 +1357,7 @@ class AITextSummarizerBlock(AIBlockBase):
def __init__(self):
super().__init__(
id="a0a69be1-4528-491c-a85a-a4ab6873e3f0",
description="A block that summarizes long texts using a Large Language Model (LLM), with configurable focus topics and summary styles.",
description="Utilize a Large Language Model (LLM) to summarize a long text.",
categories={BlockCategory.AI, BlockCategory.TEXT},
input_schema=AITextSummarizerBlock.Input,
output_schema=AITextSummarizerBlock.Output,
@@ -1538,7 +1534,7 @@ class AIConversationBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for the conversation.",
)
credentials: AICredentials = AICredentialsField()
@@ -1562,7 +1558,7 @@ class AIConversationBlock(AIBlockBase):
def __init__(self):
super().__init__(
id="32a87eab-381e-4dd4-bdb8-4c47151be35a",
description="A block that facilitates multi-turn conversations with a Large Language Model (LLM), maintaining context across message exchanges.",
description="Advanced LLM call that takes a list of messages and sends them to the language model.",
categories={BlockCategory.AI},
input_schema=AIConversationBlock.Input,
output_schema=AIConversationBlock.Output,
@@ -1576,7 +1572,7 @@ class AIConversationBlock(AIBlockBase):
},
{"role": "user", "content": "Where was it played?"},
],
"model": DEFAULT_LLM_MODEL,
"model": LlmModel.GPT4O,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
@@ -1639,7 +1635,7 @@ class AIListGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for generating the list.",
advanced=True,
)
@@ -1682,7 +1678,7 @@ class AIListGeneratorBlock(AIBlockBase):
def __init__(self):
super().__init__(
id="9c0b0450-d199-458b-a731-072189dd6593",
description="A block that creates lists of items based on prompts using a Large Language Model (LLM), with optional source data for context.",
description="Generate a list of values based on the given prompt using a Large Language Model (LLM).",
categories={BlockCategory.AI, BlockCategory.TEXT},
input_schema=AIListGeneratorBlock.Input,
output_schema=AIListGeneratorBlock.Output,
@@ -1696,7 +1692,7 @@ class AIListGeneratorBlock(AIBlockBase):
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
"fictional worlds."
),
"model": DEFAULT_LLM_MODEL,
"model": LlmModel.GPT4O,
"credentials": TEST_CREDENTIALS_INPUT,
"max_retries": 3,
"force_json_output": False,

File diff suppressed because it is too large Load Diff

View File

@@ -18,7 +18,6 @@ from backend.data.block import (
BlockSchemaOutput,
)
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
from backend.util.exceptions import BlockExecutionError, BlockInputError
logger = logging.getLogger(__name__)
@@ -112,27 +111,9 @@ class ReplicateModelBlock(Block):
yield "status", "succeeded"
yield "model_name", input_data.model_name
except Exception as e:
error_msg = str(e)
logger.error(f"Error running Replicate model: {error_msg}")
# Input validation errors (422, 400) → BlockInputError
if (
"422" in error_msg
or "Input validation failed" in error_msg
or "400" in error_msg
):
raise BlockInputError(
message=f"Invalid model inputs: {error_msg}",
block_name=self.name,
block_id=self.id,
) from e
# Everything else → BlockExecutionError
else:
raise BlockExecutionError(
message=f"Replicate model error: {error_msg}",
block_name=self.name,
block_id=self.id,
) from e
error_msg = f"Unexpected error running Replicate model: {str(e)}"
logger.error(error_msg)
raise RuntimeError(error_msg)
async def run_model(self, model_ref: str, model_inputs: dict, api_key: SecretStr):
"""

View File

@@ -45,16 +45,10 @@ class GetWikipediaSummaryBlock(Block, GetRequest):
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
topic = input_data.topic
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"
# Note: User-Agent is now automatically set by the request library
# to comply with Wikimedia's robot policy (https://w.wiki/4wJS)
try:
response = await self.get_request(url, json=True)
if "extract" not in response:
raise ValueError(f"Unable to parse Wikipedia response: {response}")
yield "summary", response["extract"]
except Exception as e:
raise ValueError(f"Failed to fetch Wikipedia summary: {e}") from e
response = await self.get_request(url, json=True)
if "extract" not in response:
raise RuntimeError(f"Unable to parse Wikipedia response: {response}")
yield "summary", response["extract"]
TEST_CREDENTIALS = APIKeyCredentials(

View File

@@ -1,11 +1,8 @@
import logging
import re
from collections import Counter
from concurrent.futures import Future
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel
import backend.blocks.llm as llm
from backend.blocks.agent import AgentExecutorBlock
from backend.data.block import (
@@ -23,41 +20,16 @@ from backend.data.dynamic_fields import (
is_dynamic_field,
is_tool_pin,
)
from backend.data.execution import ExecutionContext
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util import json
from backend.util.clients import get_database_manager_async_client
from backend.util.prompt import MAIN_OBJECTIVE_PREFIX
if TYPE_CHECKING:
from backend.data.graph import Link, Node
from backend.executor.manager import ExecutionProcessor
logger = logging.getLogger(__name__)
class ToolInfo(BaseModel):
"""Processed tool call information."""
tool_call: Any # The original tool call object from LLM response
tool_name: str # The function name
tool_def: dict[str, Any] # The tool definition from tool_functions
input_data: dict[str, Any] # Processed input data ready for tool execution
field_mapping: dict[str, str] # Field name mapping for the tool
class ExecutionParams(BaseModel):
"""Tool execution parameters."""
user_id: str
graph_id: str
node_id: str
graph_version: int
graph_exec_id: str
node_exec_id: str
execution_context: "ExecutionContext"
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
"""
Return a list of tool_call_ids if the entry is a tool request.
@@ -133,50 +105,6 @@ def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
return {"role": "tool", "tool_call_id": call_id, "content": content}
def _combine_tool_responses(tool_outputs: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Combine multiple Anthropic tool responses into a single user message.
For non-Anthropic formats, returns the original list unchanged.
"""
if len(tool_outputs) <= 1:
return tool_outputs
# Anthropic responses have role="user", type="message", and content is a list with tool_result items
anthropic_responses = [
output
for output in tool_outputs
if (
output.get("role") == "user"
and output.get("type") == "message"
and isinstance(output.get("content"), list)
and any(
item.get("type") == "tool_result"
for item in output.get("content", [])
if isinstance(item, dict)
)
)
]
if len(anthropic_responses) > 1:
combined_content = [
item for response in anthropic_responses for item in response["content"]
]
combined_response = {
"role": "user",
"type": "message",
"content": combined_content,
}
non_anthropic_responses = [
output for output in tool_outputs if output not in anthropic_responses
]
return [combined_response] + non_anthropic_responses
return tool_outputs
def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
"""
Safely convert raw_response to dictionary format for conversation history.
@@ -226,7 +154,7 @@ class SmartDecisionMakerBlock(Block):
)
model: llm.LlmModel = SchemaField(
title="LLM Model",
default=llm.DEFAULT_LLM_MODEL,
default=llm.LlmModel.GPT4O,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -276,17 +204,6 @@ class SmartDecisionMakerBlock(Block):
default="localhost:11434",
description="Ollama host for local models",
)
agent_mode_max_iterations: int = SchemaField(
title="Agent Mode Max Iterations",
description="Maximum iterations for agent mode. 0 = traditional mode (single LLM call, yield tool calls for external execution), -1 = infinite agent mode (loop until finished), 1+ = agent mode with max iterations limit.",
advanced=True,
default=0,
)
conversation_compaction: bool = SchemaField(
default=True,
title="Context window auto-compaction",
description="Automatically compact the context window once it hits the limit",
)
@classmethod
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
@@ -589,7 +506,6 @@ class SmartDecisionMakerBlock(Block):
Returns the response if successful, raises ValueError if validation fails.
"""
resp = await llm.llm_call(
compress_prompt_to_fit=input_data.conversation_compaction,
credentials=credentials,
llm_model=input_data.model,
prompt=current_prompt,
@@ -677,291 +593,6 @@ class SmartDecisionMakerBlock(Block):
return resp
def _process_tool_calls(
self, response, tool_functions: list[dict[str, Any]]
) -> list[ToolInfo]:
"""Process tool calls and extract tool definitions, arguments, and input data.
Returns a list of tool info dicts with:
- tool_call: The original tool call object
- tool_name: The function name
- tool_def: The tool definition from tool_functions
- input_data: Processed input data dict (includes None values)
- field_mapping: Field name mapping for the tool
"""
if not response.tool_calls:
return []
processed_tools = []
for tool_call in response.tool_calls:
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
tool_def = next(
(
tool
for tool in tool_functions
if tool["function"]["name"] == tool_name
),
None,
)
if not tool_def:
if len(tool_functions) == 1:
tool_def = tool_functions[0]
else:
continue
# Build input data for the tool
input_data = {}
field_mapping = tool_def["function"].get("_field_mapping", {})
if "function" in tool_def and "parameters" in tool_def["function"]:
expected_args = tool_def["function"]["parameters"].get("properties", {})
for clean_arg_name in expected_args:
original_field_name = field_mapping.get(
clean_arg_name, clean_arg_name
)
arg_value = tool_args.get(clean_arg_name)
# Include all expected parameters, even if None (for backward compatibility with tests)
input_data[original_field_name] = arg_value
processed_tools.append(
ToolInfo(
tool_call=tool_call,
tool_name=tool_name,
tool_def=tool_def,
input_data=input_data,
field_mapping=field_mapping,
)
)
return processed_tools
def _update_conversation(
self, prompt: list[dict], response, tool_outputs: list | None = None
):
"""Update conversation history with response and tool outputs."""
# Don't add separate reasoning message with tool calls (breaks Anthropic's tool_use->tool_result pairing)
assistant_message = _convert_raw_response_to_dict(response.raw_response)
has_tool_calls = isinstance(assistant_message.get("content"), list) and any(
item.get("type") == "tool_use"
for item in assistant_message.get("content", [])
)
if response.reasoning and not has_tool_calls:
prompt.append(
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
)
prompt.append(assistant_message)
if tool_outputs:
prompt.extend(tool_outputs)
async def _execute_single_tool_with_manager(
self,
tool_info: ToolInfo,
execution_params: ExecutionParams,
execution_processor: "ExecutionProcessor",
) -> dict:
"""Execute a single tool using the execution manager for proper integration."""
# Lazy imports to avoid circular dependencies
from backend.data.execution import NodeExecutionEntry
tool_call = tool_info.tool_call
tool_def = tool_info.tool_def
raw_input_data = tool_info.input_data
# Get sink node and field mapping
sink_node_id = tool_def["function"]["_sink_node_id"]
# Use proper database operations for tool execution
db_client = get_database_manager_async_client()
# Get target node
target_node = await db_client.get_node(sink_node_id)
if not target_node:
raise ValueError(f"Target node {sink_node_id} not found")
# Create proper node execution using upsert_execution_input
node_exec_result = None
final_input_data = None
# Add all inputs to the execution
if not raw_input_data:
raise ValueError(f"Tool call has no input data: {tool_call}")
for input_name, input_value in raw_input_data.items():
node_exec_result, final_input_data = await db_client.upsert_execution_input(
node_id=sink_node_id,
graph_exec_id=execution_params.graph_exec_id,
input_name=input_name,
input_data=input_value,
)
assert node_exec_result is not None, "node_exec_result should not be None"
# Create NodeExecutionEntry for execution manager
node_exec_entry = NodeExecutionEntry(
user_id=execution_params.user_id,
graph_exec_id=execution_params.graph_exec_id,
graph_id=execution_params.graph_id,
graph_version=execution_params.graph_version,
node_exec_id=node_exec_result.node_exec_id,
node_id=sink_node_id,
block_id=target_node.block_id,
inputs=final_input_data or {},
execution_context=execution_params.execution_context,
)
# Use the execution manager to execute the tool node
try:
# Get NodeExecutionProgress from the execution manager's running nodes
node_exec_progress = execution_processor.running_node_execution[
sink_node_id
]
# Use the execution manager's own graph stats
graph_stats_pair = (
execution_processor.execution_stats,
execution_processor.execution_stats_lock,
)
# Create a completed future for the task tracking system
node_exec_future = Future()
node_exec_progress.add_task(
node_exec_id=node_exec_result.node_exec_id,
task=node_exec_future,
)
# Execute the node directly since we're in the SmartDecisionMaker context
node_exec_future.set_result(
await execution_processor.on_node_execution(
node_exec=node_exec_entry,
node_exec_progress=node_exec_progress,
nodes_input_masks=None,
graph_stats_pair=graph_stats_pair,
)
)
# Get outputs from database after execution completes using database manager client
node_outputs = await db_client.get_execution_outputs_by_node_exec_id(
node_exec_result.node_exec_id
)
# Create tool response
tool_response_content = (
json.dumps(node_outputs)
if node_outputs
else "Tool executed successfully"
)
return _create_tool_response(tool_call.id, tool_response_content)
except Exception as e:
logger.error(f"Tool execution with manager failed: {e}")
# Return error response
return _create_tool_response(
tool_call.id, f"Tool execution failed: {str(e)}"
)
async def _execute_tools_agent_mode(
self,
input_data,
credentials,
tool_functions: list[dict[str, Any]],
prompt: list[dict],
graph_exec_id: str,
node_id: str,
node_exec_id: str,
user_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
execution_processor: "ExecutionProcessor",
):
"""Execute tools in agent mode with a loop until finished."""
max_iterations = input_data.agent_mode_max_iterations
iteration = 0
# Execution parameters for tool execution
execution_params = ExecutionParams(
user_id=user_id,
graph_id=graph_id,
node_id=node_id,
graph_version=graph_version,
graph_exec_id=graph_exec_id,
node_exec_id=node_exec_id,
execution_context=execution_context,
)
current_prompt = list(prompt)
while max_iterations < 0 or iteration < max_iterations:
iteration += 1
logger.debug(f"Agent mode iteration {iteration}")
# Prepare prompt for this iteration
iteration_prompt = list(current_prompt)
# On the last iteration, add a special system message to encourage completion
if max_iterations > 0 and iteration == max_iterations:
last_iteration_message = {
"role": "system",
"content": f"{MAIN_OBJECTIVE_PREFIX}This is your last iteration ({iteration}/{max_iterations}). "
"Try to complete the task with the information you have. If you cannot fully complete it, "
"provide a summary of what you've accomplished and what remains to be done. "
"Prefer finishing with a clear response rather than making additional tool calls.",
}
iteration_prompt.append(last_iteration_message)
# Get LLM response
try:
response = await self._attempt_llm_call_with_validation(
credentials, input_data, iteration_prompt, tool_functions
)
except Exception as e:
yield "error", f"LLM call failed in agent mode iteration {iteration}: {str(e)}"
return
# Process tool calls
processed_tools = self._process_tool_calls(response, tool_functions)
# If no tool calls, we're done
if not processed_tools:
yield "finished", response.response
self._update_conversation(current_prompt, response)
yield "conversations", current_prompt
return
# Execute tools and collect responses
tool_outputs = []
for tool_info in processed_tools:
try:
tool_response = await self._execute_single_tool_with_manager(
tool_info, execution_params, execution_processor
)
tool_outputs.append(tool_response)
except Exception as e:
logger.error(f"Tool execution failed: {e}")
# Create error response for the tool
error_response = _create_tool_response(
tool_info.tool_call.id, f"Error: {str(e)}"
)
tool_outputs.append(error_response)
tool_outputs = _combine_tool_responses(tool_outputs)
self._update_conversation(current_prompt, response, tool_outputs)
# Yield intermediate conversation state
yield "conversations", current_prompt
# If we reach max iterations, yield the current state
if max_iterations < 0:
yield "finished", f"Agent mode completed after {iteration} iterations"
else:
yield "finished", f"Agent mode completed after {max_iterations} iterations (limit reached)"
yield "conversations", current_prompt
async def run(
self,
input_data: Input,
@@ -972,31 +603,9 @@ class SmartDecisionMakerBlock(Block):
graph_exec_id: str,
node_exec_id: str,
user_id: str,
graph_version: int,
execution_context: ExecutionContext,
execution_processor: "ExecutionProcessor",
nodes_to_skip: set[str] | None = None,
**kwargs,
) -> BlockOutput:
tool_functions = await self._create_tool_node_signatures(node_id)
original_tool_count = len(tool_functions)
# Filter out tools for nodes that should be skipped (e.g., missing optional credentials)
if nodes_to_skip:
tool_functions = [
tf
for tf in tool_functions
if tf.get("function", {}).get("_sink_node_id") not in nodes_to_skip
]
# Only raise error if we had tools but they were all filtered out
if original_tool_count > 0 and not tool_functions:
raise ValueError(
"No available tools to execute - all downstream nodes are unavailable "
"(possibly due to missing optional credentials)"
)
yield "tool_functions", json.dumps(tool_functions)
conversation_history = input_data.conversation_history or []
@@ -1039,52 +648,24 @@ class SmartDecisionMakerBlock(Block):
input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values)
prefix = "[Main Objective Prompt]: "
if input_data.sys_prompt and not any(
p["role"] == "system" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
for p in prompt
p["role"] == "system" and p["content"].startswith(prefix) for p in prompt
):
prompt.append(
{
"role": "system",
"content": MAIN_OBJECTIVE_PREFIX + input_data.sys_prompt,
}
)
prompt.append({"role": "system", "content": prefix + input_data.sys_prompt})
if input_data.prompt and not any(
p["role"] == "user" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
for p in prompt
p["role"] == "user" and p["content"].startswith(prefix) for p in prompt
):
prompt.append(
{"role": "user", "content": MAIN_OBJECTIVE_PREFIX + input_data.prompt}
)
prompt.append({"role": "user", "content": prefix + input_data.prompt})
# Execute tools based on the selected mode
if input_data.agent_mode_max_iterations != 0:
# In agent mode, execute tools directly in a loop until finished
async for result in self._execute_tools_agent_mode(
input_data=input_data,
credentials=credentials,
tool_functions=tool_functions,
prompt=prompt,
graph_exec_id=graph_exec_id,
node_id=node_id,
node_exec_id=node_exec_id,
user_id=user_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
execution_processor=execution_processor,
):
yield result
return
# One-off mode: single LLM call and yield tool calls for external execution
current_prompt = list(prompt)
max_attempts = max(1, int(input_data.retry))
response = None
last_error = None
for _ in range(max_attempts):
for attempt in range(max_attempts):
try:
response = await self._attempt_llm_call_with_validation(
credentials, input_data, current_prompt, tool_functions

View File

@@ -196,15 +196,6 @@ class TestXMLParserBlockSecurity:
async for _ in block.run(XMLParserBlock.Input(input_xml=large_xml)):
pass
async def test_rejects_text_outside_root(self):
"""Ensure parser surfaces readable errors for invalid root text."""
block = XMLParserBlock()
invalid_xml = "<root><child>value</child></root> trailing"
with pytest.raises(ValueError, match="text outside the root element"):
async for _ in block.run(XMLParserBlock.Input(input_xml=invalid_xml)):
pass
class TestStoreMediaFileSecurity:
"""Test file storage security limits."""

View File

@@ -28,7 +28,7 @@ class TestLLMStatsTracking:
response = await llm.llm_call(
credentials=llm.TEST_CREDENTIALS,
llm_model=llm.DEFAULT_LLM_MODEL,
llm_model=llm.LlmModel.GPT4O,
prompt=[{"role": "user", "content": "Hello"}],
max_tokens=100,
)
@@ -65,7 +65,7 @@ class TestLLMStatsTracking:
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test prompt",
expected_format={"key1": "desc1", "key2": "desc2"},
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore # type: ignore
)
@@ -109,7 +109,7 @@ class TestLLMStatsTracking:
# Run the block
input_data = llm.AITextGeneratorBlock.Input(
prompt="Generate text",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)
@@ -170,7 +170,7 @@ class TestLLMStatsTracking:
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test prompt",
expected_format={"key1": "desc1", "key2": "desc2"},
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2,
)
@@ -228,7 +228,7 @@ class TestLLMStatsTracking:
input_data = llm.AITextSummarizerBlock.Input(
text=long_text,
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_tokens=100, # Small chunks
chunk_overlap=10,
@@ -299,7 +299,7 @@ class TestLLMStatsTracking:
# Test with very short text (should only need 1 chunk + 1 final summary)
input_data = llm.AITextSummarizerBlock.Input(
text="This is a short text.",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_tokens=1000, # Large enough to avoid chunking
)
@@ -346,7 +346,7 @@ class TestLLMStatsTracking:
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"},
],
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)
@@ -387,7 +387,7 @@ class TestLLMStatsTracking:
# Run the block
input_data = llm.AIListGeneratorBlock.Input(
focus="test items",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_retries=3,
)
@@ -469,7 +469,7 @@ class TestLLMStatsTracking:
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test",
expected_format={"result": "desc"},
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)
@@ -513,7 +513,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
style=llm.SummaryStyle.BULLET_POINTS,
)
@@ -558,7 +558,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
style=llm.SummaryStyle.BULLET_POINTS,
max_tokens=1000,
@@ -593,7 +593,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)
@@ -623,7 +623,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_tokens=1000,
)
@@ -654,7 +654,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)

View File

@@ -1,14 +1,10 @@
import logging
import threading
from collections import defaultdict
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.api.model import CreateGraph
from backend.api.rest_api import AgentServer
from backend.data.execution import ExecutionContext
from backend.data.model import ProviderName, User
from backend.server.model import CreateGraph
from backend.server.rest_api import AgentServer
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.test import SpinTestServer, wait_execution
@@ -21,10 +17,10 @@ async def create_graph(s: SpinTestServer, g, u: User):
async def create_credentials(s: SpinTestServer, u: User):
import backend.blocks.llm as llm_module
import backend.blocks.llm as llm
provider = ProviderName.OPENAI
credentials = llm_module.TEST_CREDENTIALS
credentials = llm.TEST_CREDENTIALS
return await s.agent_server.test_create_credentials(u.id, provider, credentials)
@@ -200,6 +196,8 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
@pytest.mark.asyncio
async def test_smart_decision_maker_tracks_llm_stats():
"""Test that SmartDecisionMakerBlock correctly tracks LLM usage stats."""
from unittest.mock import MagicMock, patch
import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
@@ -218,6 +216,7 @@ async def test_smart_decision_maker_tracks_llm_stats():
}
# Mock the _create_tool_node_signatures method to avoid database calls
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
@@ -233,21 +232,12 @@ async def test_smart_decision_maker_tracks_llm_stats():
# Create test input
input_data = SmartDecisionMakerBlock.Input(
prompt="Should I continue with this task?",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
# Execute the block
outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
@@ -256,9 +246,6 @@ async def test_smart_decision_maker_tracks_llm_stats():
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
@@ -276,6 +263,8 @@ async def test_smart_decision_maker_tracks_llm_stats():
@pytest.mark.asyncio
async def test_smart_decision_maker_parameter_validation():
"""Test that SmartDecisionMakerBlock correctly validates tool call parameters."""
from unittest.mock import MagicMock, patch
import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
@@ -322,6 +311,8 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_with_typo.reasoning = None
mock_response_with_typo.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
@@ -335,20 +326,11 @@ async def test_smart_decision_maker_parameter_validation():
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2, # Set retry to 2 for testing
agent_mode_max_iterations=0,
)
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
# Should raise ValueError after retries due to typo'd parameter name
with pytest.raises(ValueError) as exc_info:
outputs = {}
@@ -360,9 +342,6 @@ async def test_smart_decision_maker_parameter_validation():
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
@@ -389,6 +368,8 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_missing_required.reasoning = None
mock_response_missing_required.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
@@ -402,19 +383,10 @@ async def test_smart_decision_maker_parameter_validation():
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
# Should raise ValueError due to missing required parameter
with pytest.raises(ValueError) as exc_info:
outputs = {}
@@ -426,9 +398,6 @@ async def test_smart_decision_maker_parameter_validation():
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
@@ -449,6 +418,8 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_valid.reasoning = None
mock_response_valid.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
@@ -462,21 +433,12 @@ async def test_smart_decision_maker_parameter_validation():
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
# Should succeed - optional parameter missing is OK
outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
@@ -485,9 +447,6 @@ async def test_smart_decision_maker_parameter_validation():
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
@@ -513,6 +472,8 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_all_params.reasoning = None
mock_response_all_params.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
@@ -526,21 +487,12 @@ async def test_smart_decision_maker_parameter_validation():
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
# Should succeed with all parameters
outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
@@ -549,9 +501,6 @@ async def test_smart_decision_maker_parameter_validation():
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
@@ -564,6 +513,8 @@ async def test_smart_decision_maker_parameter_validation():
@pytest.mark.asyncio
async def test_smart_decision_maker_raw_response_conversion():
"""Test that SmartDecisionMaker correctly handles different raw_response types with retry mechanism."""
from unittest.mock import MagicMock, patch
import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
@@ -633,6 +584,7 @@ async def test_smart_decision_maker_raw_response_conversion():
)
# Mock llm_call to return different responses on different calls
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call", new_callable=AsyncMock
@@ -648,22 +600,13 @@ async def test_smart_decision_maker_raw_response_conversion():
input_data = SmartDecisionMakerBlock.Input(
prompt="Test prompt",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2,
agent_mode_max_iterations=0,
)
# Should succeed after retry, demonstrating our helper function works
outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
@@ -672,9 +615,6 @@ async def test_smart_decision_maker_raw_response_conversion():
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
@@ -710,6 +650,8 @@ async def test_smart_decision_maker_raw_response_conversion():
"I'll help you with that." # Ollama returns string
)
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
@@ -722,20 +664,11 @@ async def test_smart_decision_maker_raw_response_conversion():
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Simple prompt",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
@@ -744,9 +677,6 @@ async def test_smart_decision_maker_raw_response_conversion():
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
@@ -766,6 +696,8 @@ async def test_smart_decision_maker_raw_response_conversion():
"content": "Test response",
} # Dict format
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
@@ -778,20 +710,11 @@ async def test_smart_decision_maker_raw_response_conversion():
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Another test",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
@@ -800,260 +723,8 @@ async def test_smart_decision_maker_raw_response_conversion():
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
assert "finished" in outputs
assert outputs["finished"] == "Test response"
@pytest.mark.asyncio
async def test_smart_decision_maker_agent_mode():
"""Test that agent mode executes tools directly and loops until finished."""
import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
block = SmartDecisionMakerBlock()
# Mock tool call that requires multiple iterations
mock_tool_call_1 = MagicMock()
mock_tool_call_1.id = "call_1"
mock_tool_call_1.function.name = "search_keywords"
mock_tool_call_1.function.arguments = (
'{"query": "test", "max_keyword_difficulty": 50}'
)
mock_response_1 = MagicMock()
mock_response_1.response = None
mock_response_1.tool_calls = [mock_tool_call_1]
mock_response_1.prompt_tokens = 50
mock_response_1.completion_tokens = 25
mock_response_1.reasoning = "Using search tool"
mock_response_1.raw_response = {
"role": "assistant",
"content": None,
"tool_calls": [{"id": "call_1", "type": "function"}],
}
# Final response with no tool calls (finished)
mock_response_2 = MagicMock()
mock_response_2.response = "Task completed successfully"
mock_response_2.tool_calls = []
mock_response_2.prompt_tokens = 30
mock_response_2.completion_tokens = 15
mock_response_2.reasoning = None
mock_response_2.raw_response = {
"role": "assistant",
"content": "Task completed successfully",
}
# Mock the LLM call to return different responses on each iteration
llm_call_mock = AsyncMock()
llm_call_mock.side_effect = [mock_response_1, mock_response_2]
# Mock tool node signatures
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "search_keywords",
"_sink_node_id": "test-sink-node-id",
"_field_mapping": {},
"parameters": {
"properties": {
"query": {"type": "string"},
"max_keyword_difficulty": {"type": "integer"},
},
"required": ["query", "max_keyword_difficulty"],
},
},
}
]
# Mock database and execution components
mock_db_client = AsyncMock()
mock_node = MagicMock()
mock_node.block_id = "test-block-id"
mock_db_client.get_node.return_value = mock_node
# Mock upsert_execution_input to return proper NodeExecutionResult and input data
mock_node_exec_result = MagicMock()
mock_node_exec_result.node_exec_id = "test-tool-exec-id"
mock_input_data = {"query": "test", "max_keyword_difficulty": 50}
mock_db_client.upsert_execution_input.return_value = (
mock_node_exec_result,
mock_input_data,
)
# No longer need mock_execute_node since we use execution_processor.on_node_execution
with patch("backend.blocks.llm.llm_call", llm_call_mock), patch.object(
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
), patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db_client,
), patch(
"backend.executor.manager.async_update_node_execution_status",
new_callable=AsyncMock,
), patch(
"backend.integrations.creds_manager.IntegrationCredentialsManager"
):
# Create a mock execution context
mock_execution_context = ExecutionContext(
safe_mode=False,
)
# Create a mock execution processor for agent mode tests
mock_execution_processor = AsyncMock()
# Configure the execution processor mock with required attributes
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = threading.Lock()
# Mock the on_node_execution method to return successful stats
mock_node_stats = MagicMock()
mock_node_stats.error = None # No error
mock_execution_processor.on_node_execution = AsyncMock(
return_value=mock_node_stats
)
# Mock the get_execution_outputs_by_node_exec_id method
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
"result": {"status": "success", "data": "search completed"}
}
# Test agent mode with max_iterations = 3
input_data = SmartDecisionMakerBlock.Input(
prompt="Complete this task using tools",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=3, # Enable agent mode with 3 max iterations
)
outputs = {}
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph-id",
node_id="test-node-id",
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
# Verify agent mode behavior
assert "tool_functions" in outputs # tool_functions is yielded in both modes
assert "finished" in outputs
assert outputs["finished"] == "Task completed successfully"
assert "conversations" in outputs
# Verify the conversation includes tool responses
conversations = outputs["conversations"]
assert len(conversations) > 2 # Should have multiple conversation entries
# Verify LLM was called twice (once for tool call, once for finish)
assert llm_call_mock.call_count == 2
# Verify tool was executed via execution processor
assert mock_execution_processor.on_node_execution.call_count == 1
@pytest.mark.asyncio
async def test_smart_decision_maker_traditional_mode_default():
"""Test that default behavior (agent_mode_max_iterations=0) works as traditional mode."""
import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
block = SmartDecisionMakerBlock()
# Mock tool call
mock_tool_call = MagicMock()
mock_tool_call.function.name = "search_keywords"
mock_tool_call.function.arguments = (
'{"query": "test", "max_keyword_difficulty": 50}'
)
mock_response = MagicMock()
mock_response.response = None
mock_response.tool_calls = [mock_tool_call]
mock_response.prompt_tokens = 50
mock_response.completion_tokens = 25
mock_response.reasoning = None
mock_response.raw_response = {"role": "assistant", "content": None}
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "search_keywords",
"_sink_node_id": "test-sink-node-id",
"_field_mapping": {},
"parameters": {
"properties": {
"query": {"type": "string"},
"max_keyword_difficulty": {"type": "integer"},
},
"required": ["query", "max_keyword_difficulty"],
},
},
}
]
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response,
), patch.object(
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
):
# Test default behavior (traditional mode)
input_data = SmartDecisionMakerBlock.Input(
prompt="Test prompt",
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0, # Traditional mode
)
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
outputs = {}
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph-id",
node_id="test-node-id",
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
# Verify traditional mode behavior
assert (
"tool_functions" in outputs
) # Should yield tool_functions in traditional mode
assert (
"tools_^_test-sink-node-id_~_query" in outputs
) # Should yield individual tool parameters
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
assert "conversations" in outputs

View File

@@ -1,7 +1,7 @@
"""Comprehensive tests for SmartDecisionMakerBlock dynamic field handling."""
import json
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from unittest.mock import AsyncMock, Mock, patch
import pytest
@@ -308,47 +308,10 @@ async def test_output_yielding_with_dynamic_fields():
) as mock_llm:
mock_llm.return_value = mock_response
# Mock the database manager to avoid HTTP calls during tool execution
with patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
) as mock_db_manager, patch.object(
# Mock the function signature creation
with patch.object(
block, "_create_tool_node_signatures", new_callable=AsyncMock
) as mock_sig:
# Set up the mock database manager
mock_db_client = AsyncMock()
mock_db_manager.return_value = mock_db_client
# Mock the node retrieval
mock_target_node = Mock()
mock_target_node.id = "test-sink-node-id"
mock_target_node.block_id = "CreateDictionaryBlock"
mock_target_node.block = Mock()
mock_target_node.block.name = "Create Dictionary"
mock_db_client.get_node.return_value = mock_target_node
# Mock the execution result creation
mock_node_exec_result = Mock()
mock_node_exec_result.node_exec_id = "mock-node-exec-id"
mock_final_input_data = {
"values_#_name": "Alice",
"values_#_age": 30,
"values_#_email": "alice@example.com",
}
mock_db_client.upsert_execution_input.return_value = (
mock_node_exec_result,
mock_final_input_data,
)
# Mock the output retrieval
mock_outputs = {
"values_#_name": "Alice",
"values_#_age": 30,
"values_#_email": "alice@example.com",
}
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = (
mock_outputs
)
mock_sig.return_value = [
{
"type": "function",
@@ -373,17 +336,11 @@ async def test_output_yielding_with_dynamic_fields():
input_data = block.input_schema(
prompt="Create a user dictionary",
credentials=llm.TEST_CREDENTIALS_INPUT,
model=llm.DEFAULT_LLM_MODEL,
agent_mode_max_iterations=0, # Use traditional mode to test output yielding
model=llm.LlmModel.GPT4O,
)
# Run the block
outputs = {}
from backend.data.execution import ExecutionContext
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = MagicMock()
async for output_name, output_value in block.run(
input_data,
credentials=llm.TEST_CREDENTIALS,
@@ -392,9 +349,6 @@ async def test_output_yielding_with_dynamic_fields():
graph_exec_id="test_exec",
node_exec_id="test_node_exec",
user_id="test_user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_value
@@ -557,108 +511,45 @@ async def test_validation_errors_dont_pollute_conversation():
}
]
# Mock the database manager to avoid HTTP calls during tool execution
with patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
) as mock_db_manager:
# Set up the mock database manager for agent mode
mock_db_client = AsyncMock()
mock_db_manager.return_value = mock_db_client
# Create input data
from backend.blocks import llm
# Mock the node retrieval
mock_target_node = Mock()
mock_target_node.id = "test-sink-node-id"
mock_target_node.block_id = "TestBlock"
mock_target_node.block = Mock()
mock_target_node.block.name = "Test Block"
mock_db_client.get_node.return_value = mock_target_node
input_data = block.input_schema(
prompt="Test prompt",
credentials=llm.TEST_CREDENTIALS_INPUT,
model=llm.LlmModel.GPT4O,
retry=3, # Allow retries
)
# Mock the execution result creation
mock_node_exec_result = Mock()
mock_node_exec_result.node_exec_id = "mock-node-exec-id"
mock_final_input_data = {"correct_param": "value"}
mock_db_client.upsert_execution_input.return_value = (
mock_node_exec_result,
mock_final_input_data,
)
# Run the block
outputs = {}
async for output_name, output_value in block.run(
input_data,
credentials=llm.TEST_CREDENTIALS,
graph_id="test_graph",
node_id="test_node",
graph_exec_id="test_exec",
node_exec_id="test_node_exec",
user_id="test_user",
):
outputs[output_name] = output_value
# Mock the output retrieval
mock_outputs = {"correct_param": "value"}
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = (
mock_outputs
)
# Verify we had 2 LLM calls (initial + retry)
assert call_count == 2
# Create input data
from backend.blocks import llm
# Check the final conversation output
final_conversation = outputs.get("conversations", [])
input_data = block.input_schema(
prompt="Test prompt",
credentials=llm.TEST_CREDENTIALS_INPUT,
model=llm.DEFAULT_LLM_MODEL,
retry=3, # Allow retries
agent_mode_max_iterations=1,
)
# The final conversation should NOT contain the validation error message
error_messages = [
msg
for msg in final_conversation
if msg.get("role") == "user"
and "parameter errors" in msg.get("content", "")
]
assert (
len(error_messages) == 0
), "Validation error leaked into final conversation"
# Run the block
outputs = {}
from backend.data.execution import ExecutionContext
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a proper mock execution processor for agent mode
from collections import defaultdict
mock_execution_processor = AsyncMock()
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = MagicMock()
# Create a mock NodeExecutionProgress for the sink node
mock_node_exec_progress = MagicMock()
mock_node_exec_progress.add_task = MagicMock()
mock_node_exec_progress.pop_output = MagicMock(
return_value=None
) # No outputs to process
# Set up running_node_execution as a defaultdict that returns our mock for any key
mock_execution_processor.running_node_execution = defaultdict(
lambda: mock_node_exec_progress
)
# Mock the on_node_execution method that gets called during tool execution
mock_node_stats = MagicMock()
mock_node_stats.error = None
mock_execution_processor.on_node_execution.return_value = (
mock_node_stats
)
async for output_name, output_value in block.run(
input_data,
credentials=llm.TEST_CREDENTIALS,
graph_id="test_graph",
node_id="test_node",
graph_exec_id="test_exec",
node_exec_id="test_node_exec",
user_id="test_user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_value
# Verify we had at least 1 LLM call
assert call_count >= 1
# Check the final conversation output
final_conversation = outputs.get("conversations", [])
# The final conversation should NOT contain validation error messages
# Even if retries don't happen in agent mode, we should not leak errors
error_messages = [
msg
for msg in final_conversation
if msg.get("role") == "user"
and "parameter errors" in msg.get("content", "")
]
assert (
len(error_messages) == 0
), "Validation error leaked into final conversation"
# The final conversation should only have the successful response
assert final_conversation[-1]["content"] == "valid"

View File

@@ -1,5 +1,5 @@
from gravitasml.parser import Parser
from gravitasml.token import Token, tokenize
from gravitasml.token import tokenize
from backend.data.block import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
from backend.data.model import SchemaField
@@ -25,38 +25,6 @@ class XMLParserBlock(Block):
],
)
@staticmethod
def _validate_tokens(tokens: list[Token]) -> None:
"""Ensure the XML has a single root element and no stray text."""
if not tokens:
raise ValueError("XML input is empty.")
depth = 0
root_seen = False
for token in tokens:
if token.type == "TAG_OPEN":
if depth == 0 and root_seen:
raise ValueError("XML must have a single root element.")
depth += 1
if depth == 1:
root_seen = True
elif token.type == "TAG_CLOSE":
depth -= 1
if depth < 0:
raise SyntaxError("Unexpected closing tag in XML input.")
elif token.type in {"TEXT", "ESCAPE"}:
if depth == 0 and token.value:
raise ValueError(
"XML contains text outside the root element; "
"wrap content in a single root tag."
)
if depth != 0:
raise SyntaxError("Unclosed tag detected in XML input.")
if not root_seen:
raise ValueError("XML must include a root element.")
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
# Security fix: Add size limits to prevent XML bomb attacks
MAX_XML_SIZE = 10 * 1024 * 1024 # 10MB limit for XML input
@@ -67,9 +35,7 @@ class XMLParserBlock(Block):
)
try:
tokens = list(tokenize(input_data.input_xml))
self._validate_tokens(tokens)
tokens = tokenize(input_data.input_xml)
parser = Parser(tokens)
parsed_result = parser.parse()
yield "parsed_xml", parsed_result

View File

@@ -111,8 +111,6 @@ class TranscribeYoutubeVideoBlock(Block):
return parsed_url.path.split("/")[2]
if parsed_url.path[:3] == "/v/":
return parsed_url.path.split("/")[2]
if parsed_url.path.startswith("/shorts/"):
return parsed_url.path.split("/")[2]
raise ValueError(f"Invalid YouTube URL: {url}")
def get_transcript(

View File

@@ -244,7 +244,11 @@ def websocket(server_address: str, graph_exec_id: str):
import websockets.asyncio.client
from backend.api.ws_api import WSMessage, WSMethod, WSSubscribeGraphExecutionRequest
from backend.server.ws_api import (
WSMessage,
WSMethod,
WSSubscribeGraphExecutionRequest,
)
async def send_message(server_address: str):
uri = f"ws://{server_address}"

View File

@@ -1 +0,0 @@
"""CLI utilities for backend development & administration"""

View File

@@ -1,57 +0,0 @@
#!/usr/bin/env python3
"""
Script to generate OpenAPI JSON specification for the FastAPI app.
This script imports the FastAPI app from backend.api.rest_api and outputs
the OpenAPI specification as JSON to stdout or a specified file.
Usage:
`poetry run python generate_openapi_json.py`
`poetry run python generate_openapi_json.py --output openapi.json`
`poetry run python generate_openapi_json.py --indent 4 --output openapi.json`
"""
import json
import os
from pathlib import Path
import click
@click.command()
@click.option(
"--output",
type=click.Path(dir_okay=False, path_type=Path),
help="Output file path (default: stdout)",
)
@click.option(
"--pretty",
type=click.BOOL,
default=False,
help="Pretty-print JSON output (indented 2 spaces)",
)
def main(output: Path, pretty: bool):
"""Generate and output the OpenAPI JSON specification."""
openapi_schema = get_openapi_schema()
json_output = json.dumps(openapi_schema, indent=2 if pretty else None)
if output:
output.write_text(json_output)
click.echo(f"✅ OpenAPI specification written to {output}\n\nPreview:")
click.echo(f"\n{json_output[:500]} ...")
else:
print(json_output)
def get_openapi_schema():
"""Get the OpenAPI schema from the FastAPI app"""
from backend.api.rest_api import app
return app.openapi()
if __name__ == "__main__":
os.environ["LOG_LEVEL"] = "ERROR" # disable stdout log output
main()

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
from backend.api.features.library.model import LibraryAgentPreset
from backend.server.v2.library.model import LibraryAgentPreset
from .graph import NodeModel
from .integrations import Webhook # noqa: F401

View File

@@ -1,45 +1,12 @@
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional
import prisma.types
from pydantic import BaseModel
from backend.data.db import query_raw_with_schema
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
class AccuracyAlertData(BaseModel):
"""Alert data when accuracy drops significantly."""
graph_id: str
user_id: Optional[str]
drop_percent: float
three_day_avg: float
seven_day_avg: float
detected_at: datetime
class AccuracyLatestData(BaseModel):
"""Latest execution accuracy data point."""
date: datetime
daily_score: Optional[float]
three_day_avg: Optional[float]
seven_day_avg: Optional[float]
fourteen_day_avg: Optional[float]
class AccuracyTrendsResponse(BaseModel):
"""Response model for accuracy trends and alerts."""
latest_data: AccuracyLatestData
alert: Optional[AccuracyAlertData]
historical_data: Optional[list[AccuracyLatestData]] = None
async def log_raw_analytics(
user_id: str,
type: str,
@@ -76,217 +43,3 @@ async def log_raw_metric(
)
return result
async def get_accuracy_trends_and_alerts(
graph_id: str,
days_back: int = 30,
user_id: Optional[str] = None,
drop_threshold: float = 10.0,
include_historical: bool = False,
) -> AccuracyTrendsResponse:
"""Get accuracy trends and detect alerts for a specific graph."""
query_template = """
WITH daily_scores AS (
SELECT
DATE(e."createdAt") as execution_date,
AVG(CASE
WHEN e.stats IS NOT NULL
AND e.stats::json->>'correctness_score' IS NOT NULL
AND e.stats::json->>'correctness_score' != 'null'
THEN (e.stats::json->>'correctness_score')::float * 100
ELSE NULL
END) as daily_score
FROM {schema_prefix}"AgentGraphExecution" e
WHERE e."agentGraphId" = $1::text
AND e."isDeleted" = false
AND e."createdAt" >= $2::timestamp
AND e."executionStatus" IN ('COMPLETED', 'FAILED', 'TERMINATED')
{user_filter}
GROUP BY DATE(e."createdAt")
HAVING COUNT(*) >= 3 -- Need at least 3 executions per day
),
trends AS (
SELECT
execution_date,
daily_score,
AVG(daily_score) OVER (
ORDER BY execution_date
ROWS BETWEEN 2 PRECEDING AND CURRENT ROW
) as three_day_avg,
AVG(daily_score) OVER (
ORDER BY execution_date
ROWS BETWEEN 6 PRECEDING AND CURRENT ROW
) as seven_day_avg,
AVG(daily_score) OVER (
ORDER BY execution_date
ROWS BETWEEN 13 PRECEDING AND CURRENT ROW
) as fourteen_day_avg
FROM daily_scores
)
SELECT *,
CASE
WHEN three_day_avg IS NOT NULL AND seven_day_avg IS NOT NULL AND seven_day_avg > 0
THEN ((seven_day_avg - three_day_avg) / seven_day_avg * 100)
ELSE NULL
END as drop_percent
FROM trends
ORDER BY execution_date DESC
{limit_clause}
"""
start_date = datetime.now(timezone.utc) - timedelta(days=days_back)
params = [graph_id, start_date]
user_filter = ""
if user_id:
user_filter = 'AND e."userId" = $3::text'
params.append(user_id)
# Determine limit clause
limit_clause = "" if include_historical else "LIMIT 1"
final_query = query_template.format(
schema_prefix="{schema_prefix}",
user_filter=user_filter,
limit_clause=limit_clause,
)
result = await query_raw_with_schema(final_query, *params)
if not result:
return AccuracyTrendsResponse(
latest_data=AccuracyLatestData(
date=datetime.now(timezone.utc),
daily_score=None,
three_day_avg=None,
seven_day_avg=None,
fourteen_day_avg=None,
),
alert=None,
)
latest = result[0]
alert = None
if (
latest["drop_percent"] is not None
and latest["drop_percent"] >= drop_threshold
and latest["three_day_avg"] is not None
and latest["seven_day_avg"] is not None
):
alert = AccuracyAlertData(
graph_id=graph_id,
user_id=user_id,
drop_percent=float(latest["drop_percent"]),
three_day_avg=float(latest["three_day_avg"]),
seven_day_avg=float(latest["seven_day_avg"]),
detected_at=datetime.now(timezone.utc),
)
# Prepare historical data if requested
historical_data = None
if include_historical:
historical_data = []
for row in result:
historical_data.append(
AccuracyLatestData(
date=row["execution_date"],
daily_score=(
float(row["daily_score"])
if row["daily_score"] is not None
else None
),
three_day_avg=(
float(row["three_day_avg"])
if row["three_day_avg"] is not None
else None
),
seven_day_avg=(
float(row["seven_day_avg"])
if row["seven_day_avg"] is not None
else None
),
fourteen_day_avg=(
float(row["fourteen_day_avg"])
if row["fourteen_day_avg"] is not None
else None
),
)
)
return AccuracyTrendsResponse(
latest_data=AccuracyLatestData(
date=latest["execution_date"],
daily_score=(
float(latest["daily_score"])
if latest["daily_score"] is not None
else None
),
three_day_avg=(
float(latest["three_day_avg"])
if latest["three_day_avg"] is not None
else None
),
seven_day_avg=(
float(latest["seven_day_avg"])
if latest["seven_day_avg"] is not None
else None
),
fourteen_day_avg=(
float(latest["fourteen_day_avg"])
if latest["fourteen_day_avg"] is not None
else None
),
),
alert=alert,
historical_data=historical_data,
)
class MarketplaceGraphData(BaseModel):
"""Data structure for marketplace graph monitoring."""
graph_id: str
user_id: Optional[str]
execution_count: int
async def get_marketplace_graphs_for_monitoring(
days_back: int = 30,
min_executions: int = 10,
) -> list[MarketplaceGraphData]:
"""Get published marketplace graphs with recent executions for monitoring."""
query_template = """
WITH marketplace_graphs AS (
SELECT DISTINCT
slv."agentGraphId" as graph_id,
slv."agentGraphVersion" as graph_version
FROM {schema_prefix}"StoreListing" sl
JOIN {schema_prefix}"StoreListingVersion" slv ON sl."activeVersionId" = slv."id"
WHERE sl."hasApprovedVersion" = true
AND sl."isDeleted" = false
)
SELECT DISTINCT
mg.graph_id,
NULL as user_id, -- Marketplace graphs don't have a specific user_id for monitoring
COUNT(*) as execution_count
FROM marketplace_graphs mg
JOIN {schema_prefix}"AgentGraphExecution" e ON e."agentGraphId" = mg.graph_id
WHERE e."createdAt" >= $1::timestamp
AND e."isDeleted" = false
AND e."executionStatus" IN ('COMPLETED', 'FAILED', 'TERMINATED')
GROUP BY mg.graph_id
HAVING COUNT(*) >= $2
ORDER BY execution_count DESC
"""
start_date = datetime.now(timezone.utc) - timedelta(days=days_back)
result = await query_raw_with_schema(query_template, start_date, min_executions)
return [
MarketplaceGraphData(
graph_id=row["graph_id"],
user_id=row["user_id"],
execution_count=int(row["execution_count"]),
)
for row in result
]

View File

@@ -1,24 +1,22 @@
import logging
import uuid
from datetime import datetime, timezone
from typing import Literal, Optional
from typing import Optional
from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission, APIKeyStatus
from prisma.models import APIKey as PrismaAPIKey
from prisma.types import APIKeyWhereUniqueInput
from pydantic import Field
from pydantic import BaseModel, Field
from backend.data.includes import MAX_USER_API_KEYS_FETCH
from backend.util.exceptions import NotAuthorizedError, NotFoundError
from .base import APIAuthorizationInfo
logger = logging.getLogger(__name__)
keysmith = APIKeySmith()
class APIKeyInfo(APIAuthorizationInfo):
class APIKeyInfo(BaseModel):
id: str
name: str
head: str = Field(
@@ -28,9 +26,12 @@ class APIKeyInfo(APIAuthorizationInfo):
description=f"The last {APIKeySmith.TAIL_LENGTH} characters of the key"
)
status: APIKeyStatus
permissions: list[APIKeyPermission]
created_at: datetime
last_used_at: Optional[datetime] = None
revoked_at: Optional[datetime] = None
description: Optional[str] = None
type: Literal["api_key"] = "api_key" # type: ignore
user_id: str
@staticmethod
def from_db(api_key: PrismaAPIKey):
@@ -40,7 +41,7 @@ class APIKeyInfo(APIAuthorizationInfo):
head=api_key.head,
tail=api_key.tail,
status=APIKeyStatus(api_key.status),
scopes=[APIKeyPermission(p) for p in api_key.permissions],
permissions=[APIKeyPermission(p) for p in api_key.permissions],
created_at=api_key.createdAt,
last_used_at=api_key.lastUsedAt,
revoked_at=api_key.revokedAt,
@@ -210,7 +211,7 @@ async def suspend_api_key(key_id: str, user_id: str) -> APIKeyInfo:
def has_permission(api_key: APIKeyInfo, required_permission: APIKeyPermission) -> bool:
return required_permission in api_key.scopes
return required_permission in api_key.permissions
async def get_api_key_by_id(key_id: str, user_id: str) -> Optional[APIKeyInfo]:

View File

@@ -1,15 +0,0 @@
from datetime import datetime
from typing import Literal, Optional
from prisma.enums import APIKeyPermission
from pydantic import BaseModel
class APIAuthorizationInfo(BaseModel):
user_id: str
scopes: list[APIKeyPermission]
type: Literal["oauth", "api_key"]
created_at: datetime
expires_at: Optional[datetime] = None
last_used_at: Optional[datetime] = None
revoked_at: Optional[datetime] = None

View File

@@ -1,872 +0,0 @@
"""
OAuth 2.0 Provider Data Layer
Handles management of OAuth applications, authorization codes,
access tokens, and refresh tokens.
Hashing strategy:
- Access tokens & Refresh tokens: SHA256 (deterministic, allows direct lookup by hash)
- Client secrets: Scrypt with salt (lookup by client_id, then verify with salt)
"""
import hashlib
import logging
import secrets
import uuid
from datetime import datetime, timedelta, timezone
from typing import Literal, Optional
from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission as APIPermission
from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
from prisma.models import OAuthApplication as PrismaOAuthApplication
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
from prisma.types import OAuthApplicationUpdateInput
from pydantic import BaseModel, Field, SecretStr
from .base import APIAuthorizationInfo
logger = logging.getLogger(__name__)
keysmith = APIKeySmith() # Only used for client secret hashing (Scrypt)
def _generate_token() -> str:
"""Generate a cryptographically secure random token."""
return secrets.token_urlsafe(32)
def _hash_token(token: str) -> str:
"""Hash a token using SHA256 (deterministic, for direct lookup)."""
return hashlib.sha256(token.encode()).hexdigest()
# Token TTLs
AUTHORIZATION_CODE_TTL = timedelta(minutes=10)
ACCESS_TOKEN_TTL = timedelta(hours=1)
REFRESH_TOKEN_TTL = timedelta(days=30)
ACCESS_TOKEN_PREFIX = "agpt_xt_"
REFRESH_TOKEN_PREFIX = "agpt_rt_"
# ============================================================================
# Exception Classes
# ============================================================================
class OAuthError(Exception):
"""Base OAuth error"""
pass
class InvalidClientError(OAuthError):
"""Invalid client_id or client_secret"""
pass
class InvalidGrantError(OAuthError):
"""Invalid or expired authorization code/refresh token"""
def __init__(self, reason: str):
self.reason = reason
super().__init__(f"Invalid grant: {reason}")
class InvalidTokenError(OAuthError):
"""Invalid, expired, or revoked token"""
def __init__(self, reason: str):
self.reason = reason
super().__init__(f"Invalid token: {reason}")
# ============================================================================
# Data Models
# ============================================================================
class OAuthApplicationInfo(BaseModel):
"""OAuth application information (without client secret hash)"""
id: str
name: str
description: Optional[str] = None
logo_url: Optional[str] = None
client_id: str
redirect_uris: list[str]
grant_types: list[str]
scopes: list[APIPermission]
owner_id: str
is_active: bool
created_at: datetime
updated_at: datetime
@staticmethod
def from_db(app: PrismaOAuthApplication):
return OAuthApplicationInfo(
id=app.id,
name=app.name,
description=app.description,
logo_url=app.logoUrl,
client_id=app.clientId,
redirect_uris=app.redirectUris,
grant_types=app.grantTypes,
scopes=[APIPermission(s) for s in app.scopes],
owner_id=app.ownerId,
is_active=app.isActive,
created_at=app.createdAt,
updated_at=app.updatedAt,
)
class OAuthApplicationInfoWithSecret(OAuthApplicationInfo):
"""OAuth application with client secret hash (for validation)"""
client_secret_hash: str
client_secret_salt: str
@staticmethod
def from_db(app: PrismaOAuthApplication):
return OAuthApplicationInfoWithSecret(
**OAuthApplicationInfo.from_db(app).model_dump(),
client_secret_hash=app.clientSecret,
client_secret_salt=app.clientSecretSalt,
)
def verify_secret(self, plaintext_secret: str) -> bool:
"""Verify a plaintext client secret against the stored hash"""
# Use keysmith.verify_key() with stored salt
return keysmith.verify_key(
plaintext_secret, self.client_secret_hash, self.client_secret_salt
)
class OAuthAuthorizationCodeInfo(BaseModel):
"""Authorization code information"""
id: str
code: str
created_at: datetime
expires_at: datetime
application_id: str
user_id: str
scopes: list[APIPermission]
redirect_uri: str
code_challenge: Optional[str] = None
code_challenge_method: Optional[str] = None
used_at: Optional[datetime] = None
@property
def is_used(self) -> bool:
return self.used_at is not None
@staticmethod
def from_db(code: PrismaOAuthAuthorizationCode):
return OAuthAuthorizationCodeInfo(
id=code.id,
code=code.code,
created_at=code.createdAt,
expires_at=code.expiresAt,
application_id=code.applicationId,
user_id=code.userId,
scopes=[APIPermission(s) for s in code.scopes],
redirect_uri=code.redirectUri,
code_challenge=code.codeChallenge,
code_challenge_method=code.codeChallengeMethod,
used_at=code.usedAt,
)
class OAuthAccessTokenInfo(APIAuthorizationInfo):
"""Access token information"""
id: str
expires_at: datetime # type: ignore
application_id: str
type: Literal["oauth"] = "oauth" # type: ignore
@staticmethod
def from_db(token: PrismaOAuthAccessToken):
return OAuthAccessTokenInfo(
id=token.id,
user_id=token.userId,
scopes=[APIPermission(s) for s in token.scopes],
created_at=token.createdAt,
expires_at=token.expiresAt,
last_used_at=None,
revoked_at=token.revokedAt,
application_id=token.applicationId,
)
class OAuthAccessToken(OAuthAccessTokenInfo):
"""Access token with plaintext token included (sensitive)"""
token: SecretStr = Field(description="Plaintext token (sensitive)")
@staticmethod
def from_db(token: PrismaOAuthAccessToken, plaintext_token: str): # type: ignore
return OAuthAccessToken(
**OAuthAccessTokenInfo.from_db(token).model_dump(),
token=SecretStr(plaintext_token),
)
class OAuthRefreshTokenInfo(BaseModel):
"""Refresh token information"""
id: str
user_id: str
scopes: list[APIPermission]
created_at: datetime
expires_at: datetime
application_id: str
revoked_at: Optional[datetime] = None
@property
def is_revoked(self) -> bool:
return self.revoked_at is not None
@staticmethod
def from_db(token: PrismaOAuthRefreshToken):
return OAuthRefreshTokenInfo(
id=token.id,
user_id=token.userId,
scopes=[APIPermission(s) for s in token.scopes],
created_at=token.createdAt,
expires_at=token.expiresAt,
application_id=token.applicationId,
revoked_at=token.revokedAt,
)
class OAuthRefreshToken(OAuthRefreshTokenInfo):
"""Refresh token with plaintext token included (sensitive)"""
token: SecretStr = Field(description="Plaintext token (sensitive)")
@staticmethod
def from_db(token: PrismaOAuthRefreshToken, plaintext_token: str): # type: ignore
return OAuthRefreshToken(
**OAuthRefreshTokenInfo.from_db(token).model_dump(),
token=SecretStr(plaintext_token),
)
class TokenIntrospectionResult(BaseModel):
"""Result of token introspection (RFC 7662)"""
active: bool
scopes: Optional[list[str]] = None
client_id: Optional[str] = None
user_id: Optional[str] = None
exp: Optional[int] = None # Unix timestamp
token_type: Optional[Literal["access_token", "refresh_token"]] = None
# ============================================================================
# OAuth Application Management
# ============================================================================
async def get_oauth_application(client_id: str) -> Optional[OAuthApplicationInfo]:
"""Get OAuth application by client ID (without secret)"""
app = await PrismaOAuthApplication.prisma().find_unique(
where={"clientId": client_id}
)
if not app:
return None
return OAuthApplicationInfo.from_db(app)
async def get_oauth_application_with_secret(
client_id: str,
) -> Optional[OAuthApplicationInfoWithSecret]:
"""Get OAuth application by client ID (with secret hash for validation)"""
app = await PrismaOAuthApplication.prisma().find_unique(
where={"clientId": client_id}
)
if not app:
return None
return OAuthApplicationInfoWithSecret.from_db(app)
async def validate_client_credentials(
client_id: str, client_secret: str
) -> OAuthApplicationInfo:
"""
Validate client credentials and return application info.
Raises:
InvalidClientError: If client_id or client_secret is invalid, or app is inactive
"""
app = await get_oauth_application_with_secret(client_id)
if not app:
raise InvalidClientError("Invalid client_id")
if not app.is_active:
raise InvalidClientError("Application is not active")
# Verify client secret
if not app.verify_secret(client_secret):
raise InvalidClientError("Invalid client_secret")
# Return without secret hash
return OAuthApplicationInfo(**app.model_dump(exclude={"client_secret_hash"}))
def validate_redirect_uri(app: OAuthApplicationInfo, redirect_uri: str) -> bool:
"""Validate that redirect URI is registered for the application"""
return redirect_uri in app.redirect_uris
def validate_scopes(
app: OAuthApplicationInfo, requested_scopes: list[APIPermission]
) -> bool:
"""Validate that all requested scopes are allowed for the application"""
return all(scope in app.scopes for scope in requested_scopes)
# ============================================================================
# Authorization Code Flow
# ============================================================================
def _generate_authorization_code() -> str:
"""Generate a cryptographically secure authorization code"""
# 32 bytes = 256 bits of entropy
return secrets.token_urlsafe(32)
async def create_authorization_code(
application_id: str,
user_id: str,
scopes: list[APIPermission],
redirect_uri: str,
code_challenge: Optional[str] = None,
code_challenge_method: Optional[Literal["S256", "plain"]] = None,
) -> OAuthAuthorizationCodeInfo:
"""
Create a new authorization code.
Expires in 10 minutes and can only be used once.
"""
code = _generate_authorization_code()
now = datetime.now(timezone.utc)
expires_at = now + AUTHORIZATION_CODE_TTL
saved_code = await PrismaOAuthAuthorizationCode.prisma().create(
data={
"id": str(uuid.uuid4()),
"code": code,
"expiresAt": expires_at,
"applicationId": application_id,
"userId": user_id,
"scopes": [s for s in scopes],
"redirectUri": redirect_uri,
"codeChallenge": code_challenge,
"codeChallengeMethod": code_challenge_method,
}
)
return OAuthAuthorizationCodeInfo.from_db(saved_code)
async def consume_authorization_code(
code: str,
application_id: str,
redirect_uri: str,
code_verifier: Optional[str] = None,
) -> tuple[str, list[APIPermission]]:
"""
Consume an authorization code and return (user_id, scopes).
This marks the code as used and validates:
- Code exists and matches application
- Code is not expired
- Code has not been used
- Redirect URI matches
- PKCE code verifier matches (if code challenge was provided)
Raises:
InvalidGrantError: If code is invalid, expired, used, or PKCE fails
"""
auth_code = await PrismaOAuthAuthorizationCode.prisma().find_unique(
where={"code": code}
)
if not auth_code:
raise InvalidGrantError("authorization code not found")
# Validate application
if auth_code.applicationId != application_id:
raise InvalidGrantError(
"authorization code does not belong to this application"
)
# Check if already used
if auth_code.usedAt is not None:
raise InvalidGrantError(
f"authorization code already used at {auth_code.usedAt}"
)
# Check expiration
now = datetime.now(timezone.utc)
if auth_code.expiresAt < now:
raise InvalidGrantError("authorization code expired")
# Validate redirect URI
if auth_code.redirectUri != redirect_uri:
raise InvalidGrantError("redirect_uri mismatch")
# Validate PKCE if code challenge was provided
if auth_code.codeChallenge:
if not code_verifier:
raise InvalidGrantError("code_verifier required but not provided")
if not _verify_pkce(
code_verifier, auth_code.codeChallenge, auth_code.codeChallengeMethod
):
raise InvalidGrantError("PKCE verification failed")
# Mark code as used
await PrismaOAuthAuthorizationCode.prisma().update(
where={"code": code},
data={"usedAt": now},
)
return auth_code.userId, [APIPermission(s) for s in auth_code.scopes]
def _verify_pkce(
code_verifier: str, code_challenge: str, code_challenge_method: Optional[str]
) -> bool:
"""
Verify PKCE code verifier against code challenge.
Supports:
- S256: SHA256(code_verifier) == code_challenge
- plain: code_verifier == code_challenge
"""
if code_challenge_method == "S256":
# Hash the verifier with SHA256 and base64url encode
hashed = hashlib.sha256(code_verifier.encode("ascii")).digest()
computed_challenge = (
secrets.token_urlsafe(len(hashed)).encode("ascii").decode("ascii")
)
# For proper base64url encoding
import base64
computed_challenge = (
base64.urlsafe_b64encode(hashed).decode("ascii").rstrip("=")
)
return secrets.compare_digest(computed_challenge, code_challenge)
elif code_challenge_method == "plain" or code_challenge_method is None:
# Plain comparison
return secrets.compare_digest(code_verifier, code_challenge)
else:
logger.warning(f"Unsupported code challenge method: {code_challenge_method}")
return False
# ============================================================================
# Access Token Management
# ============================================================================
async def create_access_token(
application_id: str, user_id: str, scopes: list[APIPermission]
) -> OAuthAccessToken:
"""
Create a new access token.
Returns OAuthAccessToken (with plaintext token).
"""
plaintext_token = ACCESS_TOKEN_PREFIX + _generate_token()
token_hash = _hash_token(plaintext_token)
now = datetime.now(timezone.utc)
expires_at = now + ACCESS_TOKEN_TTL
saved_token = await PrismaOAuthAccessToken.prisma().create(
data={
"id": str(uuid.uuid4()),
"token": token_hash, # SHA256 hash for direct lookup
"expiresAt": expires_at,
"applicationId": application_id,
"userId": user_id,
"scopes": [s for s in scopes],
}
)
return OAuthAccessToken.from_db(saved_token, plaintext_token=plaintext_token)
async def validate_access_token(
token: str,
) -> tuple[OAuthAccessTokenInfo, OAuthApplicationInfo]:
"""
Validate an access token and return token info.
Raises:
InvalidTokenError: If token is invalid, expired, or revoked
InvalidClientError: If the client application is not marked as active
"""
token_hash = _hash_token(token)
# Direct lookup by hash
access_token = await PrismaOAuthAccessToken.prisma().find_unique(
where={"token": token_hash}, include={"Application": True}
)
if not access_token:
raise InvalidTokenError("access token not found")
if not access_token.Application: # should be impossible
raise InvalidClientError("Client application not found")
if not access_token.Application.isActive:
raise InvalidClientError("Client application is disabled")
if access_token.revokedAt is not None:
raise InvalidTokenError("access token has been revoked")
# Check expiration
now = datetime.now(timezone.utc)
if access_token.expiresAt < now:
raise InvalidTokenError("access token expired")
return (
OAuthAccessTokenInfo.from_db(access_token),
OAuthApplicationInfo.from_db(access_token.Application),
)
async def revoke_access_token(
token: str, application_id: str
) -> OAuthAccessTokenInfo | None:
"""
Revoke an access token.
Args:
token: The plaintext access token to revoke
application_id: The application ID making the revocation request.
Only tokens belonging to this application will be revoked.
Returns:
OAuthAccessTokenInfo if token was found and revoked, None otherwise.
Note:
Always performs exactly 2 DB queries regardless of outcome to prevent
timing side-channel attacks that could reveal token existence.
"""
try:
token_hash = _hash_token(token)
# Use update_many to filter by both token and applicationId
updated_count = await PrismaOAuthAccessToken.prisma().update_many(
where={
"token": token_hash,
"applicationId": application_id,
"revokedAt": None,
},
data={"revokedAt": datetime.now(timezone.utc)},
)
# Always perform second query to ensure constant time
result = await PrismaOAuthAccessToken.prisma().find_unique(
where={"token": token_hash}
)
# Only return result if we actually revoked something
if updated_count == 0:
return None
return OAuthAccessTokenInfo.from_db(result) if result else None
except Exception as e:
logger.exception(f"Error revoking access token: {e}")
return None
# ============================================================================
# Refresh Token Management
# ============================================================================
async def create_refresh_token(
application_id: str, user_id: str, scopes: list[APIPermission]
) -> OAuthRefreshToken:
"""
Create a new refresh token.
Returns OAuthRefreshToken (with plaintext token).
"""
plaintext_token = REFRESH_TOKEN_PREFIX + _generate_token()
token_hash = _hash_token(plaintext_token)
now = datetime.now(timezone.utc)
expires_at = now + REFRESH_TOKEN_TTL
saved_token = await PrismaOAuthRefreshToken.prisma().create(
data={
"id": str(uuid.uuid4()),
"token": token_hash, # SHA256 hash for direct lookup
"expiresAt": expires_at,
"applicationId": application_id,
"userId": user_id,
"scopes": [s for s in scopes],
}
)
return OAuthRefreshToken.from_db(saved_token, plaintext_token=plaintext_token)
async def refresh_tokens(
refresh_token: str, application_id: str
) -> tuple[OAuthAccessToken, OAuthRefreshToken]:
"""
Use a refresh token to create new access and refresh tokens.
Returns (new_access_token, new_refresh_token) both with plaintext tokens included.
Raises:
InvalidGrantError: If refresh token is invalid, expired, or revoked
"""
token_hash = _hash_token(refresh_token)
# Direct lookup by hash
rt = await PrismaOAuthRefreshToken.prisma().find_unique(where={"token": token_hash})
if not rt:
raise InvalidGrantError("refresh token not found")
# NOTE: no need to check Application.isActive, this is checked by the token endpoint
if rt.revokedAt is not None:
raise InvalidGrantError("refresh token has been revoked")
# Validate application
if rt.applicationId != application_id:
raise InvalidGrantError("refresh token does not belong to this application")
# Check expiration
now = datetime.now(timezone.utc)
if rt.expiresAt < now:
raise InvalidGrantError("refresh token expired")
# Revoke old refresh token
await PrismaOAuthRefreshToken.prisma().update(
where={"token": token_hash},
data={"revokedAt": now},
)
# Create new access and refresh tokens with same scopes
scopes = [APIPermission(s) for s in rt.scopes]
new_access_token = await create_access_token(
rt.applicationId,
rt.userId,
scopes,
)
new_refresh_token = await create_refresh_token(
rt.applicationId,
rt.userId,
scopes,
)
return new_access_token, new_refresh_token
async def revoke_refresh_token(
token: str, application_id: str
) -> OAuthRefreshTokenInfo | None:
"""
Revoke a refresh token.
Args:
token: The plaintext refresh token to revoke
application_id: The application ID making the revocation request.
Only tokens belonging to this application will be revoked.
Returns:
OAuthRefreshTokenInfo if token was found and revoked, None otherwise.
Note:
Always performs exactly 2 DB queries regardless of outcome to prevent
timing side-channel attacks that could reveal token existence.
"""
try:
token_hash = _hash_token(token)
# Use update_many to filter by both token and applicationId
updated_count = await PrismaOAuthRefreshToken.prisma().update_many(
where={
"token": token_hash,
"applicationId": application_id,
"revokedAt": None,
},
data={"revokedAt": datetime.now(timezone.utc)},
)
# Always perform second query to ensure constant time
result = await PrismaOAuthRefreshToken.prisma().find_unique(
where={"token": token_hash}
)
# Only return result if we actually revoked something
if updated_count == 0:
return None
return OAuthRefreshTokenInfo.from_db(result) if result else None
except Exception as e:
logger.exception(f"Error revoking refresh token: {e}")
return None
# ============================================================================
# Token Introspection
# ============================================================================
async def introspect_token(
token: str,
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = None,
) -> TokenIntrospectionResult:
"""
Introspect a token and return its metadata (RFC 7662).
Returns TokenIntrospectionResult with active=True and metadata if valid,
or active=False if the token is invalid/expired/revoked.
"""
# Try as access token first (or if hint says "access_token")
if token_type_hint != "refresh_token":
try:
token_info, app = await validate_access_token(token)
return TokenIntrospectionResult(
active=True,
scopes=list(s.value for s in token_info.scopes),
client_id=app.client_id if app else None,
user_id=token_info.user_id,
exp=int(token_info.expires_at.timestamp()),
token_type="access_token",
)
except InvalidTokenError:
pass # Try as refresh token
# Try as refresh token
token_hash = _hash_token(token)
refresh_token = await PrismaOAuthRefreshToken.prisma().find_unique(
where={"token": token_hash}
)
if refresh_token and refresh_token.revokedAt is None:
# Check if valid (not expired)
now = datetime.now(timezone.utc)
if refresh_token.expiresAt > now:
app = await get_oauth_application_by_id(refresh_token.applicationId)
return TokenIntrospectionResult(
active=True,
scopes=list(s for s in refresh_token.scopes),
client_id=app.client_id if app else None,
user_id=refresh_token.userId,
exp=int(refresh_token.expiresAt.timestamp()),
token_type="refresh_token",
)
# Token not found or inactive
return TokenIntrospectionResult(active=False)
async def get_oauth_application_by_id(app_id: str) -> Optional[OAuthApplicationInfo]:
"""Get OAuth application by ID"""
app = await PrismaOAuthApplication.prisma().find_unique(where={"id": app_id})
if not app:
return None
return OAuthApplicationInfo.from_db(app)
async def list_user_oauth_applications(user_id: str) -> list[OAuthApplicationInfo]:
"""Get all OAuth applications owned by a user"""
apps = await PrismaOAuthApplication.prisma().find_many(
where={"ownerId": user_id},
order={"createdAt": "desc"},
)
return [OAuthApplicationInfo.from_db(app) for app in apps]
async def update_oauth_application(
app_id: str,
*,
owner_id: str,
is_active: Optional[bool] = None,
logo_url: Optional[str] = None,
) -> Optional[OAuthApplicationInfo]:
"""
Update OAuth application active status.
Only the owner can update their app's status.
Returns the updated app info, or None if app not found or not owned by user.
"""
# First verify ownership
app = await PrismaOAuthApplication.prisma().find_first(
where={"id": app_id, "ownerId": owner_id}
)
if not app:
return None
patch: OAuthApplicationUpdateInput = {}
if is_active is not None:
patch["isActive"] = is_active
if logo_url:
patch["logoUrl"] = logo_url
if not patch:
return OAuthApplicationInfo.from_db(app) # return unchanged
updated_app = await PrismaOAuthApplication.prisma().update(
where={"id": app_id},
data=patch,
)
return OAuthApplicationInfo.from_db(updated_app) if updated_app else None
# ============================================================================
# Token Cleanup
# ============================================================================
async def cleanup_expired_oauth_tokens() -> dict[str, int]:
"""
Delete expired OAuth tokens from the database.
This removes:
- Expired authorization codes (10 min TTL)
- Expired access tokens (1 hour TTL)
- Expired refresh tokens (30 day TTL)
Returns a dict with counts of deleted tokens by type.
"""
now = datetime.now(timezone.utc)
# Delete expired authorization codes
codes_result = await PrismaOAuthAuthorizationCode.prisma().delete_many(
where={"expiresAt": {"lt": now}}
)
# Delete expired access tokens
access_result = await PrismaOAuthAccessToken.prisma().delete_many(
where={"expiresAt": {"lt": now}}
)
# Delete expired refresh tokens
refresh_result = await PrismaOAuthRefreshToken.prisma().delete_many(
where={"expiresAt": {"lt": now}}
)
deleted = {
"authorization_codes": codes_result,
"access_tokens": access_result,
"refresh_tokens": refresh_result,
}
total = sum(deleted.values())
if total > 0:
logger.info(f"Cleaned up {total} expired OAuth tokens: {deleted}")
return deleted

View File

@@ -50,8 +50,6 @@ from .model import (
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
from .graph import Link
app_config = Config()
@@ -474,7 +472,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.block_type = block_type
self.webhook_config = webhook_config
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
self.requires_human_review: bool = False
if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig):
@@ -604,90 +601,16 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
async for output_name, output_data in self._execute(input_data, **kwargs):
yield output_name, output_data
except Exception as ex:
if isinstance(ex, BlockError):
raise ex
else:
raise (
BlockExecutionError
if isinstance(ex, ValueError)
else BlockUnknownError
)(
if not isinstance(ex, BlockError):
raise BlockUnknownError(
message=str(ex),
block_name=self.name,
block_id=self.id,
) from ex
async def is_block_exec_need_review(
self,
input_data: BlockInput,
*,
user_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: "ExecutionContext",
**kwargs,
) -> tuple[bool, BlockInput]:
"""
Check if this block execution needs human review and handle the review process.
Returns:
Tuple of (should_pause, input_data_to_use)
- should_pause: True if execution should be paused for review
- input_data_to_use: The input data to use (may be modified by reviewer)
"""
# Skip review if not required or safe mode is disabled
if not self.requires_human_review or not execution_context.safe_mode:
return False, input_data
from backend.blocks.helpers.review import HITLReviewHelper
# Handle the review request and get decision
decision = await HITLReviewHelper.handle_review_decision(
input_data=input_data,
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=self.name,
editable=True,
)
if decision is None:
# We're awaiting review - pause execution
return True, input_data
if not decision.should_proceed:
# Review was rejected, raise an error to stop execution
raise BlockExecutionError(
message=f"Block execution rejected by reviewer: {decision.message}",
block_name=self.name,
block_id=self.id,
)
# Review was approved - use the potentially modified data
# ReviewResult.data must be a dict for block inputs
reviewed_data = decision.review_result.data
if not isinstance(reviewed_data, dict):
raise BlockExecutionError(
message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
block_name=self.name,
block_id=self.id,
)
return False, reviewed_data
else:
raise ex
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
# Check for review requirement and get potentially modified input data
should_pause, input_data = await self.is_block_exec_need_review(
input_data, **kwargs
)
if should_pause:
return
# Validate the input data (original or reviewer-modified) once
if error := self.input_schema.validate_data(input_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
@@ -695,7 +618,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
block_id=self.id,
)
# Use the validated input data
async for output_name, output_data in self.run(
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
**kwargs,

View File

@@ -59,13 +59,12 @@ from backend.integrations.credentials_store import (
MODEL_COST: dict[LlmModel, int] = {
LlmModel.O3: 4,
LlmModel.O3_MINI: 2,
LlmModel.O1: 16,
LlmModel.O3_MINI: 2, # $1.10 / $4.40
LlmModel.O1: 16, # $15 / $60
LlmModel.O1_MINI: 4,
# GPT-5 models
LlmModel.GPT5_2: 6,
LlmModel.GPT5_1: 5,
LlmModel.GPT5: 2,
LlmModel.GPT5_1: 5,
LlmModel.GPT5_MINI: 1,
LlmModel.GPT5_NANO: 1,
LlmModel.GPT5_CHAT: 5,
@@ -88,7 +87,7 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.AIML_API_LLAMA3_3_70B: 1,
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
LlmModel.LLAMA3_3_70B: 1,
LlmModel.LLAMA3_3_70B: 1, # $0.59 / $0.79
LlmModel.LLAMA3_1_8B: 1,
LlmModel.OLLAMA_LLAMA3_3: 1,
LlmModel.OLLAMA_LLAMA3_2: 1,

View File

@@ -16,7 +16,6 @@ from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBala
from prisma.types import CreditRefundRequestCreateInput, CreditTransactionWhereInput
from pydantic import BaseModel
from backend.api.features.admin.model import UserHistoryResponse
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.db import query_raw_with_schema
from backend.data.includes import MAX_CREDIT_REFUND_REQUESTS_FETCH
@@ -30,6 +29,7 @@ from backend.data.model import (
from backend.data.notifications import NotificationEventModel, RefundRequestData
from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications.notifications import queue_notification_async
from backend.server.v2.admin.model import UserHistoryResponse
from backend.util.exceptions import InsufficientBalanceError
from backend.util.feature_flag import Flag, is_feature_enabled
from backend.util.json import SafeJson, dumps
@@ -341,19 +341,6 @@ class UserCreditBase(ABC):
if result:
# UserBalance is already updated by the CTE
# Clear insufficient funds notification flags when credits are added
# so user can receive alerts again if they run out in the future.
if transaction.amount > 0 and transaction.type in [
CreditTransactionType.GRANT,
CreditTransactionType.TOP_UP,
]:
from backend.executor.manager import (
clear_insufficient_funds_notifications,
)
await clear_insufficient_funds_notifications(user_id)
return result[0]["balance"]
async def _add_transaction(
@@ -543,22 +530,6 @@ class UserCreditBase(ABC):
if result:
new_balance, tx_key = result[0]["balance"], result[0]["transactionKey"]
# UserBalance is already updated by the CTE
# Clear insufficient funds notification flags when credits are added
# so user can receive alerts again if they run out in the future.
if (
amount > 0
and is_active
and transaction_type
in [CreditTransactionType.GRANT, CreditTransactionType.TOP_UP]
):
# Lazy import to avoid circular dependency with executor.manager
from backend.executor.manager import (
clear_insufficient_funds_notifications,
)
await clear_insufficient_funds_notifications(user_id)
return new_balance, tx_key
# If no result, either user doesn't exist or insufficient balance

View File

@@ -111,7 +111,7 @@ def get_database_schema() -> str:
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
"""Execute raw SQL query with proper schema handling."""
schema = get_database_schema()
schema_prefix = f'"{schema}".' if schema != "public" else ""
schema_prefix = f"{schema}." if schema != "public" else ""
formatted_query = query_template.format(schema_prefix=schema_prefix)
import prisma as prisma_module

View File

@@ -5,7 +5,6 @@ from enum import Enum
from multiprocessing import Manager
from queue import Empty
from typing import (
TYPE_CHECKING,
Annotated,
Any,
AsyncGenerator,
@@ -66,9 +65,6 @@ from .includes import (
)
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
if TYPE_CHECKING:
pass
T = TypeVar("T")
logger = logging.getLogger(__name__)
@@ -383,7 +379,6 @@ class GraphExecutionWithNodes(GraphExecution):
self,
execution_context: ExecutionContext,
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_to_skip: Optional[set[str]] = None,
):
return GraphExecutionEntry(
user_id=self.user_id,
@@ -391,7 +386,6 @@ class GraphExecutionWithNodes(GraphExecution):
graph_version=self.graph_version or 0,
graph_exec_id=self.id,
nodes_input_masks=compiled_nodes_input_masks,
nodes_to_skip=nodes_to_skip or set(),
execution_context=execution_context,
)
@@ -842,30 +836,6 @@ async def upsert_execution_output(
await AgentNodeExecutionInputOutput.prisma().create(data=data)
async def get_execution_outputs_by_node_exec_id(
node_exec_id: str,
) -> dict[str, Any]:
"""
Get all execution outputs for a specific node execution ID.
Args:
node_exec_id: The node execution ID to get outputs for
Returns:
Dictionary mapping output names to their data values
"""
outputs = await AgentNodeExecutionInputOutput.prisma().find_many(
where={"referencedByOutputExecId": node_exec_id}
)
result = {}
for output in outputs:
if output.data is not None:
result[output.name] = type_utils.convert(output.data, JsonValue)
return result
async def update_graph_execution_start_time(
graph_exec_id: str,
) -> GraphExecution | None:
@@ -1147,8 +1117,6 @@ class GraphExecutionEntry(BaseModel):
graph_id: str
graph_version: int
nodes_input_masks: Optional[NodesInputMasks] = None
nodes_to_skip: set[str] = Field(default_factory=set)
"""Node IDs that should be skipped due to optional credentials not being configured."""
execution_context: ExecutionContext = Field(default_factory=ExecutionContext)
@@ -1497,35 +1465,3 @@ async def get_graph_execution_by_share_token(
created_at=execution.createdAt,
outputs=outputs,
)
async def get_frequently_executed_graphs(
days_back: int = 30,
min_executions: int = 10,
) -> list[dict]:
"""Get graphs that have been frequently executed for monitoring."""
query_template = """
SELECT DISTINCT
e."agentGraphId" as graph_id,
e."userId" as user_id,
COUNT(*) as execution_count
FROM {schema_prefix}"AgentGraphExecution" e
WHERE e."createdAt" >= $1::timestamp
AND e."isDeleted" = false
AND e."executionStatus" IN ('COMPLETED', 'FAILED', 'TERMINATED')
GROUP BY e."agentGraphId", e."userId"
HAVING COUNT(*) >= $2
ORDER BY execution_count DESC
"""
start_date = datetime.now(timezone.utc) - timedelta(days=days_back)
result = await query_raw_with_schema(query_template, start_date, min_executions)
return [
{
"graph_id": row["graph_id"],
"user_id": row["user_id"],
"execution_count": int(row["execution_count"]),
}
for row in result
]

View File

@@ -94,15 +94,6 @@ class Node(BaseDbModel):
input_links: list[Link] = []
output_links: list[Link] = []
@property
def credentials_optional(self) -> bool:
"""
Whether credentials are optional for this node.
When True and credentials are not configured, the node will be skipped
during execution rather than causing a validation error.
"""
return self.metadata.get("credentials_optional", False)
@property
def block(self) -> AnyBlockSchema | "_UnknownBlockBase":
"""Get the block for this node. Returns UnknownBlock if block is deleted/missing."""
@@ -244,10 +235,7 @@ class BaseGraph(BaseDbModel):
return any(
node.block_id
for node in self.nodes
if (
node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
or node.block.requires_human_review
)
if node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
)
@property
@@ -338,35 +326,7 @@ class Graph(BaseGraph):
@computed_field
@property
def credentials_input_schema(self) -> dict[str, Any]:
schema = self._credentials_input_schema.jsonschema()
# Determine which credential fields are required based on credentials_optional metadata
graph_credentials_inputs = self.aggregate_credentials_inputs()
required_fields = []
# Build a map of node_id -> node for quick lookup
all_nodes = {node.id: node for node in self.nodes}
for sub_graph in self.sub_graphs:
for node in sub_graph.nodes:
all_nodes[node.id] = node
for field_key, (
_field_info,
node_field_pairs,
) in graph_credentials_inputs.items():
# A field is required if ANY node using it has credentials_optional=False
is_required = False
for node_id, _field_name in node_field_pairs:
node = all_nodes.get(node_id)
if node and not node.credentials_optional:
is_required = True
break
if is_required:
required_fields.append(field_key)
schema["required"] = required_fields
return schema
return self._credentials_input_schema.jsonschema()
@property
def _credentials_input_schema(self) -> type[BlockSchema]:

View File

@@ -6,14 +6,14 @@ import fastapi.exceptions
import pytest
from pytest_snapshot.plugin import Snapshot
import backend.api.features.store.model as store
from backend.api.model import CreateGraph
import backend.server.v2.store.model as store
from backend.blocks.basic import StoreValueBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.data.block import BlockSchema, BlockSchemaInput
from backend.data.graph import Graph, Link, Node
from backend.data.model import SchemaField
from backend.data.user import DEFAULT_USER_ID
from backend.server.model import CreateGraph
from backend.usecases.sample import create_test_user
from backend.util.test import SpinTestServer
@@ -396,58 +396,3 @@ async def test_access_store_listing_graph(server: SpinTestServer):
created_graph.id, created_graph.version, "3e53486c-cf57-477e-ba2a-cb02dc828e1b"
)
assert got_graph is not None
# ============================================================================
# Tests for Optional Credentials Feature
# ============================================================================
def test_node_credentials_optional_default():
"""Test that credentials_optional defaults to False when not set in metadata."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={},
)
assert node.credentials_optional is False
def test_node_credentials_optional_true():
"""Test that credentials_optional returns True when explicitly set."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={"credentials_optional": True},
)
assert node.credentials_optional is True
def test_node_credentials_optional_false():
"""Test that credentials_optional returns False when explicitly set to False."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={"credentials_optional": False},
)
assert node.credentials_optional is False
def test_node_credentials_optional_with_other_metadata():
"""Test that credentials_optional works correctly with other metadata present."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={
"position": {"x": 100, "y": 200},
"customized_name": "My Custom Node",
"credentials_optional": True,
},
)
assert node.credentials_optional is True
assert node.metadata["position"] == {"x": 100, "y": 200}
assert node.metadata["customized_name"] == "My Custom Node"

View File

@@ -13,7 +13,7 @@ from prisma.models import PendingHumanReview
from prisma.types import PendingHumanReviewUpdateInput
from pydantic import BaseModel
from backend.api.features.executions.review.model import (
from backend.server.v2.executions.review.model import (
PendingHumanReviewModel,
SafeJsonData,
)
@@ -100,7 +100,7 @@ async def get_or_create_human_review(
return None
else:
return ReviewResult(
data=review.payload,
data=review.payload if review.status == ReviewStatus.APPROVED else None,
status=review.status,
message=review.reviewMessage or "",
processed=review.processed,

View File

@@ -23,7 +23,7 @@ from backend.util.exceptions import NotFoundError
from backend.util.json import SafeJson
if TYPE_CHECKING:
from backend.api.features.library.model import LibraryAgentPreset
from backend.server.v2.library.model import LibraryAgentPreset
from .db import BaseDbModel
from .graph import NodeModel
@@ -79,7 +79,7 @@ class WebhookWithRelations(Webhook):
# integrations.py → library/model.py → integrations.py (for Webhook)
# Runtime import is used in WebhookWithRelations.from_db() method instead
# Import at runtime to avoid circular dependency
from backend.api.features.library.model import LibraryAgentPreset
from backend.server.v2.library.model import LibraryAgentPreset
return WebhookWithRelations(
**Webhook.from_db(webhook).model_dump(),
@@ -285,8 +285,8 @@ async def unlink_webhook_from_graph(
user_id: The ID of the user (for authorization)
"""
# Avoid circular imports
from backend.api.features.library.db import set_preset_webhook
from backend.data.graph import set_node_webhook
from backend.server.v2.library.db import set_preset_webhook
# Find all nodes in this graph that use this webhook
nodes = await AgentNode.prisma().find_many(

View File

@@ -22,7 +22,7 @@ from typing import (
from urllib.parse import urlparse
from uuid import uuid4
from prisma.enums import CreditTransactionType, OnboardingStep
from prisma.enums import CreditTransactionType
from pydantic import (
BaseModel,
ConfigDict,
@@ -868,20 +868,3 @@ class UserExecutionSummaryStats(BaseModel):
total_execution_time: float = Field(default=0)
average_execution_time: float = Field(default=0)
cost_breakdown: dict[str, float] = Field(default_factory=dict)
class UserOnboarding(BaseModel):
userId: str
completedSteps: list[OnboardingStep]
walletShown: bool
notified: list[OnboardingStep]
rewardedFor: list[OnboardingStep]
usageReason: Optional[str]
integrations: list[str]
otherIntegrations: Optional[str]
selectedStoreListingVersionId: Optional[str]
agentInput: Optional[dict[str, Any]]
onboardingAgentExecutionId: Optional[str]
agentRuns: int
lastRunAt: Optional[datetime]
consecutiveRunDays: int

View File

@@ -2,10 +2,10 @@ from __future__ import annotations
from typing import AsyncGenerator
from pydantic import BaseModel, field_serializer
from pydantic import BaseModel
from backend.api.model import NotificationPayload
from backend.data.event_bus import AsyncRedisEventBus
from backend.server.model import NotificationPayload
from backend.util.settings import Settings
@@ -15,11 +15,6 @@ class NotificationEvent(BaseModel):
user_id: str
payload: NotificationPayload
@field_serializer("payload")
def serialize_payload(self, payload: NotificationPayload):
"""Ensure extra fields survive Redis serialization."""
return payload.model_dump()
class AsyncRedisNotificationEventBus(AsyncRedisEventBus[NotificationEvent]):
Model = NotificationEvent # type: ignore

View File

@@ -1,7 +1,6 @@
import re
from datetime import datetime, timedelta, timezone
from typing import Any, Literal, Optional
from zoneinfo import ZoneInfo
from datetime import datetime
from typing import Any, Optional
import prisma
import pydantic
@@ -9,18 +8,17 @@ from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
from backend.api.features.store.model import StoreAgentDetails
from backend.api.model import OnboardingNotificationPayload
from backend.data import execution as execution_db
from backend.data.block import get_blocks
from backend.data.credit import get_user_credit_model
from backend.data.model import CredentialsMetaInput
from backend.data.notification_bus import (
AsyncRedisNotificationEventBus,
NotificationEvent,
)
from backend.data.user import get_user_by_id
from backend.server.model import OnboardingNotificationPayload
from backend.server.v2.store.model import StoreAgentDetails
from backend.util.cache import cached
from backend.util.json import SafeJson
from backend.util.timezone_utils import get_user_timezone_or_utc
# Mapping from user reason id to categories to search for when choosing agent to show
REASON_MAPPING: dict[str, list[str]] = {
@@ -33,20 +31,9 @@ REASON_MAPPING: dict[str, list[str]] = {
POINTS_AGENT_COUNT = 50 # Number of agents to calculate points for
MIN_AGENT_COUNT = 2 # Minimum number of marketplace agents to enable onboarding
FrontendOnboardingStep = Literal[
OnboardingStep.WELCOME,
OnboardingStep.USAGE_REASON,
OnboardingStep.INTEGRATIONS,
OnboardingStep.AGENT_CHOICE,
OnboardingStep.AGENT_NEW_RUN,
OnboardingStep.AGENT_INPUT,
OnboardingStep.CONGRATS,
OnboardingStep.MARKETPLACE_VISIT,
OnboardingStep.BUILDER_OPEN,
]
class UserOnboardingUpdate(pydantic.BaseModel):
completedSteps: Optional[list[OnboardingStep]] = None
walletShown: Optional[bool] = None
notified: Optional[list[OnboardingStep]] = None
usageReason: Optional[str] = None
@@ -55,6 +42,9 @@ class UserOnboardingUpdate(pydantic.BaseModel):
selectedStoreListingVersionId: Optional[str] = None
agentInput: Optional[dict[str, Any]] = None
onboardingAgentExecutionId: Optional[str] = None
agentRuns: Optional[int] = None
lastRunAt: Optional[datetime] = None
consecutiveRunDays: Optional[int] = None
async def get_user_onboarding(user_id: str):
@@ -93,6 +83,14 @@ async def reset_user_onboarding(user_id: str):
async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update: UserOnboardingUpdateInput = {}
onboarding = await get_user_onboarding(user_id)
if data.completedSteps is not None:
update["completedSteps"] = list(
set(data.completedSteps + onboarding.completedSteps)
)
for step in data.completedSteps:
if step not in onboarding.completedSteps:
await _reward_user(user_id, onboarding, step)
await _send_onboarding_notification(user_id, step)
if data.walletShown:
update["walletShown"] = data.walletShown
if data.notified is not None:
@@ -109,6 +107,12 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update["agentInput"] = SafeJson(data.agentInput)
if data.onboardingAgentExecutionId is not None:
update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
if data.agentRuns is not None and data.agentRuns > onboarding.agentRuns:
update["agentRuns"] = data.agentRuns
if data.lastRunAt is not None:
update["lastRunAt"] = data.lastRunAt
if data.consecutiveRunDays is not None:
update["consecutiveRunDays"] = data.consecutiveRunDays
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
@@ -157,12 +161,14 @@ async def _reward_user(user_id: str, onboarding: UserOnboarding, step: Onboardin
if step in onboarding.rewardedFor:
return
onboarding.rewardedFor.append(step)
user_credit_model = await get_user_credit_model(user_id)
await user_credit_model.onboarding_reward(user_id, reward, step)
await UserOnboarding.prisma().update(
where={"userId": user_id},
data={
"rewardedFor": list(set(onboarding.rewardedFor + [step])),
"completedSteps": list(set(onboarding.completedSteps + [step])),
"rewardedFor": onboarding.rewardedFor,
},
)
@@ -171,52 +177,31 @@ async def complete_onboarding_step(user_id: str, step: OnboardingStep):
"""
Completes the specified onboarding step for the user if not already completed.
"""
onboarding = await get_user_onboarding(user_id)
if step not in onboarding.completedSteps:
await UserOnboarding.prisma().update(
where={"userId": user_id},
data={
"completedSteps": list(set(onboarding.completedSteps + [step])),
},
await update_user_onboarding(
user_id,
UserOnboardingUpdate(completedSteps=onboarding.completedSteps + [step]),
)
await _reward_user(user_id, onboarding, step)
await _send_onboarding_notification(user_id, step)
async def _send_onboarding_notification(
user_id: str, step: OnboardingStep | None, event: str = "step_completed"
):
async def _send_onboarding_notification(user_id: str, step: OnboardingStep):
"""
Sends an onboarding notification to the user.
Sends an onboarding notification to the user for the specified step.
"""
payload = OnboardingNotificationPayload(
type="onboarding",
event=event,
step=step,
event="step_completed",
step=step.value,
)
await AsyncRedisNotificationEventBus().publish(
NotificationEvent(user_id=user_id, payload=payload)
)
async def complete_re_run_agent(user_id: str, graph_id: str) -> None:
"""
Complete RE_RUN_AGENT step when a user runs a graph they've run before.
Keeps overhead low by only counting executions if the step is still pending.
"""
onboarding = await get_user_onboarding(user_id)
if OnboardingStep.RE_RUN_AGENT in onboarding.completedSteps:
return
# Includes current execution, so count > 1 means there was at least one prior run.
previous_exec_count = await execution_db.get_graph_executions_count(
user_id=user_id, graph_id=graph_id
)
if previous_exec_count > 1:
await complete_onboarding_step(user_id, OnboardingStep.RE_RUN_AGENT)
def _clean_and_split(text: str) -> list[str]:
def clean_and_split(text: str) -> list[str]:
"""
Removes all special characters from a string, truncates it to 100 characters,
and splits it by whitespace and commas.
@@ -239,7 +224,7 @@ def _clean_and_split(text: str) -> list[str]:
return words
def _calculate_points(
def calculate_points(
agent, categories: list[str], custom: list[str], integrations: list[str]
) -> int:
"""
@@ -283,85 +268,18 @@ def _calculate_points(
return int(points)
def _normalize_datetime(value: datetime | None) -> datetime | None:
if value is None:
return None
if value.tzinfo is None:
return value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc)
def get_credentials_blocks() -> dict[str, str]:
# Returns a dictionary of block id to credentials field name
creds: dict[str, str] = {}
blocks = get_blocks()
for id, block in blocks.items():
for field_name, field_info in block().input_schema.model_fields.items():
if field_info.annotation == CredentialsMetaInput:
creds[id] = field_name
return creds
def _calculate_consecutive_run_days(
last_run_at: datetime | None, current_consecutive_days: int, user_timezone: str
) -> tuple[datetime, int]:
tz = ZoneInfo(user_timezone)
local_now = datetime.now(tz)
normalized_last_run = _normalize_datetime(last_run_at)
if normalized_last_run is None:
return local_now.astimezone(timezone.utc), 1
last_run_local = normalized_last_run.astimezone(tz)
last_run_date = last_run_local.date()
today = local_now.date()
if last_run_date == today:
return local_now.astimezone(timezone.utc), current_consecutive_days
if last_run_date == today - timedelta(days=1):
return local_now.astimezone(timezone.utc), current_consecutive_days + 1
return local_now.astimezone(timezone.utc), 1
def _get_run_milestone_steps(
new_run_count: int, consecutive_days: int
) -> list[OnboardingStep]:
milestones: list[OnboardingStep] = []
if new_run_count >= 10:
milestones.append(OnboardingStep.RUN_AGENTS)
if new_run_count >= 100:
milestones.append(OnboardingStep.RUN_AGENTS_100)
if consecutive_days >= 3:
milestones.append(OnboardingStep.RUN_3_DAYS)
if consecutive_days >= 14:
milestones.append(OnboardingStep.RUN_14_DAYS)
return milestones
async def _get_user_timezone(user_id: str) -> str:
user = await get_user_by_id(user_id)
return get_user_timezone_or_utc(user.timezone if user else None)
async def increment_runs(user_id: str):
"""
Increment a user's run counters and trigger any onboarding milestones.
"""
user_timezone = await _get_user_timezone(user_id)
onboarding = await get_user_onboarding(user_id)
new_run_count = onboarding.agentRuns + 1
last_run_at, consecutive_run_days = _calculate_consecutive_run_days(
onboarding.lastRunAt, onboarding.consecutiveRunDays, user_timezone
)
await UserOnboarding.prisma().update(
where={"userId": user_id},
data={
"agentRuns": {"increment": 1},
"lastRunAt": last_run_at,
"consecutiveRunDays": consecutive_run_days,
},
)
milestones = _get_run_milestone_steps(new_run_count, consecutive_run_days)
new_steps = [step for step in milestones if step not in onboarding.completedSteps]
for step in new_steps:
await complete_onboarding_step(user_id, step)
# Send progress notification if no steps were completed, so client refetches onboarding state
if not new_steps:
await _send_onboarding_notification(user_id, None, event="increment_runs")
CREDENTIALS_FIELDS: dict[str, str] = get_credentials_blocks()
async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
@@ -370,7 +288,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
where_clause: dict[str, Any] = {}
custom = _clean_and_split((user_onboarding.usageReason or "").lower())
custom = clean_and_split((user_onboarding.usageReason or "").lower())
if categories:
where_clause["OR"] = [
@@ -418,7 +336,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
# Calculate points for the first X agents and choose the top 2
agent_points = []
for agent in storeAgents[:POINTS_AGENT_COUNT]:
points = _calculate_points(
points = calculate_points(
agent, categories, custom, user_onboarding.integrations
)
agent_points.append((agent, points))
@@ -432,7 +350,6 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
slug=agent.slug,
agent_name=agent.agent_name,
agent_video=agent.agent_video or "",
agent_output_demo=agent.agent_output_demo or "",
agent_image=agent.agent_image,
creator=agent.creator_username,
creator_avatar=agent.creator_avatar,
@@ -442,8 +359,6 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
runs=agent.runs,
rating=agent.rating,
versions=agent.versions,
agentGraphVersions=agent.agentGraphVersions,
agentGraphId=agent.agentGraphId,
last_updated=agent.updated_at,
)
for agent in recommended_agents

View File

@@ -2,24 +2,13 @@ import logging
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cast
from backend.api.features.library.db import (
add_store_agent_to_library,
list_library_agents,
)
from backend.api.features.store.db import get_store_agent_details, get_store_agents
from backend.data import db
from backend.data.analytics import (
get_accuracy_trends_and_alerts,
get_marketplace_graphs_for_monitoring,
)
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
create_graph_execution,
get_block_error_stats,
get_child_graph_executions,
get_execution_kv_data,
get_execution_outputs_by_node_exec_id,
get_frequently_executed_graphs,
get_graph_execution_meta,
get_graph_executions,
get_graph_executions_count,
@@ -66,6 +55,8 @@ from backend.data.user import (
get_user_notification_preference,
update_user_integrations,
)
from backend.server.v2.library.db import add_store_agent_to_library, list_library_agents
from backend.server.v2.store.db import get_store_agent_details, get_store_agents
from backend.util.service import (
AppService,
AppServiceClient,
@@ -151,13 +142,9 @@ class DatabaseManager(AppService):
update_graph_execution_stats = _(update_graph_execution_stats)
upsert_execution_input = _(upsert_execution_input)
upsert_execution_output = _(upsert_execution_output)
get_execution_outputs_by_node_exec_id = _(get_execution_outputs_by_node_exec_id)
get_execution_kv_data = _(get_execution_kv_data)
set_execution_kv_data = _(set_execution_kv_data)
get_block_error_stats = _(get_block_error_stats)
get_accuracy_trends_and_alerts = _(get_accuracy_trends_and_alerts)
get_frequently_executed_graphs = _(get_frequently_executed_graphs)
get_marketplace_graphs_for_monitoring = _(get_marketplace_graphs_for_monitoring)
# Graphs
get_node = _(get_node)
@@ -239,10 +226,6 @@ class DatabaseManagerClient(AppServiceClient):
# Block error monitoring
get_block_error_stats = _(d.get_block_error_stats)
# Execution accuracy monitoring
get_accuracy_trends_and_alerts = _(d.get_accuracy_trends_and_alerts)
get_frequently_executed_graphs = _(d.get_frequently_executed_graphs)
get_marketplace_graphs_for_monitoring = _(d.get_marketplace_graphs_for_monitoring)
# Human In The Loop
has_pending_reviews_for_graph_exec = _(d.has_pending_reviews_for_graph_exec)
@@ -282,7 +265,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
get_user_integrations = d.get_user_integrations
upsert_execution_input = d.upsert_execution_input
upsert_execution_output = d.upsert_execution_output
get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
update_graph_execution_stats = d.update_graph_execution_stats
update_node_execution_status = d.update_node_execution_status
update_node_execution_status_batch = d.update_node_execution_status_batch

View File

@@ -48,8 +48,27 @@ from backend.data.notifications import (
ZeroBalanceData,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.activity_status_generator import (
generate_activity_status_for_execution,
)
from backend.executor.utils import (
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
GRAPH_EXECUTION_EXCHANGE,
GRAPH_EXECUTION_QUEUE_NAME,
GRAPH_EXECUTION_ROUTING_KEY,
CancelExecutionEvent,
ExecutionOutputEntry,
LogMetadata,
NodeExecutionProgress,
block_usage_cost,
create_execution_queue_config,
execution_usage_cost,
validate_exec,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.notifications.notifications import queue_notification
from backend.server.v2.AutoMod.manager import automod_manager
from backend.util import json
from backend.util.clients import (
get_async_execution_event_bus,
@@ -76,24 +95,7 @@ from backend.util.retry import (
)
from backend.util.settings import Settings
from .activity_status_generator import generate_activity_status_for_execution
from .automod.manager import automod_manager
from .cluster_lock import ClusterLock
from .utils import (
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
GRAPH_EXECUTION_EXCHANGE,
GRAPH_EXECUTION_QUEUE_NAME,
GRAPH_EXECUTION_ROUTING_KEY,
CancelExecutionEvent,
ExecutionOutputEntry,
LogMetadata,
NodeExecutionProgress,
block_usage_cost,
create_execution_queue_config,
execution_usage_cost,
validate_exec,
)
if TYPE_CHECKING:
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
@@ -114,40 +116,6 @@ utilization_gauge = Gauge(
"Ratio of active graph runs to max graph workers",
)
# Redis key prefix for tracking insufficient funds Discord notifications.
# We only send one notification per user per agent until they top up credits.
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
# TTL for the notification flag (30 days) - acts as a fallback cleanup
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60
async def clear_insufficient_funds_notifications(user_id: str) -> int:
"""
Clear all insufficient funds notification flags for a user.
This should be called when a user tops up their credits, allowing
Discord notifications to be sent again if they run out of funds.
Args:
user_id: The user ID to clear notifications for.
Returns:
The number of keys that were deleted.
"""
try:
redis_client = await redis.get_redis_async()
pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
keys = [key async for key in redis_client.scan_iter(match=pattern)]
if keys:
return await redis_client.delete(*keys)
return 0
except Exception as e:
logger.warning(
f"Failed to clear insufficient funds notification flags for user "
f"{user_id}: {e}"
)
return 0
# Thread-local storage for ExecutionProcessor instances
_tls = threading.local()
@@ -165,8 +133,9 @@ def execute_graph(
cluster_lock: ClusterLock,
):
"""Execute graph using thread-local ExecutionProcessor instance"""
processor: ExecutionProcessor = _tls.processor
return processor.on_graph_execution(graph_exec_entry, cancel_event, cluster_lock)
return _tls.processor.on_graph_execution(
graph_exec_entry, cancel_event, cluster_lock
)
T = TypeVar("T")
@@ -174,11 +143,10 @@ T = TypeVar("T")
async def execute_node(
node: Node,
creds_manager: IntegrationCredentialsManager,
data: NodeExecutionEntry,
execution_processor: "ExecutionProcessor",
execution_stats: NodeExecutionStats | None = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_to_skip: Optional[set[str]] = None,
) -> BlockOutput:
"""
Execute a node in the graph. This will trigger a block execution on a node,
@@ -201,7 +169,6 @@ async def execute_node(
node_id = data.node_id
node_block = node.block
execution_context = data.execution_context
creds_manager = execution_processor.creds_manager
log_metadata = LogMetadata(
logger=_logger,
@@ -245,8 +212,6 @@ async def execute_node(
"node_exec_id": node_exec_id,
"user_id": user_id,
"execution_context": execution_context,
"execution_processor": execution_processor,
"nodes_to_skip": nodes_to_skip or set(),
}
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
@@ -544,7 +509,6 @@ class ExecutionProcessor:
node_exec_progress: NodeExecutionProgress,
nodes_input_masks: Optional[NodesInputMasks],
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
nodes_to_skip: Optional[set[str]] = None,
) -> NodeExecutionStats:
log_metadata = LogMetadata(
logger=_logger,
@@ -567,7 +531,6 @@ class ExecutionProcessor:
db_client=db_client,
log_metadata=log_metadata,
nodes_input_masks=nodes_input_masks,
nodes_to_skip=nodes_to_skip,
)
if isinstance(status, BaseException):
raise status
@@ -613,7 +576,6 @@ class ExecutionProcessor:
db_client: "DatabaseManagerAsyncClient",
log_metadata: LogMetadata,
nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_to_skip: Optional[set[str]] = None,
) -> ExecutionStatus:
status = ExecutionStatus.RUNNING
@@ -646,11 +608,10 @@ class ExecutionProcessor:
async for output_name, output_data in execute_node(
node=node,
creds_manager=self.creds_manager,
data=node_exec,
execution_processor=self,
execution_stats=stats,
nodes_input_masks=nodes_input_masks,
nodes_to_skip=nodes_to_skip,
):
await persist_output(output_name, output_data)
@@ -899,17 +860,12 @@ class ExecutionProcessor:
execution_stats_lock = threading.Lock()
# State holders ----------------------------------------------------
self.running_node_execution: dict[str, NodeExecutionProgress] = defaultdict(
running_node_execution: dict[str, NodeExecutionProgress] = defaultdict(
NodeExecutionProgress
)
self.running_node_evaluation: dict[str, Future] = {}
self.execution_stats = execution_stats
self.execution_stats_lock = execution_stats_lock
running_node_evaluation: dict[str, Future] = {}
execution_queue = ExecutionQueue[NodeExecutionEntry]()
running_node_execution = self.running_node_execution
running_node_evaluation = self.running_node_evaluation
try:
if db_client.get_credits(graph_exec.user_id) <= 0:
raise InsufficientBalanceError(
@@ -962,21 +918,6 @@ class ExecutionProcessor:
queued_node_exec = execution_queue.get()
# Check if this node should be skipped due to optional credentials
if queued_node_exec.node_id in graph_exec.nodes_to_skip:
log_metadata.info(
f"Skipping node execution {queued_node_exec.node_exec_id} "
f"for node {queued_node_exec.node_id} - optional credentials not configured"
)
# Mark the node as completed without executing
# No outputs will be produced, so downstream nodes won't trigger
update_node_execution_status(
db_client=db_client,
exec_id=queued_node_exec.node_exec_id,
status=ExecutionStatus.COMPLETED,
)
continue
log_metadata.debug(
f"Dispatching node execution {queued_node_exec.node_exec_id} "
f"for node {queued_node_exec.node_id}",
@@ -1037,7 +978,6 @@ class ExecutionProcessor:
execution_stats,
execution_stats_lock,
),
nodes_to_skip=graph_exec.nodes_to_skip,
),
self.node_execution_loop,
)
@@ -1317,40 +1257,12 @@ class ExecutionProcessor:
graph_id: str,
e: InsufficientBalanceError,
):
# Check if we've already sent a notification for this user+agent combo.
# We only send one notification per user per agent until they top up credits.
redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
try:
redis_client = redis.get_redis()
# SET NX returns True only if the key was newly set (didn't exist)
is_new_notification = redis_client.set(
redis_key,
"1",
nx=True,
ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
)
if not is_new_notification:
# Already notified for this user+agent, skip all notifications
logger.debug(
f"Skipping duplicate insufficient funds notification for "
f"user={user_id}, graph={graph_id}"
)
return
except Exception as redis_error:
# If Redis fails, log and continue to send the notification
# (better to occasionally duplicate than to never notify)
logger.warning(
f"Failed to check/set insufficient funds notification flag in Redis: "
f"{redis_error}"
)
shortfall = abs(e.amount) - e.balance
metadata = db_client.get_graph_metadata(graph_id)
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
# Queue user email notification
queue_notification(
NotificationEventModel(
user_id=user_id,
@@ -1364,7 +1276,6 @@ class ExecutionProcessor:
)
)
# Send Discord system alert
try:
user_email = db_client.get_user_email_by_id(user_id)

View File

@@ -1,560 +0,0 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from prisma.enums import NotificationType
from backend.data.notifications import ZeroBalanceData
from backend.executor.manager import (
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX,
ExecutionProcessor,
clear_insufficient_funds_notifications,
)
from backend.util.exceptions import InsufficientBalanceError
from backend.util.test import SpinTestServer
async def async_iter(items):
"""Helper to create an async iterator from a list."""
for item in items:
yield item
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_insufficient_funds_sends_discord_alert_first_time(
server: SpinTestServer,
):
"""Test that the first insufficient funds notification sends a Discord alert."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id = "test-graph-456"
error = InsufficientBalanceError(
message="Insufficient balance",
user_id=user_id,
balance=72, # $0.72
amount=-714, # Attempting to spend $7.14
)
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Setup mocks
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.frontend_base_url = "https://test.com"
# Mock Redis to simulate first-time notification (set returns True)
mock_redis_client = MagicMock()
mock_redis_module.get_redis.return_value = mock_redis_client
mock_redis_client.set.return_value = True # Key was newly set
# Create mock database client
mock_db_client = MagicMock()
mock_graph_metadata = MagicMock()
mock_graph_metadata.name = "Test Agent"
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# Test the insufficient funds handler
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id,
e=error,
)
# Verify notification was queued
mock_queue_notif.assert_called_once()
notification_call = mock_queue_notif.call_args[0][0]
assert notification_call.type == NotificationType.ZERO_BALANCE
assert notification_call.user_id == user_id
assert isinstance(notification_call.data, ZeroBalanceData)
assert notification_call.data.current_balance == 72
# Verify Redis was checked with correct key pattern
expected_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
mock_redis_client.set.assert_called_once()
call_args = mock_redis_client.set.call_args
assert call_args[0][0] == expected_key
assert call_args[1]["nx"] is True
# Verify Discord alert was sent
mock_client.discord_system_alert.assert_called_once()
discord_message = mock_client.discord_system_alert.call_args[0][0]
assert "Insufficient Funds Alert" in discord_message
assert "test@example.com" in discord_message
assert "Test Agent" in discord_message
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_insufficient_funds_skips_duplicate_notifications(
server: SpinTestServer,
):
"""Test that duplicate insufficient funds notifications skip both email and Discord."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id = "test-graph-456"
error = InsufficientBalanceError(
message="Insufficient balance",
user_id=user_id,
balance=72,
amount=-714,
)
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Setup mocks
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.frontend_base_url = "https://test.com"
# Mock Redis to simulate duplicate notification (set returns False/None)
mock_redis_client = MagicMock()
mock_redis_module.get_redis.return_value = mock_redis_client
mock_redis_client.set.return_value = None # Key already existed
# Create mock database client
mock_db_client = MagicMock()
mock_db_client.get_graph_metadata.return_value = MagicMock(name="Test Agent")
# Test the insufficient funds handler
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id,
e=error,
)
# Verify email notification was NOT queued (deduplication worked)
mock_queue_notif.assert_not_called()
# Verify Discord alert was NOT sent (deduplication worked)
mock_client.discord_system_alert.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
server: SpinTestServer,
):
"""Test that different agents for the same user get separate Discord alerts."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id_1 = "test-graph-111"
graph_id_2 = "test-graph-222"
error = InsufficientBalanceError(
message="Insufficient balance",
user_id=user_id,
balance=72,
amount=-714,
)
with patch("backend.executor.manager.queue_notification"), patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.frontend_base_url = "https://test.com"
mock_redis_client = MagicMock()
mock_redis_module.get_redis.return_value = mock_redis_client
# Both calls return True (first time for each agent)
mock_redis_client.set.return_value = True
mock_db_client = MagicMock()
mock_graph_metadata = MagicMock()
mock_graph_metadata.name = "Test Agent"
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# First agent notification
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id_1,
e=error,
)
# Second agent notification
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id_2,
e=error,
)
# Verify Discord alerts were sent for both agents
assert mock_client.discord_system_alert.call_count == 2
# Verify Redis was called with different keys
assert mock_redis_client.set.call_count == 2
calls = mock_redis_client.set.call_args_list
assert (
calls[0][0][0]
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_1}"
)
assert (
calls[1][0][0]
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_2}"
)
@pytest.mark.asyncio(loop_scope="session")
async def test_clear_insufficient_funds_notifications(server: SpinTestServer):
"""Test that clearing notifications removes all keys for a user."""
user_id = "test-user-123"
with patch("backend.executor.manager.redis") as mock_redis_module:
mock_redis_client = MagicMock()
# get_redis_async is an async function, so we need AsyncMock for it
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
# Mock scan_iter to return some keys as an async iterator
mock_keys = [
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1",
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-2",
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-3",
]
mock_redis_client.scan_iter.return_value = async_iter(mock_keys)
# delete is awaited, so use AsyncMock
mock_redis_client.delete = AsyncMock(return_value=3)
# Clear notifications
result = await clear_insufficient_funds_notifications(user_id)
# Verify correct pattern was used
expected_pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
mock_redis_client.scan_iter.assert_called_once_with(match=expected_pattern)
# Verify delete was called with all keys
mock_redis_client.delete.assert_called_once_with(*mock_keys)
# Verify return value
assert result == 3
@pytest.mark.asyncio(loop_scope="session")
async def test_clear_insufficient_funds_notifications_no_keys(server: SpinTestServer):
"""Test clearing notifications when there are no keys to clear."""
user_id = "test-user-no-notifications"
with patch("backend.executor.manager.redis") as mock_redis_module:
mock_redis_client = MagicMock()
# get_redis_async is an async function, so we need AsyncMock for it
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
# Mock scan_iter to return no keys as an async iterator
mock_redis_client.scan_iter.return_value = async_iter([])
# Clear notifications
result = await clear_insufficient_funds_notifications(user_id)
# Verify delete was not called
mock_redis_client.delete.assert_not_called()
# Verify return value
assert result == 0
@pytest.mark.asyncio(loop_scope="session")
async def test_clear_insufficient_funds_notifications_handles_redis_error(
server: SpinTestServer,
):
"""Test that clearing notifications handles Redis errors gracefully."""
user_id = "test-user-redis-error"
with patch("backend.executor.manager.redis") as mock_redis_module:
# Mock get_redis_async to raise an error
mock_redis_module.get_redis_async = AsyncMock(
side_effect=Exception("Redis connection failed")
)
# Clear notifications should not raise, just return 0
result = await clear_insufficient_funds_notifications(user_id)
# Verify it returned 0 (graceful failure)
assert result == 0
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_insufficient_funds_continues_on_redis_error(
server: SpinTestServer,
):
"""Test that both email and Discord notifications are still sent when Redis fails."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id = "test-graph-456"
error = InsufficientBalanceError(
message="Insufficient balance",
user_id=user_id,
balance=72,
amount=-714,
)
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.frontend_base_url = "https://test.com"
# Mock Redis to raise an error
mock_redis_client = MagicMock()
mock_redis_module.get_redis.return_value = mock_redis_client
mock_redis_client.set.side_effect = Exception("Redis connection error")
mock_db_client = MagicMock()
mock_graph_metadata = MagicMock()
mock_graph_metadata.name = "Test Agent"
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# Test the insufficient funds handler
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id,
e=error,
)
# Verify email notification was still queued despite Redis error
mock_queue_notif.assert_called_once()
# Verify Discord alert was still sent despite Redis error
mock_client.discord_system_alert.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_add_transaction_clears_notifications_on_grant(server: SpinTestServer):
"""Test that _add_transaction clears notification flags when adding GRANT credits."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-grant-clear"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 1000, "transactionKey": "test-tx-key"}]
# Mock async Redis for notification clearing
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
mock_redis_client.scan_iter.return_value = async_iter(
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
)
mock_redis_client.delete = AsyncMock(return_value=1)
# Create a concrete instance
credit_model = UserCredit()
# Call _add_transaction with GRANT type (should clear notifications)
await credit_model._add_transaction(
user_id=user_id,
amount=500, # Positive amount
transaction_type=CreditTransactionType.GRANT,
is_active=True, # Active transaction
)
# Verify notification clearing was called
mock_redis_module.get_redis_async.assert_called_once()
mock_redis_client.scan_iter.assert_called_once_with(
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
)
@pytest.mark.asyncio(loop_scope="session")
async def test_add_transaction_clears_notifications_on_top_up(server: SpinTestServer):
"""Test that _add_transaction clears notification flags when adding TOP_UP credits."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-topup-clear"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 2000, "transactionKey": "test-tx-key-2"}]
# Mock async Redis for notification clearing
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
mock_redis_client.scan_iter.return_value = async_iter([])
mock_redis_client.delete = AsyncMock(return_value=0)
credit_model = UserCredit()
# Call _add_transaction with TOP_UP type (should clear notifications)
await credit_model._add_transaction(
user_id=user_id,
amount=1000, # Positive amount
transaction_type=CreditTransactionType.TOP_UP,
is_active=True,
)
# Verify notification clearing was attempted
mock_redis_module.get_redis_async.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_add_transaction_skips_clearing_for_inactive_transaction(
server: SpinTestServer,
):
"""Test that _add_transaction does NOT clear notifications for inactive transactions."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-inactive"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 500, "transactionKey": "test-tx-key-3"}]
# Mock async Redis
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
credit_model = UserCredit()
# Call _add_transaction with is_active=False (should NOT clear notifications)
await credit_model._add_transaction(
user_id=user_id,
amount=500,
transaction_type=CreditTransactionType.TOP_UP,
is_active=False, # Inactive - pending Stripe payment
)
# Verify notification clearing was NOT called
mock_redis_module.get_redis_async.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_add_transaction_skips_clearing_for_usage_transaction(
server: SpinTestServer,
):
"""Test that _add_transaction does NOT clear notifications for USAGE transactions."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-usage"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 400, "transactionKey": "test-tx-key-4"}]
# Mock async Redis
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
credit_model = UserCredit()
# Call _add_transaction with USAGE type (spending, should NOT clear)
await credit_model._add_transaction(
user_id=user_id,
amount=-100, # Negative - spending credits
transaction_type=CreditTransactionType.USAGE,
is_active=True,
)
# Verify notification clearing was NOT called
mock_redis_module.get_redis_async.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_enable_transaction_clears_notifications(server: SpinTestServer):
"""Test that _enable_transaction clears notification flags when enabling a TOP_UP."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-enable"
with patch("backend.data.credit.CreditTransaction") as mock_credit_tx, patch(
"backend.data.credit.query_raw_with_schema"
) as mock_query, patch("backend.executor.manager.redis") as mock_redis_module:
# Mock finding the pending transaction
mock_transaction = MagicMock()
mock_transaction.amount = 1000
mock_transaction.type = CreditTransactionType.TOP_UP
mock_credit_tx.prisma.return_value.find_first = AsyncMock(
return_value=mock_transaction
)
# Mock the query to return updated balance
mock_query.return_value = [{"balance": 1500}]
# Mock async Redis for notification clearing
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
mock_redis_client.scan_iter.return_value = async_iter(
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
)
mock_redis_client.delete = AsyncMock(return_value=1)
credit_model = UserCredit()
# Call _enable_transaction (simulates Stripe checkout completion)
from backend.util.json import SafeJson
await credit_model._enable_transaction(
transaction_key="cs_test_123",
user_id=user_id,
metadata=SafeJson({"payment": "completed"}),
)
# Verify notification clearing was called
mock_redis_module.get_redis_async.assert_called_once()
mock_redis_client.scan_iter.assert_called_once_with(
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
)

View File

@@ -3,16 +3,16 @@ import logging
import fastapi.responses
import pytest
import backend.api.features.library.model
import backend.api.features.store.model
from backend.api.model import CreateGraph
from backend.api.rest_api import AgentServer
import backend.server.v2.library.model
import backend.server.v2.store.model
from backend.blocks.basic import StoreValueBlock
from backend.blocks.data_manipulation import FindInDictionaryBlock
from backend.blocks.io import AgentInputBlock
from backend.blocks.maths import CalculatorBlock, Operation
from backend.data import execution, graph
from backend.data.model import User
from backend.server.model import CreateGraph
from backend.server.rest_api import AgentServer
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.test import SpinTestServer, wait_execution
@@ -356,7 +356,7 @@ async def test_execute_preset(server: SpinTestServer):
test_graph = await create_graph(server, test_graph, test_user)
# Create preset with initial values
preset = backend.api.features.library.model.LibraryAgentPresetCreatable(
preset = backend.server.v2.library.model.LibraryAgentPresetCreatable(
name="Test Preset With Clash",
description="Test preset with clashing input values",
graph_id=test_graph.id,
@@ -444,7 +444,7 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
test_graph = await create_graph(server, test_graph, test_user)
# Create preset with initial values
preset = backend.api.features.library.model.LibraryAgentPresetCreatable(
preset = backend.server.v2.library.model.LibraryAgentPresetCreatable(
name="Test Preset With Clash",
description="Test preset with clashing input values",
graph_id=test_graph.id,
@@ -485,7 +485,7 @@ async def test_store_listing_graph(server: SpinTestServer):
test_user = await create_test_user()
test_graph = await create_graph(server, create_test_graph(), test_user)
store_submission_request = backend.api.features.store.model.StoreSubmissionRequest(
store_submission_request = backend.server.v2.store.model.StoreSubmissionRequest(
agent_id=test_graph.id,
agent_version=test_graph.version,
slug=test_graph.id,
@@ -514,7 +514,7 @@ async def test_store_listing_graph(server: SpinTestServer):
admin_user = await create_test_user(alt_user=True)
await server.agent_server.test_review_store_listing(
backend.api.features.store.model.ReviewSubmissionRequest(
backend.server.v2.store.model.ReviewSubmissionRequest(
store_listing_version_id=slv_id,
is_approved=True,
comments="Test comments",
@@ -523,7 +523,7 @@ async def test_store_listing_graph(server: SpinTestServer):
)
# Add the approved store listing to the admin user's library so they can execute it
from backend.api.features.library.db import add_store_agent_to_library
from backend.server.v2.library.db import add_store_agent_to_library
await add_store_agent_to_library(
store_listing_version_id=slv_id, user_id=admin_user.id

View File

@@ -23,18 +23,15 @@ from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import MetaData, create_engine
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
from backend.data.block import BlockInput
from backend.data.execution import GraphExecutionWithNodes
from backend.data.model import CredentialsMetaInput
from backend.data.onboarding import increment_runs
from backend.executor import utils as execution_utils
from backend.monitoring import (
NotificationJobArgs,
process_existing_batches,
process_weekly_summary,
report_block_error_rates,
report_execution_accuracy_alerts,
report_late_executions,
)
from backend.util.clients import get_scheduler_client
@@ -156,7 +153,6 @@ async def _execute_graph(**kwargs):
inputs=args.input_data,
graph_credentials_inputs=args.input_credentials,
)
await increment_runs(args.user_id)
elapsed = asyncio.get_event_loop().time() - start_time
logger.info(
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
@@ -243,17 +239,6 @@ def cleanup_expired_files():
run_async(cleanup_expired_files_async())
def cleanup_oauth_tokens():
"""Clean up expired OAuth tokens from the database."""
# Wait for completion
run_async(cleanup_expired_oauth_tokens())
def execution_accuracy_alerts():
"""Check execution accuracy and send alerts if drops are detected."""
return report_execution_accuracy_alerts()
# Monitoring functions are now imported from monitoring module
@@ -453,28 +438,6 @@ class Scheduler(AppService):
jobstore=Jobstores.EXECUTION.value,
)
# OAuth Token Cleanup - configurable interval
self.scheduler.add_job(
cleanup_oauth_tokens,
id="cleanup_oauth_tokens",
trigger="interval",
replace_existing=True,
seconds=config.oauth_token_cleanup_interval_hours
* 3600, # Convert hours to seconds
jobstore=Jobstores.EXECUTION.value,
)
# Execution Accuracy Monitoring - configurable interval
self.scheduler.add_job(
execution_accuracy_alerts,
id="report_execution_accuracy_alerts",
trigger="interval",
replace_existing=True,
seconds=config.execution_accuracy_check_interval_hours
* 3600, # Convert hours to seconds
jobstore=Jobstores.EXECUTION.value,
)
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
@@ -622,16 +585,6 @@ class Scheduler(AppService):
"""Manually trigger cleanup of expired cloud storage files."""
return cleanup_expired_files()
@expose
def execute_cleanup_oauth_tokens(self):
"""Manually trigger cleanup of expired OAuth tokens."""
return cleanup_oauth_tokens()
@expose
def execute_report_execution_accuracy_alerts(self):
"""Manually trigger execution accuracy alert checking."""
return execution_accuracy_alerts()
class SchedulerClient(AppServiceClient):
@classmethod

View File

@@ -1,7 +1,7 @@
import pytest
from backend.api.model import CreateGraph
from backend.data import db
from backend.server.model import CreateGraph
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.clients import get_scheduler_client
from backend.util.test import SpinTestServer

View File

@@ -239,19 +239,14 @@ async def _validate_node_input_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[dict[str, dict[str, str]], set[str]]:
) -> dict[str, dict[str, str]]:
"""
Checks all credentials for all nodes of the graph and returns structured errors
and a set of nodes that should be skipped due to optional missing credentials.
Checks all credentials for all nodes of the graph and returns structured errors.
Returns:
tuple[
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node,
set[node_id]: Nodes that should be skipped (optional credentials not configured)
]
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node
"""
credential_errors: dict[str, dict[str, str]] = defaultdict(dict)
nodes_to_skip: set[str] = set()
for node in graph.nodes:
block = node.block
@@ -261,46 +256,27 @@ async def _validate_node_input_credentials(
if not credentials_fields:
continue
# Track if any credential field is missing for this node
has_missing_credentials = False
for field_name, credentials_meta_type in credentials_fields.items():
try:
# Check nodes_input_masks first, then input_default
field_value = None
if (
nodes_input_masks
and (node_input_mask := nodes_input_masks.get(node.id))
and field_name in node_input_mask
):
field_value = node_input_mask[field_name]
credentials_meta = credentials_meta_type.model_validate(
node_input_mask[field_name]
)
elif field_name in node.input_default:
# For optional credentials, don't use input_default - treat as missing
# This prevents stale credential IDs from failing validation
if node.credentials_optional:
field_value = None
else:
field_value = node.input_default[field_name]
# Check if credentials are missing (None, empty, or not present)
if field_value is None or (
isinstance(field_value, dict) and not field_value.get("id")
):
has_missing_credentials = True
# If node has credentials_optional flag, mark for skipping instead of error
if node.credentials_optional:
continue # Don't add error, will be marked for skip after loop
else:
credential_errors[node.id][
field_name
] = "These credentials are required"
continue
credentials_meta = credentials_meta_type.model_validate(field_value)
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
else:
# Missing credentials
credential_errors[node.id][
field_name
] = "These credentials are required"
continue
except ValidationError as e:
# Validation error means credentials were provided but invalid
# This should always be an error, even if optional
credential_errors[node.id][field_name] = f"Invalid credentials: {e}"
continue
@@ -311,7 +287,6 @@ async def _validate_node_input_credentials(
)
except Exception as e:
# Handle any errors fetching credentials
# If credentials were explicitly configured but unavailable, it's an error
credential_errors[node.id][
field_name
] = f"Credentials not available: {e}"
@@ -338,19 +313,7 @@ async def _validate_node_input_credentials(
] = "Invalid credentials: type/provider mismatch"
continue
# If node has optional credentials and any are missing, mark for skipping
# But only if there are no other errors for this node
if (
has_missing_credentials
and node.credentials_optional
and node.id not in credential_errors
):
nodes_to_skip.add(node.id)
logger.info(
f"Node #{node.id} will be skipped: optional credentials not configured"
)
return credential_errors, nodes_to_skip
return credential_errors
def make_node_credentials_input_map(
@@ -392,25 +355,21 @@ async def validate_graph_with_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[Mapping[str, Mapping[str, str]], set[str]]:
) -> Mapping[str, Mapping[str, str]]:
"""
Validate graph including credentials and return structured errors per node,
along with a set of nodes that should be skipped due to optional missing credentials.
Validate graph including credentials and return structured errors per node.
Returns:
tuple[
dict[node_id, dict[field_name, error_message]]: Validation errors per node,
set[node_id]: Nodes that should be skipped (optional credentials not configured)
]
dict[node_id, dict[field_name, error_message]]: Validation errors per node
"""
# Get input validation errors
node_input_errors = GraphModel.validate_graph_get_errors(
graph, for_run=True, nodes_input_masks=nodes_input_masks
)
# Get credential input/availability/validation errors and nodes to skip
node_credential_input_errors, nodes_to_skip = (
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
# Get credential input/availability/validation errors
node_credential_input_errors = await _validate_node_input_credentials(
graph, user_id, nodes_input_masks
)
# Merge credential errors with structural errors
@@ -419,7 +378,7 @@ async def validate_graph_with_credentials(
node_input_errors[node_id] = {}
node_input_errors[node_id].update(field_errors)
return node_input_errors, nodes_to_skip
return node_input_errors
async def _construct_starting_node_execution_input(
@@ -427,7 +386,7 @@ async def _construct_starting_node_execution_input(
user_id: str,
graph_inputs: BlockInput,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[list[tuple[str, BlockInput]], set[str]]:
) -> list[tuple[str, BlockInput]]:
"""
Validates and prepares the input data for executing a graph.
This function checks the graph for starting nodes, validates the input data
@@ -441,14 +400,11 @@ async def _construct_starting_node_execution_input(
node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`
Returns:
tuple[
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID
and the corresponding input data for that node.
set[str]: Node IDs that should be skipped (optional credentials not configured)
]
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
the corresponding input data for that node.
"""
# Use new validation function that includes credentials
validation_errors, nodes_to_skip = await validate_graph_with_credentials(
validation_errors = await validate_graph_with_credentials(
graph, user_id, nodes_input_masks
)
n_error_nodes = len(validation_errors)
@@ -489,7 +445,7 @@ async def _construct_starting_node_execution_input(
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
)
return nodes_input, nodes_to_skip
return nodes_input
async def validate_and_construct_node_execution_input(
@@ -500,7 +456,7 @@ async def validate_and_construct_node_execution_input(
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
is_sub_graph: bool = False,
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks, set[str]]:
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]:
"""
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
This centralizes the logic used by both scheduler validation and actual execution.
@@ -517,7 +473,6 @@ async def validate_and_construct_node_execution_input(
GraphModel: Full graph object for the given `graph_id`.
list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
dict[str, BlockInput]: Node input masks including all passed-in credentials.
set[str]: Node IDs that should be skipped (optional credentials not configured).
Raises:
NotFoundError: If the graph is not found.
@@ -559,16 +514,14 @@ async def validate_and_construct_node_execution_input(
nodes_input_masks or {},
)
starting_nodes_input, nodes_to_skip = (
await _construct_starting_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=graph_inputs,
nodes_input_masks=nodes_input_masks,
)
starting_nodes_input = await _construct_starting_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=graph_inputs,
nodes_input_masks=nodes_input_masks,
)
return graph, starting_nodes_input, nodes_input_masks, nodes_to_skip
return graph, starting_nodes_input, nodes_input_masks
def _merge_nodes_input_masks(
@@ -826,9 +779,6 @@ async def add_graph_execution(
# Use existing execution's compiled input masks
compiled_nodes_input_masks = graph_exec.nodes_input_masks or {}
# For resumed executions, nodes_to_skip was already determined at creation time
# TODO: Consider storing nodes_to_skip in DB if we need to preserve it across resumes
nodes_to_skip: set[str] = set()
logger.info(f"Resuming graph execution #{graph_exec.id} for graph #{graph_id}")
else:
@@ -837,7 +787,7 @@ async def add_graph_execution(
)
# Create new execution
graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip = (
graph, starting_nodes_input, compiled_nodes_input_masks = (
await validate_and_construct_node_execution_input(
graph_id=graph_id,
user_id=user_id,
@@ -886,7 +836,6 @@ async def add_graph_execution(
try:
graph_exec_entry = graph_exec.to_graph_execution_entry(
compiled_nodes_input_masks=compiled_nodes_input_masks,
nodes_to_skip=nodes_to_skip,
execution_context=execution_context,
)
logger.info(f"Publishing execution {graph_exec.id} to execution queue")

View File

@@ -367,13 +367,10 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
)
# Setup mock returns
# The function returns (graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip)
nodes_to_skip: set[str] = set()
mock_validate.return_value = (
mock_graph,
starting_nodes_input,
compiled_nodes_input_masks,
nodes_to_skip,
)
mock_prisma.is_connected.return_value = True
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
@@ -459,212 +456,3 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
# Both executions should succeed (though they create different objects)
assert result1 == mock_graph_exec
assert result2 == mock_graph_exec_2
# ============================================================================
# Tests for Optional Credentials Feature
# ============================================================================
@pytest.mark.asyncio
async def test_validate_node_input_credentials_returns_nodes_to_skip(
mocker: MockerFixture,
):
"""
Test that _validate_node_input_credentials returns nodes_to_skip set
for nodes with credentials_optional=True and missing credentials.
"""
from backend.executor.utils import _validate_node_input_credentials
# Create a mock node with credentials_optional=True
mock_node = mocker.MagicMock()
mock_node.id = "node-with-optional-creds"
mock_node.credentials_optional = True
mock_node.input_default = {} # No credentials configured
# Create a mock block with credentials field
mock_block = mocker.MagicMock()
mock_credentials_field_type = mocker.MagicMock()
mock_block.input_schema.get_credentials_fields.return_value = {
"credentials": mock_credentials_field_type
}
mock_node.block = mock_block
# Create mock graph
mock_graph = mocker.MagicMock()
mock_graph.nodes = [mock_node]
# Call the function
errors, nodes_to_skip = await _validate_node_input_credentials(
graph=mock_graph,
user_id="test-user-id",
nodes_input_masks=None,
)
# Node should be in nodes_to_skip, not in errors
assert mock_node.id in nodes_to_skip
assert mock_node.id not in errors
@pytest.mark.asyncio
async def test_validate_node_input_credentials_required_missing_creds_error(
mocker: MockerFixture,
):
"""
Test that _validate_node_input_credentials returns errors
for nodes with credentials_optional=False and missing credentials.
"""
from backend.executor.utils import _validate_node_input_credentials
# Create a mock node with credentials_optional=False (required)
mock_node = mocker.MagicMock()
mock_node.id = "node-with-required-creds"
mock_node.credentials_optional = False
mock_node.input_default = {} # No credentials configured
# Create a mock block with credentials field
mock_block = mocker.MagicMock()
mock_credentials_field_type = mocker.MagicMock()
mock_block.input_schema.get_credentials_fields.return_value = {
"credentials": mock_credentials_field_type
}
mock_node.block = mock_block
# Create mock graph
mock_graph = mocker.MagicMock()
mock_graph.nodes = [mock_node]
# Call the function
errors, nodes_to_skip = await _validate_node_input_credentials(
graph=mock_graph,
user_id="test-user-id",
nodes_input_masks=None,
)
# Node should be in errors, not in nodes_to_skip
assert mock_node.id in errors
assert "credentials" in errors[mock_node.id]
assert "required" in errors[mock_node.id]["credentials"].lower()
assert mock_node.id not in nodes_to_skip
@pytest.mark.asyncio
async def test_validate_graph_with_credentials_returns_nodes_to_skip(
mocker: MockerFixture,
):
"""
Test that validate_graph_with_credentials returns nodes_to_skip set
from _validate_node_input_credentials.
"""
from backend.executor.utils import validate_graph_with_credentials
# Mock _validate_node_input_credentials to return specific values
mock_validate = mocker.patch(
"backend.executor.utils._validate_node_input_credentials"
)
expected_errors = {"node1": {"field": "error"}}
expected_nodes_to_skip = {"node2", "node3"}
mock_validate.return_value = (expected_errors, expected_nodes_to_skip)
# Mock GraphModel with validate_graph_get_errors method
mock_graph = mocker.MagicMock()
mock_graph.validate_graph_get_errors.return_value = {}
# Call the function
errors, nodes_to_skip = await validate_graph_with_credentials(
graph=mock_graph,
user_id="test-user-id",
nodes_input_masks=None,
)
# Verify nodes_to_skip is passed through
assert nodes_to_skip == expected_nodes_to_skip
assert "node1" in errors
@pytest.mark.asyncio
async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
"""
Test that add_graph_execution properly passes nodes_to_skip
to the graph execution entry.
"""
from backend.data.execution import GraphExecutionWithNodes
from backend.executor.utils import add_graph_execution
# Mock data
graph_id = "test-graph-id"
user_id = "test-user-id"
inputs = {"test_input": "test_value"}
graph_version = 1
# Mock the graph object
mock_graph = mocker.MagicMock()
mock_graph.version = graph_version
# Starting nodes and masks
starting_nodes_input = [("node1", {"input1": "value1"})]
compiled_nodes_input_masks = {}
nodes_to_skip = {"skipped-node-1", "skipped-node-2"}
# Mock the graph execution object
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
mock_graph_exec.id = "execution-id-123"
mock_graph_exec.node_executions = []
# Track what's passed to to_graph_execution_entry
captured_kwargs = {}
def capture_to_entry(**kwargs):
captured_kwargs.update(kwargs)
return mocker.MagicMock()
mock_graph_exec.to_graph_execution_entry.side_effect = capture_to_entry
# Setup mocks
mock_validate = mocker.patch(
"backend.executor.utils.validate_and_construct_node_execution_input"
)
mock_edb = mocker.patch("backend.executor.utils.execution_db")
mock_prisma = mocker.patch("backend.executor.utils.prisma")
mock_udb = mocker.patch("backend.executor.utils.user_db")
mock_gdb = mocker.patch("backend.executor.utils.graph_db")
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
mock_get_event_bus = mocker.patch(
"backend.executor.utils.get_async_execution_event_bus"
)
# Setup returns - include nodes_to_skip in the tuple
mock_validate.return_value = (
mock_graph,
starting_nodes_input,
compiled_nodes_input_masks,
nodes_to_skip, # This should be passed through
)
mock_prisma.is_connected.return_value = True
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
mock_edb.update_graph_execution_stats = mocker.AsyncMock(
return_value=mock_graph_exec
)
mock_edb.update_node_execution_status_batch = mocker.AsyncMock()
mock_user = mocker.MagicMock()
mock_user.timezone = "UTC"
mock_settings = mocker.MagicMock()
mock_settings.human_in_the_loop_safe_mode = True
mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user)
mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings)
mock_get_queue.return_value = mocker.AsyncMock()
mock_get_event_bus.return_value = mocker.MagicMock(publish=mocker.AsyncMock())
# Call the function
await add_graph_execution(
graph_id=graph_id,
user_id=user_id,
inputs=inputs,
graph_version=graph_version,
)
# Verify nodes_to_skip was passed to to_graph_execution_entry
assert "nodes_to_skip" in captured_kwargs
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip

View File

@@ -0,0 +1,156 @@
"""
Embedding service for generating text embeddings using OpenAI.
Used for vector-based semantic search in the store.
"""
import logging
from typing import Optional
import openai
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
# Model configuration
# Using text-embedding-3-small (1536 dimensions) for compatibility with pgvector indexes
# pgvector IVFFlat/HNSW indexes have dimension limits (2000 for IVFFlat, varies for HNSW)
EMBEDDING_MODEL = "text-embedding-3-small"
EMBEDDING_DIMENSIONS = 1536
# Input validation limits
# OpenAI text-embedding-3-large supports up to 8191 tokens (~32k chars)
# We set a conservative limit to prevent abuse
MAX_TEXT_LENGTH = 10000 # characters
MAX_BATCH_SIZE = 100 # maximum texts per batch request
class EmbeddingService:
"""Service for generating text embeddings using OpenAI."""
def __init__(self, api_key: Optional[str] = None):
settings = Settings()
self.api_key = (
api_key
or settings.secrets.openai_internal_api_key
or settings.secrets.openai_api_key
)
if not self.api_key:
raise ValueError(
"OpenAI API key not configured. "
"Set OPENAI_API_KEY or OPENAI_INTERNAL_API_KEY environment variable."
)
self.client = openai.AsyncOpenAI(api_key=self.api_key)
async def generate_embedding(self, text: str) -> list[float]:
"""
Generate embedding for a single text string.
Args:
text: The text to generate an embedding for.
Returns:
A list of floats representing the embedding vector.
Raises:
ValueError: If the text is empty or exceeds maximum length.
openai.APIError: If the OpenAI API call fails.
"""
# Input validation
if not text or not text.strip():
raise ValueError("Text cannot be empty")
if len(text) > MAX_TEXT_LENGTH:
raise ValueError(
f"Text exceeds maximum length of {MAX_TEXT_LENGTH} characters"
)
try:
response = await self.client.embeddings.create(
model=EMBEDDING_MODEL,
input=text,
dimensions=EMBEDDING_DIMENSIONS,
)
return response.data[0].embedding
except openai.APIError as e:
logger.error(f"OpenAI API error generating embedding: {e}")
raise
async def generate_embeddings(self, texts: list[str]) -> list[list[float]]:
"""
Generate embeddings for multiple texts (batch).
Args:
texts: List of texts to generate embeddings for.
Returns:
List of embedding vectors, one per input text.
Raises:
ValueError: If any text is invalid or batch size exceeds limit.
openai.APIError: If the OpenAI API call fails.
"""
# Input validation
if not texts:
raise ValueError("Texts list cannot be empty")
if len(texts) > MAX_BATCH_SIZE:
raise ValueError(f"Batch size exceeds maximum of {MAX_BATCH_SIZE} texts")
for i, text in enumerate(texts):
if not text or not text.strip():
raise ValueError(f"Text at index {i} cannot be empty")
if len(text) > MAX_TEXT_LENGTH:
raise ValueError(
f"Text at index {i} exceeds maximum length of {MAX_TEXT_LENGTH} characters"
)
try:
response = await self.client.embeddings.create(
model=EMBEDDING_MODEL,
input=texts,
dimensions=EMBEDDING_DIMENSIONS,
)
# Sort by index to ensure correct ordering
sorted_data = sorted(response.data, key=lambda x: x.index)
return [item.embedding for item in sorted_data]
except openai.APIError as e:
logger.error(f"OpenAI API error generating embeddings: {e}")
raise
def create_search_text(name: str, sub_heading: str, description: str) -> str:
"""
Combine fields into searchable text for embedding.
This creates a single text string from the agent's name, sub-heading,
and description, which is then converted to an embedding vector.
Args:
name: The agent name.
sub_heading: The agent sub-heading/tagline.
description: The agent description.
Returns:
A single string combining all non-empty fields.
"""
parts = [name or "", sub_heading or "", description or ""]
return " ".join(filter(None, parts)).strip()
# Singleton instance
_embedding_service: Optional[EmbeddingService] = None
async def get_embedding_service() -> EmbeddingService:
"""
Get or create the embedding service singleton.
Returns:
The shared EmbeddingService instance.
Raises:
ValueError: If OpenAI API key is not configured.
"""
global _embedding_service
if _embedding_service is None:
_embedding_service = EmbeddingService()
return _embedding_service

View File

@@ -0,0 +1,231 @@
"""Tests for the embedding service.
This module tests:
- create_search_text utility function
- EmbeddingService input validation
- EmbeddingService API interaction (mocked)
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.integrations.embeddings import (
EMBEDDING_DIMENSIONS,
MAX_BATCH_SIZE,
MAX_TEXT_LENGTH,
EmbeddingService,
create_search_text,
)
class TestCreateSearchText:
"""Tests for the create_search_text utility function."""
def test_combines_all_fields(self):
result = create_search_text("Agent Name", "A cool agent", "Does amazing things")
assert result == "Agent Name A cool agent Does amazing things"
def test_handles_empty_name(self):
result = create_search_text("", "Sub heading", "Description")
assert result == "Sub heading Description"
def test_handles_empty_sub_heading(self):
result = create_search_text("Name", "", "Description")
assert result == "Name Description"
def test_handles_empty_description(self):
result = create_search_text("Name", "Sub heading", "")
assert result == "Name Sub heading"
def test_handles_all_empty(self):
result = create_search_text("", "", "")
assert result == ""
def test_handles_none_values(self):
# The function expects strings but should handle None gracefully
result = create_search_text(None, None, None) # type: ignore
assert result == ""
def test_preserves_content_strips_outer_whitespace(self):
# The function joins parts and strips the outer result
# Internal whitespace in each part is preserved
result = create_search_text(" Name ", " Sub ", " Desc ")
# Each part is joined with space, then outer strip applied
assert result == "Name Sub Desc"
def test_handles_only_whitespace(self):
# Parts that are only whitespace become empty after filter
result = create_search_text(" ", " ", " ")
assert result == ""
class TestEmbeddingServiceValidation:
"""Tests for EmbeddingService input validation."""
@pytest.fixture
def mock_settings(self):
"""Mock settings with a test API key."""
with patch("backend.integrations.embeddings.Settings") as mock:
mock_instance = MagicMock()
mock_instance.secrets.openai_internal_api_key = "test-api-key"
mock_instance.secrets.openai_api_key = ""
mock.return_value = mock_instance
yield mock
@pytest.fixture
def service(self, mock_settings):
"""Create an EmbeddingService instance with mocked settings."""
with patch("backend.integrations.embeddings.openai.AsyncOpenAI"):
return EmbeddingService()
def test_init_requires_api_key(self):
"""Test that initialization fails without an API key."""
with patch("backend.integrations.embeddings.Settings") as mock:
mock_instance = MagicMock()
mock_instance.secrets.openai_internal_api_key = ""
mock_instance.secrets.openai_api_key = ""
mock.return_value = mock_instance
with pytest.raises(ValueError, match="OpenAI API key not configured"):
EmbeddingService()
def test_init_accepts_explicit_api_key(self):
"""Test that explicit API key overrides settings."""
with patch("backend.integrations.embeddings.Settings") as mock:
mock_instance = MagicMock()
mock_instance.secrets.openai_internal_api_key = ""
mock_instance.secrets.openai_api_key = ""
mock.return_value = mock_instance
with patch("backend.integrations.embeddings.openai.AsyncOpenAI"):
service = EmbeddingService(api_key="explicit-key")
assert service.api_key == "explicit-key"
@pytest.mark.asyncio
async def test_generate_embedding_empty_text(self, service):
"""Test that empty text raises ValueError."""
with pytest.raises(ValueError, match="Text cannot be empty"):
await service.generate_embedding("")
@pytest.mark.asyncio
async def test_generate_embedding_whitespace_only(self, service):
"""Test that whitespace-only text raises ValueError."""
with pytest.raises(ValueError, match="Text cannot be empty"):
await service.generate_embedding(" ")
@pytest.mark.asyncio
async def test_generate_embedding_exceeds_max_length(self, service):
"""Test that text exceeding max length raises ValueError."""
long_text = "a" * (MAX_TEXT_LENGTH + 1)
with pytest.raises(ValueError, match="exceeds maximum length"):
await service.generate_embedding(long_text)
@pytest.mark.asyncio
async def test_generate_embeddings_empty_list(self, service):
"""Test that empty list raises ValueError."""
with pytest.raises(ValueError, match="Texts list cannot be empty"):
await service.generate_embeddings([])
@pytest.mark.asyncio
async def test_generate_embeddings_exceeds_batch_size(self, service):
"""Test that batch exceeding max size raises ValueError."""
texts = ["text"] * (MAX_BATCH_SIZE + 1)
with pytest.raises(ValueError, match="Batch size exceeds maximum"):
await service.generate_embeddings(texts)
@pytest.mark.asyncio
async def test_generate_embeddings_empty_text_in_batch(self, service):
"""Test that empty text in batch raises ValueError with index."""
with pytest.raises(ValueError, match="Text at index 1 cannot be empty"):
await service.generate_embeddings(["valid", "", "also valid"])
@pytest.mark.asyncio
async def test_generate_embeddings_long_text_in_batch(self, service):
"""Test that long text in batch raises ValueError with index."""
long_text = "a" * (MAX_TEXT_LENGTH + 1)
with pytest.raises(ValueError, match="Text at index 2 exceeds maximum length"):
await service.generate_embeddings(["short", "also short", long_text])
class TestEmbeddingServiceAPI:
"""Tests for EmbeddingService API interaction."""
@pytest.fixture
def mock_openai_client(self):
"""Create a mock OpenAI client."""
mock_client = MagicMock()
mock_client.embeddings = MagicMock()
return mock_client
@pytest.fixture
def service_with_mock_client(self, mock_openai_client):
"""Create an EmbeddingService with a mocked OpenAI client."""
with patch("backend.integrations.embeddings.Settings") as mock_settings:
mock_instance = MagicMock()
mock_instance.secrets.openai_internal_api_key = "test-key"
mock_instance.secrets.openai_api_key = ""
mock_settings.return_value = mock_instance
with patch(
"backend.integrations.embeddings.openai.AsyncOpenAI"
) as mock_openai:
mock_openai.return_value = mock_openai_client
service = EmbeddingService()
return service, mock_openai_client
@pytest.mark.asyncio
async def test_generate_embedding_success(self, service_with_mock_client):
"""Test successful embedding generation."""
service, mock_client = service_with_mock_client
# Create mock response
mock_embedding = [0.1] * EMBEDDING_DIMENSIONS
mock_response = MagicMock()
mock_response.data = [MagicMock(embedding=mock_embedding)]
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
result = await service.generate_embedding("test text")
assert result == mock_embedding
mock_client.embeddings.create.assert_called_once()
@pytest.mark.asyncio
async def test_generate_embeddings_success(self, service_with_mock_client):
"""Test successful batch embedding generation."""
service, mock_client = service_with_mock_client
# Create mock response with multiple embeddings
mock_embeddings = [[0.1] * EMBEDDING_DIMENSIONS, [0.2] * EMBEDDING_DIMENSIONS]
mock_response = MagicMock()
mock_response.data = [
MagicMock(embedding=mock_embeddings[0], index=0),
MagicMock(embedding=mock_embeddings[1], index=1),
]
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
result = await service.generate_embeddings(["text1", "text2"])
assert result == mock_embeddings
mock_client.embeddings.create.assert_called_once()
@pytest.mark.asyncio
async def test_generate_embeddings_preserves_order(self, service_with_mock_client):
"""Test that batch embeddings are returned in correct order even if API returns out of order."""
service, mock_client = service_with_mock_client
# Create mock response with embeddings out of order
mock_embeddings = [[0.1] * EMBEDDING_DIMENSIONS, [0.2] * EMBEDDING_DIMENSIONS]
mock_response = MagicMock()
# Return in reverse order
mock_response.data = [
MagicMock(embedding=mock_embeddings[1], index=1),
MagicMock(embedding=mock_embeddings[0], index=0),
]
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
result = await service.generate_embeddings(["text1", "text2"])
# Should be sorted by index
assert result[0] == mock_embeddings[0]
assert result[1] == mock_embeddings[1]

View File

@@ -8,7 +8,6 @@ from .discord import DiscordOAuthHandler
from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .notion import NotionOAuthHandler
from .reddit import RedditOAuthHandler
from .twitter import TwitterOAuthHandler
if TYPE_CHECKING:
@@ -21,7 +20,6 @@ _ORIGINAL_HANDLERS = [
GitHubOAuthHandler,
GoogleOAuthHandler,
NotionOAuthHandler,
RedditOAuthHandler,
TwitterOAuthHandler,
TodoistOAuthHandler,
]

View File

@@ -1,208 +0,0 @@
import time
import urllib.parse
from typing import ClassVar, Optional
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.providers import ProviderName
from backend.util.request import Requests
from backend.util.settings import Settings
settings = Settings()
class RedditOAuthHandler(BaseOAuthHandler):
"""
Reddit OAuth 2.0 handler.
Based on the documentation at:
- https://github.com/reddit-archive/reddit/wiki/OAuth2
Notes:
- Reddit requires `duration=permanent` to get refresh tokens
- Access tokens expire after 1 hour (3600 seconds)
- Reddit requires HTTP Basic Auth for token requests
- Reddit requires a unique User-Agent header
"""
PROVIDER_NAME = ProviderName.REDDIT
DEFAULT_SCOPES: ClassVar[list[str]] = [
"identity", # Get username, verify auth
"read", # Access posts and comments
"submit", # Submit new posts and comments
"edit", # Edit own posts and comments
"history", # Access user's post history
"privatemessages", # Access inbox and send private messages
"flair", # Access and set flair on posts/subreddits
]
AUTHORIZE_URL = "https://www.reddit.com/api/v1/authorize"
TOKEN_URL = "https://www.reddit.com/api/v1/access_token"
USERNAME_URL = "https://oauth.reddit.com/api/v1/me"
REVOKE_URL = "https://www.reddit.com/api/v1/revoke_token"
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
) -> str:
"""Generate Reddit OAuth 2.0 authorization URL"""
scopes = self.handle_default_scopes(scopes)
params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": " ".join(scopes),
"state": state,
"duration": "permanent", # Required for refresh tokens
}
return f"{self.AUTHORIZE_URL}?{urllib.parse.urlencode(params)}"
async def exchange_code_for_tokens(
self, code: str, scopes: list[str], code_verifier: Optional[str]
) -> OAuth2Credentials:
"""Exchange authorization code for access tokens"""
scopes = self.handle_default_scopes(scopes)
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": settings.config.reddit_user_agent,
}
data = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": self.redirect_uri,
}
# Reddit requires HTTP Basic Auth for token requests
auth = (self.client_id, self.client_secret)
response = await Requests().post(
self.TOKEN_URL, headers=headers, data=data, auth=auth
)
if not response.ok:
error_text = response.text()
raise ValueError(
f"Reddit token exchange failed: {response.status} - {error_text}"
)
tokens = response.json()
if "error" in tokens:
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")
username = await self._get_username(tokens["access_token"])
return OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=None,
username=username,
access_token=tokens["access_token"],
refresh_token=tokens.get("refresh_token"),
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
refresh_token_expires_at=None, # Reddit refresh tokens don't expire
scopes=scopes,
)
async def _get_username(self, access_token: str) -> str:
"""Get the username from the access token"""
headers = {
"Authorization": f"Bearer {access_token}",
"User-Agent": settings.config.reddit_user_agent,
}
response = await Requests().get(self.USERNAME_URL, headers=headers)
if not response.ok:
raise ValueError(f"Failed to get Reddit username: {response.status}")
data = response.json()
return data.get("name", "unknown")
async def _refresh_tokens(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
"""Refresh access tokens using refresh token"""
if not credentials.refresh_token:
raise ValueError("No refresh token available")
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": settings.config.reddit_user_agent,
}
data = {
"grant_type": "refresh_token",
"refresh_token": credentials.refresh_token.get_secret_value(),
}
auth = (self.client_id, self.client_secret)
response = await Requests().post(
self.TOKEN_URL, headers=headers, data=data, auth=auth
)
if not response.ok:
error_text = response.text()
raise ValueError(
f"Reddit token refresh failed: {response.status} - {error_text}"
)
tokens = response.json()
if "error" in tokens:
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")
username = await self._get_username(tokens["access_token"])
# Reddit may or may not return a new refresh token
new_refresh_token = tokens.get("refresh_token")
if new_refresh_token:
refresh_token: SecretStr | None = SecretStr(new_refresh_token)
elif credentials.refresh_token:
# Keep the existing refresh token
refresh_token = credentials.refresh_token
else:
refresh_token = None
return OAuth2Credentials(
id=credentials.id,
provider=self.PROVIDER_NAME,
title=credentials.title,
username=username,
access_token=tokens["access_token"],
refresh_token=refresh_token,
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
refresh_token_expires_at=None,
scopes=credentials.scopes,
)
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
"""Revoke the access token"""
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": settings.config.reddit_user_agent,
}
data = {
"token": credentials.access_token.get_secret_value(),
"token_type_hint": "access_token",
}
auth = (self.client_id, self.client_secret)
response = await Requests().post(
self.REVOKE_URL, headers=headers, data=data, auth=auth
)
# Reddit returns 204 No Content on successful revocation
return response.ok

Some files were not shown because too many files have changed in this diff Show More