Compare commits

..

5 Commits

Author SHA1 Message Date
Bently
82bddd885b Merge branch 'dev' into update-install-scripts 2025-08-19 15:57:45 +01:00
Bentlybro
c71406af8b Simplify setup scripts and remove Sentry prompts
Refactored Windows and Linux setup scripts to streamline prerequisite checks, repository detection, and service startup. Removed Sentry configuration and related prompts for a simpler setup experience. Updated user messaging and improved error handling for common Docker issues.
2025-08-13 18:07:22 +01:00
Bently
468d1af802 Merge branch 'dev' into update-install-scripts 2025-08-08 12:47:12 +01:00
Bentlybro
a2c88c7786 Refactor setup scripts for improved reliability and clarity
Reworked both Windows (.bat) and Unix (.sh) setup scripts to improve error handling, logging, and user prompts. The scripts now check for prerequisites, handle Sentry enablement more clearly, ensure environment files are copied with error checks, and consolidate service startup into a single docker compose command with log output. Unused or redundant code was removed for maintainability.
2025-08-07 10:35:48 +01:00
Bentlybro
e79b7a95dc Remove auto-start of frontend dev server in setup scripts
The setup-autogpt.bat and setup-autogpt.sh scripts no longer automatically start the frontend development server after setup. Users are now instructed to manually stop services with 'docker compose down', and the scripts prompt for exit while keeping services running.
2025-08-07 10:03:45 +01:00
247 changed files with 2275 additions and 18964 deletions

View File

@@ -1,244 +0,0 @@
# GitHub Copilot Instructions for AutoGPT
This file provides comprehensive onboarding information for GitHub Copilot coding agent to work efficiently with the AutoGPT repository.
## Repository Overview
**AutoGPT** is a powerful platform for creating, deploying, and managing continuous AI agents that automate complex workflows. This is a large monorepo (~150MB) containing multiple components:
- **AutoGPT Platform** (`autogpt_platform/`) - Main focus: Modern AI agent platform (Polyform Shield License)
- **Classic AutoGPT** (`classic/`) - Legacy agent system (MIT License)
- **Documentation** (`docs/`) - MkDocs-based documentation site
- **Infrastructure** - Docker configurations, CI/CD, and development tools
**Primary Languages & Frameworks:**
- **Backend**: Python 3.10-3.13, FastAPI, Prisma ORM, PostgreSQL, RabbitMQ
- **Frontend**: TypeScript, Next.js 15, React, Tailwind CSS, Radix UI
- **Development**: Docker, Poetry, pnpm, Playwright, Storybook
## Build and Validation Instructions
### Essential Setup Commands
**Always run these commands in the correct directory and in this order:**
1. **Initial Setup** (required once):
```bash
# Clone and enter repository
git clone <repo> && cd AutoGPT
# Start all services (database, redis, rabbitmq, clamav)
cd autogpt_platform && docker compose --profile local up deps --build --detach
```
2. **Backend Setup** (always run before backend development):
```bash
cd autogpt_platform/backend
poetry install # Install dependencies
poetry run prisma migrate dev # Run database migrations
poetry run prisma generate # Generate Prisma client
```
3. **Frontend Setup** (always run before frontend development):
```bash
cd autogpt_platform/frontend
pnpm install # Install dependencies
```
### Runtime Requirements
**Critical:** Always ensure Docker services are running before starting development:
```bash
cd autogpt_platform && docker compose --profile local up deps --build --detach
```
**Python Version:** Use Python 3.11 (required; managed by Poetry via pyproject.toml)
**Node.js Version:** Use Node.js 21+ with pnpm package manager
### Development Commands
**Backend Development:**
```bash
cd autogpt_platform/backend
poetry run serve # Start development server (port 8000)
poetry run test # Run all tests (requires ~5 minutes)
poetry run pytest path/to/test.py # Run specific test
poetry run format # Format code (Black + isort) - always run first
poetry run lint # Lint code (ruff) - run after format
```
**Frontend Development:**
```bash
cd autogpt_platform/frontend
pnpm dev # Start development server (port 3000) - use for active development
pnpm build # Build for production (only needed for E2E tests or deployment)
pnpm test # Run Playwright E2E tests (requires build first)
pnpm test-ui # Run tests with UI
pnpm format # Format and lint code
pnpm storybook # Start component development server
```
### Testing Strategy
**Backend Tests:**
- **Block Tests**: `poetry run pytest backend/blocks/test/test_block.py -xvs` (validates all blocks)
- **Specific Block**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[BlockName]' -xvs`
- **Snapshot Tests**: Use `--snapshot-update` when output changes, always review with `git diff`
**Frontend Tests:**
- **E2E Tests**: Always run `pnpm dev` before `pnpm test` (Playwright requires running instance)
- **Component Tests**: Use Storybook for isolated component development
### Critical Validation Steps
**Before committing changes:**
1. Run `poetry run format` (backend) and `pnpm format` (frontend)
2. Ensure all tests pass in modified areas
3. Verify Docker services are still running
4. Check that database migrations apply cleanly
**Common Issues & Workarounds:**
- **Prisma issues**: Run `poetry run prisma generate` after schema changes
- **Permission errors**: Ensure Docker has proper permissions
- **Port conflicts**: Check the `docker-compose.yml` file for the current list of exposed ports. You can list all mapped ports with:
- **Test timeouts**: Backend tests can take 5+ minutes, use `-x` flag to stop on first failure
## Project Layout & Architecture
### Core Architecture
**AutoGPT Platform** (`autogpt_platform/`):
- `backend/` - FastAPI server with async support
- `backend/backend/` - Core API logic
- `backend/blocks/` - Agent execution blocks
- `backend/data/` - Database models and schemas
- `schema.prisma` - Database schema definition
- `frontend/` - Next.js application
- `src/app/` - App Router pages and layouts
- `src/components/` - Reusable React components
- `src/lib/` - Utilities and configurations
- `autogpt_libs/` - Shared Python utilities
- `docker-compose.yml` - Development stack orchestration
**Key Configuration Files:**
- `pyproject.toml` - Python dependencies and tooling
- `package.json` - Node.js dependencies and scripts
- `schema.prisma` - Database schema and migrations
- `next.config.mjs` - Next.js configuration
- `tailwind.config.ts` - Styling configuration
### Security & Middleware
**Cache Protection**: Backend includes middleware preventing sensitive data caching in browsers/proxies
**Authentication**: JWT-based with Supabase integration
**User ID Validation**: All data access requires user ID checks - verify this for any `data/*.py` changes
### Development Workflow
**GitHub Actions**: Multiple CI/CD workflows in `.github/workflows/`
- `platform-backend-ci.yml` - Backend testing and validation
- `platform-frontend-ci.yml` - Frontend testing and validation
- `platform-fullstack-ci.yml` - End-to-end integration tests
**Pre-commit Hooks**: Run linting and formatting checks
**Conventional Commits**: Use format `type(scope): description` (e.g., `feat(backend): add API`)
### Key Source Files
**Backend Entry Points:**
- `backend/backend/server/server.py` - FastAPI application setup
- `backend/backend/data/` - Database models and user management
- `backend/blocks/` - Agent execution blocks and logic
**Frontend Entry Points:**
- `frontend/src/app/layout.tsx` - Root application layout
- `frontend/src/app/page.tsx` - Home page
- `frontend/src/lib/supabase/` - Authentication and database client
**Protected Routes**: Update `frontend/lib/supabase/middleware.ts` when adding protected routes
### Agent Block System
Agents are built using a visual block-based system where each block performs a single action. Blocks are defined in `backend/blocks/` and must include:
- Block definition with input/output schemas
- Execution logic with proper error handling
- Tests validating functionality
### Database & ORM
**Prisma ORM** with PostgreSQL backend including pgvector for embeddings:
- Schema in `schema.prisma`
- Migrations in `backend/migrations/`
- Always run `prisma migrate dev` and `prisma generate` after schema changes
## Environment Configuration
### Configuration Files Priority Order
1. **Backend**: `/backend/.env.default` → `/backend/.env` (user overrides)
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
3. **Platform**: `/.env.default` (Supabase/shared) → `/.env` (user overrides)
4. Docker Compose `environment:` sections override file-based config
5. Shell environment variables have highest precedence
### Docker Environment Setup
- All services use hardcoded defaults (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Copy `.env.default` files to `.env` for local development customization
## Advanced Development Patterns
### Adding New Blocks
1. Create file in `/backend/backend/blocks/`
2. Inherit from `Block` base class with input/output schemas
3. Implement `run` method with proper error handling
4. Generate block UUID using `uuid.uuid4()`
5. Register in block registry
6. Write tests alongside block implementation
7. Consider how inputs/outputs connect with other blocks in graph editor
### API Development
1. Update routes in `/backend/backend/server/routers/`
2. Add/update Pydantic models in same directory
3. Write tests alongside route files
4. For `data/*.py` changes, validate user ID checks
5. Run `poetry run test` to verify changes
### Frontend Development
1. Components in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for component development
4. Test user-facing features with Playwright E2E tests
5. Update protected routes in middleware when needed
### Security Guidelines
**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):
- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)
- Prevents sensitive data caching in browsers/proxies
- Add new cacheable endpoints to `CACHEABLE_PATHS`
### CI/CD Alignment
The repository has comprehensive CI workflows that test:
- **Backend**: Python 3.11-3.13, services (Redis/RabbitMQ/ClamAV), Prisma migrations, Poetry lock validation
- **Frontend**: Node.js 21, pnpm, Playwright with Docker Compose stack, API schema validation
- **Integration**: Full-stack type checking and E2E testing
Match these patterns when developing locally - the copilot setup environment mirrors these CI configurations.
## Collaboration with Other AI Assistants
This repository is actively developed with assistance from Claude (via CLAUDE.md files). When working on this codebase:
- Check for existing CLAUDE.md files that provide additional context
- Follow established patterns and conventions already in the codebase
- Maintain consistency with existing code style and architecture
- Consider that changes may be reviewed and extended by both human developers and AI assistants
## Trust These Instructions
These instructions are comprehensive and tested. Only perform additional searches if:
1. Information here is incomplete for your specific task
2. You encounter errors not covered by the workarounds
3. You need to understand implementation details not covered above
For detailed platform development patterns, refer to `autogpt_platform/CLAUDE.md` and `AGENTS.md` in the repository root.

View File

@@ -1,302 +0,0 @@
name: "Copilot Setup Steps"
# Automatically run the setup steps when they are changed to allow for easy validation, and
# allow manual testing through the repository's "Actions" tab
on:
workflow_dispatch:
push:
paths:
- .github/workflows/copilot-setup-steps.yml
pull_request:
paths:
- .github/workflows/copilot-setup-steps.yml
jobs:
# The job MUST be called `copilot-setup-steps` or it will not be picked up by Copilot.
copilot-setup-steps:
runs-on: ubuntu-latest
timeout-minutes: 45
# Set the permissions to the lowest permissions possible needed for your steps.
# Copilot will be given its own token for its operations.
permissions:
# If you want to clone the repository as part of your setup steps, for example to install dependencies, you'll need the `contents: read` permission. If you don't clone the repository in your setup steps, Copilot will do this for you automatically after the steps complete.
contents: read
# You can define any steps you want, and they will run before the agent starts.
# If you do not check out your code, Copilot will do this for you.
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11" # Use standard version matching CI
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
- name: Install Poetry
run: |
# Extract Poetry version from backend/poetry.lock (matches CI)
cd autogpt_platform/backend
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
# Install Poetry
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
# Add Poetry to PATH
echo "$HOME/.local/bin" >> $GITHUB_PATH
- name: Check poetry.lock
working-directory: autogpt_platform/backend
run: |
poetry lock
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
echo "Warning: poetry.lock not up to date, but continuing for setup"
git checkout poetry.lock # Reset for clean setup
fi
- name: Install Python dependencies
working-directory: autogpt_platform/backend
run: poetry install
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Enable corepack
run: corepack enable
- name: Set pnpm store directory
run: |
pnpm config set store-dir ~/.pnpm-store
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install JavaScript dependencies
working-directory: autogpt_platform/frontend
run: pnpm install --frozen-lockfile
# Install Playwright browsers for frontend testing
# NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
# - name: Install Playwright browsers
# working-directory: autogpt_platform/frontend
# run: pnpm playwright install --with-deps chromium
# Docker setup for development environment
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Copy default environment files
working-directory: autogpt_platform
run: |
# Copy default environment files for development
cp .env.default .env
cp backend/.env.default backend/.env
cp frontend/.env.default frontend/.env
# Phase 1: Cache and load Docker images for faster setup
- name: Set up Docker image cache
id: docker-cache
uses: actions/cache@v4
with:
path: ~/docker-cache
# Use a versioned key for cache invalidation when image list changes
key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
restore-keys: |
docker-images-v2-${{ runner.os }}-
docker-images-v1-${{ runner.os }}-
- name: Load or pull Docker images
working-directory: autogpt_platform
run: |
mkdir -p ~/docker-cache
# Define image list for easy maintenance
IMAGES=(
"redis:latest"
"rabbitmq:management"
"clamav/clamav-debian:latest"
"busybox:latest"
"kong:2.8.1"
"supabase/gotrue:v2.170.0"
"supabase/postgres:15.8.1.049"
"supabase/postgres-meta:v0.86.1"
"supabase/studio:20250224-d10db0f"
)
# Check if any cached tar files exist (more reliable than cache-hit)
if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
echo "Docker cache found, loading images in parallel..."
for image in "${IMAGES[@]}"; do
# Convert image name to filename (replace : and / with -)
filename=$(echo "$image" | tr ':/' '--')
if [ -f ~/docker-cache/${filename}.tar ]; then
echo "Loading $image..."
docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
fi
done
wait
echo "All cached images loaded"
else
echo "No Docker cache found, pulling images in parallel..."
# Pull all images in parallel
for image in "${IMAGES[@]}"; do
docker pull "$image" &
done
wait
# Only save cache on main branches (not PRs) to avoid cache pollution
if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
echo "Saving Docker images to cache in parallel..."
for image in "${IMAGES[@]}"; do
# Convert image name to filename (replace : and / with -)
filename=$(echo "$image" | tr ':/' '--')
echo "Saving $image..."
docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
done
wait
echo "Docker image cache saved"
else
echo "Skipping cache save for PR/feature branch"
fi
fi
echo "Docker images ready for use"
# Phase 2: Build migrate service with GitHub Actions cache
- name: Build migrate Docker image with cache
working-directory: autogpt_platform
run: |
# Build the migrate image with buildx for GHA caching
docker buildx build \
--cache-from type=gha \
--cache-to type=gha,mode=max \
--target migrate \
--tag autogpt_platform-migrate:latest \
--load \
-f backend/Dockerfile \
..
# Start services using pre-built images
- name: Start Docker services for development
working-directory: autogpt_platform
run: |
# Start essential services (migrate image already built with correct tag)
docker compose --profile local up deps --no-build --detach
echo "Waiting for services to be ready..."
# Wait for database to be ready
echo "Checking database readiness..."
timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
echo " Waiting for database..."
sleep 2
done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."
# Check migrate service status
echo "Checking migration status..."
docker compose ps migrate || echo " Migrate service not visible in ps output"
# Wait for migrate service to complete
echo "Waiting for migrations to complete..."
timeout 30 bash -c '
ATTEMPTS=0
while [ $ATTEMPTS -lt 15 ]; do
ATTEMPTS=$((ATTEMPTS + 1))
# Check using docker directly (more reliable than docker compose ps)
CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)
if [ -z "$CONTAINER_STATUS" ]; then
echo " Attempt $ATTEMPTS: Migrate container not found yet..."
elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
echo "✅ Migrations completed successfully"
docker compose logs migrate --tail=5 2>/dev/null || true
exit 0
elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
echo "❌ Migrations failed with exit code: $EXIT_CODE"
echo "Migration logs:"
docker compose logs migrate --tail=20 2>/dev/null || true
exit 1
elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
echo " Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
else
echo " Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
fi
sleep 2
done
echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
echo "Final container check:"
docker ps -a --filter "label=com.docker.compose.service=migrate" || true
echo "Migration logs (if available):"
docker compose logs migrate --tail=10 2>/dev/null || echo " No logs available"
' || echo "⚠️ Migration check completed with warnings, continuing..."
# Brief wait for other services to stabilize
echo "Waiting 5 seconds for other services to stabilize..."
sleep 5
# Verify installations and provide environment info
- name: Verify setup and show environment info
run: |
echo "=== Python Setup ==="
python --version
poetry --version
echo "=== Node.js Setup ==="
node --version
pnpm --version
echo "=== Additional Tools ==="
docker --version
docker compose version
gh --version || true
echo "=== Services Status ==="
cd autogpt_platform
docker compose ps || true
echo "=== Backend Dependencies ==="
cd backend
poetry show | head -10 || true
echo "=== Frontend Dependencies ==="
cd ../frontend
pnpm list --depth=0 | head -10 || true
echo "=== Environment Files ==="
ls -la ../.env* || true
ls -la .env* || true
ls -la ../backend/.env* || true
echo "✅ AutoGPT Platform development environment setup complete!"
echo "🚀 Ready for development with Docker services running"
echo "📝 Backend server: poetry run serve (port 8000)"
echo "🌐 Frontend server: pnpm dev (port 3000)"

View File

@@ -32,7 +32,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.11", "3.12", "3.13"]
python-version: ["3.11"]
runs-on: ubuntu-latest
services:
@@ -201,7 +201,7 @@ jobs:
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
REDIS_HOST: "localhost"
REDIS_PORT: "6379"
REDIS_PASSWORD: "testpassword"

View File

@@ -149,23 +149,14 @@ Key models (defined in `/backend/schema.prisma`):
**Adding a new block:**
Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization
Quick steps:
1. Create new file in `/backend/backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
2. Inherit from `Block` base class
3. Define input/output schemas
4. Implement `run` method
5. Register in block registry
6. Generate the block uuid using `uuid.uuid4()`
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
Note: when making many new blocks analyze the interfaces for each of these blcoks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
ex: do the inputs and outputs tie well together?
**Modifying the API:**

View File

@@ -1,13 +1,13 @@
from .config import verify_settings
from .dependencies import get_user_id, requires_admin_user, requires_user
from .helpers import add_auth_responses_to_openapi
from .depends import requires_admin_user, requires_user
from .jwt_utils import parse_jwt_token
from .middleware import APIKeyValidator, auth_middleware
from .models import User
__all__ = [
"verify_settings",
"get_user_id",
"requires_admin_user",
"parse_jwt_token",
"requires_user",
"add_auth_responses_to_openapi",
"requires_admin_user",
"APIKeyValidator",
"auth_middleware",
"User",
]

View File

@@ -1,90 +1,11 @@
import logging
import os
from jwt.algorithms import get_default_algorithms, has_crypto
logger = logging.getLogger(__name__)
class AuthConfigError(ValueError):
"""Raised when authentication configuration is invalid."""
pass
ALGO_RECOMMENDATION = (
"We highly recommend using an asymmetric algorithm such as ES256, "
"because when leaked, a shared secret would allow anyone to "
"forge valid tokens and impersonate users. "
"More info: https://supabase.com/docs/guides/auth/signing-keys#choosing-the-right-signing-algorithm" # noqa
)
class Settings:
def __init__(self):
self.JWT_VERIFY_KEY: str = os.getenv(
"JWT_VERIFY_KEY", os.getenv("SUPABASE_JWT_SECRET", "")
).strip()
self.JWT_ALGORITHM: str = os.getenv("JWT_SIGN_ALGORITHM", "HS256").strip()
self.validate()
def validate(self):
if not self.JWT_VERIFY_KEY:
raise AuthConfigError(
"JWT_VERIFY_KEY must be set. "
"An empty JWT secret would allow anyone to forge valid tokens."
)
if len(self.JWT_VERIFY_KEY) < 32:
logger.warning(
"⚠️ JWT_VERIFY_KEY appears weak (less than 32 characters). "
"Consider using a longer, cryptographically secure secret."
)
supported_algorithms = get_default_algorithms().keys()
if not has_crypto:
logger.warning(
"⚠️ Asymmetric JWT verification is not available "
"because the 'cryptography' package is not installed. "
+ ALGO_RECOMMENDATION
)
if (
self.JWT_ALGORITHM not in supported_algorithms
or self.JWT_ALGORITHM == "none"
):
raise AuthConfigError(
f"Invalid JWT_SIGN_ALGORITHM: '{self.JWT_ALGORITHM}'. "
"Supported algorithms are listed on "
"https://pyjwt.readthedocs.io/en/stable/algorithms.html"
)
if self.JWT_ALGORITHM.startswith("HS"):
logger.warning(
f"⚠️ JWT_SIGN_ALGORITHM is set to '{self.JWT_ALGORITHM}', "
"a symmetric shared-key signature algorithm. " + ALGO_RECOMMENDATION
)
self.JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
self.JWT_ALGORITHM: str = "HS256"
_settings: Settings = None # type: ignore
def get_settings() -> Settings:
global _settings
if not _settings:
_settings = Settings()
return _settings
def verify_settings() -> None:
global _settings
if not _settings:
_settings = Settings() # calls validation indirectly
return
_settings.validate()
settings = Settings()

View File

@@ -1,306 +0,0 @@
"""
Comprehensive tests for auth configuration to ensure 100% line and branch coverage.
These tests verify critical security checks preventing JWT token forgery.
"""
import logging
import os
import pytest
from pytest_mock import MockerFixture
from autogpt_libs.auth.config import AuthConfigError, Settings
def test_environment_variable_precedence(mocker: MockerFixture):
"""Test that environment variables take precedence over defaults."""
secret = "environment-secret-key-with-proper-length-123456"
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == secret
def test_environment_variable_backwards_compatible(mocker: MockerFixture):
"""Test that SUPABASE_JWT_SECRET is read if JWT_VERIFY_KEY is not set."""
secret = "environment-secret-key-with-proper-length-123456"
mocker.patch.dict(os.environ, {"SUPABASE_JWT_SECRET": secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == secret
def test_auth_config_error_inheritance():
"""Test that AuthConfigError is properly defined as an Exception."""
assert issubclass(AuthConfigError, Exception)
error = AuthConfigError("test message")
assert str(error) == "test message"
def test_settings_static_after_creation(mocker: MockerFixture):
"""Test that settings maintain their values after creation."""
secret = "immutable-secret-key-with-proper-length-12345"
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
settings = Settings()
original_secret = settings.JWT_VERIFY_KEY
# Changing environment after creation shouldn't affect settings
os.environ["JWT_VERIFY_KEY"] = "different-secret"
assert settings.JWT_VERIFY_KEY == original_secret
def test_settings_load_with_valid_secret(mocker: MockerFixture):
"""Test auth enabled with a valid JWT secret."""
valid_secret = "a" * 32 # 32 character secret
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": valid_secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == valid_secret
def test_settings_load_with_strong_secret(mocker: MockerFixture):
"""Test auth enabled with a cryptographically strong secret."""
strong_secret = "super-secret-jwt-token-with-at-least-32-characters-long"
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": strong_secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == strong_secret
assert len(settings.JWT_VERIFY_KEY) >= 32
def test_secret_empty_raises_error(mocker: MockerFixture):
"""Test that auth enabled with empty secret raises AuthConfigError."""
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": ""}, clear=True)
with pytest.raises(Exception) as exc_info:
Settings()
assert "JWT_VERIFY_KEY" in str(exc_info.value)
def test_secret_missing_raises_error(mocker: MockerFixture):
"""Test that auth enabled without secret env var raises AuthConfigError."""
mocker.patch.dict(os.environ, {}, clear=True)
with pytest.raises(Exception) as exc_info:
Settings()
assert "JWT_VERIFY_KEY" in str(exc_info.value)
@pytest.mark.parametrize("secret", [" ", " ", "\t", "\n", " \t\n "])
def test_secret_only_whitespace_raises_error(mocker: MockerFixture, secret: str):
"""Test that auth enabled with whitespace-only secret raises error."""
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
with pytest.raises(ValueError):
Settings()
def test_secret_weak_logs_warning(
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
):
"""Test that weak JWT secret triggers warning log."""
weak_secret = "short" # Less than 32 characters
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": weak_secret}, clear=True)
with caplog.at_level(logging.WARNING):
settings = Settings()
assert settings.JWT_VERIFY_KEY == weak_secret
assert "key appears weak" in caplog.text.lower()
assert "less than 32 characters" in caplog.text
def test_secret_31_char_logs_warning(
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
):
"""Test that 31-character secret triggers warning (boundary test)."""
secret_31 = "a" * 31 # Exactly 31 characters
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret_31}, clear=True)
with caplog.at_level(logging.WARNING):
settings = Settings()
assert len(settings.JWT_VERIFY_KEY) == 31
assert "key appears weak" in caplog.text.lower()
def test_secret_32_char_no_warning(
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
):
"""Test that 32-character secret does not trigger warning (boundary test)."""
secret_32 = "a" * 32 # Exactly 32 characters
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret_32}, clear=True)
with caplog.at_level(logging.WARNING):
settings = Settings()
assert len(settings.JWT_VERIFY_KEY) == 32
assert "JWT secret appears weak" not in caplog.text
def test_secret_whitespace_stripped(mocker: MockerFixture):
"""Test that JWT secret whitespace is stripped."""
secret = "a" * 32
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": f" {secret} "}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == secret
def test_secret_with_special_characters(mocker: MockerFixture):
"""Test JWT secret with special characters."""
special_secret = "!@#$%^&*()_+-=[]{}|;:,.<>?`~" + "a" * 10 # 40 chars total
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": special_secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == special_secret
def test_secret_with_unicode(mocker: MockerFixture):
"""Test JWT secret with unicode characters."""
unicode_secret = "秘密🔐キー" + "a" * 25 # Ensure >32 bytes
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": unicode_secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == unicode_secret
def test_secret_very_long(mocker: MockerFixture):
"""Test JWT secret with excessive length."""
long_secret = "a" * 1000 # 1000 character secret
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": long_secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == long_secret
assert len(settings.JWT_VERIFY_KEY) == 1000
def test_secret_with_newline(mocker: MockerFixture):
"""Test JWT secret containing newlines."""
multiline_secret = "secret\nwith\nnewlines" + "a" * 20
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": multiline_secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == multiline_secret
def test_secret_base64_encoded(mocker: MockerFixture):
"""Test JWT secret that looks like base64."""
base64_secret = "dGhpc19pc19hX3NlY3JldF9rZXlfd2l0aF9wcm9wZXJfbGVuZ3Ro"
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": base64_secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == base64_secret
def test_secret_numeric_only(mocker: MockerFixture):
"""Test JWT secret with only numbers."""
numeric_secret = "1234567890" * 4 # 40 character numeric secret
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": numeric_secret}, clear=True)
settings = Settings()
assert settings.JWT_VERIFY_KEY == numeric_secret
def test_algorithm_default_hs256(mocker: MockerFixture):
"""Test that JWT algorithm defaults to HS256."""
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": "a" * 32}, clear=True)
settings = Settings()
assert settings.JWT_ALGORITHM == "HS256"
def test_algorithm_whitespace_stripped(mocker: MockerFixture):
"""Test that JWT algorithm whitespace is stripped."""
secret = "a" * 32
mocker.patch.dict(
os.environ,
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": " HS256 "},
clear=True,
)
settings = Settings()
assert settings.JWT_ALGORITHM == "HS256"
def test_no_crypto_warning(mocker: MockerFixture, caplog: pytest.LogCaptureFixture):
"""Test warning when crypto package is not available."""
secret = "a" * 32
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
# Mock has_crypto to return False
mocker.patch("autogpt_libs.auth.config.has_crypto", False)
with caplog.at_level(logging.WARNING):
Settings()
assert "Asymmetric JWT verification is not available" in caplog.text
assert "cryptography" in caplog.text
def test_algorithm_invalid_raises_error(mocker: MockerFixture):
"""Test that invalid JWT algorithm raises AuthConfigError."""
secret = "a" * 32
mocker.patch.dict(
os.environ,
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": "INVALID_ALG"},
clear=True,
)
with pytest.raises(AuthConfigError) as exc_info:
Settings()
assert "Invalid JWT_SIGN_ALGORITHM" in str(exc_info.value)
assert "INVALID_ALG" in str(exc_info.value)
def test_algorithm_none_raises_error(mocker: MockerFixture):
"""Test that 'none' algorithm raises AuthConfigError."""
secret = "a" * 32
mocker.patch.dict(
os.environ,
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": "none"},
clear=True,
)
with pytest.raises(AuthConfigError) as exc_info:
Settings()
assert "Invalid JWT_SIGN_ALGORITHM" in str(exc_info.value)
@pytest.mark.parametrize("algorithm", ["HS256", "HS384", "HS512"])
def test_algorithm_symmetric_warning(
mocker: MockerFixture, caplog: pytest.LogCaptureFixture, algorithm: str
):
"""Test warning for symmetric algorithms (HS256, HS384, HS512)."""
secret = "a" * 32
mocker.patch.dict(
os.environ,
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": algorithm},
clear=True,
)
with caplog.at_level(logging.WARNING):
settings = Settings()
assert algorithm in caplog.text
assert "symmetric shared-key signature algorithm" in caplog.text
assert settings.JWT_ALGORITHM == algorithm
@pytest.mark.parametrize(
"algorithm",
["ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512"],
)
def test_algorithm_asymmetric_no_warning(
mocker: MockerFixture, caplog: pytest.LogCaptureFixture, algorithm: str
):
"""Test that asymmetric algorithms do not trigger warning."""
secret = "a" * 32
mocker.patch.dict(
os.environ,
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": algorithm},
clear=True,
)
with caplog.at_level(logging.WARNING):
settings = Settings()
# Should not contain the symmetric algorithm warning
assert "symmetric shared-key signature algorithm" not in caplog.text
assert settings.JWT_ALGORITHM == algorithm

View File

@@ -1,45 +0,0 @@
"""
FastAPI dependency functions for JWT-based authentication and authorization.
These are the high-level dependency functions used in route definitions.
"""
import fastapi
from .jwt_utils import get_jwt_payload, verify_user
from .models import User
def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
"""
FastAPI dependency that requires a valid authenticated user.
Raises:
HTTPException: 401 for authentication failures
"""
return verify_user(jwt_payload, admin_only=False)
def requires_admin_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
"""
FastAPI dependency that requires a valid admin user.
Raises:
HTTPException: 401 for authentication failures, 403 for insufficient permissions
"""
return verify_user(jwt_payload, admin_only=True)
def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
"""
FastAPI dependency that returns the ID of the authenticated user.
Raises:
HTTPException: 401 for authentication failures or missing user ID
"""
user_id = jwt_payload.get("sub")
if not user_id:
raise fastapi.HTTPException(
status_code=401, detail="User ID not found in token"
)
return user_id

View File

@@ -1,335 +0,0 @@
"""
Comprehensive integration tests for authentication dependencies.
Tests the full authentication flow from HTTP requests to user validation.
"""
import os
import pytest
from fastapi import FastAPI, HTTPException, Security
from fastapi.testclient import TestClient
from pytest_mock import MockerFixture
from autogpt_libs.auth.dependencies import (
get_user_id,
requires_admin_user,
requires_user,
)
from autogpt_libs.auth.models import User
class TestAuthDependencies:
"""Test suite for authentication dependency functions."""
@pytest.fixture
def app(self):
"""Create a test FastAPI application."""
app = FastAPI()
@app.get("/user")
def get_user_endpoint(user: User = Security(requires_user)):
return {"user_id": user.user_id, "role": user.role}
@app.get("/admin")
def get_admin_endpoint(user: User = Security(requires_admin_user)):
return {"user_id": user.user_id, "role": user.role}
@app.get("/user-id")
def get_user_id_endpoint(user_id: str = Security(get_user_id)):
return {"user_id": user_id}
return app
@pytest.fixture
def client(self, app):
"""Create a test client."""
return TestClient(app)
def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
"""Test requires_user with valid JWT payload."""
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
# Mock get_jwt_payload to return our test payload
mocker.patch(
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user = requires_user(jwt_payload)
assert isinstance(user, User)
assert user.user_id == "user-123"
assert user.role == "user"
def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
"""Test requires_user accepts admin users."""
jwt_payload = {
"sub": "admin-456",
"role": "admin",
"email": "admin@example.com",
}
mocker.patch(
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user = requires_user(jwt_payload)
assert user.user_id == "admin-456"
assert user.role == "admin"
def test_requires_user_missing_sub(self):
"""Test requires_user with missing user ID."""
jwt_payload = {"role": "user", "email": "user@example.com"}
with pytest.raises(HTTPException) as exc_info:
requires_user(jwt_payload)
assert exc_info.value.status_code == 401
assert "User ID not found" in exc_info.value.detail
def test_requires_user_empty_sub(self):
"""Test requires_user with empty user ID."""
jwt_payload = {"sub": "", "role": "user"}
with pytest.raises(HTTPException) as exc_info:
requires_user(jwt_payload)
assert exc_info.value.status_code == 401
def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
"""Test requires_admin_user with admin role."""
jwt_payload = {
"sub": "admin-789",
"role": "admin",
"email": "admin@example.com",
}
mocker.patch(
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user = requires_admin_user(jwt_payload)
assert user.user_id == "admin-789"
assert user.role == "admin"
def test_requires_admin_user_with_regular_user(self):
"""Test requires_admin_user rejects regular users."""
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
with pytest.raises(HTTPException) as exc_info:
requires_admin_user(jwt_payload)
assert exc_info.value.status_code == 403
assert "Admin access required" in exc_info.value.detail
def test_requires_admin_user_missing_role(self):
"""Test requires_admin_user with missing role."""
jwt_payload = {"sub": "user-123", "email": "user@example.com"}
with pytest.raises(KeyError):
requires_admin_user(jwt_payload)
def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
"""Test get_user_id extracts user ID correctly."""
jwt_payload = {"sub": "user-id-xyz", "role": "user"}
mocker.patch(
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user_id = get_user_id(jwt_payload)
assert user_id == "user-id-xyz"
def test_get_user_id_missing_sub(self):
"""Test get_user_id with missing user ID."""
jwt_payload = {"role": "user"}
with pytest.raises(HTTPException) as exc_info:
get_user_id(jwt_payload)
assert exc_info.value.status_code == 401
assert "User ID not found" in exc_info.value.detail
def test_get_user_id_none_sub(self):
"""Test get_user_id with None user ID."""
jwt_payload = {"sub": None, "role": "user"}
with pytest.raises(HTTPException) as exc_info:
get_user_id(jwt_payload)
assert exc_info.value.status_code == 401
class TestAuthDependenciesIntegration:
"""Integration tests for auth dependencies with FastAPI."""
acceptable_jwt_secret = "test-secret-with-proper-length-123456"
@pytest.fixture
def create_token(self, mocker: MockerFixture):
"""Helper to create JWT tokens."""
import jwt
mocker.patch.dict(
os.environ,
{"JWT_VERIFY_KEY": self.acceptable_jwt_secret},
clear=True,
)
def _create_token(payload, secret=self.acceptable_jwt_secret):
return jwt.encode(payload, secret, algorithm="HS256")
return _create_token
def test_endpoint_auth_enabled_no_token(self):
"""Test endpoints require token when auth is enabled."""
app = FastAPI()
@app.get("/test")
def test_endpoint(user: User = Security(requires_user)):
return {"user_id": user.user_id}
client = TestClient(app)
# Should fail without auth header
response = client.get("/test")
assert response.status_code == 401
def test_endpoint_with_valid_token(self, create_token):
"""Test endpoint with valid JWT token."""
app = FastAPI()
@app.get("/test")
def test_endpoint(user: User = Security(requires_user)):
return {"user_id": user.user_id, "role": user.role}
client = TestClient(app)
token = create_token(
{"sub": "test-user", "role": "user", "aud": "authenticated"},
secret=self.acceptable_jwt_secret,
)
response = client.get("/test", headers={"Authorization": f"Bearer {token}"})
assert response.status_code == 200
assert response.json()["user_id"] == "test-user"
def test_admin_endpoint_requires_admin_role(self, create_token):
"""Test admin endpoint rejects non-admin users."""
app = FastAPI()
@app.get("/admin")
def admin_endpoint(user: User = Security(requires_admin_user)):
return {"user_id": user.user_id}
client = TestClient(app)
# Regular user token
user_token = create_token(
{"sub": "regular-user", "role": "user", "aud": "authenticated"},
secret=self.acceptable_jwt_secret,
)
response = client.get(
"/admin", headers={"Authorization": f"Bearer {user_token}"}
)
assert response.status_code == 403
# Admin token
admin_token = create_token(
{"sub": "admin-user", "role": "admin", "aud": "authenticated"},
secret=self.acceptable_jwt_secret,
)
response = client.get(
"/admin", headers={"Authorization": f"Bearer {admin_token}"}
)
assert response.status_code == 200
assert response.json()["user_id"] == "admin-user"
class TestAuthDependenciesEdgeCases:
"""Edge case tests for authentication dependencies."""
def test_dependency_with_complex_payload(self):
"""Test dependencies handle complex JWT payloads."""
complex_payload = {
"sub": "user-123",
"role": "admin",
"email": "test@example.com",
"app_metadata": {"provider": "email", "providers": ["email"]},
"user_metadata": {
"full_name": "Test User",
"avatar_url": "https://example.com/avatar.jpg",
},
"aud": "authenticated",
"iat": 1234567890,
"exp": 9999999999,
}
user = requires_user(complex_payload)
assert user.user_id == "user-123"
assert user.email == "test@example.com"
admin = requires_admin_user(complex_payload)
assert admin.role == "admin"
def test_dependency_with_unicode_in_payload(self):
"""Test dependencies handle unicode in JWT payloads."""
unicode_payload = {
"sub": "user-😀-123",
"role": "user",
"email": "测试@example.com",
"name": "日本語",
}
user = requires_user(unicode_payload)
assert "😀" in user.user_id
assert user.email == "测试@example.com"
def test_dependency_with_null_values(self):
"""Test dependencies handle null values in payload."""
null_payload = {
"sub": "user-123",
"role": "user",
"email": None,
"phone": None,
"metadata": None,
}
user = requires_user(null_payload)
assert user.user_id == "user-123"
assert user.email is None
def test_concurrent_requests_isolation(self):
"""Test that concurrent requests don't interfere with each other."""
payload1 = {"sub": "user-1", "role": "user"}
payload2 = {"sub": "user-2", "role": "admin"}
# Simulate concurrent processing
user1 = requires_user(payload1)
user2 = requires_admin_user(payload2)
assert user1.user_id == "user-1"
assert user2.user_id == "user-2"
assert user1.role == "user"
assert user2.role == "admin"
@pytest.mark.parametrize(
"payload,expected_error,admin_only",
[
(None, "Authorization header is missing", False),
({}, "User ID not found", False),
({"sub": ""}, "User ID not found", False),
({"role": "user"}, "User ID not found", False),
({"sub": "user", "role": "user"}, "Admin access required", True),
],
)
def test_dependency_error_cases(
self, payload, expected_error: str, admin_only: bool
):
"""Test that errors propagate correctly through dependencies."""
# Import verify_user to test it directly since dependencies use FastAPI Security
from autogpt_libs.auth.jwt_utils import verify_user
with pytest.raises(HTTPException) as exc_info:
verify_user(payload, admin_only=admin_only)
assert expected_error in exc_info.value.detail
def test_dependency_valid_user(self):
"""Test valid user case for dependency."""
# Import verify_user to test it directly since dependencies use FastAPI Security
from autogpt_libs.auth.jwt_utils import verify_user
# Valid case
user = verify_user({"sub": "user", "role": "user"}, admin_only=False)
assert user.user_id == "user"

View File

@@ -0,0 +1,46 @@
import fastapi
from .config import settings
from .middleware import auth_middleware
from .models import DEFAULT_USER_ID, User
def requires_user(payload: dict = fastapi.Depends(auth_middleware)) -> User:
return verify_user(payload, admin_only=False)
def requires_admin_user(
payload: dict = fastapi.Depends(auth_middleware),
) -> User:
return verify_user(payload, admin_only=True)
def verify_user(payload: dict | None, admin_only: bool) -> User:
if not payload:
if settings.ENABLE_AUTH:
raise fastapi.HTTPException(
status_code=401, detail="Authorization header is missing"
)
# This handles the case when authentication is disabled
payload = {"sub": DEFAULT_USER_ID, "role": "admin"}
user_id = payload.get("sub")
if not user_id:
raise fastapi.HTTPException(
status_code=401, detail="User ID not found in token"
)
if admin_only and payload["role"] != "admin":
raise fastapi.HTTPException(status_code=403, detail="Admin access required")
return User.from_payload(payload)
def get_user_id(payload: dict = fastapi.Depends(auth_middleware)) -> str:
user_id = payload.get("sub")
if not user_id:
raise fastapi.HTTPException(
status_code=401, detail="User ID not found in token"
)
return user_id

View File

@@ -0,0 +1,68 @@
import pytest
from .depends import requires_admin_user, requires_user, verify_user
def test_verify_user_no_payload():
user = verify_user(None, admin_only=False)
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
assert user.role == "admin"
def test_verify_user_no_user_id():
with pytest.raises(Exception):
verify_user({"role": "admin"}, admin_only=False)
def test_verify_user_not_admin():
with pytest.raises(Exception):
verify_user(
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"},
admin_only=True,
)
def test_verify_user_with_admin_role():
user = verify_user(
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "admin"},
admin_only=True,
)
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
assert user.role == "admin"
def test_verify_user_with_user_role():
user = verify_user(
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"},
admin_only=False,
)
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
assert user.role == "user"
def test_requires_user():
user = requires_user(
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"}
)
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
assert user.role == "user"
def test_requires_user_no_user_id():
with pytest.raises(Exception):
requires_user({"role": "user"})
def test_requires_admin_user():
user = requires_admin_user(
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "admin"}
)
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
assert user.role == "admin"
def test_requires_admin_user_not_admin():
with pytest.raises(Exception):
requires_admin_user(
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"}
)

View File

@@ -1,68 +0,0 @@
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from .jwt_utils import bearer_jwt_auth
def add_auth_responses_to_openapi(app: FastAPI) -> None:
"""
Set up custom OpenAPI schema generation that adds 401 responses
to all authenticated endpoints.
This is needed when using HTTPBearer with auto_error=False to get proper
401 responses instead of 403, but FastAPI only automatically adds security
responses when auto_error=True.
"""
def custom_openapi():
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
version=app.version,
description=app.description,
routes=app.routes,
)
# Add 401 response to all endpoints that have security requirements
for path, methods in openapi_schema["paths"].items():
for method, details in methods.items():
security_schemas = [
schema
for auth_option in details.get("security", [])
for schema in auth_option.keys()
]
if bearer_jwt_auth.scheme_name not in security_schemas:
continue
if "responses" not in details:
details["responses"] = {}
details["responses"]["401"] = {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
}
# Ensure #/components/responses exists
if "components" not in openapi_schema:
openapi_schema["components"] = {}
if "responses" not in openapi_schema["components"]:
openapi_schema["components"]["responses"] = {}
# Define 401 response
openapi_schema["components"]["responses"]["HTTP401NotAuthenticatedError"] = {
"description": "Authentication required",
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {"detail": {"type": "string"}},
}
}
},
}
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi

View File

@@ -1,435 +0,0 @@
"""
Comprehensive tests for auth helpers module to achieve 100% coverage.
Tests OpenAPI schema generation and authentication response handling.
"""
from unittest import mock
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from autogpt_libs.auth.helpers import add_auth_responses_to_openapi
from autogpt_libs.auth.jwt_utils import bearer_jwt_auth
def test_add_auth_responses_to_openapi_basic():
"""Test adding 401 responses to OpenAPI schema."""
app = FastAPI(title="Test App", version="1.0.0")
# Add some test endpoints with authentication
from fastapi import Depends
from autogpt_libs.auth.dependencies import requires_user
@app.get("/protected", dependencies=[Depends(requires_user)])
def protected_endpoint():
return {"message": "Protected"}
@app.get("/public")
def public_endpoint():
return {"message": "Public"}
# Apply the OpenAPI customization
add_auth_responses_to_openapi(app)
# Get the OpenAPI schema
schema = app.openapi()
# Verify basic schema properties
assert schema["info"]["title"] == "Test App"
assert schema["info"]["version"] == "1.0.0"
# Verify 401 response component is added
assert "components" in schema
assert "responses" in schema["components"]
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
# Verify 401 response structure
error_response = schema["components"]["responses"]["HTTP401NotAuthenticatedError"]
assert error_response["description"] == "Authentication required"
assert "application/json" in error_response["content"]
assert "schema" in error_response["content"]["application/json"]
# Verify schema properties
response_schema = error_response["content"]["application/json"]["schema"]
assert response_schema["type"] == "object"
assert "detail" in response_schema["properties"]
assert response_schema["properties"]["detail"]["type"] == "string"
def test_add_auth_responses_to_openapi_with_security():
"""Test that 401 responses are added only to secured endpoints."""
app = FastAPI()
# Mock endpoint with security
from fastapi import Security
from autogpt_libs.auth.dependencies import get_user_id
@app.get("/secured")
def secured_endpoint(user_id: str = Security(get_user_id)):
return {"user_id": user_id}
@app.post("/also-secured")
def another_secured(user_id: str = Security(get_user_id)):
return {"status": "ok"}
@app.get("/unsecured")
def unsecured_endpoint():
return {"public": True}
# Apply OpenAPI customization
add_auth_responses_to_openapi(app)
# Get schema
schema = app.openapi()
# Check that secured endpoints have 401 responses
if "/secured" in schema["paths"]:
if "get" in schema["paths"]["/secured"]:
secured_get = schema["paths"]["/secured"]["get"]
if "responses" in secured_get:
assert "401" in secured_get["responses"]
assert (
secured_get["responses"]["401"]["$ref"]
== "#/components/responses/HTTP401NotAuthenticatedError"
)
if "/also-secured" in schema["paths"]:
if "post" in schema["paths"]["/also-secured"]:
secured_post = schema["paths"]["/also-secured"]["post"]
if "responses" in secured_post:
assert "401" in secured_post["responses"]
# Check that unsecured endpoint does not have 401 response
if "/unsecured" in schema["paths"]:
if "get" in schema["paths"]["/unsecured"]:
unsecured_get = schema["paths"]["/unsecured"]["get"]
if "responses" in unsecured_get:
assert "401" not in unsecured_get.get("responses", {})
def test_add_auth_responses_to_openapi_cached_schema():
"""Test that OpenAPI schema is cached after first generation."""
app = FastAPI()
# Apply customization
add_auth_responses_to_openapi(app)
# Get schema twice
schema1 = app.openapi()
schema2 = app.openapi()
# Should return the same cached object
assert schema1 is schema2
def test_add_auth_responses_to_openapi_existing_responses():
"""Test handling endpoints that already have responses defined."""
app = FastAPI()
from fastapi import Security
from autogpt_libs.auth.jwt_utils import get_jwt_payload
@app.get(
"/with-responses",
responses={
200: {"description": "Success"},
404: {"description": "Not found"},
},
)
def endpoint_with_responses(jwt: dict = Security(get_jwt_payload)):
return {"data": "test"}
# Apply customization
add_auth_responses_to_openapi(app)
schema = app.openapi()
# Check that existing responses are preserved and 401 is added
if "/with-responses" in schema["paths"]:
if "get" in schema["paths"]["/with-responses"]:
responses = schema["paths"]["/with-responses"]["get"].get("responses", {})
# Original responses should be preserved
if "200" in responses:
assert responses["200"]["description"] == "Success"
if "404" in responses:
assert responses["404"]["description"] == "Not found"
# 401 should be added
if "401" in responses:
assert (
responses["401"]["$ref"]
== "#/components/responses/HTTP401NotAuthenticatedError"
)
def test_add_auth_responses_to_openapi_no_security_endpoints():
"""Test with app that has no secured endpoints."""
app = FastAPI()
@app.get("/public1")
def public1():
return {"message": "public1"}
@app.post("/public2")
def public2():
return {"message": "public2"}
# Apply customization
add_auth_responses_to_openapi(app)
schema = app.openapi()
# Component should still be added for consistency
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
# But no endpoints should have 401 responses
for path in schema["paths"].values():
for method in path.values():
if isinstance(method, dict) and "responses" in method:
assert "401" not in method["responses"]
def test_add_auth_responses_to_openapi_multiple_security_schemes():
"""Test endpoints with multiple security requirements."""
app = FastAPI()
from fastapi import Security
from autogpt_libs.auth.dependencies import requires_admin_user, requires_user
from autogpt_libs.auth.models import User
@app.get("/multi-auth")
def multi_auth(
user: User = Security(requires_user),
admin: User = Security(requires_admin_user),
):
return {"status": "super secure"}
# Apply customization
add_auth_responses_to_openapi(app)
schema = app.openapi()
# Should have 401 response
if "/multi-auth" in schema["paths"]:
if "get" in schema["paths"]["/multi-auth"]:
responses = schema["paths"]["/multi-auth"]["get"].get("responses", {})
if "401" in responses:
assert (
responses["401"]["$ref"]
== "#/components/responses/HTTP401NotAuthenticatedError"
)
def test_add_auth_responses_to_openapi_empty_components():
"""Test when OpenAPI schema has no components section initially."""
app = FastAPI()
# Mock get_openapi to return schema without components
original_get_openapi = get_openapi
def mock_get_openapi(*args, **kwargs):
schema = original_get_openapi(*args, **kwargs)
# Remove components if it exists
if "components" in schema:
del schema["components"]
return schema
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
# Apply customization
add_auth_responses_to_openapi(app)
schema = app.openapi()
# Components should be created
assert "components" in schema
assert "responses" in schema["components"]
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
def test_add_auth_responses_to_openapi_all_http_methods():
"""Test that all HTTP methods are handled correctly."""
app = FastAPI()
from fastapi import Security
from autogpt_libs.auth.jwt_utils import get_jwt_payload
@app.get("/resource")
def get_resource(jwt: dict = Security(get_jwt_payload)):
return {"method": "GET"}
@app.post("/resource")
def post_resource(jwt: dict = Security(get_jwt_payload)):
return {"method": "POST"}
@app.put("/resource")
def put_resource(jwt: dict = Security(get_jwt_payload)):
return {"method": "PUT"}
@app.patch("/resource")
def patch_resource(jwt: dict = Security(get_jwt_payload)):
return {"method": "PATCH"}
@app.delete("/resource")
def delete_resource(jwt: dict = Security(get_jwt_payload)):
return {"method": "DELETE"}
# Apply customization
add_auth_responses_to_openapi(app)
schema = app.openapi()
# All methods should have 401 response
if "/resource" in schema["paths"]:
for method in ["get", "post", "put", "patch", "delete"]:
if method in schema["paths"]["/resource"]:
method_spec = schema["paths"]["/resource"][method]
if "responses" in method_spec:
assert "401" in method_spec["responses"]
def test_bearer_jwt_auth_scheme_config():
"""Test that bearer_jwt_auth is configured correctly."""
assert bearer_jwt_auth.scheme_name == "HTTPBearerJWT"
assert bearer_jwt_auth.auto_error is False
def test_add_auth_responses_with_no_routes():
"""Test OpenAPI generation with app that has no routes."""
app = FastAPI(title="Empty App")
# Apply customization to empty app
add_auth_responses_to_openapi(app)
schema = app.openapi()
# Should still have basic structure
assert schema["info"]["title"] == "Empty App"
assert "components" in schema
assert "responses" in schema["components"]
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
def test_custom_openapi_function_replacement():
"""Test that the custom openapi function properly replaces the default."""
app = FastAPI()
# Store original function
original_openapi = app.openapi
# Apply customization
add_auth_responses_to_openapi(app)
# Function should be replaced
assert app.openapi != original_openapi
assert callable(app.openapi)
def test_endpoint_without_responses_section():
"""Test endpoint that has security but no responses section initially."""
app = FastAPI()
from fastapi import Security
from fastapi.openapi.utils import get_openapi as original_get_openapi
from autogpt_libs.auth.jwt_utils import get_jwt_payload
# Create endpoint
@app.get("/no-responses")
def endpoint_without_responses(jwt: dict = Security(get_jwt_payload)):
return {"data": "test"}
# Mock get_openapi to remove responses from the endpoint
def mock_get_openapi(*args, **kwargs):
schema = original_get_openapi(*args, **kwargs)
# Remove responses from our endpoint to trigger line 40
if "/no-responses" in schema.get("paths", {}):
if "get" in schema["paths"]["/no-responses"]:
# Delete responses to force the code to create it
if "responses" in schema["paths"]["/no-responses"]["get"]:
del schema["paths"]["/no-responses"]["get"]["responses"]
return schema
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
# Apply customization
add_auth_responses_to_openapi(app)
# Get schema and verify 401 was added
schema = app.openapi()
# The endpoint should now have 401 response
if "/no-responses" in schema["paths"]:
if "get" in schema["paths"]["/no-responses"]:
responses = schema["paths"]["/no-responses"]["get"].get("responses", {})
assert "401" in responses
assert (
responses["401"]["$ref"]
== "#/components/responses/HTTP401NotAuthenticatedError"
)
def test_components_with_existing_responses():
"""Test when components already has a responses section."""
app = FastAPI()
# Mock get_openapi to return schema with existing components/responses
from fastapi.openapi.utils import get_openapi as original_get_openapi
def mock_get_openapi(*args, **kwargs):
schema = original_get_openapi(*args, **kwargs)
# Add existing components/responses
if "components" not in schema:
schema["components"] = {}
schema["components"]["responses"] = {
"ExistingResponse": {"description": "An existing response"}
}
return schema
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
# Apply customization
add_auth_responses_to_openapi(app)
schema = app.openapi()
# Both responses should exist
assert "ExistingResponse" in schema["components"]["responses"]
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
# Verify our 401 response structure
error_response = schema["components"]["responses"][
"HTTP401NotAuthenticatedError"
]
assert error_response["description"] == "Authentication required"
def test_openapi_schema_persistence():
"""Test that modifications to OpenAPI schema persist correctly."""
app = FastAPI()
from fastapi import Security
from autogpt_libs.auth.jwt_utils import get_jwt_payload
@app.get("/test")
def test_endpoint(jwt: dict = Security(get_jwt_payload)):
return {"test": True}
# Apply customization
add_auth_responses_to_openapi(app)
# Get schema multiple times
schema1 = app.openapi()
# Modify the cached schema (shouldn't affect future calls)
schema1["info"]["title"] = "Modified Title"
# Clear cache and get again
app.openapi_schema = None
schema2 = app.openapi()
# Should regenerate with original title
assert schema2["info"]["title"] == app.title
assert schema2["info"]["title"] != "Modified Title"

View File

@@ -1,48 +1,11 @@
import logging
from typing import Any
from typing import Any, Dict
import jwt
from fastapi import HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from .config import get_settings
from .models import User
logger = logging.getLogger(__name__)
# Bearer token authentication scheme
bearer_jwt_auth = HTTPBearer(
bearerFormat="jwt", scheme_name="HTTPBearerJWT", auto_error=False
)
from .config import settings
def get_jwt_payload(
credentials: HTTPAuthorizationCredentials | None = Security(bearer_jwt_auth),
) -> dict[str, Any]:
"""
Extract and validate JWT payload from HTTP Authorization header.
This is the core authentication function that handles:
- Reading the `Authorization` header to obtain the JWT token
- Verifying the JWT token's signature
- Decoding the JWT token's payload
:param credentials: HTTP Authorization credentials from bearer token
:return: JWT payload dictionary
:raises HTTPException: 401 if authentication fails
"""
if not credentials:
raise HTTPException(status_code=401, detail="Authorization header is missing")
try:
payload = parse_jwt_token(credentials.credentials)
logger.debug("Token decoded successfully")
return payload
except ValueError as e:
raise HTTPException(status_code=401, detail=str(e))
def parse_jwt_token(token: str) -> dict[str, Any]:
def parse_jwt_token(token: str) -> Dict[str, Any]:
"""
Parse and validate a JWT token.
@@ -50,11 +13,10 @@ def parse_jwt_token(token: str) -> dict[str, Any]:
:return: The decoded payload
:raises ValueError: If the token is invalid or expired
"""
settings = get_settings()
try:
payload = jwt.decode(
token,
settings.JWT_VERIFY_KEY,
settings.JWT_SECRET_KEY,
algorithms=[settings.JWT_ALGORITHM],
audience="authenticated",
)
@@ -63,18 +25,3 @@ def parse_jwt_token(token: str) -> dict[str, Any]:
raise ValueError("Token has expired")
except jwt.InvalidTokenError as e:
raise ValueError(f"Invalid token: {str(e)}")
def verify_user(jwt_payload: dict | None, admin_only: bool) -> User:
if jwt_payload is None:
raise HTTPException(status_code=401, detail="Authorization header is missing")
user_id = jwt_payload.get("sub")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found in token")
if admin_only and jwt_payload["role"] != "admin":
raise HTTPException(status_code=403, detail="Admin access required")
return User.from_payload(jwt_payload)

View File

@@ -1,308 +0,0 @@
"""
Comprehensive tests for JWT token parsing and validation.
Ensures 100% line and branch coverage for JWT security functions.
"""
import os
from datetime import datetime, timedelta, timezone
import jwt
import pytest
from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials
from pytest_mock import MockerFixture
from autogpt_libs.auth import config, jwt_utils
from autogpt_libs.auth.config import Settings
from autogpt_libs.auth.models import User
MOCK_JWT_SECRET = "test-secret-key-with-at-least-32-characters"
TEST_USER_PAYLOAD = {
"sub": "test-user-id",
"role": "user",
"aud": "authenticated",
"email": "test@example.com",
}
TEST_ADMIN_PAYLOAD = {
"sub": "admin-user-id",
"role": "admin",
"aud": "authenticated",
"email": "admin@example.com",
}
@pytest.fixture(autouse=True)
def mock_config(mocker: MockerFixture):
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": MOCK_JWT_SECRET}, clear=True)
mocker.patch.object(config, "_settings", Settings())
yield
def create_token(payload, secret=None, algorithm="HS256"):
"""Helper to create JWT tokens."""
if secret is None:
secret = MOCK_JWT_SECRET
return jwt.encode(payload, secret, algorithm=algorithm)
def test_parse_jwt_token_valid():
"""Test parsing a valid JWT token."""
token = create_token(TEST_USER_PAYLOAD)
result = jwt_utils.parse_jwt_token(token)
assert result["sub"] == "test-user-id"
assert result["role"] == "user"
assert result["aud"] == "authenticated"
def test_parse_jwt_token_expired():
"""Test parsing an expired JWT token."""
expired_payload = {
**TEST_USER_PAYLOAD,
"exp": datetime.now(timezone.utc) - timedelta(hours=1),
}
token = create_token(expired_payload)
with pytest.raises(ValueError) as exc_info:
jwt_utils.parse_jwt_token(token)
assert "Token has expired" in str(exc_info.value)
def test_parse_jwt_token_invalid_signature():
"""Test parsing a token with invalid signature."""
# Create token with different secret
token = create_token(TEST_USER_PAYLOAD, secret="wrong-secret")
with pytest.raises(ValueError) as exc_info:
jwt_utils.parse_jwt_token(token)
assert "Invalid token" in str(exc_info.value)
def test_parse_jwt_token_malformed():
"""Test parsing a malformed token."""
malformed_tokens = [
"not.a.token",
"invalid",
"",
# Header only
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9",
# No signature
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0",
]
for token in malformed_tokens:
with pytest.raises(ValueError) as exc_info:
jwt_utils.parse_jwt_token(token)
assert "Invalid token" in str(exc_info.value)
def test_parse_jwt_token_wrong_audience():
"""Test parsing a token with wrong audience."""
wrong_aud_payload = {**TEST_USER_PAYLOAD, "aud": "wrong-audience"}
token = create_token(wrong_aud_payload)
with pytest.raises(ValueError) as exc_info:
jwt_utils.parse_jwt_token(token)
assert "Invalid token" in str(exc_info.value)
def test_parse_jwt_token_missing_audience():
"""Test parsing a token without audience claim."""
no_aud_payload = {k: v for k, v in TEST_USER_PAYLOAD.items() if k != "aud"}
token = create_token(no_aud_payload)
with pytest.raises(ValueError) as exc_info:
jwt_utils.parse_jwt_token(token)
assert "Invalid token" in str(exc_info.value)
def test_get_jwt_payload_with_valid_token():
"""Test extracting JWT payload with valid bearer token."""
token = create_token(TEST_USER_PAYLOAD)
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
result = jwt_utils.get_jwt_payload(credentials)
assert result["sub"] == "test-user-id"
assert result["role"] == "user"
def test_get_jwt_payload_no_credentials():
"""Test JWT payload when no credentials provided."""
with pytest.raises(HTTPException) as exc_info:
jwt_utils.get_jwt_payload(None)
assert exc_info.value.status_code == 401
assert "Authorization header is missing" in exc_info.value.detail
def test_get_jwt_payload_invalid_token():
"""Test JWT payload extraction with invalid token."""
credentials = HTTPAuthorizationCredentials(
scheme="Bearer", credentials="invalid.token.here"
)
with pytest.raises(HTTPException) as exc_info:
jwt_utils.get_jwt_payload(credentials)
assert exc_info.value.status_code == 401
assert "Invalid token" in exc_info.value.detail
def test_verify_user_with_valid_user():
"""Test verifying a valid user."""
user = jwt_utils.verify_user(TEST_USER_PAYLOAD, admin_only=False)
assert isinstance(user, User)
assert user.user_id == "test-user-id"
assert user.role == "user"
assert user.email == "test@example.com"
def test_verify_user_with_admin():
"""Test verifying an admin user."""
user = jwt_utils.verify_user(TEST_ADMIN_PAYLOAD, admin_only=True)
assert isinstance(user, User)
assert user.user_id == "admin-user-id"
assert user.role == "admin"
def test_verify_user_admin_only_with_regular_user():
"""Test verifying regular user when admin is required."""
with pytest.raises(HTTPException) as exc_info:
jwt_utils.verify_user(TEST_USER_PAYLOAD, admin_only=True)
assert exc_info.value.status_code == 403
assert "Admin access required" in exc_info.value.detail
def test_verify_user_no_payload():
"""Test verifying user with no payload."""
with pytest.raises(HTTPException) as exc_info:
jwt_utils.verify_user(None, admin_only=False)
assert exc_info.value.status_code == 401
assert "Authorization header is missing" in exc_info.value.detail
def test_verify_user_missing_sub():
"""Test verifying user with payload missing 'sub' field."""
invalid_payload = {"role": "user", "email": "test@example.com"}
with pytest.raises(HTTPException) as exc_info:
jwt_utils.verify_user(invalid_payload, admin_only=False)
assert exc_info.value.status_code == 401
assert "User ID not found in token" in exc_info.value.detail
def test_verify_user_empty_sub():
"""Test verifying user with empty 'sub' field."""
invalid_payload = {"sub": "", "role": "user"}
with pytest.raises(HTTPException) as exc_info:
jwt_utils.verify_user(invalid_payload, admin_only=False)
assert exc_info.value.status_code == 401
assert "User ID not found in token" in exc_info.value.detail
def test_verify_user_none_sub():
"""Test verifying user with None 'sub' field."""
invalid_payload = {"sub": None, "role": "user"}
with pytest.raises(HTTPException) as exc_info:
jwt_utils.verify_user(invalid_payload, admin_only=False)
assert exc_info.value.status_code == 401
assert "User ID not found in token" in exc_info.value.detail
def test_verify_user_missing_role_admin_check():
"""Test verifying admin when role field is missing."""
no_role_payload = {"sub": "user-id"}
with pytest.raises(KeyError):
# This will raise KeyError when checking payload["role"]
jwt_utils.verify_user(no_role_payload, admin_only=True)
# ======================== EDGE CASES ======================== #
def test_jwt_with_additional_claims():
"""Test JWT token with additional custom claims."""
extra_claims_payload = {
"sub": "user-id",
"role": "user",
"aud": "authenticated",
"custom_claim": "custom_value",
"permissions": ["read", "write"],
"metadata": {"key": "value"},
}
token = create_token(extra_claims_payload)
result = jwt_utils.parse_jwt_token(token)
assert result["sub"] == "user-id"
assert result["custom_claim"] == "custom_value"
assert result["permissions"] == ["read", "write"]
def test_jwt_with_numeric_sub():
"""Test JWT token with numeric user ID."""
payload = {
"sub": 12345, # Numeric ID
"role": "user",
"aud": "authenticated",
}
# Should convert to string internally
user = jwt_utils.verify_user(payload, admin_only=False)
assert user.user_id == 12345
def test_jwt_with_very_long_sub():
"""Test JWT token with very long user ID."""
long_id = "a" * 1000
payload = {
"sub": long_id,
"role": "user",
"aud": "authenticated",
}
user = jwt_utils.verify_user(payload, admin_only=False)
assert user.user_id == long_id
def test_jwt_with_special_characters_in_claims():
"""Test JWT token with special characters in claims."""
payload = {
"sub": "user@example.com/special-chars!@#$%",
"role": "admin",
"aud": "authenticated",
"email": "test+special@example.com",
}
user = jwt_utils.verify_user(payload, admin_only=True)
assert "special-chars!@#$%" in user.user_id
def test_jwt_with_future_iat():
"""Test JWT token with issued-at time in future."""
future_payload = {
"sub": "user-id",
"role": "user",
"aud": "authenticated",
"iat": datetime.now(timezone.utc) + timedelta(hours=1),
}
token = create_token(future_payload)
# PyJWT validates iat claim and should reject future tokens
with pytest.raises(ValueError, match="not yet valid"):
jwt_utils.parse_jwt_token(token)
def test_jwt_with_different_algorithms():
"""Test that only HS256 algorithm is accepted."""
payload = {
"sub": "user-id",
"role": "user",
"aud": "authenticated",
}
# Try different algorithms
algorithms = ["HS384", "HS512", "none"]
for algo in algorithms:
if algo == "none":
# Special case for 'none' algorithm (security vulnerability if accepted)
token = create_token(payload, "", algorithm="none")
else:
token = create_token(payload, algorithm=algo)
with pytest.raises(ValueError) as exc_info:
jwt_utils.parse_jwt_token(token)
assert "Invalid token" in str(exc_info.value)

View File

@@ -0,0 +1,140 @@
import inspect
import logging
import secrets
from typing import Any, Callable, Optional
from fastapi import HTTPException, Request, Security
from fastapi.security import APIKeyHeader, HTTPBearer
from starlette.status import HTTP_401_UNAUTHORIZED
from .config import settings
from .jwt_utils import parse_jwt_token
security = HTTPBearer()
logger = logging.getLogger(__name__)
async def auth_middleware(request: Request):
if not settings.ENABLE_AUTH:
# If authentication is disabled, allow the request to proceed
logger.warning("Auth disabled")
return {}
security = HTTPBearer()
credentials = await security(request)
if not credentials:
raise HTTPException(status_code=401, detail="Authorization header is missing")
try:
payload = parse_jwt_token(credentials.credentials)
request.state.user = payload
logger.debug("Token decoded successfully")
except ValueError as e:
raise HTTPException(status_code=401, detail=str(e))
return payload
class APIKeyValidator:
"""
Configurable API key validator that supports custom validation functions
for FastAPI applications.
This class provides a flexible way to implement API key authentication with optional
custom validation logic. It can be used for simple token matching
or more complex validation scenarios like database lookups.
Examples:
Simple token validation:
```python
validator = APIKeyValidator(
header_name="X-API-Key",
expected_token="your-secret-token"
)
@app.get("/protected", dependencies=[Depends(validator.get_dependency())])
def protected_endpoint():
return {"message": "Access granted"}
```
Custom validation with database lookup:
```python
async def validate_with_db(api_key: str):
api_key_obj = await db.get_api_key(api_key)
return api_key_obj if api_key_obj and api_key_obj.is_active else None
validator = APIKeyValidator(
header_name="X-API-Key",
validate_fn=validate_with_db
)
```
Args:
header_name (str): The name of the header containing the API key
expected_token (Optional[str]): The expected API key value for simple token matching
validate_fn (Optional[Callable]): Custom validation function that takes an API key
string and returns a boolean or object. Can be async.
error_status (int): HTTP status code to use for validation errors
error_message (str): Error message to return when validation fails
"""
def __init__(
self,
header_name: str,
expected_token: Optional[str] = None,
validate_fn: Optional[Callable[[str], bool]] = None,
error_status: int = HTTP_401_UNAUTHORIZED,
error_message: str = "Invalid API key",
):
# Create the APIKeyHeader as a class property
self.security_scheme = APIKeyHeader(name=header_name)
self.expected_token = expected_token
self.custom_validate_fn = validate_fn
self.error_status = error_status
self.error_message = error_message
async def default_validator(self, api_key: str) -> bool:
if not self.expected_token:
raise ValueError(
"Expected Token Required to be set when uisng API Key Validator default validation"
)
return secrets.compare_digest(api_key, self.expected_token)
async def __call__(
self, request: Request, api_key: str = Security(APIKeyHeader)
) -> Any:
if api_key is None:
raise HTTPException(status_code=self.error_status, detail="Missing API key")
# Use custom validation if provided, otherwise use default equality check
validator = self.custom_validate_fn or self.default_validator
result = (
await validator(api_key)
if inspect.iscoroutinefunction(validator)
else validator(api_key)
)
if not result:
raise HTTPException(
status_code=self.error_status, detail=self.error_message
)
# Store validation result in request state if it's not just a boolean
if result is not True:
request.state.api_key = result
return result
def get_dependency(self):
"""
Returns a callable dependency that FastAPI will recognize as a security scheme
"""
async def validate_api_key(
request: Request, api_key: str = Security(self.security_scheme)
) -> Any:
return await self(request, api_key)
# This helps FastAPI recognize it as a security dependency
validate_api_key.__name__ = f"validate_{self.security_scheme.model.name}"
return validate_api_key

View File

@@ -54,7 +54,7 @@ version = "1.2.0"
description = "Backport of asyncio.Runner, a context manager that controls event loop life cycle."
optional = false
python-versions = "<3.11,>=3.8"
groups = ["dev"]
groups = ["main"]
markers = "python_version < \"3.11\""
files = [
{file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
@@ -85,87 +85,6 @@ files = [
{file = "certifi-2025.7.14.tar.gz", hash = "sha256:8ea99dbdfaaf2ba2f9bac77b9249ef62ec5218e7c2b2e903378ed5fccf765995"},
]
[[package]]
name = "cffi"
version = "1.17.1"
description = "Foreign Function Interface for Python calling C code."
optional = false
python-versions = ">=3.8"
groups = ["main"]
markers = "platform_python_implementation != \"PyPy\""
files = [
{file = "cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14"},
{file = "cffi-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67"},
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382"},
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702"},
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3"},
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5da5719280082ac6bd9aa7becb3938dc9f9cbd57fac7d2871717b1feb0902ab6"},
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb1a08b8008b281856e5971307cc386a8e9c5b625ac297e853d36da6efe9c17"},
{file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8"},
{file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e"},
{file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be"},
{file = "cffi-1.17.1-cp310-cp310-win32.whl", hash = "sha256:c9c3d058ebabb74db66e431095118094d06abf53284d9c81f27300d0e0d8bc7c"},
{file = "cffi-1.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:0f048dcf80db46f0098ccac01132761580d28e28bc0f78ae0d58048063317e15"},
{file = "cffi-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401"},
{file = "cffi-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf"},
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4"},
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41"},
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1"},
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6"},
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d"},
{file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6"},
{file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f"},
{file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b"},
{file = "cffi-1.17.1-cp311-cp311-win32.whl", hash = "sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655"},
{file = "cffi-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0"},
{file = "cffi-1.17.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:805b4371bf7197c329fcb3ead37e710d1bca9da5d583f5073b799d5c5bd1eee4"},
{file = "cffi-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:733e99bc2df47476e3848417c5a4540522f234dfd4ef3ab7fafdf555b082ec0c"},
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36"},
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5"},
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff"},
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f79fc4fc25f1c8698ff97788206bb3c2598949bfe0fef03d299eb1b5356ada99"},
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93"},
{file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3"},
{file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8"},
{file = "cffi-1.17.1-cp312-cp312-win32.whl", hash = "sha256:a08d7e755f8ed21095a310a693525137cfe756ce62d066e53f502a83dc550f65"},
{file = "cffi-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:51392eae71afec0d0c8fb1a53b204dbb3bcabcb3c9b807eedf3e1e6ccf2de903"},
{file = "cffi-1.17.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f3a2b4222ce6b60e2e8b337bb9596923045681d71e5a082783484d845390938e"},
{file = "cffi-1.17.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0984a4925a435b1da406122d4d7968dd861c1385afe3b45ba82b750f229811e2"},
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3"},
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683"},
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5"},
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4"},
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd"},
{file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed"},
{file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9"},
{file = "cffi-1.17.1-cp313-cp313-win32.whl", hash = "sha256:e03eab0a8677fa80d646b5ddece1cbeaf556c313dcfac435ba11f107ba117b5d"},
{file = "cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a"},
{file = "cffi-1.17.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:636062ea65bd0195bc012fea9321aca499c0504409f413dc88af450b57ffd03b"},
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c7eac2ef9b63c79431bc4b25f1cd649d7f061a28808cbc6c47b534bd789ef964"},
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e221cf152cff04059d011ee126477f0d9588303eb57e88923578ace7baad17f9"},
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:31000ec67d4221a71bd3f67df918b1f88f676f1c3b535a7eb473255fdc0b83fc"},
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f17be4345073b0a7b8ea599688f692ac3ef23ce28e5df79c04de519dbc4912c"},
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e2b1fac190ae3ebfe37b979cc1ce69c81f4e4fe5746bb401dca63a9062cdaf1"},
{file = "cffi-1.17.1-cp38-cp38-win32.whl", hash = "sha256:7596d6620d3fa590f677e9ee430df2958d2d6d6de2feeae5b20e82c00b76fbf8"},
{file = "cffi-1.17.1-cp38-cp38-win_amd64.whl", hash = "sha256:78122be759c3f8a014ce010908ae03364d00a1f81ab5c7f4a7a5120607ea56e1"},
{file = "cffi-1.17.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b2ab587605f4ba0bf81dc0cb08a41bd1c0a5906bd59243d56bad7668a6fc6c16"},
{file = "cffi-1.17.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:28b16024becceed8c6dfbc75629e27788d8a3f9030691a1dbf9821a128b22c36"},
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d599671f396c4723d016dbddb72fe8e0397082b0a77a4fab8028923bec050e8"},
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca74b8dbe6e8e8263c0ffd60277de77dcee6c837a3d0881d8c1ead7268c9e576"},
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7f5baafcc48261359e14bcd6d9bff6d4b28d9103847c9e136694cb0501aef87"},
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:98e3969bcff97cae1b2def8ba499ea3d6f31ddfdb7635374834cf89a1a08ecf0"},
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cdf5ce3acdfd1661132f2a9c19cac174758dc2352bfe37d98aa7512c6b7178b3"},
{file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9755e4345d1ec879e3849e62222a18c7174d65a6a92d5b346b1863912168b595"},
{file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f1e22e8c4419538cb197e4dd60acc919d7696e5ef98ee4da4e01d3f8cfa4cc5a"},
{file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c03e868a0b3bc35839ba98e74211ed2b05d2119be4e8a0f224fba9384f1fe02e"},
{file = "cffi-1.17.1-cp39-cp39-win32.whl", hash = "sha256:e31ae45bc2e29f6b2abd0de1cc3b9d5205aa847cafaecb8af1476a609a2f6eb7"},
{file = "cffi-1.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662"},
{file = "cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824"},
]
[package.dependencies]
pycparser = "*"
[[package]]
name = "charset-normalizer"
version = "3.4.2"
@@ -289,176 +208,12 @@ version = "0.4.6"
description = "Cross-platform colored terminal text."
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
groups = ["main", "dev"]
groups = ["main"]
files = [
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
[[package]]
name = "coverage"
version = "7.10.5"
description = "Code coverage measurement for Python"
optional = false
python-versions = ">=3.9"
groups = ["dev"]
files = [
{file = "coverage-7.10.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c6a5c3414bfc7451b879141ce772c546985163cf553f08e0f135f0699a911801"},
{file = "coverage-7.10.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bc8e4d99ce82f1710cc3c125adc30fd1487d3cf6c2cd4994d78d68a47b16989a"},
{file = "coverage-7.10.5-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:02252dc1216e512a9311f596b3169fad54abcb13827a8d76d5630c798a50a754"},
{file = "coverage-7.10.5-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:73269df37883e02d460bee0cc16be90509faea1e3bd105d77360b512d5bb9c33"},
{file = "coverage-7.10.5-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1f8a81b0614642f91c9effd53eec284f965577591f51f547a1cbeb32035b4c2f"},
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6a29f8e0adb7f8c2b95fa2d4566a1d6e6722e0a637634c6563cb1ab844427dd9"},
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:fcf6ab569436b4a647d4e91accba12509ad9f2554bc93d3aee23cc596e7f99c3"},
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:90dc3d6fb222b194a5de60af8d190bedeeddcbc7add317e4a3cd333ee6b7c879"},
{file = "coverage-7.10.5-cp310-cp310-win32.whl", hash = "sha256:414a568cd545f9dc75f0686a0049393de8098414b58ea071e03395505b73d7a8"},
{file = "coverage-7.10.5-cp310-cp310-win_amd64.whl", hash = "sha256:e551f9d03347196271935fd3c0c165f0e8c049220280c1120de0084d65e9c7ff"},
{file = "coverage-7.10.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c177e6ffe2ebc7c410785307758ee21258aa8e8092b44d09a2da767834f075f2"},
{file = "coverage-7.10.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:14d6071c51ad0f703d6440827eaa46386169b5fdced42631d5a5ac419616046f"},
{file = "coverage-7.10.5-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:61f78c7c3bc272a410c5ae3fde7792b4ffb4acc03d35a7df73ca8978826bb7ab"},
{file = "coverage-7.10.5-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f39071caa126f69d63f99b324fb08c7b1da2ec28cbb1fe7b5b1799926492f65c"},
{file = "coverage-7.10.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:343a023193f04d46edc46b2616cdbee68c94dd10208ecd3adc56fcc54ef2baa1"},
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:585ffe93ae5894d1ebdee69fc0b0d4b7c75d8007983692fb300ac98eed146f78"},
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b0ef4e66f006ed181df29b59921bd8fc7ed7cd6a9289295cd8b2824b49b570df"},
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:eb7b0bbf7cc1d0453b843eca7b5fa017874735bef9bfdfa4121373d2cc885ed6"},
{file = "coverage-7.10.5-cp311-cp311-win32.whl", hash = "sha256:1d043a8a06987cc0c98516e57c4d3fc2c1591364831e9deb59c9e1b4937e8caf"},
{file = "coverage-7.10.5-cp311-cp311-win_amd64.whl", hash = "sha256:fefafcca09c3ac56372ef64a40f5fe17c5592fab906e0fdffd09543f3012ba50"},
{file = "coverage-7.10.5-cp311-cp311-win_arm64.whl", hash = "sha256:7e78b767da8b5fc5b2faa69bb001edafcd6f3995b42a331c53ef9572c55ceb82"},
{file = "coverage-7.10.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c2d05c7e73c60a4cecc7d9b60dbfd603b4ebc0adafaef371445b47d0f805c8a9"},
{file = "coverage-7.10.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:32ddaa3b2c509778ed5373b177eb2bf5662405493baeff52278a0b4f9415188b"},
{file = "coverage-7.10.5-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:dd382410039fe062097aa0292ab6335a3f1e7af7bba2ef8d27dcda484918f20c"},
{file = "coverage-7.10.5-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7fa22800f3908df31cea6fb230f20ac49e343515d968cc3a42b30d5c3ebf9b5a"},
{file = "coverage-7.10.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f366a57ac81f5e12797136552f5b7502fa053c861a009b91b80ed51f2ce651c6"},
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5f1dc8f1980a272ad4a6c84cba7981792344dad33bf5869361576b7aef42733a"},
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:2285c04ee8676f7938b02b4936d9b9b672064daab3187c20f73a55f3d70e6b4a"},
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c2492e4dd9daab63f5f56286f8a04c51323d237631eb98505d87e4c4ff19ec34"},
{file = "coverage-7.10.5-cp312-cp312-win32.whl", hash = "sha256:38a9109c4ee8135d5df5505384fc2f20287a47ccbe0b3f04c53c9a1989c2bbaf"},
{file = "coverage-7.10.5-cp312-cp312-win_amd64.whl", hash = "sha256:6b87f1ad60b30bc3c43c66afa7db6b22a3109902e28c5094957626a0143a001f"},
{file = "coverage-7.10.5-cp312-cp312-win_arm64.whl", hash = "sha256:672a6c1da5aea6c629819a0e1461e89d244f78d7b60c424ecf4f1f2556c041d8"},
{file = "coverage-7.10.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ef3b83594d933020f54cf65ea1f4405d1f4e41a009c46df629dd964fcb6e907c"},
{file = "coverage-7.10.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2b96bfdf7c0ea9faebce088a3ecb2382819da4fbc05c7b80040dbc428df6af44"},
{file = "coverage-7.10.5-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:63df1fdaffa42d914d5c4d293e838937638bf75c794cf20bee12978fc8c4e3bc"},
{file = "coverage-7.10.5-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8002dc6a049aac0e81ecec97abfb08c01ef0c1fbf962d0c98da3950ace89b869"},
{file = "coverage-7.10.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:63d4bb2966d6f5f705a6b0c6784c8969c468dbc4bcf9d9ded8bff1c7e092451f"},
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:1f672efc0731a6846b157389b6e6d5d5e9e59d1d1a23a5c66a99fd58339914d5"},
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:3f39cef43d08049e8afc1fde4a5da8510fc6be843f8dea350ee46e2a26b2f54c"},
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2968647e3ed5a6c019a419264386b013979ff1fb67dd11f5c9886c43d6a31fc2"},
{file = "coverage-7.10.5-cp313-cp313-win32.whl", hash = "sha256:0d511dda38595b2b6934c2b730a1fd57a3635c6aa2a04cb74714cdfdd53846f4"},
{file = "coverage-7.10.5-cp313-cp313-win_amd64.whl", hash = "sha256:9a86281794a393513cf117177fd39c796b3f8e3759bb2764259a2abba5cce54b"},
{file = "coverage-7.10.5-cp313-cp313-win_arm64.whl", hash = "sha256:cebd8e906eb98bb09c10d1feed16096700b1198d482267f8bf0474e63a7b8d84"},
{file = "coverage-7.10.5-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0520dff502da5e09d0d20781df74d8189ab334a1e40d5bafe2efaa4158e2d9e7"},
{file = "coverage-7.10.5-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:d9cd64aca68f503ed3f1f18c7c9174cbb797baba02ca8ab5112f9d1c0328cd4b"},
{file = "coverage-7.10.5-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0913dd1613a33b13c4f84aa6e3f4198c1a21ee28ccb4f674985c1f22109f0aae"},
{file = "coverage-7.10.5-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:1b7181c0feeb06ed8a02da02792f42f829a7b29990fef52eff257fef0885d760"},
{file = "coverage-7.10.5-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:36d42b7396b605f774d4372dd9c49bed71cbabce4ae1ccd074d155709dd8f235"},
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b4fdc777e05c4940b297bf47bf7eedd56a39a61dc23ba798e4b830d585486ca5"},
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:42144e8e346de44a6f1dbd0a56575dd8ab8dfa7e9007da02ea5b1c30ab33a7db"},
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:66c644cbd7aed8fe266d5917e2c9f65458a51cfe5eeff9c05f15b335f697066e"},
{file = "coverage-7.10.5-cp313-cp313t-win32.whl", hash = "sha256:2d1b73023854068c44b0c554578a4e1ef1b050ed07cf8b431549e624a29a66ee"},
{file = "coverage-7.10.5-cp313-cp313t-win_amd64.whl", hash = "sha256:54a1532c8a642d8cc0bd5a9a51f5a9dcc440294fd06e9dda55e743c5ec1a8f14"},
{file = "coverage-7.10.5-cp313-cp313t-win_arm64.whl", hash = "sha256:74d5b63fe3f5f5d372253a4ef92492c11a4305f3550631beaa432fc9df16fcff"},
{file = "coverage-7.10.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:68c5e0bc5f44f68053369fa0d94459c84548a77660a5f2561c5e5f1e3bed7031"},
{file = "coverage-7.10.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:cf33134ffae93865e32e1e37df043bef15a5e857d8caebc0099d225c579b0fa3"},
{file = "coverage-7.10.5-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:ad8fa9d5193bafcf668231294241302b5e683a0518bf1e33a9a0dfb142ec3031"},
{file = "coverage-7.10.5-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:146fa1531973d38ab4b689bc764592fe6c2f913e7e80a39e7eeafd11f0ef6db2"},
{file = "coverage-7.10.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6013a37b8a4854c478d3219ee8bc2392dea51602dd0803a12d6f6182a0061762"},
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:eb90fe20db9c3d930fa2ad7a308207ab5b86bf6a76f54ab6a40be4012d88fcae"},
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:384b34482272e960c438703cafe63316dfbea124ac62006a455c8410bf2a2262"},
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:467dc74bd0a1a7de2bedf8deaf6811f43602cb532bd34d81ffd6038d6d8abe99"},
{file = "coverage-7.10.5-cp314-cp314-win32.whl", hash = "sha256:556d23d4e6393ca898b2e63a5bca91e9ac2d5fb13299ec286cd69a09a7187fde"},
{file = "coverage-7.10.5-cp314-cp314-win_amd64.whl", hash = "sha256:f4446a9547681533c8fa3e3c6cf62121eeee616e6a92bd9201c6edd91beffe13"},
{file = "coverage-7.10.5-cp314-cp314-win_arm64.whl", hash = "sha256:5e78bd9cf65da4c303bf663de0d73bf69f81e878bf72a94e9af67137c69b9fe9"},
{file = "coverage-7.10.5-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:5661bf987d91ec756a47c7e5df4fbcb949f39e32f9334ccd3f43233bbb65e508"},
{file = "coverage-7.10.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:a46473129244db42a720439a26984f8c6f834762fc4573616c1f37f13994b357"},
{file = "coverage-7.10.5-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:1f64b8d3415d60f24b058b58d859e9512624bdfa57a2d1f8aff93c1ec45c429b"},
{file = "coverage-7.10.5-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:44d43de99a9d90b20e0163f9770542357f58860a26e24dc1d924643bd6aa7cb4"},
{file = "coverage-7.10.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a931a87e5ddb6b6404e65443b742cb1c14959622777f2a4efd81fba84f5d91ba"},
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:f9559b906a100029274448f4c8b8b0a127daa4dade5661dfd821b8c188058842"},
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:b08801e25e3b4526ef9ced1aa29344131a8f5213c60c03c18fe4c6170ffa2874"},
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ed9749bb8eda35f8b636fb7632f1c62f735a236a5d4edadd8bbcc5ea0542e732"},
{file = "coverage-7.10.5-cp314-cp314t-win32.whl", hash = "sha256:609b60d123fc2cc63ccee6d17e4676699075db72d14ac3c107cc4976d516f2df"},
{file = "coverage-7.10.5-cp314-cp314t-win_amd64.whl", hash = "sha256:0666cf3d2c1626b5a3463fd5b05f5e21f99e6aec40a3192eee4d07a15970b07f"},
{file = "coverage-7.10.5-cp314-cp314t-win_arm64.whl", hash = "sha256:bc85eb2d35e760120540afddd3044a5bf69118a91a296a8b3940dfc4fdcfe1e2"},
{file = "coverage-7.10.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:62835c1b00c4a4ace24c1a88561a5a59b612fbb83a525d1c70ff5720c97c0610"},
{file = "coverage-7.10.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5255b3bbcc1d32a4069d6403820ac8e6dbcc1d68cb28a60a1ebf17e47028e898"},
{file = "coverage-7.10.5-cp39-cp39-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3876385722e335d6e991c430302c24251ef9c2a9701b2b390f5473199b1b8ebf"},
{file = "coverage-7.10.5-cp39-cp39-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8048ce4b149c93447a55d279078c8ae98b08a6951a3c4d2d7e87f4efc7bfe100"},
{file = "coverage-7.10.5-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4028e7558e268dd8bcf4d9484aad393cafa654c24b4885f6f9474bf53183a82a"},
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:03f47dc870eec0367fcdd603ca6a01517d2504e83dc18dbfafae37faec66129a"},
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:2d488d7d42b6ded7ea0704884f89dcabd2619505457de8fc9a6011c62106f6e5"},
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b3dcf2ead47fa8be14224ee817dfc1df98043af568fe120a22f81c0eb3c34ad2"},
{file = "coverage-7.10.5-cp39-cp39-win32.whl", hash = "sha256:02650a11324b80057b8c9c29487020073d5e98a498f1857f37e3f9b6ea1b2426"},
{file = "coverage-7.10.5-cp39-cp39-win_amd64.whl", hash = "sha256:b45264dd450a10f9e03237b41a9a24e85cbb1e278e5a32adb1a303f58f0017f3"},
{file = "coverage-7.10.5-py3-none-any.whl", hash = "sha256:0be24d35e4db1d23d0db5c0f6a74a962e2ec83c426b5cac09f4234aadef38e4a"},
{file = "coverage-7.10.5.tar.gz", hash = "sha256:f2e57716a78bc3ae80b2207be0709a3b2b63b9f2dcf9740ee6ac03588a2015b6"},
]
[package.dependencies]
tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}
[package.extras]
toml = ["tomli ; python_full_version <= \"3.11.0a6\""]
[[package]]
name = "cryptography"
version = "45.0.6"
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
optional = false
python-versions = "!=3.9.0,!=3.9.1,>=3.7"
groups = ["main"]
files = [
{file = "cryptography-45.0.6-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:048e7ad9e08cf4c0ab07ff7f36cc3115924e22e2266e034450a890d9e312dd74"},
{file = "cryptography-45.0.6-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:44647c5d796f5fc042bbc6d61307d04bf29bccb74d188f18051b635f20a9c75f"},
{file = "cryptography-45.0.6-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e40b80ecf35ec265c452eea0ba94c9587ca763e739b8e559c128d23bff7ebbbf"},
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:00e8724bdad672d75e6f069b27970883179bd472cd24a63f6e620ca7e41cc0c5"},
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7a3085d1b319d35296176af31c90338eeb2ddac8104661df79f80e1d9787b8b2"},
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1b7fa6a1c1188c7ee32e47590d16a5a0646270921f8020efc9a511648e1b2e08"},
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:275ba5cc0d9e320cd70f8e7b96d9e59903c815ca579ab96c1e37278d231fc402"},
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:f4028f29a9f38a2025abedb2e409973709c660d44319c61762202206ed577c42"},
{file = "cryptography-45.0.6-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ee411a1b977f40bd075392c80c10b58025ee5c6b47a822a33c1198598a7a5f05"},
{file = "cryptography-45.0.6-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e2a21a8eda2d86bb604934b6b37691585bd095c1f788530c1fcefc53a82b3453"},
{file = "cryptography-45.0.6-cp311-abi3-win32.whl", hash = "sha256:d063341378d7ee9c91f9d23b431a3502fc8bfacd54ef0a27baa72a0843b29159"},
{file = "cryptography-45.0.6-cp311-abi3-win_amd64.whl", hash = "sha256:833dc32dfc1e39b7376a87b9a6a4288a10aae234631268486558920029b086ec"},
{file = "cryptography-45.0.6-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:3436128a60a5e5490603ab2adbabc8763613f638513ffa7d311c900a8349a2a0"},
{file = "cryptography-45.0.6-cp37-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0d9ef57b6768d9fa58e92f4947cea96ade1233c0e236db22ba44748ffedca394"},
{file = "cryptography-45.0.6-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea3c42f2016a5bbf71825537c2ad753f2870191134933196bee408aac397b3d9"},
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:20ae4906a13716139d6d762ceb3e0e7e110f7955f3bc3876e3a07f5daadec5f3"},
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:2dac5ec199038b8e131365e2324c03d20e97fe214af051d20c49db129844e8b3"},
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:18f878a34b90d688982e43f4b700408b478102dd58b3e39de21b5ebf6509c301"},
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:5bd6020c80c5b2b2242d6c48487d7b85700f5e0038e67b29d706f98440d66eb5"},
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:eccddbd986e43014263eda489abbddfbc287af5cddfd690477993dbb31e31016"},
{file = "cryptography-45.0.6-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:550ae02148206beb722cfe4ef0933f9352bab26b087af00e48fdfb9ade35c5b3"},
{file = "cryptography-45.0.6-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5b64e668fc3528e77efa51ca70fadcd6610e8ab231e3e06ae2bab3b31c2b8ed9"},
{file = "cryptography-45.0.6-cp37-abi3-win32.whl", hash = "sha256:780c40fb751c7d2b0c6786ceee6b6f871e86e8718a8ff4bc35073ac353c7cd02"},
{file = "cryptography-45.0.6-cp37-abi3-win_amd64.whl", hash = "sha256:20d15aed3ee522faac1a39fbfdfee25d17b1284bafd808e1640a74846d7c4d1b"},
{file = "cryptography-45.0.6-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:705bb7c7ecc3d79a50f236adda12ca331c8e7ecfbea51edd931ce5a7a7c4f012"},
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:826b46dae41a1155a0c0e66fafba43d0ede1dc16570b95e40c4d83bfcf0a451d"},
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:cc4d66f5dc4dc37b89cfef1bd5044387f7a1f6f0abb490815628501909332d5d"},
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:f68f833a9d445cc49f01097d95c83a850795921b3f7cc6488731e69bde3288da"},
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:3b5bf5267e98661b9b888a9250d05b063220dfa917a8203744454573c7eb79db"},
{file = "cryptography-45.0.6-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2384f2ab18d9be88a6e4f8972923405e2dbb8d3e16c6b43f15ca491d7831bd18"},
{file = "cryptography-45.0.6-pp311-pypy311_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fc022c1fa5acff6def2fc6d7819bbbd31ccddfe67d075331a65d9cfb28a20983"},
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:3de77e4df42ac8d4e4d6cdb342d989803ad37707cf8f3fbf7b088c9cbdd46427"},
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:599c8d7df950aa68baa7e98f7b73f4f414c9f02d0e8104a30c0182a07732638b"},
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:31a2b9a10530a1cb04ffd6aa1cd4d3be9ed49f7d77a4dafe198f3b382f41545c"},
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:e5b3dda1b00fb41da3af4c5ef3f922a200e33ee5ba0f0bc9ecf0b0c173958385"},
{file = "cryptography-45.0.6-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:629127cfdcdc6806dfe234734d7cb8ac54edaf572148274fa377a7d3405b0043"},
{file = "cryptography-45.0.6.tar.gz", hash = "sha256:5c966c732cf6e4a276ce83b6e4c729edda2df6929083a952cc7da973c539c719"},
]
[package.dependencies]
cffi = {version = ">=1.14", markers = "platform_python_implementation != \"PyPy\""}
[package.extras]
docs = ["sphinx (>=5.3.0)", "sphinx-inline-tabs ; python_full_version >= \"3.8.0\"", "sphinx-rtd-theme (>=3.0.0) ; python_full_version >= \"3.8.0\""]
docstest = ["pyenchant (>=3)", "readme-renderer (>=30.0)", "sphinxcontrib-spelling (>=7.3.1)"]
nox = ["nox (>=2024.4.15)", "nox[uv] (>=2024.3.2) ; python_full_version >= \"3.8.0\""]
pep8test = ["check-sdist ; python_full_version >= \"3.8.0\"", "click (>=8.0.1)", "mypy (>=1.4)", "ruff (>=0.3.6)"]
sdist = ["build (>=1.0.0)"]
ssh = ["bcrypt (>=3.1.5)"]
test = ["certifi (>=2024)", "cryptography-vectors (==45.0.6)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
test-randomorder = ["pytest-randomly"]
[[package]]
name = "deprecation"
version = "2.1.0"
@@ -480,7 +235,7 @@ version = "1.3.0"
description = "Backport of PEP 654 (exception groups)"
optional = false
python-versions = ">=3.7"
groups = ["main", "dev"]
groups = ["main"]
markers = "python_version < \"3.11\""
files = [
{file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"},
@@ -955,7 +710,7 @@ version = "2.1.0"
description = "brain-dead simple config-ini parsing"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
groups = ["main"]
files = [
{file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"},
{file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"},
@@ -1024,7 +779,7 @@ version = "25.0"
description = "Core utilities for Python packages"
optional = false
python-versions = ">=3.8"
groups = ["main", "dev"]
groups = ["main"]
files = [
{file = "packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484"},
{file = "packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f"},
@@ -1036,7 +791,7 @@ version = "1.6.0"
description = "plugin and hook calling mechanisms for python"
optional = false
python-versions = ">=3.9"
groups = ["dev"]
groups = ["main"]
files = [
{file = "pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746"},
{file = "pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3"},
@@ -1128,19 +883,6 @@ files = [
[package.dependencies]
pyasn1 = ">=0.6.1,<0.7.0"
[[package]]
name = "pycparser"
version = "2.22"
description = "C parser in Python"
optional = false
python-versions = ">=3.8"
groups = ["main"]
markers = "platform_python_implementation != \"PyPy\""
files = [
{file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"},
{file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"},
]
[[package]]
name = "pydantic"
version = "2.11.7"
@@ -1305,7 +1047,7 @@ version = "2.19.2"
description = "Pygments is a syntax highlighting package written in Python."
optional = false
python-versions = ">=3.8"
groups = ["dev"]
groups = ["main"]
files = [
{file = "pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b"},
{file = "pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887"},
@@ -1326,9 +1068,6 @@ files = [
{file = "pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953"},
]
[package.dependencies]
cryptography = {version = ">=3.4.0", optional = true, markers = "extra == \"crypto\""}
[package.extras]
crypto = ["cryptography (>=3.4.0)"]
dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx", "sphinx-rtd-theme", "zope.interface"]
@@ -1353,7 +1092,7 @@ version = "8.4.1"
description = "pytest: simple powerful testing with Python"
optional = false
python-versions = ">=3.9"
groups = ["dev"]
groups = ["main"]
files = [
{file = "pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7"},
{file = "pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c"},
@@ -1377,7 +1116,7 @@ version = "1.1.0"
description = "Pytest support for asyncio"
optional = false
python-versions = ">=3.9"
groups = ["dev"]
groups = ["main"]
files = [
{file = "pytest_asyncio-1.1.0-py3-none-any.whl", hash = "sha256:5fe2d69607b0bd75c656d1211f969cadba035030156745ee09e7d71740e58ecf"},
{file = "pytest_asyncio-1.1.0.tar.gz", hash = "sha256:796aa822981e01b68c12e4827b8697108f7205020f24b5793b3c41555dab68ea"},
@@ -1391,33 +1130,13 @@ pytest = ">=8.2,<9"
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"]
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
[[package]]
name = "pytest-cov"
version = "6.2.1"
description = "Pytest plugin for measuring coverage."
optional = false
python-versions = ">=3.9"
groups = ["dev"]
files = [
{file = "pytest_cov-6.2.1-py3-none-any.whl", hash = "sha256:f5bc4c23f42f1cdd23c70b1dab1bbaef4fc505ba950d53e0081d0730dd7e86d5"},
{file = "pytest_cov-6.2.1.tar.gz", hash = "sha256:25cc6cc0a5358204b8108ecedc51a9b57b34cc6b8c967cc2c01a4e00d8a67da2"},
]
[package.dependencies]
coverage = {version = ">=7.5", extras = ["toml"]}
pluggy = ">=1.2"
pytest = ">=6.2.5"
[package.extras]
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
[[package]]
name = "pytest-mock"
version = "3.14.1"
description = "Thin-wrapper around the mock package for easier use with pytest"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
groups = ["main"]
files = [
{file = "pytest_mock-3.14.1-py3-none-any.whl", hash = "sha256:178aefcd11307d874b4cd3100344e7e2d888d9791a6a1d9bfe90fbc1b74fd1d0"},
{file = "pytest_mock-3.14.1.tar.gz", hash = "sha256:159e9edac4c451ce77a5cdb9fc5d1100708d2dd4ba3c3df572f14097351af80e"},
@@ -1534,31 +1253,30 @@ pyasn1 = ">=0.1.3"
[[package]]
name = "ruff"
version = "0.12.9"
version = "0.12.3"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
groups = ["dev"]
files = [
{file = "ruff-0.12.9-py3-none-linux_armv6l.whl", hash = "sha256:fcebc6c79fcae3f220d05585229463621f5dbf24d79fdc4936d9302e177cfa3e"},
{file = "ruff-0.12.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:aed9d15f8c5755c0e74467731a007fcad41f19bcce41cd75f768bbd687f8535f"},
{file = "ruff-0.12.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5b15ea354c6ff0d7423814ba6d44be2807644d0c05e9ed60caca87e963e93f70"},
{file = "ruff-0.12.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d596c2d0393c2502eaabfef723bd74ca35348a8dac4267d18a94910087807c53"},
{file = "ruff-0.12.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1b15599931a1a7a03c388b9c5df1bfa62be7ede6eb7ef753b272381f39c3d0ff"},
{file = "ruff-0.12.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3d02faa2977fb6f3f32ddb7828e212b7dd499c59eb896ae6c03ea5c303575756"},
{file = "ruff-0.12.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:17d5b6b0b3a25259b69ebcba87908496e6830e03acfb929ef9fd4c58675fa2ea"},
{file = "ruff-0.12.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:72db7521860e246adbb43f6ef464dd2a532ef2ef1f5dd0d470455b8d9f1773e0"},
{file = "ruff-0.12.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a03242c1522b4e0885af63320ad754d53983c9599157ee33e77d748363c561ce"},
{file = "ruff-0.12.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fc83e4e9751e6c13b5046d7162f205d0a7bac5840183c5beebf824b08a27340"},
{file = "ruff-0.12.9-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:881465ed56ba4dd26a691954650de6ad389a2d1fdb130fe51ff18a25639fe4bb"},
{file = "ruff-0.12.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:43f07a3ccfc62cdb4d3a3348bf0588358a66da756aa113e071b8ca8c3b9826af"},
{file = "ruff-0.12.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:07adb221c54b6bba24387911e5734357f042e5669fa5718920ee728aba3cbadc"},
{file = "ruff-0.12.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f5cd34fabfdea3933ab85d72359f118035882a01bff15bd1d2b15261d85d5f66"},
{file = "ruff-0.12.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:f6be1d2ca0686c54564da8e7ee9e25f93bdd6868263805f8c0b8fc6a449db6d7"},
{file = "ruff-0.12.9-py3-none-win32.whl", hash = "sha256:cc7a37bd2509974379d0115cc5608a1a4a6c4bff1b452ea69db83c8855d53f93"},
{file = "ruff-0.12.9-py3-none-win_amd64.whl", hash = "sha256:6fb15b1977309741d7d098c8a3cb7a30bc112760a00fb6efb7abc85f00ba5908"},
{file = "ruff-0.12.9-py3-none-win_arm64.whl", hash = "sha256:63c8c819739d86b96d500cce885956a1a48ab056bbcbc61b747ad494b2485089"},
{file = "ruff-0.12.9.tar.gz", hash = "sha256:fbd94b2e3c623f659962934e52c2bea6fc6da11f667a427a368adaf3af2c866a"},
{file = "ruff-0.12.3-py3-none-linux_armv6l.whl", hash = "sha256:47552138f7206454eaf0c4fe827e546e9ddac62c2a3d2585ca54d29a890137a2"},
{file = "ruff-0.12.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:0a9153b000c6fe169bb307f5bd1b691221c4286c133407b8827c406a55282041"},
{file = "ruff-0.12.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa6b24600cf3b750e48ddb6057e901dd5b9aa426e316addb2a1af185a7509882"},
{file = "ruff-0.12.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2506961bf6ead54887ba3562604d69cb430f59b42133d36976421bc8bd45901"},
{file = "ruff-0.12.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c4faaff1f90cea9d3033cbbcdf1acf5d7fb11d8180758feb31337391691f3df0"},
{file = "ruff-0.12.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40dced4a79d7c264389de1c59467d5d5cefd79e7e06d1dfa2c75497b5269a5a6"},
{file = "ruff-0.12.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:0262d50ba2767ed0fe212aa7e62112a1dcbfd46b858c5bf7bbd11f326998bafc"},
{file = "ruff-0.12.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12371aec33e1a3758597c5c631bae9a5286f3c963bdfb4d17acdd2d395406687"},
{file = "ruff-0.12.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:560f13b6baa49785665276c963edc363f8ad4b4fc910a883e2625bdb14a83a9e"},
{file = "ruff-0.12.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:023040a3499f6f974ae9091bcdd0385dd9e9eb4942f231c23c57708147b06311"},
{file = "ruff-0.12.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:883d844967bffff5ab28bba1a4d246c1a1b2933f48cb9840f3fdc5111c603b07"},
{file = "ruff-0.12.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2120d3aa855ff385e0e562fdee14d564c9675edbe41625c87eeab744a7830d12"},
{file = "ruff-0.12.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6b16647cbb470eaf4750d27dddc6ebf7758b918887b56d39e9c22cce2049082b"},
{file = "ruff-0.12.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e1417051edb436230023575b149e8ff843a324557fe0a265863b7602df86722f"},
{file = "ruff-0.12.3-py3-none-win32.whl", hash = "sha256:dfd45e6e926deb6409d0616078a666ebce93e55e07f0fb0228d4b2608b2c248d"},
{file = "ruff-0.12.3-py3-none-win_amd64.whl", hash = "sha256:a946cf1e7ba3209bdef039eb97647f1c77f6f540e5845ec9c114d3af8df873e7"},
{file = "ruff-0.12.3-py3-none-win_arm64.whl", hash = "sha256:5f9c7c9c8f84c2d7f27e93674d27136fbf489720251544c4da7fb3d742e011b1"},
{file = "ruff-0.12.3.tar.gz", hash = "sha256:f1b5a4b6668fd7b7ea3697d8d98857390b40c1320a63a178eee6be0899ea2d77"},
]
[[package]]
@@ -1692,7 +1410,7 @@ version = "2.2.1"
description = "A lil' TOML parser"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
groups = ["main"]
markers = "python_version < \"3.11\""
files = [
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
@@ -1735,12 +1453,11 @@ version = "4.14.1"
description = "Backported and Experimental Type Hints for Python 3.9+"
optional = false
python-versions = ">=3.9"
groups = ["main", "dev"]
groups = ["main"]
files = [
{file = "typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76"},
{file = "typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36"},
]
markers = {dev = "python_version < \"3.11\""}
[[package]]
name = "typing-inspection"
@@ -1897,4 +1614,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4.0"
content-hash = "ef7818fba061cea2841c6d7ca4852acde83e4f73b32fca1315e58660002bb0d0"
content-hash = "f67db13e6f68b1d67a55eee908c1c560bfa44da8509f98f842889a7570a9830f"

View File

@@ -15,17 +15,15 @@ google-cloud-logging = "^3.12.1"
launchdarkly-server-sdk = "^9.12.0"
pydantic = "^2.11.7"
pydantic-settings = "^2.10.1"
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
pyjwt = "^2.10.1"
pytest-asyncio = "^1.1.0"
pytest-mock = "^3.14.1"
redis = "^6.2.0"
supabase = "^2.16.0"
uvicorn = "^0.35.0"
[tool.poetry.group.dev.dependencies]
ruff = "^0.12.9"
pytest = "^8.4.1"
pytest-asyncio = "^1.1.0"
pytest-mock = "^3.14.1"
pytest-cov = "^6.2.1"
ruff = "^0.12.3"
[build-system]
requires = ["poetry-core"]

View File

@@ -16,6 +16,7 @@ DB_SCHEMA=platform
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
PRISMA_SCHEMA="postgres/schema.prisma"
ENABLE_AUTH=true
## ===== REQUIRED SERVICE CREDENTIALS ===== ##
# Redis Configuration
@@ -30,7 +31,7 @@ RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
# Supabase Authentication
SUPABASE_URL=http://localhost:8000
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
## ===== REQUIRED SECURITY KEYS ===== ##
# Generate using: from cryptography.fernet import Fernet;Fernet.generate_key().decode()
@@ -105,15 +106,6 @@ TODOIST_CLIENT_SECRET=
NOTION_CLIENT_ID=
NOTION_CLIENT_SECRET=
# Discord OAuth App credentials
# 1. Go to https://discord.com/developers/applications
# 2. Create a new application
# 3. Go to OAuth2 section and add redirect URI: http://localhost:3000/auth/integrations/oauth_callback
# 4. Copy Client ID and Client Secret below
DISCORD_CLIENT_ID=
DISCORD_CLIENT_SECRET=
REDDIT_CLIENT_ID=
REDDIT_CLIENT_SECRET=
@@ -174,4 +166,4 @@ SMARTLEAD_API_KEY=
ZEROBOUNCE_API_KEY=
# Other Services
AUTOMOD_API_KEY=
AUTOMOD_API_KEY=

View File

@@ -1,21 +1,16 @@
FROM debian:13-slim AS builder
FROM python:3.11.10-slim-bookworm AS builder
# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1
WORKDIR /app
RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy
# Update package list and install Python and build dependencies
# Update package list and install build dependencies in a single layer
RUN apt-get update --allow-releaseinfo-change --fix-missing \
&& apt-get install -y \
python3.13 \
python3.13-dev \
python3.13-venv \
python3-pip \
build-essential \
libpq5 \
libz-dev \
@@ -24,11 +19,13 @@ RUN apt-get update --allow-releaseinfo-change --fix-missing \
ENV POETRY_HOME=/opt/poetry
ENV POETRY_NO_INTERACTION=1
ENV POETRY_VIRTUALENVS_CREATE=true
ENV POETRY_VIRTUALENVS_IN_PROJECT=true
ENV POETRY_VIRTUALENVS_CREATE=false
ENV PATH=/opt/poetry/bin:$PATH
RUN pip3 install poetry --break-system-packages
# Upgrade pip and setuptools to fix security vulnerabilities
RUN pip3 install --upgrade pip setuptools
RUN pip3 install poetry
# Copy and install dependencies
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
@@ -40,30 +37,27 @@ RUN poetry install --no-ansi --no-root
COPY autogpt_platform/backend/schema.prisma ./
RUN poetry run prisma generate
FROM debian:13-slim AS server_dependencies
FROM python:3.11.10-slim-bookworm AS server_dependencies
WORKDIR /app
ENV POETRY_HOME=/opt/poetry \
POETRY_NO_INTERACTION=1 \
POETRY_VIRTUALENVS_CREATE=true \
POETRY_VIRTUALENVS_IN_PROJECT=true \
DEBIAN_FRONTEND=noninteractive
POETRY_VIRTUALENVS_CREATE=false
ENV PATH=/opt/poetry/bin:$PATH
# Install Python without upgrading system-managed packages
RUN apt-get update && apt-get install -y \
python3.13 \
python3-pip
# Upgrade pip and setuptools to fix security vulnerabilities
RUN pip3 install --upgrade pip setuptools
# Copy only necessary files from builder
COPY --from=builder /app /app
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
COPY --from=builder /usr/local/lib/python3.11 /usr/local/lib/python3.11
COPY --from=builder /usr/local/bin /usr/local/bin
# Copy Prisma binaries
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
ENV PATH="/app/.venv/bin:$PATH"
RUN mkdir -p /app/autogpt_platform/autogpt_libs
RUN mkdir -p /app/autogpt_platform/backend

View File

@@ -132,58 +132,17 @@ def test_endpoint_success(snapshot: Snapshot):
### Testing with Authentication
For the main API routes that use JWT authentication, auth is provided by the `autogpt_libs.auth` module. If the test actually uses the `user_id`, the recommended approach for testing is to mock the `get_jwt_payload` function, which underpins all higher-level auth functions used in the API (`requires_user`, `requires_admin_user`, `get_user_id`).
If the test doesn't need the `user_id` specifically, mocking is not necessary as during tests auth is disabled anyway (see `conftest.py`).
#### Using Global Auth Fixtures
Two global auth fixtures are provided by `backend/server/conftest.py`:
- `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id")
- `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id")
These provide the easiest way to set up authentication mocking in test modules:
```python
import fastapi
import fastapi.testclient
import pytest
from backend.server.v2.myroute import router
def override_auth_middleware():
return {"sub": "test-user-id"}
app = fastapi.FastAPI()
app.include_router(router)
client = fastapi.testclient.TestClient(app)
def override_get_user_id():
return "test-user-id"
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module"""
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user['get_jwt_payload']
yield
app.dependency_overrides.clear()
app.dependency_overrides[auth_middleware] = override_auth_middleware
app.dependency_overrides[get_user_id] = override_get_user_id
```
For admin-only endpoints, use `mock_jwt_admin` instead:
```python
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_admin):
"""Setup auth overrides for admin tests"""
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin['get_jwt_payload']
yield
app.dependency_overrides.clear()
```
The IDs are also available separately as fixtures:
- `test_user_id`
- `admin_user_id`
- `target_user_id` (for admin <-> user operations)
### Mocking External Services
```python
@@ -194,10 +153,10 @@ def test_external_api_call(mocker, snapshot):
"backend.services.external_api.call",
return_value=mock_response
)
response = client.post("/api/process")
assert response.status_code == 200
snapshot.snapshot_dir = "snapshots"
snapshot.assert_match(
json.dumps(response.json(), indent=2, sort_keys=True),
@@ -228,17 +187,6 @@ def test_external_api_call(mocker, snapshot):
- Use `async def` with `@pytest.mark.asyncio` for testing async functions directly
### 5. Fixtures
#### Global Fixtures (conftest.py)
Authentication fixtures are available globally from `conftest.py`:
- `mock_jwt_user` - Standard user authentication
- `mock_jwt_admin` - Admin user authentication
- `configured_snapshot` - Pre-configured snapshot fixture
#### Custom Fixtures
Create reusable fixtures for common test data:
```python
@@ -254,18 +202,9 @@ def test_create_user(sample_user, snapshot):
# ... test implementation
```
#### Test Isolation
All tests must use fixtures that ensure proper isolation:
- Authentication overrides are automatically cleaned up after each test
- Database connections are properly managed with cleanup
- Mock objects are reset between tests
## CI/CD Integration
The GitHub Actions workflow automatically runs tests on:
- Pull requests
- Pushes to main branch
@@ -277,19 +216,16 @@ Snapshot tests work in CI by:
## Troubleshooting
### Snapshot Mismatches
- Review the diff carefully
- If changes are expected: `poetry run pytest --snapshot-update`
- If changes are unexpected: Fix the code causing the difference
### Async Test Issues
- Ensure async functions use `@pytest.mark.asyncio`
- Use `AsyncMock` for mocking async functions
- FastAPI TestClient handles async automatically
### Import Errors
- Check that all dependencies are in `pyproject.toml`
- Run `poetry install` to ensure dependencies are installed
- Verify import paths are correct
@@ -298,4 +234,4 @@ Snapshot tests work in CI by:
Snapshot testing provides a powerful way to ensure API responses remain consistent. Combined with traditional assertions, it creates a robust test suite that catches regressions while remaining maintainable.
Remember: Good tests are as important as good code!
Remember: Good tests are as important as good code!

View File

@@ -25,9 +25,6 @@ class AgentExecutorBlock(Block):
user_id: str = SchemaField(description="User ID")
graph_id: str = SchemaField(description="Graph ID")
graph_version: int = SchemaField(description="Graph Version")
agent_name: Optional[str] = SchemaField(
default=None, description="Name to display in the Builder UI"
)
inputs: BlockInput = SchemaField(description="Input data for the graph")
input_schema: dict = SchemaField(description="Input schema for the graph")

View File

@@ -166,7 +166,7 @@ class AIMusicGeneratorBlock(Block):
output_format=input_data.output_format,
normalization_strategy=input_data.normalization_strategy,
)
if result and isinstance(result, str) and result.startswith("http"):
if result and result != "No output received":
yield "result", result
return
else:

View File

@@ -1,205 +0,0 @@
"""
Meeting BaaS API client module.
All API calls centralized for consistency and maintainability.
"""
from typing import Any, Dict, List, Optional
from backend.sdk import Requests
class MeetingBaasAPI:
"""Client for Meeting BaaS API endpoints."""
BASE_URL = "https://api.meetingbaas.com"
def __init__(self, api_key: str):
"""Initialize API client with authentication key."""
self.api_key = api_key
self.headers = {"x-meeting-baas-api-key": api_key}
self.requests = Requests()
# Bot Management Endpoints
async def join_meeting(
self,
bot_name: str,
meeting_url: str,
reserved: bool = False,
bot_image: Optional[str] = None,
entry_message: Optional[str] = None,
start_time: Optional[int] = None,
speech_to_text: Optional[Dict[str, Any]] = None,
webhook_url: Optional[str] = None,
automatic_leave: Optional[Dict[str, Any]] = None,
extra: Optional[Dict[str, Any]] = None,
recording_mode: str = "speaker_view",
streaming: Optional[Dict[str, Any]] = None,
deduplication_key: Optional[str] = None,
zoom_sdk_id: Optional[str] = None,
zoom_sdk_pwd: Optional[str] = None,
) -> Dict[str, Any]:
"""
Deploy a bot to join and record a meeting.
POST /bots
"""
body = {
"bot_name": bot_name,
"meeting_url": meeting_url,
"reserved": reserved,
"recording_mode": recording_mode,
}
# Add optional fields if provided
if bot_image is not None:
body["bot_image"] = bot_image
if entry_message is not None:
body["entry_message"] = entry_message
if start_time is not None:
body["start_time"] = start_time
if speech_to_text is not None:
body["speech_to_text"] = speech_to_text
if webhook_url is not None:
body["webhook_url"] = webhook_url
if automatic_leave is not None:
body["automatic_leave"] = automatic_leave
if extra is not None:
body["extra"] = extra
if streaming is not None:
body["streaming"] = streaming
if deduplication_key is not None:
body["deduplication_key"] = deduplication_key
if zoom_sdk_id is not None:
body["zoom_sdk_id"] = zoom_sdk_id
if zoom_sdk_pwd is not None:
body["zoom_sdk_pwd"] = zoom_sdk_pwd
response = await self.requests.post(
f"{self.BASE_URL}/bots",
headers=self.headers,
json=body,
)
return response.json()
async def leave_meeting(self, bot_id: str) -> bool:
"""
Remove a bot from an ongoing meeting.
DELETE /bots/{uuid}
"""
response = await self.requests.delete(
f"{self.BASE_URL}/bots/{bot_id}",
headers=self.headers,
)
return response.status in [200, 204]
async def retranscribe(
self,
bot_uuid: str,
speech_to_text: Optional[Dict[str, Any]] = None,
webhook_url: Optional[str] = None,
) -> Dict[str, Any]:
"""
Re-run transcription on a bot's audio.
POST /bots/retranscribe
"""
body: Dict[str, Any] = {"bot_uuid": bot_uuid}
if speech_to_text is not None:
body["speech_to_text"] = speech_to_text
if webhook_url is not None:
body["webhook_url"] = webhook_url
response = await self.requests.post(
f"{self.BASE_URL}/bots/retranscribe",
headers=self.headers,
json=body,
)
if response.status == 202:
return {"accepted": True}
return response.json()
# Data Retrieval Endpoints
async def get_meeting_data(
self, bot_id: str, include_transcripts: bool = True
) -> Dict[str, Any]:
"""
Retrieve meeting data including recording and transcripts.
GET /bots/meeting_data
"""
params = {
"bot_id": bot_id,
"include_transcripts": str(include_transcripts).lower(),
}
response = await self.requests.get(
f"{self.BASE_URL}/bots/meeting_data",
headers=self.headers,
params=params,
)
return response.json()
async def get_screenshots(self, bot_id: str) -> List[Dict[str, Any]]:
"""
Retrieve screenshots captured during a meeting.
GET /bots/{uuid}/screenshots
"""
response = await self.requests.get(
f"{self.BASE_URL}/bots/{bot_id}/screenshots",
headers=self.headers,
)
result = response.json()
# Ensure we return a list
if isinstance(result, list):
return result
return []
async def delete_data(self, bot_id: str) -> bool:
"""
Delete a bot's recorded data.
POST /bots/{uuid}/delete_data
"""
response = await self.requests.post(
f"{self.BASE_URL}/bots/{bot_id}/delete_data",
headers=self.headers,
)
return response.status == 200
async def list_bots_with_metadata(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
sort_by: Optional[str] = None,
sort_order: Optional[str] = None,
filter_by: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
List bots with metadata including IDs, names, and meeting details.
GET /bots/bots_with_metadata
"""
params = {}
if limit is not None:
params["limit"] = limit
if offset is not None:
params["offset"] = offset
if sort_by is not None:
params["sort_by"] = sort_by
if sort_order is not None:
params["sort_order"] = sort_order
if filter_by is not None:
params.update(filter_by)
response = await self.requests.get(
f"{self.BASE_URL}/bots/bots_with_metadata",
headers=self.headers,
params=params,
)
return response.json()

View File

@@ -1,13 +0,0 @@
"""
Shared configuration for all Meeting BaaS blocks using the SDK pattern.
"""
from backend.sdk import BlockCostType, ProviderBuilder
# Configure the Meeting BaaS provider with API key authentication
baas = (
ProviderBuilder("baas")
.with_api_key("MEETING_BAAS_API_KEY", "Meeting BaaS API Key")
.with_base_cost(5, BlockCostType.RUN) # Higher cost for meeting recording service
.build()
)

View File

@@ -1,217 +0,0 @@
"""
Meeting BaaS bot (recording) blocks.
"""
from typing import Optional
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
SchemaField,
)
from ._api import MeetingBaasAPI
from ._config import baas
class BaasBotJoinMeetingBlock(Block):
"""
Deploy a bot immediately or at a scheduled start_time to join and record a meeting.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = baas.credentials_field(
description="Meeting BaaS API credentials"
)
meeting_url: str = SchemaField(
description="The URL of the meeting the bot should join"
)
bot_name: str = SchemaField(
description="Display name for the bot in the meeting"
)
bot_image: str = SchemaField(
description="URL to an image for the bot's avatar (16:9 ratio recommended)",
default="",
)
entry_message: str = SchemaField(
description="Chat message the bot will post upon entry", default=""
)
reserved: bool = SchemaField(
description="Use a reserved bot slot (joins 4 min before meeting)",
default=False,
)
start_time: Optional[int] = SchemaField(
description="Unix timestamp (ms) when bot should join", default=None
)
webhook_url: str | None = SchemaField(
description="URL to receive webhook events for this bot", default=None
)
timeouts: dict = SchemaField(
description="Automatic leave timeouts configuration", default={}
)
extra: dict = SchemaField(
description="Custom metadata to attach to the bot", default={}
)
class Output(BlockSchema):
bot_id: str = SchemaField(description="UUID of the deployed bot")
join_response: dict = SchemaField(
description="Full response from join operation"
)
def __init__(self):
super().__init__(
id="377d1a6a-a99b-46cf-9af3-1d1b12758e04",
description="Deploy a bot to join and record a meeting",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
api_key = credentials.api_key.get_secret_value()
api = MeetingBaasAPI(api_key)
# Call API with all parameters
data = await api.join_meeting(
bot_name=input_data.bot_name,
meeting_url=input_data.meeting_url,
reserved=input_data.reserved,
bot_image=input_data.bot_image if input_data.bot_image else None,
entry_message=(
input_data.entry_message if input_data.entry_message else None
),
start_time=input_data.start_time,
speech_to_text={"provider": "Default"},
webhook_url=input_data.webhook_url if input_data.webhook_url else None,
automatic_leave=input_data.timeouts if input_data.timeouts else None,
extra=input_data.extra if input_data.extra else None,
)
yield "bot_id", data.get("bot_id", "")
yield "join_response", data
class BaasBotLeaveMeetingBlock(Block):
"""
Force the bot to exit the call.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = baas.credentials_field(
description="Meeting BaaS API credentials"
)
bot_id: str = SchemaField(description="UUID of the bot to remove from meeting")
class Output(BlockSchema):
left: bool = SchemaField(description="Whether the bot successfully left")
def __init__(self):
super().__init__(
id="bf77d128-8b25-4280-b5c7-2d553ba7e482",
description="Remove a bot from an ongoing meeting",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
api_key = credentials.api_key.get_secret_value()
api = MeetingBaasAPI(api_key)
# Leave meeting
left = await api.leave_meeting(input_data.bot_id)
yield "left", left
class BaasBotFetchMeetingDataBlock(Block):
"""
Pull MP4 URL, transcript & metadata for a completed meeting.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = baas.credentials_field(
description="Meeting BaaS API credentials"
)
bot_id: str = SchemaField(description="UUID of the bot whose data to fetch")
include_transcripts: bool = SchemaField(
description="Include transcript data in response", default=True
)
class Output(BlockSchema):
mp4_url: str = SchemaField(
description="URL to download the meeting recording (time-limited)"
)
transcript: list = SchemaField(description="Meeting transcript data")
metadata: dict = SchemaField(description="Meeting metadata and bot information")
def __init__(self):
super().__init__(
id="ea7c1309-303c-4da1-893f-89c0e9d64e78",
description="Retrieve recorded meeting data",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
api_key = credentials.api_key.get_secret_value()
api = MeetingBaasAPI(api_key)
# Fetch meeting data
data = await api.get_meeting_data(
bot_id=input_data.bot_id,
include_transcripts=input_data.include_transcripts,
)
yield "mp4_url", data.get("mp4", "")
yield "transcript", data.get("bot_data", {}).get("transcripts", [])
yield "metadata", data.get("bot_data", {}).get("bot", {})
class BaasBotDeleteRecordingBlock(Block):
"""
Purge MP4 + transcript data for privacy or storage management.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = baas.credentials_field(
description="Meeting BaaS API credentials"
)
bot_id: str = SchemaField(description="UUID of the bot whose data to delete")
class Output(BlockSchema):
deleted: bool = SchemaField(
description="Whether the data was successfully deleted"
)
def __init__(self):
super().__init__(
id="bf8d1aa6-42d8-4944-b6bd-6bac554c0d3b",
description="Permanently delete a meeting's recorded data",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
api_key = credentials.api_key.get_secret_value()
api = MeetingBaasAPI(api_key)
# Delete recording data
deleted = await api.delete_data(input_data.bot_id)
yield "deleted", deleted

View File

@@ -1,178 +0,0 @@
"""
DataForSEO API client with async support using the SDK patterns.
"""
import base64
from typing import Any, Dict, List, Optional
from backend.sdk import Requests, UserPasswordCredentials
class DataForSeoClient:
"""Client for the DataForSEO API using async requests."""
API_URL = "https://api.dataforseo.com"
def __init__(self, credentials: UserPasswordCredentials):
self.credentials = credentials
self.requests = Requests(
trusted_origins=["https://api.dataforseo.com"],
raise_for_status=False,
)
def _get_headers(self) -> Dict[str, str]:
"""Generate the authorization header using Basic Auth."""
username = self.credentials.username.get_secret_value()
password = self.credentials.password.get_secret_value()
credentials_str = f"{username}:{password}"
encoded = base64.b64encode(credentials_str.encode("ascii")).decode("ascii")
return {
"Authorization": f"Basic {encoded}",
"Content-Type": "application/json",
}
async def keyword_suggestions(
self,
keyword: str,
location_code: Optional[int] = None,
language_code: Optional[str] = None,
include_seed_keyword: bool = True,
include_serp_info: bool = False,
include_clickstream_data: bool = False,
limit: int = 100,
) -> List[Dict[str, Any]]:
"""
Get keyword suggestions from DataForSEO Labs.
Args:
keyword: Seed keyword
location_code: Location code for targeting
language_code: Language code (e.g., "en")
include_seed_keyword: Include seed keyword in results
include_serp_info: Include SERP data
include_clickstream_data: Include clickstream metrics
limit: Maximum number of results (up to 3000)
Returns:
API response with keyword suggestions
"""
endpoint = f"{self.API_URL}/v3/dataforseo_labs/google/keyword_suggestions/live"
# Build payload only with non-None values to avoid sending null fields
task_data: dict[str, Any] = {
"keyword": keyword,
}
if location_code is not None:
task_data["location_code"] = location_code
if language_code is not None:
task_data["language_code"] = language_code
if include_seed_keyword is not None:
task_data["include_seed_keyword"] = include_seed_keyword
if include_serp_info is not None:
task_data["include_serp_info"] = include_serp_info
if include_clickstream_data is not None:
task_data["include_clickstream_data"] = include_clickstream_data
if limit is not None:
task_data["limit"] = limit
payload = [task_data]
response = await self.requests.post(
endpoint,
headers=self._get_headers(),
json=payload,
)
data = response.json()
# Check for API errors
if response.status != 200:
error_message = data.get("status_message", "Unknown error")
raise Exception(
f"DataForSEO API error ({response.status}): {error_message}"
)
# Extract the results from the response
if data.get("tasks") and len(data["tasks"]) > 0:
task = data["tasks"][0]
if task.get("status_code") == 20000: # Success code
return task.get("result", [])
else:
error_msg = task.get("status_message", "Task failed")
raise Exception(f"DataForSEO task error: {error_msg}")
return []
async def related_keywords(
self,
keyword: str,
location_code: Optional[int] = None,
language_code: Optional[str] = None,
include_seed_keyword: bool = True,
include_serp_info: bool = False,
include_clickstream_data: bool = False,
limit: int = 100,
) -> List[Dict[str, Any]]:
"""
Get related keywords from DataForSEO Labs.
Args:
keyword: Seed keyword
location_code: Location code for targeting
language_code: Language code (e.g., "en")
include_seed_keyword: Include seed keyword in results
include_serp_info: Include SERP data
include_clickstream_data: Include clickstream metrics
limit: Maximum number of results (up to 3000)
Returns:
API response with related keywords
"""
endpoint = f"{self.API_URL}/v3/dataforseo_labs/google/related_keywords/live"
# Build payload only with non-None values to avoid sending null fields
task_data: dict[str, Any] = {
"keyword": keyword,
}
if location_code is not None:
task_data["location_code"] = location_code
if language_code is not None:
task_data["language_code"] = language_code
if include_seed_keyword is not None:
task_data["include_seed_keyword"] = include_seed_keyword
if include_serp_info is not None:
task_data["include_serp_info"] = include_serp_info
if include_clickstream_data is not None:
task_data["include_clickstream_data"] = include_clickstream_data
if limit is not None:
task_data["limit"] = limit
payload = [task_data]
response = await self.requests.post(
endpoint,
headers=self._get_headers(),
json=payload,
)
data = response.json()
# Check for API errors
if response.status != 200:
error_message = data.get("status_message", "Unknown error")
raise Exception(
f"DataForSEO API error ({response.status}): {error_message}"
)
# Extract the results from the response
if data.get("tasks") and len(data["tasks"]) > 0:
task = data["tasks"][0]
if task.get("status_code") == 20000: # Success code
return task.get("result", [])
else:
error_msg = task.get("status_message", "Task failed")
raise Exception(f"DataForSEO task error: {error_msg}")
return []

View File

@@ -1,17 +0,0 @@
"""
Configuration for all DataForSEO blocks using the new SDK pattern.
"""
from backend.sdk import BlockCostType, ProviderBuilder
# Build the DataForSEO provider with username/password authentication
dataforseo = (
ProviderBuilder("dataforseo")
.with_user_password(
username_env_var="DATAFORSEO_USERNAME",
password_env_var="DATAFORSEO_PASSWORD",
title="DataForSEO Credentials",
)
.with_base_cost(1, BlockCostType.RUN)
.build()
)

View File

@@ -1,273 +0,0 @@
"""
DataForSEO Google Keyword Suggestions block.
"""
from typing import Any, Dict, List, Optional
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
SchemaField,
UserPasswordCredentials,
)
from ._api import DataForSeoClient
from ._config import dataforseo
class KeywordSuggestion(BlockSchema):
"""Schema for a keyword suggestion result."""
keyword: str = SchemaField(description="The keyword suggestion")
search_volume: Optional[int] = SchemaField(
description="Monthly search volume", default=None
)
competition: Optional[float] = SchemaField(
description="Competition level (0-1)", default=None
)
cpc: Optional[float] = SchemaField(
description="Cost per click in USD", default=None
)
keyword_difficulty: Optional[int] = SchemaField(
description="Keyword difficulty score", default=None
)
serp_info: Optional[Dict[str, Any]] = SchemaField(
description="data from SERP for each keyword", default=None
)
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
description="Clickstream data metrics", default=None
)
class DataForSeoKeywordSuggestionsBlock(Block):
"""Block for getting keyword suggestions from DataForSEO Labs."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = dataforseo.credentials_field(
description="DataForSEO credentials (username and password)"
)
keyword: str = SchemaField(description="Seed keyword to get suggestions for")
location_code: Optional[int] = SchemaField(
description="Location code for targeting (e.g., 2840 for USA)",
default=2840, # USA
)
language_code: Optional[str] = SchemaField(
description="Language code (e.g., 'en' for English)",
default="en",
)
include_seed_keyword: bool = SchemaField(
description="Include the seed keyword in results",
default=True,
)
include_serp_info: bool = SchemaField(
description="Include SERP information",
default=False,
)
include_clickstream_data: bool = SchemaField(
description="Include clickstream metrics",
default=False,
)
limit: int = SchemaField(
description="Maximum number of results (up to 3000)",
default=100,
ge=1,
le=3000,
)
class Output(BlockSchema):
suggestions: List[KeywordSuggestion] = SchemaField(
description="List of keyword suggestions with metrics"
)
suggestion: KeywordSuggestion = SchemaField(
description="A single keyword suggestion with metrics"
)
total_count: int = SchemaField(
description="Total number of suggestions returned"
)
seed_keyword: str = SchemaField(
description="The seed keyword used for the query"
)
def __init__(self):
super().__init__(
id="73c3e7c4-2b3f-4e9f-9e3e-8f7a5c3e2d45",
description="Get keyword suggestions from DataForSEO Labs Google API",
categories={BlockCategory.SEARCH, BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
test_input={
"credentials": dataforseo.get_test_credentials().model_dump(),
"keyword": "digital marketing",
"location_code": 2840,
"language_code": "en",
"limit": 1,
},
test_credentials=dataforseo.get_test_credentials(),
test_output=[
(
"suggestion",
lambda x: hasattr(x, "keyword")
and x.keyword == "digital marketing strategy",
),
("suggestions", lambda x: isinstance(x, list) and len(x) == 1),
("total_count", 1),
("seed_keyword", "digital marketing"),
],
test_mock={
"_fetch_keyword_suggestions": lambda *args, **kwargs: [
{
"items": [
{
"keyword": "digital marketing strategy",
"keyword_info": {
"search_volume": 10000,
"competition": 0.5,
"cpc": 2.5,
},
"keyword_properties": {
"keyword_difficulty": 50,
},
}
]
}
]
},
)
async def _fetch_keyword_suggestions(
self,
client: DataForSeoClient,
input_data: Input,
) -> Any:
"""Private method to fetch keyword suggestions - can be mocked for testing."""
return await client.keyword_suggestions(
keyword=input_data.keyword,
location_code=input_data.location_code,
language_code=input_data.language_code,
include_seed_keyword=input_data.include_seed_keyword,
include_serp_info=input_data.include_serp_info,
include_clickstream_data=input_data.include_clickstream_data,
limit=input_data.limit,
)
async def run(
self,
input_data: Input,
*,
credentials: UserPasswordCredentials,
**kwargs,
) -> BlockOutput:
"""Execute the keyword suggestions query."""
client = DataForSeoClient(credentials)
results = await self._fetch_keyword_suggestions(client, input_data)
# Process and format the results
suggestions = []
if results and len(results) > 0:
# results is a list, get the first element
first_result = results[0] if isinstance(results, list) else results
items = (
first_result.get("items", []) if isinstance(first_result, dict) else []
)
for item in items:
# Create the KeywordSuggestion object
suggestion = KeywordSuggestion(
keyword=item.get("keyword", ""),
search_volume=item.get("keyword_info", {}).get("search_volume"),
competition=item.get("keyword_info", {}).get("competition"),
cpc=item.get("keyword_info", {}).get("cpc"),
keyword_difficulty=item.get("keyword_properties", {}).get(
"keyword_difficulty"
),
serp_info=(
item.get("serp_info") if input_data.include_serp_info else None
),
clickstream_data=(
item.get("clickstream_keyword_info")
if input_data.include_clickstream_data
else None
),
)
yield "suggestion", suggestion
suggestions.append(suggestion)
yield "suggestions", suggestions
yield "total_count", len(suggestions)
yield "seed_keyword", input_data.keyword
class KeywordSuggestionExtractorBlock(Block):
"""Extracts individual fields from a KeywordSuggestion object."""
class Input(BlockSchema):
suggestion: KeywordSuggestion = SchemaField(
description="The keyword suggestion object to extract fields from"
)
class Output(BlockSchema):
keyword: str = SchemaField(description="The keyword suggestion")
search_volume: Optional[int] = SchemaField(
description="Monthly search volume", default=None
)
competition: Optional[float] = SchemaField(
description="Competition level (0-1)", default=None
)
cpc: Optional[float] = SchemaField(
description="Cost per click in USD", default=None
)
keyword_difficulty: Optional[int] = SchemaField(
description="Keyword difficulty score", default=None
)
serp_info: Optional[Dict[str, Any]] = SchemaField(
description="data from SERP for each keyword", default=None
)
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
description="Clickstream data metrics", default=None
)
def __init__(self):
super().__init__(
id="4193cb94-677c-48b0-9eec-6ac72fffd0f2",
description="Extract individual fields from a KeywordSuggestion object",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
test_input={
"suggestion": KeywordSuggestion(
keyword="test keyword",
search_volume=1000,
competition=0.5,
cpc=2.5,
keyword_difficulty=60,
).model_dump()
},
test_output=[
("keyword", "test keyword"),
("search_volume", 1000),
("competition", 0.5),
("cpc", 2.5),
("keyword_difficulty", 60),
("serp_info", None),
("clickstream_data", None),
],
)
async def run(
self,
input_data: Input,
**kwargs,
) -> BlockOutput:
"""Extract fields from the KeywordSuggestion object."""
suggestion = input_data.suggestion
yield "keyword", suggestion.keyword
yield "search_volume", suggestion.search_volume
yield "competition", suggestion.competition
yield "cpc", suggestion.cpc
yield "keyword_difficulty", suggestion.keyword_difficulty
yield "serp_info", suggestion.serp_info
yield "clickstream_data", suggestion.clickstream_data

View File

@@ -1,283 +0,0 @@
"""
DataForSEO Google Related Keywords block.
"""
from typing import Any, Dict, List, Optional
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
SchemaField,
UserPasswordCredentials,
)
from ._api import DataForSeoClient
from ._config import dataforseo
class RelatedKeyword(BlockSchema):
"""Schema for a related keyword result."""
keyword: str = SchemaField(description="The related keyword")
search_volume: Optional[int] = SchemaField(
description="Monthly search volume", default=None
)
competition: Optional[float] = SchemaField(
description="Competition level (0-1)", default=None
)
cpc: Optional[float] = SchemaField(
description="Cost per click in USD", default=None
)
keyword_difficulty: Optional[int] = SchemaField(
description="Keyword difficulty score", default=None
)
serp_info: Optional[Dict[str, Any]] = SchemaField(
description="SERP data for the keyword", default=None
)
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
description="Clickstream data metrics", default=None
)
class DataForSeoRelatedKeywordsBlock(Block):
"""Block for getting related keywords from DataForSEO Labs."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = dataforseo.credentials_field(
description="DataForSEO credentials (username and password)"
)
keyword: str = SchemaField(
description="Seed keyword to find related keywords for"
)
location_code: Optional[int] = SchemaField(
description="Location code for targeting (e.g., 2840 for USA)",
default=2840, # USA
)
language_code: Optional[str] = SchemaField(
description="Language code (e.g., 'en' for English)",
default="en",
)
include_seed_keyword: bool = SchemaField(
description="Include the seed keyword in results",
default=True,
)
include_serp_info: bool = SchemaField(
description="Include SERP information",
default=False,
)
include_clickstream_data: bool = SchemaField(
description="Include clickstream metrics",
default=False,
)
limit: int = SchemaField(
description="Maximum number of results (up to 3000)",
default=100,
ge=1,
le=3000,
)
class Output(BlockSchema):
related_keywords: List[RelatedKeyword] = SchemaField(
description="List of related keywords with metrics"
)
related_keyword: RelatedKeyword = SchemaField(
description="A related keyword with metrics"
)
total_count: int = SchemaField(
description="Total number of related keywords returned"
)
seed_keyword: str = SchemaField(
description="The seed keyword used for the query"
)
def __init__(self):
super().__init__(
id="8f2e4d6a-1b3c-4a5e-9d7f-2c8e6a4b3f1d",
description="Get related keywords from DataForSEO Labs Google API",
categories={BlockCategory.SEARCH, BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
test_input={
"credentials": dataforseo.get_test_credentials().model_dump(),
"keyword": "content marketing",
"location_code": 2840,
"language_code": "en",
"limit": 1,
},
test_credentials=dataforseo.get_test_credentials(),
test_output=[
(
"related_keyword",
lambda x: hasattr(x, "keyword") and x.keyword == "content strategy",
),
("related_keywords", lambda x: isinstance(x, list) and len(x) == 1),
("total_count", 1),
("seed_keyword", "content marketing"),
],
test_mock={
"_fetch_related_keywords": lambda *args, **kwargs: [
{
"items": [
{
"keyword_data": {
"keyword": "content strategy",
"keyword_info": {
"search_volume": 8000,
"competition": 0.4,
"cpc": 3.0,
},
"keyword_properties": {
"keyword_difficulty": 45,
},
}
}
]
}
]
},
)
async def _fetch_related_keywords(
self,
client: DataForSeoClient,
input_data: Input,
) -> Any:
"""Private method to fetch related keywords - can be mocked for testing."""
return await client.related_keywords(
keyword=input_data.keyword,
location_code=input_data.location_code,
language_code=input_data.language_code,
include_seed_keyword=input_data.include_seed_keyword,
include_serp_info=input_data.include_serp_info,
include_clickstream_data=input_data.include_clickstream_data,
limit=input_data.limit,
)
async def run(
self,
input_data: Input,
*,
credentials: UserPasswordCredentials,
**kwargs,
) -> BlockOutput:
"""Execute the related keywords query."""
client = DataForSeoClient(credentials)
results = await self._fetch_related_keywords(client, input_data)
# Process and format the results
related_keywords = []
if results and len(results) > 0:
# results is a list, get the first element
first_result = results[0] if isinstance(results, list) else results
items = (
first_result.get("items", []) if isinstance(first_result, dict) else []
)
for item in items:
# Extract keyword_data from the item
keyword_data = item.get("keyword_data", {})
# Create the RelatedKeyword object
keyword = RelatedKeyword(
keyword=keyword_data.get("keyword", ""),
search_volume=keyword_data.get("keyword_info", {}).get(
"search_volume"
),
competition=keyword_data.get("keyword_info", {}).get("competition"),
cpc=keyword_data.get("keyword_info", {}).get("cpc"),
keyword_difficulty=keyword_data.get("keyword_properties", {}).get(
"keyword_difficulty"
),
serp_info=(
keyword_data.get("serp_info")
if input_data.include_serp_info
else None
),
clickstream_data=(
keyword_data.get("clickstream_keyword_info")
if input_data.include_clickstream_data
else None
),
)
yield "related_keyword", keyword
related_keywords.append(keyword)
yield "related_keywords", related_keywords
yield "total_count", len(related_keywords)
yield "seed_keyword", input_data.keyword
class RelatedKeywordExtractorBlock(Block):
"""Extracts individual fields from a RelatedKeyword object."""
class Input(BlockSchema):
related_keyword: RelatedKeyword = SchemaField(
description="The related keyword object to extract fields from"
)
class Output(BlockSchema):
keyword: str = SchemaField(description="The related keyword")
search_volume: Optional[int] = SchemaField(
description="Monthly search volume", default=None
)
competition: Optional[float] = SchemaField(
description="Competition level (0-1)", default=None
)
cpc: Optional[float] = SchemaField(
description="Cost per click in USD", default=None
)
keyword_difficulty: Optional[int] = SchemaField(
description="Keyword difficulty score", default=None
)
serp_info: Optional[Dict[str, Any]] = SchemaField(
description="SERP data for the keyword", default=None
)
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
description="Clickstream data metrics", default=None
)
def __init__(self):
super().__init__(
id="98342061-09d2-4952-bf77-0761fc8cc9a8",
description="Extract individual fields from a RelatedKeyword object",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
test_input={
"related_keyword": RelatedKeyword(
keyword="test related keyword",
search_volume=800,
competition=0.4,
cpc=3.0,
keyword_difficulty=55,
).model_dump()
},
test_output=[
("keyword", "test related keyword"),
("search_volume", 800),
("competition", 0.4),
("cpc", 3.0),
("keyword_difficulty", 55),
("serp_info", None),
("clickstream_data", None),
],
)
async def run(
self,
input_data: Input,
**kwargs,
) -> BlockOutput:
"""Extract fields from the RelatedKeyword object."""
related_keyword = input_data.related_keyword
yield "keyword", related_keyword.keyword
yield "search_volume", related_keyword.search_volume
yield "competition", related_keyword.competition
yield "cpc", related_keyword.cpc
yield "keyword_difficulty", related_keyword.keyword_difficulty
yield "serp_info", related_keyword.serp_info
yield "clickstream_data", related_keyword.clickstream_data

View File

@@ -2,29 +2,45 @@ import base64
import io
import mimetypes
from pathlib import Path
from typing import Any
from typing import Any, Literal
import aiohttp
import discord
from pydantic import SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, SchemaField
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
from ._auth import (
TEST_BOT_CREDENTIALS,
TEST_BOT_CREDENTIALS_INPUT,
DiscordBotCredentialsField,
DiscordBotCredentialsInput,
)
DiscordCredentials = CredentialsMetaInput[
Literal[ProviderName.DISCORD], Literal["api_key"]
]
# Keep backward compatibility alias
DiscordCredentials = DiscordBotCredentialsInput
DiscordCredentialsField = DiscordBotCredentialsField
TEST_CREDENTIALS = TEST_BOT_CREDENTIALS
TEST_CREDENTIALS_INPUT = TEST_BOT_CREDENTIALS_INPUT
def DiscordCredentialsField() -> DiscordCredentials:
return CredentialsField(description="Discord bot token")
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="discord",
api_key=SecretStr("test_api_key"),
title="Mock Discord API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.type,
}
class ReadDiscordMessagesBlock(Block):

View File

@@ -1,117 +0,0 @@
"""
Discord API helper functions for making authenticated requests.
"""
import logging
from typing import Optional
from pydantic import BaseModel
from backend.data.model import OAuth2Credentials
from backend.util.request import Requests
logger = logging.getLogger(__name__)
class DiscordAPIException(Exception):
"""Exception raised for Discord API errors."""
def __init__(self, message: str, status_code: int):
super().__init__(message)
self.status_code = status_code
class DiscordOAuthUser(BaseModel):
"""Model for Discord OAuth user response."""
user_id: str
username: str
avatar_url: str
banner: Optional[str] = None
accent_color: Optional[int] = None
def get_api(credentials: OAuth2Credentials) -> Requests:
"""
Create a Requests instance configured for Discord API calls with OAuth2 credentials.
Args:
credentials: The OAuth2 credentials containing the access token.
Returns:
A configured Requests instance for Discord API calls.
"""
return Requests(
trusted_origins=[],
extra_headers={
"Authorization": f"Bearer {credentials.access_token.get_secret_value()}",
"Content-Type": "application/json",
},
raise_for_status=False,
)
async def get_current_user(credentials: OAuth2Credentials) -> DiscordOAuthUser:
"""
Fetch the current user's information using Discord OAuth2 API.
Reference: https://discord.com/developers/docs/resources/user#get-current-user
Args:
credentials: The OAuth2 credentials.
Returns:
A model containing user data with avatar URL.
Raises:
DiscordAPIException: If the API request fails.
"""
api = get_api(credentials)
response = await api.get("https://discord.com/api/oauth2/@me")
if not response.ok:
error_text = response.text()
raise DiscordAPIException(
f"Failed to fetch user info: {response.status} - {error_text}",
response.status,
)
data = response.json()
logger.info(f"Discord OAuth2 API Response: {data}")
# The /api/oauth2/@me endpoint returns a user object nested in the response
user_info = data.get("user", {})
logger.info(f"User info extracted: {user_info}")
# Build avatar URL
user_id = user_info.get("id")
avatar_hash = user_info.get("avatar")
if avatar_hash:
# Custom avatar
avatar_ext = "gif" if avatar_hash.startswith("a_") else "png"
avatar_url = (
f"https://cdn.discordapp.com/avatars/{user_id}/{avatar_hash}.{avatar_ext}"
)
else:
# Default avatar based on discriminator or user ID
discriminator = user_info.get("discriminator", "0")
if discriminator == "0":
# New username system - use user ID for default avatar
default_avatar_index = (int(user_id) >> 22) % 6
else:
# Legacy discriminator system
default_avatar_index = int(discriminator) % 5
avatar_url = (
f"https://cdn.discordapp.com/embed/avatars/{default_avatar_index}.png"
)
result = DiscordOAuthUser(
user_id=user_id,
username=user_info.get("username", ""),
avatar_url=avatar_url,
banner=user_info.get("banner"),
accent_color=user_info.get("accent_color"),
)
logger.info(f"Returning user data: {result.model_dump()}")
return result

View File

@@ -1,74 +0,0 @@
from typing import Literal
from pydantic import SecretStr
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
OAuth2Credentials,
)
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets
secrets = Secrets()
DISCORD_OAUTH_IS_CONFIGURED = bool(
secrets.discord_client_id and secrets.discord_client_secret
)
# Bot token credentials (existing)
DiscordBotCredentials = APIKeyCredentials
DiscordBotCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.DISCORD], Literal["api_key"]
]
# OAuth2 credentials (new)
DiscordOAuthCredentials = OAuth2Credentials
DiscordOAuthCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.DISCORD], Literal["oauth2"]
]
def DiscordBotCredentialsField() -> DiscordBotCredentialsInput:
"""Creates a Discord bot token credentials field."""
return CredentialsField(description="Discord bot token")
def DiscordOAuthCredentialsField(scopes: list[str]) -> DiscordOAuthCredentialsInput:
"""Creates a Discord OAuth2 credentials field."""
return CredentialsField(
description="Discord OAuth2 credentials",
required_scopes=set(scopes) | {"identify"}, # Basic user info scope
)
# Test credentials for bot tokens
TEST_BOT_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="discord",
api_key=SecretStr("test_api_key"),
title="Mock Discord API key",
expires_at=None,
)
TEST_BOT_CREDENTIALS_INPUT = {
"provider": TEST_BOT_CREDENTIALS.provider,
"id": TEST_BOT_CREDENTIALS.id,
"type": TEST_BOT_CREDENTIALS.type,
"title": TEST_BOT_CREDENTIALS.type,
}
# Test credentials for OAuth2
TEST_OAUTH_CREDENTIALS = OAuth2Credentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="discord",
access_token=SecretStr("test_access_token"),
title="Mock Discord OAuth",
scopes=["identify"],
username="testuser",
)
TEST_OAUTH_CREDENTIALS_INPUT = {
"provider": TEST_OAUTH_CREDENTIALS.provider,
"id": TEST_OAUTH_CREDENTIALS.id,
"type": TEST_OAUTH_CREDENTIALS.type,
"title": TEST_OAUTH_CREDENTIALS.type,
}

View File

@@ -1,99 +0,0 @@
"""
Discord OAuth-based blocks.
"""
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import OAuth2Credentials, SchemaField
from ._api import DiscordOAuthUser, get_current_user
from ._auth import (
DISCORD_OAUTH_IS_CONFIGURED,
TEST_OAUTH_CREDENTIALS,
TEST_OAUTH_CREDENTIALS_INPUT,
DiscordOAuthCredentialsField,
DiscordOAuthCredentialsInput,
)
class DiscordGetCurrentUserBlock(Block):
"""
Gets information about the currently authenticated Discord user using OAuth2.
This block requires Discord OAuth2 credentials (not bot tokens).
"""
class Input(BlockSchema):
credentials: DiscordOAuthCredentialsInput = DiscordOAuthCredentialsField(
["identify"]
)
class Output(BlockSchema):
user_id: str = SchemaField(description="The authenticated user's Discord ID")
username: str = SchemaField(description="The user's username")
avatar_url: str = SchemaField(description="URL to the user's avatar image")
banner_url: str = SchemaField(
description="URL to the user's banner image (if set)", default=""
)
accent_color: int = SchemaField(
description="The user's accent color as an integer", default=0
)
def __init__(self):
super().__init__(
id="8c7e39b8-4e9d-4f3a-b4e1-2a8c9d5f6e3b",
input_schema=DiscordGetCurrentUserBlock.Input,
output_schema=DiscordGetCurrentUserBlock.Output,
description="Gets information about the currently authenticated Discord user using OAuth2 credentials.",
categories={BlockCategory.SOCIAL},
disabled=not DISCORD_OAUTH_IS_CONFIGURED,
test_input={
"credentials": TEST_OAUTH_CREDENTIALS_INPUT,
},
test_credentials=TEST_OAUTH_CREDENTIALS,
test_output=[
("user_id", "123456789012345678"),
("username", "testuser"),
(
"avatar_url",
"https://cdn.discordapp.com/avatars/123456789012345678/avatar.png",
),
("banner_url", ""),
("accent_color", 0),
],
test_mock={
"get_user": lambda _: DiscordOAuthUser(
user_id="123456789012345678",
username="testuser",
avatar_url="https://cdn.discordapp.com/avatars/123456789012345678/avatar.png",
banner=None,
accent_color=0,
)
},
)
@staticmethod
async def get_user(credentials: OAuth2Credentials) -> DiscordOAuthUser:
user_info = await get_current_user(credentials)
return user_info
async def run(
self, input_data: Input, *, credentials: OAuth2Credentials, **kwargs
) -> BlockOutput:
try:
result = await self.get_user(credentials)
# Yield each output field
yield "user_id", result.user_id
yield "username", result.username
yield "avatar_url", result.avatar_url
# Handle banner URL if banner hash exists
if result.banner:
banner_url = f"https://cdn.discordapp.com/banners/{result.user_id}/{result.banner}.png"
yield "banner_url", banner_url
else:
yield "banner_url", ""
yield "accent_color", result.accent_color or 0
except Exception as e:
raise ValueError(f"Failed to get Discord user info: {e}")

View File

@@ -93,11 +93,11 @@ class Webset(BaseModel):
"""
Set of key-value pairs you want to associate with this object.
"""
created_at: Annotated[datetime | None, Field(alias="createdAt")] = None
created_at: Annotated[datetime, Field(alias="createdAt")] | None = None
"""
The date and time the webset was created
"""
updated_at: Annotated[datetime | None, Field(alias="updatedAt")] = None
updated_at: Annotated[datetime, Field(alias="updatedAt")] | None = None
"""
The date and time the webset was last updated
"""

View File

@@ -30,7 +30,6 @@ TEST_CREDENTIALS_INPUT = {
class IdeogramModelName(str, Enum):
V3 = "V_3"
V2 = "V_2"
V1 = "V_1"
V1_TURBO = "V_1_TURBO"
@@ -96,8 +95,8 @@ class IdeogramModelBlock(Block):
title="Prompt",
)
ideogram_model_name: IdeogramModelName = SchemaField(
description="The name of the Image Generation Model, e.g., V_3",
default=IdeogramModelName.V3,
description="The name of the Image Generation Model, e.g., V_2",
default=IdeogramModelName.V2,
title="Image Generation Model",
advanced=False,
)
@@ -237,111 +236,6 @@ class IdeogramModelBlock(Block):
negative_prompt: Optional[str],
color_palette_name: str,
custom_colors: Optional[list[str]],
):
# Use V3 endpoint for V3 model, legacy endpoint for others
if model_name == "V_3":
return await self._run_model_v3(
api_key,
prompt,
seed,
aspect_ratio,
magic_prompt_option,
style_type,
negative_prompt,
color_palette_name,
custom_colors,
)
else:
return await self._run_model_legacy(
api_key,
model_name,
prompt,
seed,
aspect_ratio,
magic_prompt_option,
style_type,
negative_prompt,
color_palette_name,
custom_colors,
)
async def _run_model_v3(
self,
api_key: SecretStr,
prompt: str,
seed: Optional[int],
aspect_ratio: str,
magic_prompt_option: str,
style_type: str,
negative_prompt: Optional[str],
color_palette_name: str,
custom_colors: Optional[list[str]],
):
url = "https://api.ideogram.ai/v1/ideogram-v3/generate"
headers = {
"Api-Key": api_key.get_secret_value(),
"Content-Type": "application/json",
}
# Map legacy aspect ratio values to V3 format
aspect_ratio_map = {
"ASPECT_10_16": "10x16",
"ASPECT_16_10": "16x10",
"ASPECT_9_16": "9x16",
"ASPECT_16_9": "16x9",
"ASPECT_3_2": "3x2",
"ASPECT_2_3": "2x3",
"ASPECT_4_3": "4x3",
"ASPECT_3_4": "3x4",
"ASPECT_1_1": "1x1",
"ASPECT_1_3": "1x3",
"ASPECT_3_1": "3x1",
# Additional V3 supported ratios
"ASPECT_1_2": "1x2",
"ASPECT_2_1": "2x1",
"ASPECT_4_5": "4x5",
"ASPECT_5_4": "5x4",
}
v3_aspect_ratio = aspect_ratio_map.get(
aspect_ratio, "1x1"
) # Default to 1x1 if not found
# Use JSON for V3 endpoint (simpler than multipart/form-data)
data: Dict[str, Any] = {
"prompt": prompt,
"aspect_ratio": v3_aspect_ratio,
"magic_prompt": magic_prompt_option,
"style_type": style_type,
}
if seed is not None:
data["seed"] = seed
if negative_prompt:
data["negative_prompt"] = negative_prompt
# Note: V3 endpoint may have different color palette support
# For now, we'll omit color palettes for V3 to avoid errors
try:
response = await Requests().post(url, headers=headers, json=data)
return response.json()["data"][0]["url"]
except RequestException as e:
raise Exception(f"Failed to fetch image with V3 endpoint: {str(e)}")
async def _run_model_legacy(
self,
api_key: SecretStr,
model_name: str,
prompt: str,
seed: Optional[int],
aspect_ratio: str,
magic_prompt_option: str,
style_type: str,
negative_prompt: Optional[str],
color_palette_name: str,
custom_colors: Optional[list[str]],
):
url = "https://api.ideogram.ai/generate"
headers = {
@@ -355,33 +249,28 @@ class IdeogramModelBlock(Block):
"model": model_name,
"aspect_ratio": aspect_ratio,
"magic_prompt_option": magic_prompt_option,
"style_type": style_type,
}
}
# Only add style_type for V2, V2_TURBO, and V3 models (V1 models don't support it)
if model_name in ["V_2", "V_2_TURBO", "V_3"]:
data["image_request"]["style_type"] = style_type
if seed is not None:
data["image_request"]["seed"] = seed
if negative_prompt:
data["image_request"]["negative_prompt"] = negative_prompt
# Only add color palette for V2 and V2_TURBO models (V1 models don't support it)
if model_name in ["V_2", "V_2_TURBO"]:
if color_palette_name != "NONE":
data["color_palette"] = {"name": color_palette_name}
elif custom_colors:
data["color_palette"] = {
"members": [{"color_hex": color} for color in custom_colors]
}
if color_palette_name != "NONE":
data["color_palette"] = {"name": color_palette_name}
elif custom_colors:
data["color_palette"] = {
"members": [{"color_hex": color} for color in custom_colors]
}
try:
response = await Requests().post(url, headers=headers, json=data)
return response.json()["data"][0]["url"]
except RequestException as e:
raise Exception(f"Failed to fetch image with legacy endpoint: {str(e)}")
raise Exception(f"Failed to fetch image: {str(e)}")
async def upscale_image(self, api_key: SecretStr, image_url: str):
url = "https://api.ideogram.ai/upscale"

View File

@@ -1,8 +0,0 @@
from backend.sdk import BlockCostType, ProviderBuilder
stagehand = (
ProviderBuilder("stagehand")
.with_api_key("STAGEHAND_API_KEY", "Stagehand API Key")
.with_base_cost(1, BlockCostType.RUN)
.build()
)

View File

@@ -1,393 +0,0 @@
import logging
import signal
import threading
from contextlib import contextmanager
from enum import Enum
# Monkey patch Stagehands to prevent signal handling in worker threads
import stagehand.main
from stagehand import Stagehand
from backend.blocks.llm import (
MODEL_METADATA,
AICredentials,
AICredentialsField,
LlmModel,
ModelMetadata,
)
from backend.blocks.stagehand._config import stagehand as stagehand_provider
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
SchemaField,
)
# Store the original method
original_register_signal_handlers = stagehand.main.Stagehand._register_signal_handlers
def safe_register_signal_handlers(self):
"""Only register signal handlers in the main thread"""
if threading.current_thread() is threading.main_thread():
original_register_signal_handlers(self)
else:
# Skip signal handling in worker threads
pass
# Replace the method
stagehand.main.Stagehand._register_signal_handlers = safe_register_signal_handlers
@contextmanager
def disable_signal_handling():
"""Context manager to temporarily disable signal handling"""
if threading.current_thread() is not threading.main_thread():
# In worker threads, temporarily replace signal.signal with a no-op
original_signal = signal.signal
def noop_signal(*args, **kwargs):
pass
signal.signal = noop_signal
try:
yield
finally:
signal.signal = original_signal
else:
# In main thread, don't modify anything
yield
logger = logging.getLogger(__name__)
class StagehandRecommendedLlmModel(str, Enum):
"""
This is subset of LLModel from autogpt_platform/backend/backend/blocks/llm.py
It contains only the models recommended by Stagehand
"""
# OpenAI
GPT41 = "gpt-4.1-2025-04-14"
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
# Anthropic
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
@property
def provider_name(self) -> str:
"""
Returns the provider name for the model in the required format for Stagehand:
provider/model_name
"""
model_metadata = MODEL_METADATA[LlmModel(self.value)]
model_name = self.value
if len(model_name.split("/")) == 1 and not self.value.startswith(
model_metadata.provider
):
assert (
model_metadata.provider != "open_router"
), "Logic failed and open_router provider attempted to be prepended to model name! in stagehand/_config.py"
model_name = f"{model_metadata.provider}/{model_name}"
logger.error(f"Model name: {model_name}")
return model_name
@property
def provider(self) -> str:
return MODEL_METADATA[LlmModel(self.value)].provider
@property
def metadata(self) -> ModelMetadata:
return MODEL_METADATA[LlmModel(self.value)]
@property
def context_window(self) -> int:
return MODEL_METADATA[LlmModel(self.value)].context_window
@property
def max_output_tokens(self) -> int | None:
return MODEL_METADATA[LlmModel(self.value)].max_output_tokens
class StagehandObserveBlock(Block):
class Input(BlockSchema):
# Browserbase credentials (Stagehand provider) or raw API key
stagehand_credentials: CredentialsMetaInput = (
stagehand_provider.credentials_field(
description="Stagehand/Browserbase API key"
)
)
browserbase_project_id: str = SchemaField(
description="Browserbase project ID (required if using Browserbase)",
)
# Model selection and credentials (provider-discriminated like llm.py)
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
url: str = SchemaField(
description="URL to navigate to.",
)
instruction: str = SchemaField(
description="Natural language description of elements or actions to discover.",
)
iframes: bool = SchemaField(
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
default=True,
)
domSettleTimeoutMs: int = SchemaField(
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
default=45000,
)
class Output(BlockSchema):
selector: str = SchemaField(description="XPath selector to locate element.")
description: str = SchemaField(description="Human-readable description")
method: str | None = SchemaField(description="Suggested action method")
arguments: list[str] | None = SchemaField(
description="Additional action parameters"
)
def __init__(self):
super().__init__(
id="d3863944-0eaf-45c4-a0c9-63e0fe1ee8b9",
description="Find suggested actions for your workflows",
categories={BlockCategory.AI, BlockCategory.DEVELOPER_TOOLS},
input_schema=StagehandObserveBlock.Input,
output_schema=StagehandObserveBlock.Output,
)
async def run(
self,
input_data: Input,
*,
stagehand_credentials: APIKeyCredentials,
model_credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
logger.info(f"OBSERVE: Stagehand credentials: {stagehand_credentials}")
logger.info(
f"OBSERVE: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
)
with disable_signal_handling():
stagehand = Stagehand(
api_key=stagehand_credentials.api_key.get_secret_value(),
project_id=input_data.browserbase_project_id,
model_name=input_data.model.provider_name,
model_api_key=model_credentials.api_key.get_secret_value(),
)
await stagehand.init()
page = stagehand.page
assert page is not None, "Stagehand page is not initialized"
await page.goto(input_data.url)
observe_results = await page.observe(
input_data.instruction,
iframes=input_data.iframes,
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
)
for result in observe_results:
yield "selector", result.selector
yield "description", result.description
yield "method", result.method
yield "arguments", result.arguments
class StagehandActBlock(Block):
class Input(BlockSchema):
# Browserbase credentials (Stagehand provider) or raw API key
stagehand_credentials: CredentialsMetaInput = (
stagehand_provider.credentials_field(
description="Stagehand/Browserbase API key"
)
)
browserbase_project_id: str = SchemaField(
description="Browserbase project ID (required if using Browserbase)",
)
# Model selection and credentials (provider-discriminated like llm.py)
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
url: str = SchemaField(
description="URL to navigate to.",
)
action: list[str] = SchemaField(
description="Action to perform. Suggested actions are: click, fill, type, press, scroll, select from dropdown. For multi-step actions, add an entry for each step.",
)
variables: dict[str, str] = SchemaField(
description="Variables to use in the action. Variables contains data you want the action to use.",
default_factory=dict,
)
iframes: bool = SchemaField(
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
default=True,
)
domSettleTimeoutMs: int = SchemaField(
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
default=45000,
)
timeoutMs: int = SchemaField(
description="Timeout in milliseconds for DOM ready. Extended timeout for slow-loading forms",
default=60000,
)
class Output(BlockSchema):
success: bool = SchemaField(
description="Whether the action was completed successfully"
)
message: str = SchemaField(description="Details about the actions execution.")
action: str = SchemaField(description="Action performed")
def __init__(self):
super().__init__(
id="86eba68b-9549-4c0b-a0db-47d85a56cc27",
description="Interact with a web page by performing actions on a web page. Use it to build self-healing and deterministic automations that adapt to website chang.",
categories={BlockCategory.AI, BlockCategory.DEVELOPER_TOOLS},
input_schema=StagehandActBlock.Input,
output_schema=StagehandActBlock.Output,
)
async def run(
self,
input_data: Input,
*,
stagehand_credentials: APIKeyCredentials,
model_credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
logger.info(f"ACT: Stagehand credentials: {stagehand_credentials}")
logger.info(
f"ACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
)
with disable_signal_handling():
stagehand = Stagehand(
api_key=stagehand_credentials.api_key.get_secret_value(),
project_id=input_data.browserbase_project_id,
model_name=input_data.model.provider_name,
model_api_key=model_credentials.api_key.get_secret_value(),
)
await stagehand.init()
page = stagehand.page
assert page is not None, "Stagehand page is not initialized"
await page.goto(input_data.url)
for action in input_data.action:
action_results = await page.act(
action,
variables=input_data.variables,
iframes=input_data.iframes,
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
timeoutMs=input_data.timeoutMs,
)
yield "success", action_results.success
yield "message", action_results.message
yield "action", action_results.action
class StagehandExtractBlock(Block):
class Input(BlockSchema):
# Browserbase credentials (Stagehand provider) or raw API key
stagehand_credentials: CredentialsMetaInput = (
stagehand_provider.credentials_field(
description="Stagehand/Browserbase API key"
)
)
browserbase_project_id: str = SchemaField(
description="Browserbase project ID (required if using Browserbase)",
)
# Model selection and credentials (provider-discriminated like llm.py)
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
url: str = SchemaField(
description="URL to navigate to.",
)
instruction: str = SchemaField(
description="Natural language description of elements or actions to discover.",
)
iframes: bool = SchemaField(
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
default=True,
)
domSettleTimeoutMs: int = SchemaField(
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
default=45000,
)
class Output(BlockSchema):
extraction: str = SchemaField(description="Extracted data from the page.")
def __init__(self):
super().__init__(
id="fd3c0b18-2ba6-46ae-9339-fcb40537ad98",
description="Extract structured data from a webpage.",
categories={BlockCategory.AI, BlockCategory.DEVELOPER_TOOLS},
input_schema=StagehandExtractBlock.Input,
output_schema=StagehandExtractBlock.Output,
)
async def run(
self,
input_data: Input,
*,
stagehand_credentials: APIKeyCredentials,
model_credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
logger.info(f"EXTRACT: Stagehand credentials: {stagehand_credentials}")
logger.info(
f"EXTRACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
)
with disable_signal_handling():
stagehand = Stagehand(
api_key=stagehand_credentials.api_key.get_secret_value(),
project_id=input_data.browserbase_project_id,
model_name=input_data.model.provider_name,
model_api_key=model_credentials.api_key.get_secret_value(),
)
await stagehand.init()
page = stagehand.page
assert page is not None, "Stagehand page is not initialized"
await page.goto(input_data.url)
extraction = await page.extract(
input_data.instruction,
iframes=input_data.iframes,
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
)
yield "extraction", str(extraction.model_dump()["extraction"])

View File

@@ -1,283 +0,0 @@
import logging
from typing import Any
from pydantic import BaseModel
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.clients import get_database_manager_async_client
logger = logging.getLogger(__name__)
# Duplicate pydantic models for store data so we don't accidently change the data shape in the blocks unintentionally when editing the backend
class LibraryAgent(BaseModel):
"""Model representing an agent in the user's library."""
library_agent_id: str = ""
agent_id: str = ""
agent_version: int = 0
agent_name: str = ""
description: str = ""
creator: str = ""
is_archived: bool = False
categories: list[str] = []
class AddToLibraryFromStoreBlock(Block):
"""
Block that adds an agent from the store to the user's library.
This enables users to easily import agents from the marketplace into their personal collection.
"""
class Input(BlockSchema):
store_listing_version_id: str = SchemaField(
description="The ID of the store listing version to add to library"
)
agent_name: str | None = SchemaField(
description="Optional custom name for the agent in your library",
default=None,
)
class Output(BlockSchema):
success: bool = SchemaField(
description="Whether the agent was successfully added to library"
)
library_agent_id: str = SchemaField(
description="The ID of the library agent entry"
)
agent_id: str = SchemaField(description="The ID of the agent graph")
agent_version: int = SchemaField(
description="The version number of the agent graph"
)
agent_name: str = SchemaField(description="The name of the agent")
message: str = SchemaField(description="Success or error message")
def __init__(self):
super().__init__(
id="2602a7b1-3f4d-4e5f-9c8b-1a2b3c4d5e6f",
description="Add an agent from the store to your personal library",
categories={BlockCategory.BASIC},
input_schema=AddToLibraryFromStoreBlock.Input,
output_schema=AddToLibraryFromStoreBlock.Output,
test_input={
"store_listing_version_id": "test-listing-id",
"agent_name": "My Custom Agent",
},
test_output=[
("success", True),
("library_agent_id", "test-library-id"),
("agent_id", "test-agent-id"),
("agent_version", 1),
("agent_name", "Test Agent"),
("message", "Agent successfully added to library"),
],
test_mock={
"_add_to_library": lambda *_, **__: LibraryAgent(
library_agent_id="test-library-id",
agent_id="test-agent-id",
agent_version=1,
agent_name="Test Agent",
)
},
)
async def run(
self,
input_data: Input,
*,
user_id: str,
**kwargs,
) -> BlockOutput:
library_agent = await self._add_to_library(
user_id=user_id,
store_listing_version_id=input_data.store_listing_version_id,
custom_name=input_data.agent_name,
)
yield "success", True
yield "library_agent_id", library_agent.library_agent_id
yield "agent_id", library_agent.agent_id
yield "agent_version", library_agent.agent_version
yield "agent_name", library_agent.agent_name
yield "message", "Agent successfully added to library"
async def _add_to_library(
self,
user_id: str,
store_listing_version_id: str,
custom_name: str | None = None,
) -> LibraryAgent:
"""
Add a store agent to the user's library using the existing library database function.
"""
library_agent = (
await get_database_manager_async_client().add_store_agent_to_library(
store_listing_version_id=store_listing_version_id, user_id=user_id
)
)
# If custom name is provided, we could update the library agent name here
# For now, we'll just return the agent info
agent_name = custom_name if custom_name else library_agent.name
return LibraryAgent(
library_agent_id=library_agent.id,
agent_id=library_agent.graph_id,
agent_version=library_agent.graph_version,
agent_name=agent_name,
)
class ListLibraryAgentsBlock(Block):
"""
Block that lists all agents in the user's library.
"""
class Input(BlockSchema):
search_query: str | None = SchemaField(
description="Optional search query to filter agents", default=None
)
limit: int = SchemaField(
description="Maximum number of agents to return", default=50, ge=1, le=100
)
page: int = SchemaField(
description="Page number for pagination", default=1, ge=1
)
class Output(BlockSchema):
agents: list[LibraryAgent] = SchemaField(
description="List of agents in the library",
default_factory=list,
)
agent: LibraryAgent = SchemaField(
description="Individual library agent (yielded for each agent)"
)
total_count: int = SchemaField(
description="Total number of agents in library", default=0
)
page: int = SchemaField(description="Current page number", default=1)
total_pages: int = SchemaField(description="Total number of pages", default=1)
def __init__(self):
super().__init__(
id="082602d3-a74d-4600-9e9c-15b3af7eae98",
description="List all agents in your personal library",
categories={BlockCategory.BASIC, BlockCategory.DATA},
input_schema=ListLibraryAgentsBlock.Input,
output_schema=ListLibraryAgentsBlock.Output,
test_input={
"search_query": None,
"limit": 10,
"page": 1,
},
test_output=[
(
"agents",
[
LibraryAgent(
library_agent_id="test-lib-id",
agent_id="test-agent-id",
agent_version=1,
agent_name="Test Library Agent",
description="A test agent in library",
creator="Test User",
),
],
),
("total_count", 1),
("page", 1),
("total_pages", 1),
(
"agent",
LibraryAgent(
library_agent_id="test-lib-id",
agent_id="test-agent-id",
agent_version=1,
agent_name="Test Library Agent",
description="A test agent in library",
creator="Test User",
),
),
],
test_mock={
"_list_library_agents": lambda *_, **__: {
"agents": [
LibraryAgent(
library_agent_id="test-lib-id",
agent_id="test-agent-id",
agent_version=1,
agent_name="Test Library Agent",
description="A test agent in library",
creator="Test User",
)
],
"total": 1,
"page": 1,
"total_pages": 1,
}
},
)
async def run(
self,
input_data: Input,
*,
user_id: str,
**kwargs,
) -> BlockOutput:
result = await self._list_library_agents(
user_id=user_id,
search_query=input_data.search_query,
limit=input_data.limit,
page=input_data.page,
)
agents = result["agents"]
yield "agents", agents
yield "total_count", result["total"]
yield "page", result["page"]
yield "total_pages", result["total_pages"]
# Yield each agent individually for better graph connectivity
for agent in agents:
yield "agent", agent
async def _list_library_agents(
self,
user_id: str,
search_query: str | None = None,
limit: int = 50,
page: int = 1,
) -> dict[str, Any]:
"""
List agents in the user's library using the database client.
"""
result = await get_database_manager_async_client().list_library_agents(
user_id=user_id,
search_term=search_query,
page=page,
page_size=limit,
)
agents = [
LibraryAgent(
library_agent_id=agent.id,
agent_id=agent.graph_id,
agent_version=agent.graph_version,
agent_name=agent.name,
description=getattr(agent, "description", ""),
creator=getattr(agent, "creator", ""),
is_archived=getattr(agent, "is_archived", False),
categories=getattr(agent, "categories", []),
)
for agent in result.agents
]
return {
"agents": agents,
"total": result.pagination.total_items,
"page": result.pagination.current_page,
"total_pages": result.pagination.total_pages,
}

View File

@@ -1,311 +0,0 @@
import logging
from typing import Literal
from pydantic import BaseModel
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.clients import get_database_manager_async_client
logger = logging.getLogger(__name__)
# Duplicate pydantic models for store data so we don't accidently change the data shape in the blocks unintentionally when editing the backend
class StoreAgent(BaseModel):
"""Model representing a store agent."""
slug: str = ""
name: str = ""
description: str = ""
creator: str = ""
rating: float = 0.0
runs: int = 0
categories: list[str] = []
class StoreAgentDict(BaseModel):
"""Dictionary representation of a store agent."""
slug: str
name: str
description: str
creator: str
rating: float
runs: int
class SearchAgentsResponse(BaseModel):
"""Response from searching store agents."""
agents: list[StoreAgentDict]
total_count: int
class StoreAgentDetails(BaseModel):
"""Detailed information about a store agent."""
found: bool
store_listing_version_id: str = ""
agent_name: str = ""
description: str = ""
creator: str = ""
categories: list[str] = []
runs: int = 0
rating: float = 0.0
class GetStoreAgentDetailsBlock(Block):
"""
Block that retrieves detailed information about an agent from the store.
"""
class Input(BlockSchema):
creator: str = SchemaField(description="The username of the agent creator")
slug: str = SchemaField(description="The name of the agent")
class Output(BlockSchema):
found: bool = SchemaField(
description="Whether the agent was found in the store"
)
store_listing_version_id: str = SchemaField(
description="The store listing version ID"
)
agent_name: str = SchemaField(description="Name of the agent")
description: str = SchemaField(description="Description of the agent")
creator: str = SchemaField(description="Creator of the agent")
categories: list[str] = SchemaField(
description="Categories the agent belongs to", default_factory=list
)
runs: int = SchemaField(
description="Number of times the agent has been run", default=0
)
rating: float = SchemaField(
description="Average rating of the agent", default=0.0
)
def __init__(self):
super().__init__(
id="b604f0ec-6e0d-40a7-bf55-9fd09997cced",
description="Get detailed information about an agent from the store",
categories={BlockCategory.BASIC, BlockCategory.DATA},
input_schema=GetStoreAgentDetailsBlock.Input,
output_schema=GetStoreAgentDetailsBlock.Output,
test_input={"creator": "test-creator", "slug": "test-agent-slug"},
test_output=[
("found", True),
("store_listing_version_id", "test-listing-id"),
("agent_name", "Test Agent"),
("description", "A test agent"),
("creator", "Test Creator"),
("categories", ["productivity", "automation"]),
("runs", 100),
("rating", 4.5),
],
test_mock={
"_get_agent_details": lambda *_, **__: StoreAgentDetails(
found=True,
store_listing_version_id="test-listing-id",
agent_name="Test Agent",
description="A test agent",
creator="Test Creator",
categories=["productivity", "automation"],
runs=100,
rating=4.5,
)
},
static_output=True,
)
async def run(
self,
input_data: Input,
**kwargs,
) -> BlockOutput:
details = await self._get_agent_details(
creator=input_data.creator, slug=input_data.slug
)
yield "found", details.found
yield "store_listing_version_id", details.store_listing_version_id
yield "agent_name", details.agent_name
yield "description", details.description
yield "creator", details.creator
yield "categories", details.categories
yield "runs", details.runs
yield "rating", details.rating
async def _get_agent_details(self, creator: str, slug: str) -> StoreAgentDetails:
"""
Retrieve detailed information about a store agent.
"""
# Get by specific version ID
agent_details = (
await get_database_manager_async_client().get_store_agent_details(
username=creator, agent_name=slug
)
)
return StoreAgentDetails(
found=True,
store_listing_version_id=agent_details.store_listing_version_id,
agent_name=agent_details.agent_name,
description=agent_details.description,
creator=agent_details.creator,
categories=(
agent_details.categories if hasattr(agent_details, "categories") else []
),
runs=agent_details.runs,
rating=agent_details.rating,
)
class SearchStoreAgentsBlock(Block):
"""
Block that searches for agents in the store based on various criteria.
"""
class Input(BlockSchema):
query: str | None = SchemaField(
description="Search query to find agents", default=None
)
category: str | None = SchemaField(
description="Filter by category", default=None
)
sort_by: Literal["rating", "runs", "name", "recent"] = SchemaField(
description="How to sort the results", default="rating"
)
limit: int = SchemaField(
description="Maximum number of results to return", default=10, ge=1, le=100
)
class Output(BlockSchema):
agents: list[StoreAgent] = SchemaField(
description="List of agents matching the search criteria",
default_factory=list,
)
agent: StoreAgent = SchemaField(description="Basic information of the agent")
total_count: int = SchemaField(
description="Total number of agents found", default=0
)
def __init__(self):
super().__init__(
id="39524701-026c-4328-87cc-1b88c8e2cb4c",
description="Search for agents in the store",
categories={BlockCategory.BASIC, BlockCategory.DATA},
input_schema=SearchStoreAgentsBlock.Input,
output_schema=SearchStoreAgentsBlock.Output,
test_input={
"query": "productivity",
"category": None,
"sort_by": "rating",
"limit": 10,
},
test_output=[
(
"agents",
[
{
"slug": "test-agent",
"name": "Test Agent",
"description": "A test agent",
"creator": "Test Creator",
"rating": 4.5,
"runs": 100,
}
],
),
("total_count", 1),
(
"agent",
{
"slug": "test-agent",
"name": "Test Agent",
"description": "A test agent",
"creator": "Test Creator",
"rating": 4.5,
"runs": 100,
},
),
],
test_mock={
"_search_agents": lambda *_, **__: SearchAgentsResponse(
agents=[
StoreAgentDict(
slug="test-agent",
name="Test Agent",
description="A test agent",
creator="Test Creator",
rating=4.5,
runs=100,
)
],
total_count=1,
)
},
)
async def run(
self,
input_data: Input,
**kwargs,
) -> BlockOutput:
result = await self._search_agents(
query=input_data.query,
category=input_data.category,
sort_by=input_data.sort_by,
limit=input_data.limit,
)
agents = result.agents
total_count = result.total_count
# Convert to dict for output
agents_as_dicts = [agent.model_dump() for agent in agents]
yield "agents", agents_as_dicts
yield "total_count", total_count
for agent_dict in agents_as_dicts:
yield "agent", agent_dict
async def _search_agents(
self,
query: str | None = None,
category: str | None = None,
sort_by: str = "rating",
limit: int = 10,
) -> SearchAgentsResponse:
"""
Search for agents in the store using the existing store database function.
"""
# Map our sort_by to the store's sorted_by parameter
sorted_by_map = {
"rating": "most_popular",
"runs": "most_runs",
"name": "alphabetical",
"recent": "recently_updated",
}
result = await get_database_manager_async_client().get_store_agents(
featured=False,
creators=None,
sorted_by=sorted_by_map.get(sort_by, "most_popular"),
search_query=query,
category=category,
page=1,
page_size=limit,
)
agents = [
StoreAgentDict(
slug=agent.slug,
name=agent.agent_name,
description=agent.description,
creator=agent.creator,
rating=agent.rating,
runs=agent.runs,
)
for agent in result.agents
]
return SearchAgentsResponse(agents=agents, total_count=len(agents))

View File

@@ -1,155 +0,0 @@
from unittest.mock import MagicMock
import pytest
from backend.blocks.system.library_operations import (
AddToLibraryFromStoreBlock,
LibraryAgent,
)
from backend.blocks.system.store_operations import (
GetStoreAgentDetailsBlock,
SearchAgentsResponse,
SearchStoreAgentsBlock,
StoreAgentDetails,
StoreAgentDict,
)
@pytest.mark.asyncio
async def test_add_to_library_from_store_block_success(mocker):
"""Test successful addition of agent from store to library."""
block = AddToLibraryFromStoreBlock()
# Mock the library agent response
mock_library_agent = MagicMock()
mock_library_agent.id = "lib-agent-123"
mock_library_agent.graph_id = "graph-456"
mock_library_agent.graph_version = 1
mock_library_agent.name = "Test Agent"
mocker.patch.object(
block,
"_add_to_library",
return_value=LibraryAgent(
library_agent_id="lib-agent-123",
agent_id="graph-456",
agent_version=1,
agent_name="Test Agent",
),
)
input_data = block.Input(
store_listing_version_id="store-listing-v1", agent_name="Custom Agent Name"
)
outputs = {}
async for name, value in block.run(input_data, user_id="test-user"):
outputs[name] = value
assert outputs["success"] is True
assert outputs["library_agent_id"] == "lib-agent-123"
assert outputs["agent_id"] == "graph-456"
assert outputs["agent_version"] == 1
assert outputs["agent_name"] == "Test Agent"
assert outputs["message"] == "Agent successfully added to library"
@pytest.mark.asyncio
async def test_get_store_agent_details_block_success(mocker):
"""Test successful retrieval of store agent details."""
block = GetStoreAgentDetailsBlock()
mocker.patch.object(
block,
"_get_agent_details",
return_value=StoreAgentDetails(
found=True,
store_listing_version_id="version-123",
agent_name="Test Agent",
description="A test agent for testing",
creator="Test Creator",
categories=["productivity", "automation"],
runs=100,
rating=4.5,
),
)
input_data = block.Input(creator="Test Creator", slug="test-slug")
outputs = {}
async for name, value in block.run(input_data):
outputs[name] = value
assert outputs["found"] is True
assert outputs["store_listing_version_id"] == "version-123"
assert outputs["agent_name"] == "Test Agent"
assert outputs["description"] == "A test agent for testing"
assert outputs["creator"] == "Test Creator"
assert outputs["categories"] == ["productivity", "automation"]
assert outputs["runs"] == 100
assert outputs["rating"] == 4.5
@pytest.mark.asyncio
async def test_search_store_agents_block(mocker):
"""Test searching for store agents."""
block = SearchStoreAgentsBlock()
mocker.patch.object(
block,
"_search_agents",
return_value=SearchAgentsResponse(
agents=[
StoreAgentDict(
slug="creator1/agent1",
name="Agent One",
description="First test agent",
creator="Creator 1",
rating=4.8,
runs=500,
),
StoreAgentDict(
slug="creator2/agent2",
name="Agent Two",
description="Second test agent",
creator="Creator 2",
rating=4.2,
runs=200,
),
],
total_count=2,
),
)
input_data = block.Input(
query="test", category="productivity", sort_by="rating", limit=10
)
outputs = {}
async for name, value in block.run(input_data):
outputs[name] = value
assert len(outputs["agents"]) == 2
assert outputs["total_count"] == 2
assert outputs["agents"][0]["name"] == "Agent One"
assert outputs["agents"][0]["rating"] == 4.8
@pytest.mark.asyncio
async def test_search_store_agents_block_empty_results(mocker):
"""Test searching with no results."""
block = SearchStoreAgentsBlock()
mocker.patch.object(
block,
"_search_agents",
return_value=SearchAgentsResponse(agents=[], total_count=0),
)
input_data = block.Input(query="nonexistent", limit=10)
outputs = {}
async for name, value in block.run(input_data):
outputs[name] = value
assert outputs["agents"] == []
assert outputs["total_count"] == 0

View File

@@ -1,5 +1,4 @@
import asyncio
import logging
import time
from datetime import datetime, timedelta
from typing import Any, Literal, Union
@@ -8,7 +7,6 @@ from zoneinfo import ZoneInfo
from pydantic import BaseModel
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.execution import UserContext
from backend.data.model import SchemaField
# Shared timezone literal type for all time/date blocks
@@ -53,80 +51,16 @@ TimezoneLiteral = Literal[
"Etc/GMT+12", # UTC-12:00
]
logger = logging.getLogger(__name__)
def _get_timezone(
format_type: Any, # Any format type with timezone and use_user_timezone attributes
user_timezone: str | None,
) -> ZoneInfo:
"""
Determine which timezone to use based on format settings and user context.
Args:
format_type: The format configuration containing timezone settings
user_timezone: The user's timezone from context
Returns:
ZoneInfo object for the determined timezone
"""
if format_type.use_user_timezone and user_timezone:
tz = ZoneInfo(user_timezone)
logger.debug(f"Using user timezone: {user_timezone}")
else:
tz = ZoneInfo(format_type.timezone)
logger.debug(f"Using specified timezone: {format_type.timezone}")
return tz
def _format_datetime_iso8601(dt: datetime, include_microseconds: bool = False) -> str:
"""
Format a datetime object to ISO8601 string.
Args:
dt: The datetime object to format
include_microseconds: Whether to include microseconds in the output
Returns:
ISO8601 formatted string
"""
if include_microseconds:
return dt.isoformat()
else:
return dt.isoformat(timespec="seconds")
# BACKWARDS COMPATIBILITY NOTE:
# The timezone field is kept at the format level (not block level) for backwards compatibility.
# Existing graphs have timezone saved within format_type, moving it would break them.
#
# The use_user_timezone flag was added to allow using the user's profile timezone.
# Default is False to maintain backwards compatibility - existing graphs will continue
# using their specified timezone.
#
# KNOWN ISSUE: If a user switches between format types (strftime <-> iso8601),
# the timezone setting doesn't carry over. This is a UX issue but fixing it would
# require either:
# 1. Moving timezone to block level (breaking change, needs migration)
# 2. Complex state management to sync timezone across format types
#
# Future migration path: When we do a major version bump, consider moving timezone
# to the block Input level for better UX.
class TimeStrftimeFormat(BaseModel):
discriminator: Literal["strftime"]
format: str = "%H:%M:%S"
timezone: TimezoneLiteral = "UTC"
# When True, overrides timezone with user's profile timezone
use_user_timezone: bool = False
class TimeISO8601Format(BaseModel):
discriminator: Literal["iso8601"]
timezone: TimezoneLiteral = "UTC"
# When True, overrides timezone with user's profile timezone
use_user_timezone: bool = False
include_microseconds: bool = False
@@ -181,27 +115,25 @@ class GetCurrentTimeBlock(Block):
],
)
async def run(
self, input_data: Input, *, user_context: UserContext, **kwargs
) -> BlockOutput:
# Extract timezone from user_context (always present)
effective_timezone = user_context.timezone
# Get the appropriate timezone
tz = _get_timezone(input_data.format_type, effective_timezone)
dt = datetime.now(tz=tz)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
if isinstance(input_data.format_type, TimeISO8601Format):
# ISO 8601 format for time only (extract time portion from full ISO datetime)
tz = ZoneInfo(input_data.format_type.timezone)
dt = datetime.now(tz=tz)
# Get the full ISO format and extract just the time portion with timezone
full_iso = _format_datetime_iso8601(
dt, input_data.format_type.include_microseconds
)
if input_data.format_type.include_microseconds:
full_iso = dt.isoformat()
else:
full_iso = dt.isoformat(timespec="seconds")
# Extract time portion (everything after 'T')
current_time = full_iso.split("T")[1] if "T" in full_iso else full_iso
current_time = f"T{current_time}" # Add T prefix for ISO 8601 time format
else: # TimeStrftimeFormat
tz = ZoneInfo(input_data.format_type.timezone)
dt = datetime.now(tz=tz)
current_time = dt.strftime(input_data.format_type.format)
yield "time", current_time
@@ -209,15 +141,11 @@ class DateStrftimeFormat(BaseModel):
discriminator: Literal["strftime"]
format: str = "%Y-%m-%d"
timezone: TimezoneLiteral = "UTC"
# When True, overrides timezone with user's profile timezone
use_user_timezone: bool = False
class DateISO8601Format(BaseModel):
discriminator: Literal["iso8601"]
timezone: TimezoneLiteral = "UTC"
# When True, overrides timezone with user's profile timezone
use_user_timezone: bool = False
class GetCurrentDateBlock(Block):
@@ -289,23 +217,20 @@ class GetCurrentDateBlock(Block):
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
# Extract timezone from user_context (required keyword argument)
user_context: UserContext = kwargs["user_context"]
effective_timezone = user_context.timezone
try:
offset = int(input_data.offset)
except ValueError:
offset = 0
# Get the appropriate timezone
tz = _get_timezone(input_data.format_type, effective_timezone)
current_date = datetime.now(tz=tz) - timedelta(days=offset)
if isinstance(input_data.format_type, DateISO8601Format):
# ISO 8601 format for date only (YYYY-MM-DD)
tz = ZoneInfo(input_data.format_type.timezone)
current_date = datetime.now(tz=tz) - timedelta(days=offset)
# ISO 8601 date format is YYYY-MM-DD
date_str = current_date.date().isoformat()
else: # DateStrftimeFormat
tz = ZoneInfo(input_data.format_type.timezone)
current_date = datetime.now(tz=tz) - timedelta(days=offset)
date_str = current_date.strftime(input_data.format_type.format)
yield "date", date_str
@@ -315,15 +240,11 @@ class StrftimeFormat(BaseModel):
discriminator: Literal["strftime"]
format: str = "%Y-%m-%d %H:%M:%S"
timezone: TimezoneLiteral = "UTC"
# When True, overrides timezone with user's profile timezone
use_user_timezone: bool = False
class ISO8601Format(BaseModel):
discriminator: Literal["iso8601"]
timezone: TimezoneLiteral = "UTC"
# When True, overrides timezone with user's profile timezone
use_user_timezone: bool = False
include_microseconds: bool = False
@@ -395,22 +316,20 @@ class GetCurrentDateAndTimeBlock(Block):
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
# Extract timezone from user_context (required keyword argument)
user_context: UserContext = kwargs["user_context"]
effective_timezone = user_context.timezone
# Get the appropriate timezone
tz = _get_timezone(input_data.format_type, effective_timezone)
dt = datetime.now(tz=tz)
if isinstance(input_data.format_type, ISO8601Format):
# ISO 8601 format with specified timezone (also RFC3339-compliant)
current_date_time = _format_datetime_iso8601(
dt, input_data.format_type.include_microseconds
)
else: # StrftimeFormat
current_date_time = dt.strftime(input_data.format_type.format)
tz = ZoneInfo(input_data.format_type.timezone)
dt = datetime.now(tz=tz)
# Format with or without microseconds
if input_data.format_type.include_microseconds:
current_date_time = dt.isoformat()
else:
current_date_time = dt.isoformat(timespec="seconds")
else: # StrftimeFormat
tz = ZoneInfo(input_data.format_type.timezone)
dt = datetime.now(tz=tz)
current_date_time = dt.strftime(input_data.format_type.format)
yield "date_time", current_date_time

View File

@@ -4,8 +4,6 @@ from logging import getLogger
from typing import Any, Dict, List, Union
from urllib.parse import urlencode
from pydantic import field_serializer
from backend.sdk import BaseModel, Credentials, Requests
logger = getLogger(__name__)
@@ -384,9 +382,8 @@ class CreatePostRequest(BaseModel):
# Advanced
metadata: List[Dict[str, Any]] | None = None
@field_serializer("date")
def serialize_date(self, value: datetime | None) -> str | None:
return value.isoformat() if value else None
class Config:
json_encoders = {datetime: lambda v: v.isoformat()}
class PostAuthor(BaseModel):

View File

@@ -6,6 +6,8 @@ from dotenv import load_dotenv
from backend.util.logging import configure_logging
os.environ["ENABLE_AUTH"] = "false"
load_dotenv()
# NOTE: You can run tests like with the --log-cli-level=INFO to see the logs

View File

@@ -307,18 +307,7 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
"type": ideogram_credentials.type,
}
},
),
BlockCost(
cost_amount=18,
cost_filter={
"ideogram_model_name": "V_3",
"credentials": {
"id": ideogram_credentials.id,
"provider": ideogram_credentials.provider,
"type": ideogram_credentials.type,
},
},
),
)
],
AIShortformVideoCreatorBlock: [
BlockCost(

View File

@@ -34,10 +34,10 @@ from backend.data.model import (
from backend.data.notifications import NotificationEventModel, RefundRequestData
from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications.notifications import queue_notification_async
from backend.server.model import Pagination
from backend.server.v2.admin.model import UserHistoryResponse
from backend.util.exceptions import InsufficientBalanceError
from backend.util.json import SafeJson
from backend.util.models import Pagination
from backend.util.retry import func_retry
from backend.util.settings import Settings

View File

@@ -7,7 +7,7 @@ from prisma.models import CreditTransaction
from backend.blocks.llm import AITextGeneratorBlock
from backend.data.block import get_block
from backend.data.credit import BetaUserCredit, UsageTransactionMetadata
from backend.data.execution import NodeExecutionEntry, UserContext
from backend.data.execution import NodeExecutionEntry
from backend.data.user import DEFAULT_USER_ID
from backend.executor.utils import block_usage_cost
from backend.integrations.credentials_store import openai_credentials
@@ -75,7 +75,6 @@ async def test_block_credit_usage(server: SpinTestServer):
"type": openai_credentials.type,
},
},
user_context=UserContext(timezone="UTC"),
),
)
assert spending_amount_1 > 0
@@ -89,7 +88,6 @@ async def test_block_credit_usage(server: SpinTestServer):
node_exec_id="test_node_exec",
block_id=AITextGeneratorBlock().id,
inputs={"model": "gpt-4-turbo", "api_key": "owned_api_key"},
user_context=UserContext(timezone="UTC"),
),
)
assert spending_amount_2 == 0

View File

@@ -39,7 +39,6 @@ from pydantic.fields import Field
from backend.server.v2.store.exceptions import DatabaseError
from backend.util import type as type_utils
from backend.util.json import SafeJson
from backend.util.models import Pagination
from backend.util.retry import func_retry
from backend.util.settings import Config
from backend.util.truncate import truncate
@@ -90,7 +89,6 @@ ExecutionStatus = AgentExecutionStatus
class GraphExecutionMeta(BaseDbModel):
id: str # type: ignore # Override base class to make this required
user_id: str
graph_id: str
graph_version: int
@@ -292,14 +290,13 @@ class GraphExecutionWithNodes(GraphExecution):
node_executions=node_executions,
)
def to_graph_execution_entry(self, user_context: "UserContext"):
def to_graph_execution_entry(self):
return GraphExecutionEntry(
user_id=self.user_id,
graph_id=self.graph_id,
graph_version=self.graph_version or 0,
graph_exec_id=self.id,
nodes_input_masks={}, # FIXME: store credentials on AgentGraphExecution
user_context=user_context,
)
@@ -315,10 +312,9 @@ class NodeExecutionResult(BaseModel):
input_data: BlockInput
output_data: CompletedBlockOutput
add_time: datetime
queue_time: datetime | None = None
start_time: datetime | None = None
end_time: datetime | None = None
stats: NodeExecutionStats | None = None
queue_time: datetime | None
start_time: datetime | None
end_time: datetime | None
@staticmethod
def from_db(_node_exec: AgentNodeExecution, user_id: Optional[str] = None):
@@ -370,12 +366,9 @@ class NodeExecutionResult(BaseModel):
queue_time=_node_exec.queuedTime,
start_time=_node_exec.startedTime,
end_time=_node_exec.endedTime,
stats=stats,
)
def to_node_execution_entry(
self, user_context: "UserContext"
) -> "NodeExecutionEntry":
def to_node_execution_entry(self) -> "NodeExecutionEntry":
return NodeExecutionEntry(
user_id=self.user_id,
graph_exec_id=self.graph_exec_id,
@@ -384,7 +377,6 @@ class NodeExecutionResult(BaseModel):
node_id=self.node_id,
block_id=self.block_id,
inputs=self.input_data,
user_context=user_context,
)
@@ -392,13 +384,13 @@ class NodeExecutionResult(BaseModel):
async def get_graph_executions(
graph_exec_id: Optional[str] = None,
graph_id: Optional[str] = None,
user_id: Optional[str] = None,
statuses: Optional[list[ExecutionStatus]] = None,
created_time_gte: Optional[datetime] = None,
created_time_lte: Optional[datetime] = None,
limit: Optional[int] = None,
graph_exec_id: str | None = None,
graph_id: str | None = None,
user_id: str | None = None,
statuses: list[ExecutionStatus] | None = None,
created_time_gte: datetime | None = None,
created_time_lte: datetime | None = None,
limit: int | None = None,
) -> list[GraphExecutionMeta]:
"""⚠️ **Optional `user_id` check**: MUST USE check in user-facing endpoints."""
where_filter: AgentGraphExecutionWhereInput = {
@@ -426,60 +418,6 @@ async def get_graph_executions(
return [GraphExecutionMeta.from_db(execution) for execution in executions]
class GraphExecutionsPaginated(BaseModel):
"""Response schema for paginated graph executions."""
executions: list[GraphExecutionMeta]
pagination: Pagination
async def get_graph_executions_paginated(
user_id: str,
graph_id: Optional[str] = None,
page: int = 1,
page_size: int = 25,
statuses: Optional[list[ExecutionStatus]] = None,
created_time_gte: Optional[datetime] = None,
created_time_lte: Optional[datetime] = None,
) -> GraphExecutionsPaginated:
"""Get paginated graph executions for a specific graph."""
where_filter: AgentGraphExecutionWhereInput = {
"isDeleted": False,
"userId": user_id,
}
if graph_id:
where_filter["agentGraphId"] = graph_id
if created_time_gte or created_time_lte:
where_filter["createdAt"] = {
"gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
"lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
}
if statuses:
where_filter["OR"] = [{"executionStatus": status} for status in statuses]
total_count = await AgentGraphExecution.prisma().count(where=where_filter)
total_pages = (total_count + page_size - 1) // page_size
offset = (page - 1) * page_size
executions = await AgentGraphExecution.prisma().find_many(
where=where_filter,
order={"createdAt": "desc"},
take=page_size,
skip=offset,
)
return GraphExecutionsPaginated(
executions=[GraphExecutionMeta.from_db(execution) for execution in executions],
pagination=Pagination(
total_items=total_count,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)
async def get_graph_execution_meta(
user_id: str, execution_id: str
) -> GraphExecutionMeta | None:
@@ -656,42 +594,6 @@ async def upsert_execution_input(
)
async def create_node_execution(
node_exec_id: str,
node_id: str,
graph_exec_id: str,
input_name: str,
input_data: Any,
) -> None:
"""Create a new node execution with the first input."""
json_input_data = SafeJson(input_data)
await AgentNodeExecution.prisma().create(
data=AgentNodeExecutionCreateInput(
id=node_exec_id,
agentNodeId=node_id,
agentGraphExecutionId=graph_exec_id,
executionStatus=ExecutionStatus.INCOMPLETE,
Input={"create": {"name": input_name, "data": json_input_data}},
)
)
async def add_input_to_node_execution(
node_exec_id: str,
input_name: str,
input_data: Any,
) -> None:
"""Add an input to an existing node execution."""
json_input_data = SafeJson(input_data)
await AgentNodeExecutionInputOutput.prisma().create(
data=AgentNodeExecutionInputOutputCreateInput(
name=input_name,
data=json_input_data,
referencedByInputExecId=node_exec_id,
)
)
async def upsert_execution_output(
node_exec_id: str,
output_name: str,
@@ -915,19 +817,12 @@ async def get_latest_node_execution(
# ----------------- Execution Infrastructure ----------------- #
class UserContext(BaseModel):
"""Generic user context for graph execution containing user-specific settings."""
timezone: str
class GraphExecutionEntry(BaseModel):
user_id: str
graph_exec_id: str
graph_id: str
graph_version: int
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None
user_context: UserContext
class NodeExecutionEntry(BaseModel):
@@ -938,7 +833,6 @@ class NodeExecutionEntry(BaseModel):
node_id: str
block_id: str
inputs: BlockInput
user_context: UserContext
class ExecutionQueue(Generic[T]):

View File

@@ -2,6 +2,7 @@ import json
from typing import Any
from uuid import UUID
import autogpt_libs.auth.models
import fastapi.exceptions
import pytest
from pytest_snapshot.plugin import Snapshot
@@ -316,7 +317,12 @@ async def test_access_store_listing_graph(server: SpinTestServer):
is_approved=True,
comments="Test comments",
),
user_id=admin_user.id,
autogpt_libs.auth.models.User(
user_id=admin_user.id,
role="admin",
email=admin_user.email,
phone_number="1234567890",
),
)
# Now we check the graph can be accessed by a user that does not own the graph

View File

@@ -96,12 +96,6 @@ class User(BaseModel):
default=True, description="Notify on monthly summary"
)
# User timezone for scheduling and time display
timezone: str = Field(
default="not-set",
description="User timezone (IANA timezone identifier or 'not-set')",
)
@classmethod
def from_db(cls, prisma_user: "PrismaUser") -> "User":
"""Convert a database User object to application User model."""
@@ -155,7 +149,6 @@ class User(BaseModel):
notify_on_daily_summary=prisma_user.notifyOnDailySummary or True,
notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True,
notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True,
timezone=prisma_user.timezone or "not-set",
)

View File

@@ -54,6 +54,19 @@ class AgentRunData(BaseNotificationData):
class ZeroBalanceData(BaseNotificationData):
last_transaction: float
last_transaction_time: datetime
top_up_link: str
@field_validator("last_transaction_time")
@classmethod
def validate_timezone(cls, value: datetime):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class LowBalanceData(BaseNotificationData):
agent_name: str = Field(..., description="Name of the agent")
current_balance: float = Field(
..., description="Current balance in credits (100 = $1)"
@@ -62,13 +75,6 @@ class ZeroBalanceData(BaseNotificationData):
shortfall: float = Field(..., description="Amount of credits needed to continue")
class LowBalanceData(BaseNotificationData):
current_balance: float = Field(
..., description="Current balance in credits (100 = $1)"
)
billing_page_link: str = Field(..., description="Link to billing page")
class BlockExecutionFailedData(BaseNotificationData):
block_name: str
block_id: str
@@ -175,42 +181,6 @@ class RefundRequestData(BaseNotificationData):
balance: int
class AgentApprovalData(BaseNotificationData):
agent_name: str
agent_id: str
agent_version: int
reviewer_name: str
reviewer_email: str
comments: str
reviewed_at: datetime
store_url: str
@field_validator("reviewed_at")
@classmethod
def validate_timezone(cls, value: datetime):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class AgentRejectionData(BaseNotificationData):
agent_name: str
agent_id: str
agent_version: int
reviewer_name: str
reviewer_email: str
comments: str
reviewed_at: datetime
resubmit_url: str
@field_validator("reviewed_at")
@classmethod
def validate_timezone(cls, value: datetime):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
NotificationData = Annotated[
Union[
AgentRunData,
@@ -270,8 +240,6 @@ def get_notif_data_type(
NotificationType.MONTHLY_SUMMARY: MonthlySummaryData,
NotificationType.REFUND_REQUEST: RefundRequestData,
NotificationType.REFUND_PROCESSED: RefundRequestData,
NotificationType.AGENT_APPROVED: AgentApprovalData,
NotificationType.AGENT_REJECTED: AgentRejectionData,
}[notification_type]
@@ -306,7 +274,7 @@ class NotificationTypeOverride:
# These are batched by the notification service
NotificationType.AGENT_RUN: QueueType.BATCH,
# These are batched by the notification service, but with a backoff strategy
NotificationType.ZERO_BALANCE: QueueType.IMMEDIATE,
NotificationType.ZERO_BALANCE: QueueType.BACKOFF,
NotificationType.LOW_BALANCE: QueueType.IMMEDIATE,
NotificationType.BLOCK_EXECUTION_FAILED: QueueType.BACKOFF,
NotificationType.CONTINUOUS_AGENT_ERROR: QueueType.BACKOFF,
@@ -315,8 +283,6 @@ class NotificationTypeOverride:
NotificationType.MONTHLY_SUMMARY: QueueType.SUMMARY,
NotificationType.REFUND_REQUEST: QueueType.ADMIN,
NotificationType.REFUND_PROCESSED: QueueType.ADMIN,
NotificationType.AGENT_APPROVED: QueueType.IMMEDIATE,
NotificationType.AGENT_REJECTED: QueueType.IMMEDIATE,
}
return BATCHING_RULES.get(self.notification_type, QueueType.IMMEDIATE)
@@ -334,8 +300,6 @@ class NotificationTypeOverride:
NotificationType.MONTHLY_SUMMARY: "monthly_summary.html",
NotificationType.REFUND_REQUEST: "refund_request.html",
NotificationType.REFUND_PROCESSED: "refund_processed.html",
NotificationType.AGENT_APPROVED: "agent_approved.html",
NotificationType.AGENT_REJECTED: "agent_rejected.html",
}[self.notification_type]
@property
@@ -351,8 +315,6 @@ class NotificationTypeOverride:
NotificationType.MONTHLY_SUMMARY: "We did a lot this month!",
NotificationType.REFUND_REQUEST: "[ACTION REQUIRED] You got a ${{data.amount / 100}} refund request from {{data.user_name}}",
NotificationType.REFUND_PROCESSED: "Refund for ${{data.amount / 100}} to {{data.user_name}} has been processed",
NotificationType.AGENT_APPROVED: "🎉 Your agent '{{data.agent_name}}' has been approved!",
NotificationType.AGENT_REJECTED: "Your agent '{{data.agent_name}}' needs some updates",
}[self.notification_type]

View File

@@ -1,151 +0,0 @@
"""Tests for notification data models."""
from datetime import datetime, timezone
import pytest
from pydantic import ValidationError
from backend.data.notifications import AgentApprovalData, AgentRejectionData
class TestAgentApprovalData:
"""Test cases for AgentApprovalData model."""
def test_valid_agent_approval_data(self):
"""Test creating valid AgentApprovalData."""
data = AgentApprovalData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
reviewer_name="John Doe",
reviewer_email="john@example.com",
comments="Great agent, approved!",
reviewed_at=datetime.now(timezone.utc),
store_url="https://app.autogpt.com/store/test-agent-123",
)
assert data.agent_name == "Test Agent"
assert data.agent_id == "test-agent-123"
assert data.agent_version == 1
assert data.reviewer_name == "John Doe"
assert data.reviewer_email == "john@example.com"
assert data.comments == "Great agent, approved!"
assert data.store_url == "https://app.autogpt.com/store/test-agent-123"
assert data.reviewed_at.tzinfo is not None
def test_agent_approval_data_without_timezone_raises_error(self):
"""Test that AgentApprovalData raises error without timezone."""
with pytest.raises(
ValidationError, match="datetime must have timezone information"
):
AgentApprovalData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
reviewer_name="John Doe",
reviewer_email="john@example.com",
comments="Great agent, approved!",
reviewed_at=datetime.now(), # No timezone
store_url="https://app.autogpt.com/store/test-agent-123",
)
def test_agent_approval_data_with_empty_comments(self):
"""Test AgentApprovalData with empty comments."""
data = AgentApprovalData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
reviewer_name="John Doe",
reviewer_email="john@example.com",
comments="", # Empty comments
reviewed_at=datetime.now(timezone.utc),
store_url="https://app.autogpt.com/store/test-agent-123",
)
assert data.comments == ""
class TestAgentRejectionData:
"""Test cases for AgentRejectionData model."""
def test_valid_agent_rejection_data(self):
"""Test creating valid AgentRejectionData."""
data = AgentRejectionData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
reviewer_name="Jane Doe",
reviewer_email="jane@example.com",
comments="Please fix the security issues before resubmitting.",
reviewed_at=datetime.now(timezone.utc),
resubmit_url="https://app.autogpt.com/build/test-agent-123",
)
assert data.agent_name == "Test Agent"
assert data.agent_id == "test-agent-123"
assert data.agent_version == 1
assert data.reviewer_name == "Jane Doe"
assert data.reviewer_email == "jane@example.com"
assert data.comments == "Please fix the security issues before resubmitting."
assert data.resubmit_url == "https://app.autogpt.com/build/test-agent-123"
assert data.reviewed_at.tzinfo is not None
def test_agent_rejection_data_without_timezone_raises_error(self):
"""Test that AgentRejectionData raises error without timezone."""
with pytest.raises(
ValidationError, match="datetime must have timezone information"
):
AgentRejectionData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
reviewer_name="Jane Doe",
reviewer_email="jane@example.com",
comments="Please fix the security issues.",
reviewed_at=datetime.now(), # No timezone
resubmit_url="https://app.autogpt.com/build/test-agent-123",
)
def test_agent_rejection_data_with_long_comments(self):
"""Test AgentRejectionData with long comments."""
long_comment = "A" * 1000 # Very long comment
data = AgentRejectionData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
reviewer_name="Jane Doe",
reviewer_email="jane@example.com",
comments=long_comment,
reviewed_at=datetime.now(timezone.utc),
resubmit_url="https://app.autogpt.com/build/test-agent-123",
)
assert data.comments == long_comment
def test_model_serialization(self):
"""Test that models can be serialized and deserialized."""
original_data = AgentRejectionData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
reviewer_name="Jane Doe",
reviewer_email="jane@example.com",
comments="Please fix the issues.",
reviewed_at=datetime.now(timezone.utc),
resubmit_url="https://app.autogpt.com/build/test-agent-123",
)
# Serialize to dict
data_dict = original_data.model_dump()
# Deserialize back
restored_data = AgentRejectionData.model_validate(data_dict)
assert restored_data.agent_name == original_data.agent_name
assert restored_data.agent_id == original_data.agent_id
assert restored_data.agent_version == original_data.agent_version
assert restored_data.reviewer_name == original_data.reviewer_name
assert restored_data.reviewer_email == original_data.reviewer_email
assert restored_data.comments == original_data.comments
assert restored_data.reviewed_at == original_data.reviewed_at
assert restored_data.resubmit_url == original_data.resubmit_url

View File

@@ -208,8 +208,6 @@ async def get_user_notification_preference(user_id: str) -> NotificationPreferen
NotificationType.DAILY_SUMMARY: user.notifyOnDailySummary or False,
NotificationType.WEEKLY_SUMMARY: user.notifyOnWeeklySummary or False,
NotificationType.MONTHLY_SUMMARY: user.notifyOnMonthlySummary or False,
NotificationType.AGENT_APPROVED: user.notifyOnAgentApproved or False,
NotificationType.AGENT_REJECTED: user.notifyOnAgentRejected or False,
}
daily_limit = user.maxEmailsPerDay or 3
notification_preference = NotificationPreference(
@@ -268,14 +266,6 @@ async def update_user_notification_preference(
update_data["notifyOnMonthlySummary"] = data.preferences[
NotificationType.MONTHLY_SUMMARY
]
if NotificationType.AGENT_APPROVED in data.preferences:
update_data["notifyOnAgentApproved"] = data.preferences[
NotificationType.AGENT_APPROVED
]
if NotificationType.AGENT_REJECTED in data.preferences:
update_data["notifyOnAgentRejected"] = data.preferences[
NotificationType.AGENT_REJECTED
]
if data.daily_limit:
update_data["maxEmailsPerDay"] = data.daily_limit
@@ -296,8 +286,6 @@ async def update_user_notification_preference(
NotificationType.DAILY_SUMMARY: user.notifyOnDailySummary or True,
NotificationType.WEEKLY_SUMMARY: user.notifyOnWeeklySummary or True,
NotificationType.MONTHLY_SUMMARY: user.notifyOnMonthlySummary or True,
NotificationType.AGENT_APPROVED: user.notifyOnAgentApproved or True,
NotificationType.AGENT_REJECTED: user.notifyOnAgentRejected or True,
}
notification_preference = NotificationPreference(
user_id=user.id,
@@ -396,17 +384,3 @@ async def unsubscribe_user_by_token(token: str) -> None:
)
except Exception as e:
raise DatabaseError(f"Failed to unsubscribe user by token {token}: {e}") from e
async def update_user_timezone(user_id: str, timezone: str) -> User:
"""Update a user's timezone setting."""
try:
user = await PrismaUser.prisma().update(
where={"id": user_id},
data={"timezone": timezone},
)
if not user:
raise ValueError(f"User not found with ID: {user_id}")
return User.from_db(user)
except Exception as e:
raise DatabaseError(f"Failed to update timezone for user {user_id}: {e}") from e

View File

@@ -115,7 +115,7 @@ async def generate_activity_status_for_execution(
# Get all node executions for this graph execution
node_executions = await db_client.get_node_executions(
graph_exec_id=graph_exec_id, include_exec_data=True
graph_exec_id, include_exec_data=True
)
# Get graph metadata and full graph structure for name, description, and links

View File

@@ -4,13 +4,12 @@ from typing import Callable, Concatenate, ParamSpec, TypeVar, cast
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
add_input_to_node_execution,
create_graph_execution,
create_node_execution,
get_block_error_stats,
get_execution_kv_data,
get_graph_execution_meta,
get_graph_executions,
get_latest_node_execution,
get_node_execution,
get_node_executions,
set_execution_kv_data,
@@ -18,6 +17,7 @@ from backend.data.execution import (
update_graph_execution_stats,
update_node_execution_status,
update_node_execution_status_batch,
upsert_execution_input,
upsert_execution_output,
)
from backend.data.generate_data import get_user_execution_summary_data
@@ -42,8 +42,6 @@ from backend.data.user import (
get_user_notification_preference,
update_user_integrations,
)
from backend.server.v2.library.db import add_store_agent_to_library, list_library_agents
from backend.server.v2.store.db import get_store_agent_details, get_store_agents
from backend.util.service import (
AppService,
AppServiceClient,
@@ -105,13 +103,13 @@ class DatabaseManager(AppService):
create_graph_execution = _(create_graph_execution)
get_node_execution = _(get_node_execution)
get_node_executions = _(get_node_executions)
get_latest_node_execution = _(get_latest_node_execution)
update_node_execution_status = _(update_node_execution_status)
update_node_execution_status_batch = _(update_node_execution_status_batch)
update_graph_execution_start_time = _(update_graph_execution_start_time)
update_graph_execution_stats = _(update_graph_execution_stats)
upsert_execution_input = _(upsert_execution_input)
upsert_execution_output = _(upsert_execution_output)
create_node_execution = _(create_node_execution)
add_input_to_node_execution = _(add_input_to_node_execution)
get_execution_kv_data = _(get_execution_kv_data)
set_execution_kv_data = _(set_execution_kv_data)
get_block_error_stats = _(get_block_error_stats)
@@ -147,14 +145,6 @@ class DatabaseManager(AppService):
get_user_notification_oldest_message_in_batch
)
# Library
list_library_agents = _(list_library_agents)
add_store_agent_to_library = _(add_store_agent_to_library)
# Store
get_store_agents = _(get_store_agents)
get_store_agent_details = _(get_store_agent_details)
# Summary data - async
get_user_execution_summary_data = _(get_user_execution_summary_data)
@@ -171,12 +161,10 @@ class DatabaseManagerClient(AppServiceClient):
get_graph_executions = _(d.get_graph_executions)
get_graph_execution_meta = _(d.get_graph_execution_meta)
get_node_executions = _(d.get_node_executions)
create_node_execution = _(d.create_node_execution)
update_node_execution_status = _(d.update_node_execution_status)
update_graph_execution_start_time = _(d.update_graph_execution_start_time)
update_graph_execution_stats = _(d.update_graph_execution_stats)
upsert_execution_output = _(d.upsert_execution_output)
add_input_to_node_execution = _(d.add_input_to_node_execution)
# Graphs
get_graph_metadata = _(d.get_graph_metadata)
@@ -185,12 +173,12 @@ class DatabaseManagerClient(AppServiceClient):
spend_credits = _(d.spend_credits)
get_credits = _(d.get_credits)
# Summary data - async
get_user_execution_summary_data = _(d.get_user_execution_summary_data)
# Block error monitoring
get_block_error_stats = _(d.get_block_error_stats)
# User Emails
get_user_email_by_id = _(d.get_user_email_by_id)
class DatabaseManagerAsyncClient(AppServiceClient):
d = DatabaseManager
@@ -201,12 +189,16 @@ class DatabaseManagerAsyncClient(AppServiceClient):
create_graph_execution = d.create_graph_execution
get_connected_output_nodes = d.get_connected_output_nodes
get_latest_node_execution = d.get_latest_node_execution
get_graph = d.get_graph
get_graph_metadata = d.get_graph_metadata
get_graph_execution_meta = d.get_graph_execution_meta
get_node = d.get_node
get_node_execution = d.get_node_execution
get_node_executions = d.get_node_executions
get_user_integrations = d.get_user_integrations
upsert_execution_input = d.upsert_execution_input
upsert_execution_output = d.upsert_execution_output
update_graph_execution_stats = d.update_graph_execution_stats
update_node_execution_status = d.update_node_execution_status
update_node_execution_status_batch = d.update_node_execution_status_batch
@@ -231,13 +223,5 @@ class DatabaseManagerAsyncClient(AppServiceClient):
d.get_user_notification_oldest_message_in_batch
)
# Library
list_library_agents = d.list_library_agents
add_store_agent_to_library = d.add_store_agent_to_library
# Store
get_store_agents = d.get_store_agents
get_store_agent_details = d.get_store_agent_details
# Summary data
get_user_execution_summary_data = d.get_user_execution_summary_data

View File

@@ -1,154 +0,0 @@
import logging
import threading
from collections import OrderedDict
from functools import wraps
from typing import TYPE_CHECKING, Any, Optional
from backend.data.execution import ExecutionStatus, NodeExecutionResult
from backend.data.model import GraphExecutionStats, NodeExecutionStats
if TYPE_CHECKING:
from backend.executor import DatabaseManagerClient
logger = logging.getLogger(__name__)
def with_lock(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
with self._lock:
return func(self, *args, **kwargs)
return wrapper
class ExecutionCache:
def __init__(self, graph_exec_id: str, db_client: "DatabaseManagerClient"):
self._lock = threading.RLock()
self._graph_exec_id = graph_exec_id
self._graph_stats: GraphExecutionStats = GraphExecutionStats()
self._node_executions: OrderedDict[str, NodeExecutionResult] = OrderedDict()
for execution in db_client.get_node_executions(self._graph_exec_id):
self._node_executions[execution.node_exec_id] = execution
@with_lock
def get_node_execution(self, node_exec_id: str) -> Optional[NodeExecutionResult]:
execution = self._node_executions.get(node_exec_id)
return execution.model_copy(deep=True) if execution else None
@with_lock
def get_latest_node_execution(self, node_id: str) -> Optional[NodeExecutionResult]:
for execution in reversed(self._node_executions.values()):
if (
execution.node_id == node_id
and execution.status != ExecutionStatus.INCOMPLETE
):
return execution.model_copy(deep=True)
return None
@with_lock
def get_node_executions(
self,
*,
statuses: Optional[list] = None,
block_ids: Optional[list] = None,
node_id: Optional[str] = None,
):
results = []
for execution in self._node_executions.values():
if statuses and execution.status not in statuses:
continue
if block_ids and execution.block_id not in block_ids:
continue
if node_id and execution.node_id != node_id:
continue
results.append(execution.model_copy(deep=True))
return results
@with_lock
def update_node_execution_status(
self,
exec_id: str,
status: ExecutionStatus,
execution_data: Optional[dict] = None,
stats: Optional[dict] = None,
):
if exec_id not in self._node_executions:
raise RuntimeError(f"Execution {exec_id} not found in cache")
execution = self._node_executions[exec_id]
execution.status = status
if execution_data:
execution.input_data.update(execution_data)
if stats:
execution.stats = execution.stats or NodeExecutionStats()
current_stats = execution.stats.model_dump()
current_stats.update(stats)
execution.stats = NodeExecutionStats.model_validate(current_stats)
@with_lock
def upsert_execution_output(
self, node_exec_id: str, output_name: str, output_data: Any
) -> NodeExecutionResult:
if node_exec_id not in self._node_executions:
raise RuntimeError(f"Execution {node_exec_id} not found in cache")
execution = self._node_executions[node_exec_id]
if output_name not in execution.output_data:
execution.output_data[output_name] = []
execution.output_data[output_name].append(output_data)
return execution
@with_lock
def update_graph_stats(
self, status: Optional[ExecutionStatus] = None, stats: Optional[dict] = None
):
if status is not None:
pass
if stats is not None:
current_stats = self._graph_stats.model_dump()
current_stats.update(stats)
self._graph_stats = GraphExecutionStats.model_validate(current_stats)
@with_lock
def update_graph_start_time(self):
"""Update graph start time (handled by database persistence)."""
pass
@with_lock
def find_incomplete_execution_for_input(
self, node_id: str, input_name: str
) -> tuple[str, NodeExecutionResult] | None:
for exec_id, execution in self._node_executions.items():
if (
execution.node_id == node_id
and execution.status == ExecutionStatus.INCOMPLETE
and input_name not in execution.input_data
):
return exec_id, execution
return None
@with_lock
def add_node_execution(
self, node_exec_id: str, execution: NodeExecutionResult
) -> None:
self._node_executions[node_exec_id] = execution
@with_lock
def update_execution_input(
self, exec_id: str, input_name: str, input_data: Any
) -> dict:
if exec_id not in self._node_executions:
raise RuntimeError(f"Execution {exec_id} not found in cache")
execution = self._node_executions[exec_id]
execution.input_data[input_name] = input_data
return execution.input_data.copy()
def finalize(self) -> None:
with self._lock:
self._node_executions.clear()
self._graph_stats = GraphExecutionStats()

View File

@@ -1,355 +0,0 @@
"""Test execution creation with proper ID generation and persistence."""
import asyncio
import threading
import uuid
from datetime import datetime
import pytest
from backend.data.execution import ExecutionStatus
from backend.executor.execution_data import ExecutionDataClient
@pytest.fixture
def event_loop():
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest.fixture
async def execution_client_with_mock_db(event_loop):
"""Create an ExecutionDataClient with proper database records."""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
from prisma.models import AgentGraph, AgentGraphExecution, User
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
# Create test database records to satisfy foreign key constraints
try:
await User.prisma().create(
data={
"id": "test_user_123",
"email": "test@example.com",
"name": "Test User",
}
)
await AgentGraph.prisma().create(
data={
"id": "test_graph_456",
"version": 1,
"userId": "test_user_123",
"name": "Test Graph",
"description": "Test graph for execution tests",
}
)
from prisma.enums import AgentExecutionStatus
await AgentGraphExecution.prisma().create(
data={
"id": "test_graph_exec_id",
"userId": "test_user_123",
"agentGraphId": "test_graph_456",
"agentGraphVersion": 1,
"executionStatus": AgentExecutionStatus.RUNNING,
}
)
except Exception:
# Records might already exist, that's fine
pass
# Mock the graph execution metadata - align with assertions below
mock_graph_meta = GraphExecutionMeta(
id="test_graph_exec_id",
user_id="test_user_123",
graph_id="test_graph_456",
graph_version=1,
status=ExecutionStatus.RUNNING,
started_at=datetime.now(timezone.utc),
ended_at=datetime.now(timezone.utc),
stats=None,
)
# Create client with ThreadPoolExecutor and graph metadata (constructed inside patch)
from concurrent.futures import ThreadPoolExecutor
executor = ThreadPoolExecutor(max_workers=1)
# Storage for tracking created executions
created_executions = []
async def mock_create_node_execution(
node_exec_id, node_id, graph_exec_id, input_name, input_data
):
"""Mock execution creation that records what was created."""
created_executions.append(
{
"node_exec_id": node_exec_id,
"node_id": node_id,
"graph_exec_id": graph_exec_id,
"input_name": input_name,
"input_data": input_data,
}
)
return node_exec_id
def sync_mock_create_node_execution(
node_exec_id, node_id, graph_exec_id, input_name, input_data
):
"""Mock sync execution creation that records what was created."""
created_executions.append(
{
"node_exec_id": node_exec_id,
"node_id": node_id,
"graph_exec_id": graph_exec_id,
"input_name": input_name,
"input_data": input_data,
}
)
return node_exec_id
# Prepare mock async and sync DB clients
async_mock_client = AsyncMock()
async_mock_client.create_node_execution = mock_create_node_execution
sync_mock_client = MagicMock()
sync_mock_client.create_node_execution = sync_mock_create_node_execution
# Mock graph execution for return values
from backend.data.execution import GraphExecutionMeta
mock_graph_update = GraphExecutionMeta(
id="test_graph_exec_id",
user_id="test_user_123",
graph_id="test_graph_456",
graph_version=1,
status=ExecutionStatus.RUNNING,
started_at=datetime.now(timezone.utc),
ended_at=datetime.now(timezone.utc),
stats=None,
)
# No-ops for other sync methods used by the client during tests
sync_mock_client.add_input_to_node_execution.side_effect = lambda **kwargs: None
sync_mock_client.update_node_execution_status.side_effect = (
lambda *args, **kwargs: None
)
sync_mock_client.upsert_execution_output.side_effect = lambda **kwargs: None
sync_mock_client.update_graph_execution_stats.side_effect = (
lambda *args, **kwargs: mock_graph_update
)
sync_mock_client.update_graph_execution_start_time.side_effect = (
lambda *args, **kwargs: mock_graph_update
)
thread = threading.Thread(target=event_loop.run_forever, daemon=True)
thread.start()
with patch(
"backend.executor.execution_data.get_database_manager_async_client",
return_value=async_mock_client,
), patch(
"backend.executor.execution_data.get_database_manager_client",
return_value=sync_mock_client,
), patch(
"backend.executor.execution_data.get_execution_event_bus"
), patch(
"backend.executor.execution_data.non_blocking_persist", lambda func: func
):
# Now construct the client under the patch so it captures the mocked clients
client = ExecutionDataClient(executor, "test_graph_exec_id", mock_graph_meta)
# Store the mocks for the test to access if needed
setattr(client, "_test_async_client", async_mock_client)
setattr(client, "_test_sync_client", sync_mock_client)
setattr(client, "_created_executions", created_executions)
yield client
# Cleanup test database records
try:
await AgentGraphExecution.prisma().delete_many(
where={"id": "test_graph_exec_id"}
)
await AgentGraph.prisma().delete_many(where={"id": "test_graph_456"})
await User.prisma().delete_many(where={"id": "test_user_123"})
except Exception:
# Cleanup may fail if records don't exist
pass
# Cleanup
event_loop.call_soon_threadsafe(event_loop.stop)
thread.join(timeout=1)
class TestExecutionCreation:
"""Test execution creation with proper ID generation and persistence."""
async def test_execution_creation_with_valid_ids(
self, execution_client_with_mock_db
):
"""Test that execution creation generates and persists valid IDs."""
client = execution_client_with_mock_db
node_id = "test_node_789"
input_name = "test_input"
input_data = "test_value"
block_id = "test_block_abc"
# This should trigger execution creation since cache is empty
exec_id, input_dict = client.upsert_execution_input(
node_id=node_id,
input_name=input_name,
input_data=input_data,
block_id=block_id,
)
# Verify execution ID is valid UUID
try:
uuid.UUID(exec_id)
except ValueError:
pytest.fail(f"Generated execution ID '{exec_id}' is not a valid UUID")
# Verify execution was created in cache with complete data
assert exec_id in client._cache._node_executions
cached_execution = client._cache._node_executions[exec_id]
# Check all required fields have valid values
assert cached_execution.user_id == "test_user_123"
assert cached_execution.graph_id == "test_graph_456"
assert cached_execution.graph_version == 1
assert cached_execution.graph_exec_id == "test_graph_exec_id"
assert cached_execution.node_exec_id == exec_id
assert cached_execution.node_id == node_id
assert cached_execution.block_id == block_id
assert cached_execution.status == ExecutionStatus.INCOMPLETE
assert cached_execution.input_data == {input_name: input_data}
assert isinstance(cached_execution.add_time, datetime)
# Verify execution was persisted to database with our generated ID
created_executions = getattr(client, "_created_executions", [])
assert len(created_executions) == 1
created = created_executions[0]
assert created["node_exec_id"] == exec_id # Our generated ID was used
assert created["node_id"] == node_id
assert created["graph_exec_id"] == "test_graph_exec_id"
assert created["input_name"] == input_name
assert created["input_data"] == input_data
# Verify input dict returned correctly
assert input_dict == {input_name: input_data}
async def test_execution_reuse_vs_creation(self, execution_client_with_mock_db):
"""Test that execution reuse works and creation only happens when needed."""
client = execution_client_with_mock_db
node_id = "reuse_test_node"
block_id = "reuse_test_block"
# Create first execution
exec_id_1, input_dict_1 = client.upsert_execution_input(
node_id=node_id,
input_name="input_1",
input_data="value_1",
block_id=block_id,
)
# This should reuse the existing INCOMPLETE execution
exec_id_2, input_dict_2 = client.upsert_execution_input(
node_id=node_id,
input_name="input_2",
input_data="value_2",
block_id=block_id,
)
# Should reuse the same execution
assert exec_id_1 == exec_id_2
assert input_dict_2 == {"input_1": "value_1", "input_2": "value_2"}
# Only one execution should be created in database
created_executions = getattr(client, "_created_executions", [])
assert len(created_executions) == 1
# Verify cache has the merged inputs
cached_execution = client._cache._node_executions[exec_id_1]
assert cached_execution.input_data == {
"input_1": "value_1",
"input_2": "value_2",
}
# Now complete the execution and try to add another input
client.update_node_status_and_publish(
exec_id=exec_id_1, status=ExecutionStatus.COMPLETED
)
# Verify the execution status was actually updated in the cache
updated_execution = client._cache._node_executions[exec_id_1]
assert (
updated_execution.status == ExecutionStatus.COMPLETED
), f"Expected COMPLETED but got {updated_execution.status}"
# This should create a NEW execution since the first is no longer INCOMPLETE
exec_id_3, input_dict_3 = client.upsert_execution_input(
node_id=node_id,
input_name="input_3",
input_data="value_3",
block_id=block_id,
)
# Should be a different execution
assert exec_id_3 != exec_id_1
assert input_dict_3 == {"input_3": "value_3"}
# Verify cache behavior: should have two different executions in cache now
cached_executions = client._cache._node_executions
assert len(cached_executions) == 2
assert exec_id_1 in cached_executions
assert exec_id_3 in cached_executions
# First execution should be COMPLETED
assert cached_executions[exec_id_1].status == ExecutionStatus.COMPLETED
# Third execution should be INCOMPLETE (newly created)
assert cached_executions[exec_id_3].status == ExecutionStatus.INCOMPLETE
async def test_multiple_nodes_get_different_execution_ids(
self, execution_client_with_mock_db
):
"""Test that different nodes get different execution IDs."""
client = execution_client_with_mock_db
# Create executions for different nodes
exec_id_a, _ = client.upsert_execution_input(
node_id="node_a",
input_name="test_input",
input_data="test_value",
block_id="block_a",
)
exec_id_b, _ = client.upsert_execution_input(
node_id="node_b",
input_name="test_input",
input_data="test_value",
block_id="block_b",
)
# Should be different executions with different IDs
assert exec_id_a != exec_id_b
# Both should be valid UUIDs
uuid.UUID(exec_id_a)
uuid.UUID(exec_id_b)
# Both should be in cache
cached_executions = client._cache._node_executions
assert len(cached_executions) == 2
assert exec_id_a in cached_executions
assert exec_id_b in cached_executions
# Both should have correct node IDs
assert cached_executions[exec_id_a].node_id == "node_a"
assert cached_executions[exec_id_b].node_id == "node_b"
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,338 +0,0 @@
import logging
import threading
import uuid
from concurrent.futures import Executor
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Callable, ParamSpec, TypeVar, cast
from backend.data import redis_client as redis
from backend.data.credit import UsageTransactionMetadata
from backend.data.execution import (
ExecutionStatus,
GraphExecutionMeta,
NodeExecutionResult,
)
from backend.data.graph import Node
from backend.data.model import GraphExecutionStats
from backend.executor.execution_cache import ExecutionCache
from backend.util.clients import (
get_database_manager_async_client,
get_database_manager_client,
get_execution_event_bus,
)
from backend.util.settings import Settings
if TYPE_CHECKING:
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
settings = Settings()
logger = logging.getLogger(__name__)
P = ParamSpec("P")
T = TypeVar("T")
def non_blocking_persist(func: Callable[P, T]) -> Callable[P, None]:
from functools import wraps
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
# First argument is always self for methods - access through cast for typing
self = cast("ExecutionDataClient", args[0])
future = self._executor.submit(func, *args, **kwargs)
self._pending_tasks.add(future)
return wrapper
class ExecutionDataClient:
def __init__(
self, executor: Executor, graph_exec_id: str, graph_metadata: GraphExecutionMeta
):
self._executor = executor
self._graph_exec_id = graph_exec_id
self._cache = ExecutionCache(graph_exec_id, self.db_client_sync)
self._pending_tasks = set()
self._graph_metadata = graph_metadata
self.graph_lock = threading.RLock()
def finalize_execution(self, timeout: float = 30.0):
logger.info(f"Flushing db writes for execution {self._graph_exec_id}")
exceptions = []
# Wait for all pending database operations to complete
logger.debug(
f"Waiting for {len(self._pending_tasks)} pending database operations"
)
for future in list(self._pending_tasks):
try:
future.result(timeout=timeout)
except Exception as e:
logger.error(f"Background database operation failed: {e}")
exceptions.append(e)
finally:
self._pending_tasks.discard(future)
self._cache.finalize()
if exceptions:
logger.error(f"Background persistence failed with {len(exceptions)} errors")
raise RuntimeError(
f"Background persistence failed with {len(exceptions)} errors: {exceptions}"
)
@property
def db_client_async(self) -> "DatabaseManagerAsyncClient":
return get_database_manager_async_client()
@property
def db_client_sync(self) -> "DatabaseManagerClient":
return get_database_manager_client()
@property
def event_bus(self):
return get_execution_event_bus()
async def get_node(self, node_id: str) -> Node:
return await self.db_client_async.get_node(node_id)
def spend_credits(
self,
user_id: str,
cost: int,
metadata: UsageTransactionMetadata,
) -> int:
return self.db_client_sync.spend_credits(
user_id=user_id, cost=cost, metadata=metadata
)
def get_graph_execution_meta(
self, user_id: str, execution_id: str
) -> GraphExecutionMeta | None:
return self.db_client_sync.get_graph_execution_meta(
user_id=user_id, execution_id=execution_id
)
def get_graph_metadata(
self, graph_id: str, graph_version: int | None = None
) -> Any:
return self.db_client_sync.get_graph_metadata(graph_id, graph_version)
def get_credits(self, user_id: str) -> int:
return self.db_client_sync.get_credits(user_id)
def get_user_email_by_id(self, user_id: str) -> str | None:
return self.db_client_sync.get_user_email_by_id(user_id)
def get_latest_node_execution(self, node_id: str) -> NodeExecutionResult | None:
return self._cache.get_latest_node_execution(node_id)
def get_node_execution(self, node_exec_id: str) -> NodeExecutionResult | None:
return self._cache.get_node_execution(node_exec_id)
def get_node_executions(
self,
*,
node_id: str | None = None,
statuses: list[ExecutionStatus] | None = None,
block_ids: list[str] | None = None,
) -> list[NodeExecutionResult]:
return self._cache.get_node_executions(
statuses=statuses, block_ids=block_ids, node_id=node_id
)
def update_node_status_and_publish(
self,
exec_id: str,
status: ExecutionStatus,
execution_data: dict | None = None,
stats: dict[str, Any] | None = None,
):
self._cache.update_node_execution_status(exec_id, status, execution_data, stats)
self._persist_node_status_to_db(exec_id, status, execution_data, stats)
def upsert_execution_input(
self, node_id: str, input_name: str, input_data: Any, block_id: str
) -> tuple[str, dict]:
# Validate input parameters to prevent foreign key constraint errors
if not node_id or not isinstance(node_id, str):
raise ValueError(f"Invalid node_id: {node_id}")
if not self._graph_exec_id or not isinstance(self._graph_exec_id, str):
raise ValueError(f"Invalid graph_exec_id: {self._graph_exec_id}")
if not block_id or not isinstance(block_id, str):
raise ValueError(f"Invalid block_id: {block_id}")
# UPDATE: Try to find an existing incomplete execution for this node and input
if result := self._cache.find_incomplete_execution_for_input(
node_id, input_name
):
exec_id, _ = result
updated_input_data = self._cache.update_execution_input(
exec_id, input_name, input_data
)
self._persist_add_input_to_db(exec_id, input_name, input_data)
return exec_id, updated_input_data
# CREATE: No suitable execution found, create new one
node_exec_id = str(uuid.uuid4())
logger.debug(
f"Creating new execution {node_exec_id} for node {node_id} "
f"in graph execution {self._graph_exec_id}"
)
new_execution = NodeExecutionResult(
user_id=self._graph_metadata.user_id,
graph_id=self._graph_metadata.graph_id,
graph_version=self._graph_metadata.graph_version,
graph_exec_id=self._graph_exec_id,
node_exec_id=node_exec_id,
node_id=node_id,
block_id=block_id,
status=ExecutionStatus.INCOMPLETE,
input_data={input_name: input_data},
output_data={},
add_time=datetime.now(timezone.utc),
)
self._cache.add_node_execution(node_exec_id, new_execution)
self._persist_new_node_execution_to_db(
node_exec_id, node_id, input_name, input_data
)
return node_exec_id, {input_name: input_data}
def upsert_execution_output(
self, node_exec_id: str, output_name: str, output_data: Any
):
self._cache.upsert_execution_output(node_exec_id, output_name, output_data)
self._persist_execution_output_to_db(node_exec_id, output_name, output_data)
def update_graph_stats_and_publish(
self,
status: ExecutionStatus | None = None,
stats: GraphExecutionStats | None = None,
) -> None:
stats_dict = stats.model_dump() if stats else None
self._cache.update_graph_stats(status=status, stats=stats_dict)
self._persist_graph_stats_to_db(status=status, stats=stats)
def update_graph_start_time_and_publish(self) -> None:
self._cache.update_graph_start_time()
self._persist_graph_start_time_to_db()
@non_blocking_persist
def _persist_node_status_to_db(
self,
exec_id: str,
status: ExecutionStatus,
execution_data: dict | None = None,
stats: dict[str, Any] | None = None,
):
exec_update = self.db_client_sync.update_node_execution_status(
exec_id, status, execution_data, stats
)
self.event_bus.publish(exec_update)
@non_blocking_persist
def _persist_add_input_to_db(
self, node_exec_id: str, input_name: str, input_data: Any
):
self.db_client_sync.add_input_to_node_execution(
node_exec_id=node_exec_id,
input_name=input_name,
input_data=input_data,
)
@non_blocking_persist
def _persist_execution_output_to_db(
self, node_exec_id: str, output_name: str, output_data: Any
):
self.db_client_sync.upsert_execution_output(
node_exec_id=node_exec_id,
output_name=output_name,
output_data=output_data,
)
if exec_update := self.get_node_execution(node_exec_id):
self.event_bus.publish(exec_update)
@non_blocking_persist
def _persist_graph_stats_to_db(
self,
status: ExecutionStatus | None = None,
stats: GraphExecutionStats | None = None,
):
graph_update = self.db_client_sync.update_graph_execution_stats(
self._graph_exec_id, status, stats
)
if not graph_update:
raise RuntimeError(
f"Failed to update graph execution stats for {self._graph_exec_id}"
)
self.event_bus.publish(graph_update)
@non_blocking_persist
def _persist_graph_start_time_to_db(self):
graph_update = self.db_client_sync.update_graph_execution_start_time(
self._graph_exec_id
)
if not graph_update:
raise RuntimeError(
f"Failed to update graph execution start time for {self._graph_exec_id}"
)
self.event_bus.publish(graph_update)
async def generate_activity_status(
self,
graph_id: str,
graph_version: int,
execution_stats: GraphExecutionStats,
user_id: str,
execution_status: ExecutionStatus,
) -> str | None:
from backend.executor.activity_status_generator import (
generate_activity_status_for_execution,
)
return await generate_activity_status_for_execution(
graph_exec_id=self._graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_stats=execution_stats,
db_client=self.db_client_async,
user_id=user_id,
execution_status=execution_status,
)
@non_blocking_persist
def _send_execution_update(self, execution: NodeExecutionResult):
"""Send execution update to event bus."""
try:
self.event_bus.publish(execution)
except Exception as e:
logger.warning(f"Failed to send execution update: {e}")
@non_blocking_persist
def _persist_new_node_execution_to_db(
self, node_exec_id: str, node_id: str, input_name: str, input_data: Any
):
try:
self.db_client_sync.create_node_execution(
node_exec_id=node_exec_id,
node_id=node_id,
graph_exec_id=self._graph_exec_id,
input_name=input_name,
input_data=input_data,
)
except Exception as e:
logger.error(
f"Failed to create node execution {node_exec_id} for node {node_id} "
f"in graph execution {self._graph_exec_id}: {e}"
)
raise
def increment_execution_count(self, user_id: str) -> int:
r = redis.get_redis()
k = f"uec:{user_id}"
counter = cast(int, r.incr(k))
if counter == 1:
r.expire(k, settings.config.execution_counter_expiration_time)
return counter

View File

@@ -1,668 +0,0 @@
"""Test suite for ExecutionDataClient."""
import asyncio
import threading
import pytest
from backend.data.execution import ExecutionStatus
from backend.executor.execution_data import ExecutionDataClient
@pytest.fixture
def event_loop():
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest.fixture
def execution_client(event_loop):
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
mock_graph_meta = GraphExecutionMeta(
id="test_graph_exec_id",
user_id="test_user_id",
graph_id="test_graph_id",
graph_version=1,
status=ExecutionStatus.RUNNING,
started_at=datetime.now(timezone.utc),
ended_at=datetime.now(timezone.utc),
stats=None,
)
from concurrent.futures import ThreadPoolExecutor
executor = ThreadPoolExecutor(max_workers=1)
# Mock all database operations to prevent connection attempts
async_mock_client = AsyncMock()
sync_mock_client = MagicMock()
# Mock all database methods to return None or empty results
sync_mock_client.get_node_executions.return_value = []
sync_mock_client.create_node_execution.return_value = None
sync_mock_client.add_input_to_node_execution.return_value = None
sync_mock_client.update_node_execution_status.return_value = None
sync_mock_client.upsert_execution_output.return_value = None
sync_mock_client.update_graph_execution_stats.return_value = mock_graph_meta
sync_mock_client.update_graph_execution_start_time.return_value = mock_graph_meta
# Mock event bus to prevent connection attempts
mock_event_bus = MagicMock()
mock_event_bus.publish.return_value = None
thread = threading.Thread(target=event_loop.run_forever, daemon=True)
thread.start()
with patch(
"backend.executor.execution_data.get_database_manager_async_client",
return_value=async_mock_client,
), patch(
"backend.executor.execution_data.get_database_manager_client",
return_value=sync_mock_client,
), patch(
"backend.executor.execution_data.get_execution_event_bus",
return_value=mock_event_bus,
), patch(
"backend.executor.execution_data.non_blocking_persist", lambda func: func
):
client = ExecutionDataClient(executor, "test_graph_exec_id", mock_graph_meta)
yield client
event_loop.call_soon_threadsafe(event_loop.stop)
thread.join(timeout=1)
class TestExecutionDataClient:
async def test_update_node_status_writes_to_cache_immediately(
self, execution_client
):
"""Test that node status updates are immediately visible in cache."""
# First create an execution to update
node_exec_id, _ = execution_client.upsert_execution_input(
node_id="test-node",
input_name="test_input",
input_data="test_value",
block_id="test-block",
)
status = ExecutionStatus.RUNNING
execution_data = {"step": "processing"}
stats = {"duration": 5.2}
# Update status of existing execution
execution_client.update_node_status_and_publish(
exec_id=node_exec_id,
status=status,
execution_data=execution_data,
stats=stats,
)
# Verify immediate visibility in cache
cached_exec = execution_client.get_node_execution(node_exec_id)
assert cached_exec is not None
assert cached_exec.status == status
# execution_data should be merged with existing input_data, not replace it
expected_input_data = {"test_input": "test_value", "step": "processing"}
assert cached_exec.input_data == expected_input_data
def test_update_node_status_execution_not_found_raises_error(
self, execution_client
):
"""Test that updating non-existent execution raises error instead of creating it."""
non_existent_id = "does-not-exist"
with pytest.raises(
RuntimeError, match="Execution does-not-exist not found in cache"
):
execution_client.update_node_status_and_publish(
exec_id=non_existent_id, status=ExecutionStatus.COMPLETED
)
async def test_upsert_execution_output_writes_to_cache_immediately(
self, execution_client
):
"""Test that output updates are immediately visible in cache."""
# First create an execution
node_exec_id, _ = execution_client.upsert_execution_input(
node_id="test-node",
input_name="test_input",
input_data="test_value",
block_id="test-block",
)
output_name = "result"
output_data = {"answer": 42, "confidence": 0.95}
# Update to RUNNING status first
execution_client.update_node_status_and_publish(
exec_id=node_exec_id,
status=ExecutionStatus.RUNNING,
execution_data={"input": "test"},
)
execution_client.upsert_execution_output(
node_exec_id=node_exec_id, output_name=output_name, output_data=output_data
)
# Check output through the node execution
cached_exec = execution_client.get_node_execution(node_exec_id)
assert cached_exec is not None
assert output_name in cached_exec.output_data
assert cached_exec.output_data[output_name] == [output_data]
async def test_get_node_execution_reads_from_cache(self, execution_client):
"""Test that get_node_execution returns cached data immediately."""
# First create an execution to work with
node_exec_id, _ = execution_client.upsert_execution_input(
node_id="test-node",
input_name="test_input",
input_data="test_value",
block_id="test-block",
)
# Then update its status
execution_client.update_node_status_and_publish(
exec_id=node_exec_id,
status=ExecutionStatus.COMPLETED,
execution_data={"result": "success"},
)
result = execution_client.get_node_execution(node_exec_id)
assert result is not None
assert result.status == ExecutionStatus.COMPLETED
# execution_data gets merged with existing input_data
expected_input_data = {"test_input": "test_value", "result": "success"}
assert result.input_data == expected_input_data
async def test_get_latest_node_execution_reads_from_cache(self, execution_client):
"""Test that get_latest_node_execution returns cached data."""
node_id = "node-1"
# First create an execution for this node
node_exec_id, _ = execution_client.upsert_execution_input(
node_id=node_id,
input_name="test_input",
input_data="test_value",
block_id="test-block",
)
# Then update its status to make it non-INCOMPLETE (so it's returned by get_latest)
execution_client.update_node_status_and_publish(
exec_id=node_exec_id,
status=ExecutionStatus.RUNNING,
execution_data={"from": "cache"},
)
result = execution_client.get_latest_node_execution(node_id)
assert result is not None
assert result.status == ExecutionStatus.RUNNING
# execution_data gets merged with existing input_data
expected_input_data = {"test_input": "test_value", "from": "cache"}
assert result.input_data == expected_input_data
async def test_get_node_executions_sync_filters_correctly(self, execution_client):
# Create executions with different statuses
executions = [
(ExecutionStatus.RUNNING, "block-a"),
(ExecutionStatus.COMPLETED, "block-a"),
(ExecutionStatus.FAILED, "block-b"),
(ExecutionStatus.RUNNING, "block-b"),
]
exec_ids = []
for i, (status, block_id) in enumerate(executions):
# First create the execution
exec_id, _ = execution_client.upsert_execution_input(
node_id=f"node-{i}",
input_name="test_input",
input_data="test_value",
block_id=block_id,
)
exec_ids.append(exec_id)
# Then update its status and metadata
execution_client.update_node_status_and_publish(
exec_id=exec_id, status=status, execution_data={"block": block_id}
)
# Update cached execution with graph_exec_id and block_id for filtering
# Note: In real implementation, these would be set during creation
# For test purposes, we'll skip this manual update since the filtering
# logic should work with the data as created
# Test status filtering
running_execs = execution_client.get_node_executions(
statuses=[ExecutionStatus.RUNNING]
)
assert len(running_execs) == 2
assert all(e.status == ExecutionStatus.RUNNING for e in running_execs)
# Test block_id filtering
block_a_execs = execution_client.get_node_executions(block_ids=["block-a"])
assert len(block_a_execs) == 2
assert all(e.block_id == "block-a" for e in block_a_execs)
# Test combined filtering
running_block_b = execution_client.get_node_executions(
statuses=[ExecutionStatus.RUNNING], block_ids=["block-b"]
)
assert len(running_block_b) == 1
assert running_block_b[0].status == ExecutionStatus.RUNNING
assert running_block_b[0].block_id == "block-b"
async def test_write_then_read_consistency(self, execution_client):
"""Test critical race condition scenario: immediate read after write."""
# First create an execution to work with
node_exec_id, _ = execution_client.upsert_execution_input(
node_id="consistency-test-node",
input_name="test_input",
input_data="test_value",
block_id="test-block",
)
# Write status
execution_client.update_node_status_and_publish(
exec_id=node_exec_id,
status=ExecutionStatus.RUNNING,
execution_data={"step": 1},
)
# Write output
execution_client.upsert_execution_output(
node_exec_id=node_exec_id,
output_name="intermediate",
output_data={"progress": 50},
)
# Update status again
execution_client.update_node_status_and_publish(
exec_id=node_exec_id,
status=ExecutionStatus.COMPLETED,
execution_data={"step": 2},
)
# All changes should be immediately visible
cached_exec = execution_client.get_node_execution(node_exec_id)
assert cached_exec is not None
assert cached_exec.status == ExecutionStatus.COMPLETED
# execution_data gets merged with existing input_data - step 2 overwrites step 1
expected_input_data = {"test_input": "test_value", "step": 2}
assert cached_exec.input_data == expected_input_data
# Output should be visible in execution record
assert cached_exec.output_data["intermediate"] == [{"progress": 50}]
async def test_concurrent_operations_are_thread_safe(self, execution_client):
"""Test that concurrent operations don't corrupt cache."""
num_threads = 3 # Reduced for simpler test
operations_per_thread = 5 # Reduced for simpler test
# Create all executions upfront
created_exec_ids = []
for thread_id in range(num_threads):
for i in range(operations_per_thread):
exec_id, _ = execution_client.upsert_execution_input(
node_id=f"node-{thread_id}-{i}",
input_name="test_input",
input_data="test_value",
block_id=f"block-{thread_id}-{i}",
)
created_exec_ids.append((exec_id, thread_id, i))
def worker(thread_data):
"""Perform multiple operations from a thread."""
thread_id, ops = thread_data
for i, (exec_id, _, _) in enumerate(ops):
# Status updates
execution_client.update_node_status_and_publish(
exec_id=exec_id,
status=ExecutionStatus.RUNNING,
execution_data={"thread": thread_id, "op": i},
)
# Output updates (use just one exec_id per thread for outputs)
if i == 0: # Only add outputs to first execution of each thread
execution_client.upsert_execution_output(
node_exec_id=exec_id,
output_name=f"output_{i}",
output_data={"thread": thread_id, "value": i},
)
# Organize executions by thread
thread_data = []
for tid in range(num_threads):
thread_ops = [
exec_data for exec_data in created_exec_ids if exec_data[1] == tid
]
thread_data.append((tid, thread_ops))
# Start multiple threads
threads = []
for data in thread_data:
thread = threading.Thread(target=worker, args=(data,))
threads.append(thread)
thread.start()
# Wait for completion
for thread in threads:
thread.join()
# Verify data integrity
expected_executions = num_threads * operations_per_thread
all_executions = execution_client.get_node_executions()
assert len(all_executions) == expected_executions
# Verify outputs - only first execution of each thread should have outputs
output_count = 0
for execution in all_executions:
if execution.output_data:
output_count += 1
assert output_count == num_threads # One output per thread
async def test_sync_and_async_versions_consistent(self, execution_client):
"""Test that sync and async versions of output operations behave the same."""
# First create the execution
node_exec_id, _ = execution_client.upsert_execution_input(
node_id="sync-async-test-node",
input_name="test_input",
input_data="test_value",
block_id="test-block",
)
execution_client.update_node_status_and_publish(
exec_id=node_exec_id,
status=ExecutionStatus.RUNNING,
execution_data={"input": "test"},
)
execution_client.upsert_execution_output(
node_exec_id=node_exec_id,
output_name="sync_result",
output_data={"method": "sync"},
)
execution_client.upsert_execution_output(
node_exec_id=node_exec_id,
output_name="async_result",
output_data={"method": "async"},
)
cached_exec = execution_client.get_node_execution(node_exec_id)
assert cached_exec is not None
assert "sync_result" in cached_exec.output_data
assert "async_result" in cached_exec.output_data
assert cached_exec.output_data["sync_result"] == [{"method": "sync"}]
assert cached_exec.output_data["async_result"] == [{"method": "async"}]
async def test_finalize_execution_completes_and_clears_cache(
self, execution_client
):
"""Test that finalize_execution waits for background tasks and clears cache."""
# First create the execution
node_exec_id, _ = execution_client.upsert_execution_input(
node_id="pending-test-node",
input_name="test_input",
input_data="test_value",
block_id="test-block",
)
# Trigger some background operations
execution_client.update_node_status_and_publish(
exec_id=node_exec_id, status=ExecutionStatus.RUNNING
)
execution_client.upsert_execution_output(
node_exec_id=node_exec_id, output_name="test", output_data={"value": 1}
)
# Wait for background tasks - may fail in test environment due to DB issues
try:
execution_client.finalize_execution(timeout=5.0)
except RuntimeError as e:
# In test environment, background DB operations may fail, but cache should still be cleared
assert "Background persistence failed" in str(e)
# Cache should be cleared regardless of background task failures
all_executions = execution_client.get_node_executions()
assert len(all_executions) == 0 # Cache should be cleared
async def test_manager_usage_pattern(self, execution_client):
# Create executions first
node_exec_id_1, _ = execution_client.upsert_execution_input(
node_id="node-1",
input_name="input1",
input_data="data1",
block_id="block-1",
)
node_exec_id_2, _ = execution_client.upsert_execution_input(
node_id="node-2",
input_name="input_from_node1",
input_data="value1",
block_id="block-2",
)
# Simulate manager.py workflow
# 1. Start execution
execution_client.update_node_status_and_publish(
exec_id=node_exec_id_1,
status=ExecutionStatus.RUNNING,
execution_data={"input": "data1"},
)
# 2. Node produces output
execution_client.upsert_execution_output(
node_exec_id=node_exec_id_1,
output_name="result",
output_data={"output": "value1"},
)
# 3. Complete first node
execution_client.update_node_status_and_publish(
exec_id=node_exec_id_1, status=ExecutionStatus.COMPLETED
)
# 4. Start second node (would read output from first)
execution_client.update_node_status_and_publish(
exec_id=node_exec_id_2,
status=ExecutionStatus.RUNNING,
execution_data={"input_from_node1": "value1"},
)
# 5. Manager queries for executions
all_executions = execution_client.get_node_executions()
running_executions = execution_client.get_node_executions(
statuses=[ExecutionStatus.RUNNING]
)
completed_executions = execution_client.get_node_executions(
statuses=[ExecutionStatus.COMPLETED]
)
# Verify manager can see all data immediately
assert len(all_executions) == 2
assert len(running_executions) == 1
assert len(completed_executions) == 1
# Verify output is accessible
exec_1 = execution_client.get_node_execution(node_exec_id_1)
assert exec_1 is not None
assert exec_1.output_data["result"] == [{"output": "value1"}]
def test_stats_handling_in_update_node_status(self, execution_client):
"""Test that stats parameter is properly handled in update_node_status_and_publish."""
# Create a fake execution directly in cache to avoid database issues
from datetime import datetime, timezone
from backend.data.execution import NodeExecutionResult
node_exec_id = "test-stats-exec-id"
fake_execution = NodeExecutionResult(
user_id="test-user",
graph_id="test-graph",
graph_version=1,
graph_exec_id="test-graph-exec",
node_exec_id=node_exec_id,
node_id="stats-test-node",
block_id="test-block",
status=ExecutionStatus.INCOMPLETE,
input_data={"test_input": "test_value"},
output_data={},
add_time=datetime.now(timezone.utc),
queue_time=None,
start_time=None,
end_time=None,
stats=None,
)
# Add directly to cache
execution_client._cache.add_node_execution(node_exec_id, fake_execution)
stats = {"token_count": 150, "processing_time": 2.5}
# Update status with stats
execution_client.update_node_status_and_publish(
exec_id=node_exec_id,
status=ExecutionStatus.RUNNING,
execution_data={"input": "test"},
stats=stats,
)
# Verify execution was updated and stats are stored
execution = execution_client.get_node_execution(node_exec_id)
assert execution is not None
assert execution.status == ExecutionStatus.RUNNING
# Stats should be stored in proper stats field
assert execution.stats is not None
stats_dict = execution.stats.model_dump()
# Only check the fields we set, ignore defaults
assert stats_dict["token_count"] == 150
assert stats_dict["processing_time"] == 2.5
# Update with additional stats
additional_stats = {"error_count": 0}
execution_client.update_node_status_and_publish(
exec_id=node_exec_id,
status=ExecutionStatus.COMPLETED,
stats=additional_stats,
)
# Stats should be merged
execution = execution_client.get_node_execution(node_exec_id)
assert execution is not None
assert execution.status == ExecutionStatus.COMPLETED
stats_dict = execution.stats.model_dump()
# Check the merged stats
assert stats_dict["token_count"] == 150
assert stats_dict["processing_time"] == 2.5
assert stats_dict["error_count"] == 0
async def test_upsert_execution_input_scenarios(self, execution_client):
"""Test different scenarios of upsert_execution_input - create vs update."""
node_id = "test-node"
graph_exec_id = (
"test_graph_exec_id" # Must match the ExecutionDataClient's scope
)
# Scenario 1: Create new execution when none exists
exec_id_1, input_data_1 = execution_client.upsert_execution_input(
node_id=node_id,
input_name="first_input",
input_data="value1",
block_id="test-block",
)
# Should create new execution
execution = execution_client.get_node_execution(exec_id_1)
assert execution is not None
assert execution.status == ExecutionStatus.INCOMPLETE
assert execution.node_id == node_id
assert execution.graph_exec_id == graph_exec_id
assert input_data_1 == {"first_input": "value1"}
# Scenario 2: Add input to existing INCOMPLETE execution
exec_id_2, input_data_2 = execution_client.upsert_execution_input(
node_id=node_id,
input_name="second_input",
input_data="value2",
block_id="test-block",
)
# Should use same execution
assert exec_id_2 == exec_id_1
assert input_data_2 == {"first_input": "value1", "second_input": "value2"}
# Verify execution has both inputs
execution = execution_client.get_node_execution(exec_id_1)
assert execution is not None
assert execution.input_data == {
"first_input": "value1",
"second_input": "value2",
}
# Scenario 3: Create new execution when existing is not INCOMPLETE
execution_client.update_node_status_and_publish(
exec_id=exec_id_1, status=ExecutionStatus.RUNNING
)
exec_id_3, input_data_3 = execution_client.upsert_execution_input(
node_id=node_id,
input_name="third_input",
input_data="value3",
block_id="test-block",
)
# Should create new execution
assert exec_id_3 != exec_id_1
execution_3 = execution_client.get_node_execution(exec_id_3)
assert execution_3 is not None
assert input_data_3 == {"third_input": "value3"}
# Verify we now have 2 executions
all_executions = execution_client.get_node_executions()
assert len(all_executions) == 2
def test_graph_stats_operations(self, execution_client):
"""Test graph-level stats and start time operations."""
# Test update_graph_stats_and_publish
from backend.data.model import GraphExecutionStats
stats = GraphExecutionStats(
walltime=10.5, cputime=8.2, node_count=5, node_error_count=1
)
execution_client.update_graph_stats_and_publish(
status=ExecutionStatus.RUNNING, stats=stats
)
# Verify stats are stored in cache
cached_stats = execution_client._cache._graph_stats
assert cached_stats.walltime == 10.5
execution_client.update_graph_start_time_and_publish()
cached_stats = execution_client._cache._graph_stats
assert cached_stats.walltime == 10.5
def test_public_methods_accessible(self, execution_client):
"""Test that public methods are accessible."""
assert hasattr(execution_client._cache, "update_node_execution_status")
assert hasattr(execution_client._cache, "upsert_execution_output")
assert hasattr(execution_client._cache, "add_node_execution")
assert hasattr(execution_client._cache, "find_incomplete_execution_for_input")
assert hasattr(execution_client._cache, "update_execution_input")
assert hasattr(execution_client, "upsert_execution_input")
assert hasattr(execution_client, "update_node_status_and_publish")
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -5,15 +5,37 @@ import threading
import time
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Optional, TypeVar, cast
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from prometheus_client import Gauge, start_http_server
from pydantic import JsonValue
from redis.asyncio.lock import Lock as RedisLock
from backend.blocks.io import AgentOutputBlock
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
AgentRunData,
LowBalanceData,
NotificationEventModel,
NotificationType,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.activity_status_generator import (
generate_activity_status_for_execution,
)
from backend.executor.utils import LogMetadata
from backend.notifications.notifications import queue_notification
from backend.util.exceptions import InsufficientBalanceError, ModerationError
if TYPE_CHECKING:
from backend.executor import DatabaseManagerClient, DatabaseManagerAsyncClient
from prometheus_client import Gauge, start_http_server
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentOutputBlock
from backend.data import redis_client as redis
from backend.data.block import (
BlockData,
BlockInput,
@@ -25,28 +47,18 @@ from backend.data.credit import UsageTransactionMetadata
from backend.data.execution import (
ExecutionQueue,
ExecutionStatus,
GraphExecution,
GraphExecutionEntry,
NodeExecutionEntry,
UserContext,
NodeExecutionResult,
)
from backend.data.graph import Link, Node
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
AgentRunData,
LowBalanceData,
NotificationEventModel,
NotificationType,
ZeroBalanceData,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.execution_data import ExecutionDataClient
from backend.executor.utils import (
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
GRAPH_EXECUTION_QUEUE_NAME,
CancelExecutionEvent,
ExecutionOutputEntry,
LogMetadata,
NodeExecutionProgress,
block_usage_cost,
create_execution_queue_config,
@@ -55,20 +67,22 @@ from backend.executor.utils import (
validate_exec,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.notifications.notifications import queue_notification
from backend.server.v2.AutoMod.manager import automod_manager
from backend.util import json
from backend.util.clients import get_notification_manager_client
from backend.util.clients import (
get_async_execution_event_bus,
get_database_manager_async_client,
get_database_manager_client,
get_execution_event_bus,
)
from backend.util.decorator import (
async_error_logged,
async_time_measured,
error_logged,
time_measured,
)
from backend.util.exceptions import InsufficientBalanceError, ModerationError
from backend.util.file import clean_exec_files
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.metrics import DiscordChannel
from backend.util.process import AppProcess, set_service_name
from backend.util.retry import continuous_retry, func_retry
from backend.util.settings import Settings
@@ -120,6 +134,7 @@ async def execute_node(
persist the execution result, and return the subsequent node to be executed.
Args:
db_client: The client to send execution updates to the server.
creds_manager: The manager to acquire and release credentials.
data: The execution data for executing the current node.
execution_stats: The execution statistics to be updated.
@@ -175,9 +190,6 @@ async def execute_node(
"user_id": user_id,
}
# Add user context from NodeExecutionEntry
extra_exec_kwargs["user_context"] = data.user_context
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
# changes during execution. ⚠️ This means a set of credentials can only be used by
# one (running) block at a time; simultaneous execution of blocks using same
@@ -216,7 +228,7 @@ async def execute_node(
async def _enqueue_next_nodes(
execution_data_client: ExecutionDataClient,
db_client: "DatabaseManagerAsyncClient",
node: Node,
output: BlockData,
user_id: str,
@@ -224,12 +236,12 @@ async def _enqueue_next_nodes(
graph_id: str,
log_metadata: LogMetadata,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
user_context: UserContext,
) -> list[NodeExecutionEntry]:
async def add_enqueued_execution(
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
) -> NodeExecutionEntry:
execution_data_client.update_node_status_and_publish(
await async_update_node_execution_status(
db_client=db_client,
exec_id=node_exec_id,
status=ExecutionStatus.QUEUED,
execution_data=data,
@@ -242,7 +254,6 @@ async def _enqueue_next_nodes(
node_id=node_id,
block_id=block_id,
inputs=data,
user_context=user_context,
)
async def register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
@@ -262,22 +273,21 @@ async def _enqueue_next_nodes(
next_data = parse_execution_output(output, next_output_name)
if next_data is None and output_name != next_output_name:
return enqueued_executions
next_node = await execution_data_client.get_node(next_node_id)
next_node = await db_client.get_node(next_node_id)
# Multiple node can register the same next node, we need this to be atomic
# To avoid same execution to be enqueued multiple times,
# Or the same input to be consumed multiple times.
with execution_data_client.graph_lock:
async with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}"):
# Add output data to the earliest incomplete execution, or create a new one.
next_node_exec_id, next_node_input = (
execution_data_client.upsert_execution_input(
node_id=next_node_id,
input_name=next_input_name,
input_data=next_data,
block_id=next_node.block_id,
)
next_node_exec_id, next_node_input = await db_client.upsert_execution_input(
node_id=next_node_id,
graph_exec_id=graph_exec_id,
input_name=next_input_name,
input_data=next_data,
)
execution_data_client.update_node_status_and_publish(
await async_update_node_execution_status(
db_client=db_client,
exec_id=next_node_exec_id,
status=ExecutionStatus.INCOMPLETE,
)
@@ -289,8 +299,8 @@ async def _enqueue_next_nodes(
if link.is_static and link.sink_name not in next_node_input
}
if static_link_names and (
latest_execution := execution_data_client.get_latest_node_execution(
next_node_id
latest_execution := await db_client.get_latest_node_execution(
next_node_id, graph_exec_id
)
):
for name in static_link_names:
@@ -329,8 +339,9 @@ async def _enqueue_next_nodes(
# If link is static, there could be some incomplete executions waiting for it.
# Load and complete the input missing input data, and try to re-enqueue them.
for iexec in execution_data_client.get_node_executions(
for iexec in await db_client.get_node_executions(
node_id=next_node_id,
graph_exec_id=graph_exec_id,
statuses=[ExecutionStatus.INCOMPLETE],
):
idata = iexec.input_data
@@ -394,9 +405,6 @@ class ExecutionProcessor:
9. Node executor enqueues the next executed nodes to the node execution queue.
"""
# Current execution data client (scoped to current graph execution)
execution_data: ExecutionDataClient
@async_error_logged(swallow=True)
async def on_node_execution(
self,
@@ -414,7 +422,8 @@ class ExecutionProcessor:
node_id=node_exec.node_id,
block_name="-",
)
node = await self.execution_data.get_node(node_exec.node_id)
db_client = get_db_async_client()
node = await db_client.get_node(node_exec.node_id)
execution_stats = NodeExecutionStats()
timing_info, status = await self._on_node_execution(
@@ -422,6 +431,7 @@ class ExecutionProcessor:
node_exec=node_exec,
node_exec_progress=node_exec_progress,
stats=execution_stats,
db_client=db_client,
log_metadata=log_metadata,
nodes_input_masks=nodes_input_masks,
)
@@ -445,12 +455,15 @@ class ExecutionProcessor:
if node_error and not isinstance(node_error, str):
node_stats["error"] = str(node_error) or node_stats.__class__.__name__
self.execution_data.update_node_status_and_publish(
await async_update_node_execution_status(
db_client=db_client,
exec_id=node_exec.node_exec_id,
status=status,
stats=node_stats,
)
self.execution_data.update_graph_stats_and_publish(
await async_update_graph_execution_state(
db_client=db_client,
graph_exec_id=node_exec.graph_exec_id,
stats=graph_stats,
)
@@ -463,17 +476,22 @@ class ExecutionProcessor:
node_exec: NodeExecutionEntry,
node_exec_progress: NodeExecutionProgress,
stats: NodeExecutionStats,
db_client: "DatabaseManagerAsyncClient",
log_metadata: LogMetadata,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> ExecutionStatus:
status = ExecutionStatus.RUNNING
async def persist_output(output_name: str, output_data: Any) -> None:
self.execution_data.upsert_execution_output(
await db_client.upsert_execution_output(
node_exec_id=node_exec.node_exec_id,
output_name=output_name,
output_data=output_data,
)
if exec_update := await db_client.get_node_execution(
node_exec.node_exec_id
):
await send_async_execution_update(exec_update)
node_exec_progress.add_output(
ExecutionOutputEntry(
@@ -485,7 +503,8 @@ class ExecutionProcessor:
try:
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
self.execution_data.update_node_status_and_publish(
await async_update_node_execution_status(
db_client=db_client,
exec_id=node_exec.node_exec_id,
status=ExecutionStatus.RUNNING,
)
@@ -546,8 +565,6 @@ class ExecutionProcessor:
self.node_evaluation_thread = threading.Thread(
target=self.node_evaluation_loop.run_forever, daemon=True
)
# single thread executor
self.execution_data_executor = ThreadPoolExecutor(max_workers=1)
self.node_execution_thread.start()
self.node_evaluation_thread.start()
logger.info(f"[GraphExecutor] {self.tid} started")
@@ -567,13 +584,9 @@ class ExecutionProcessor:
node_eid="*",
block_name="-",
)
db_client = get_db_client()
# Get graph execution metadata first via sync client
from backend.util.clients import get_database_manager_client
db_client_sync = get_database_manager_client()
exec_meta = db_client_sync.get_graph_execution_meta(
exec_meta = db_client.get_graph_execution_meta(
user_id=graph_exec.user_id,
execution_id=graph_exec.graph_exec_id,
)
@@ -583,15 +596,12 @@ class ExecutionProcessor:
)
return
# Create scoped ExecutionDataClient for this graph execution with metadata
self.execution_data = ExecutionDataClient(
self.execution_data_executor, graph_exec.graph_exec_id, exec_meta
)
if exec_meta.status == ExecutionStatus.QUEUED:
log_metadata.info(f"⚙️ Starting graph execution #{graph_exec.graph_exec_id}")
exec_meta.status = ExecutionStatus.RUNNING
self.execution_data.update_graph_start_time_and_publish()
send_execution_update(
db_client.update_graph_execution_start_time(graph_exec.graph_exec_id)
)
elif exec_meta.status == ExecutionStatus.RUNNING:
log_metadata.info(
f"⚙️ Graph execution #{graph_exec.graph_exec_id} is already running, continuing where it left off."
@@ -601,7 +611,9 @@ class ExecutionProcessor:
log_metadata.info(
f"⚙️ Graph execution #{graph_exec.graph_exec_id} was disturbed, continuing where it left off."
)
self.execution_data.update_graph_stats_and_publish(
update_graph_execution_state(
db_client=db_client,
graph_exec_id=graph_exec.graph_exec_id,
status=ExecutionStatus.RUNNING,
)
else:
@@ -632,10 +644,12 @@ class ExecutionProcessor:
# Activity status handling
activity_status = asyncio.run_coroutine_threadsafe(
self.execution_data.generate_activity_status(
generate_activity_status_for_execution(
graph_exec_id=graph_exec.graph_exec_id,
graph_id=graph_exec.graph_id,
graph_version=graph_exec.graph_version,
execution_stats=exec_stats,
db_client=get_db_async_client(),
user_id=graph_exec.user_id,
execution_status=status,
),
@@ -650,32 +664,33 @@ class ExecutionProcessor:
)
# Communication handling
self._handle_agent_run_notif(graph_exec, exec_stats)
self._handle_agent_run_notif(db_client, graph_exec, exec_stats)
finally:
self.execution_data.update_graph_stats_and_publish(
update_graph_execution_state(
db_client=db_client,
graph_exec_id=graph_exec.graph_exec_id,
status=exec_meta.status,
stats=exec_stats,
)
self.execution_data.finalize_execution()
def _charge_usage(
self,
node_exec: NodeExecutionEntry,
execution_count: int,
) -> tuple[int, int]:
) -> int:
total_cost = 0
remaining_balance = 0
db_client = get_db_client()
block = get_block(node_exec.block_id)
if not block:
logger.error(f"Block {node_exec.block_id} not found.")
return total_cost, 0
return total_cost
cost, matching_filter = block_usage_cost(
block=block, input_data=node_exec.inputs
)
if cost > 0:
remaining_balance = self.execution_data.spend_credits(
db_client.spend_credits(
user_id=node_exec.user_id,
cost=cost,
metadata=UsageTransactionMetadata(
@@ -693,7 +708,7 @@ class ExecutionProcessor:
cost, usage_count = execution_usage_cost(execution_count)
if cost > 0:
remaining_balance = self.execution_data.spend_credits(
db_client.spend_credits(
user_id=node_exec.user_id,
cost=cost,
metadata=UsageTransactionMetadata(
@@ -708,7 +723,7 @@ class ExecutionProcessor:
)
total_cost += cost
return total_cost, remaining_balance
return total_cost
@time_measured
def _on_graph_execution(
@@ -726,6 +741,7 @@ class ExecutionProcessor:
"""
execution_status: ExecutionStatus = ExecutionStatus.RUNNING
error: Exception | None = None
db_client = get_db_client()
execution_stats_lock = threading.Lock()
# State holders ----------------------------------------------------
@@ -736,7 +752,7 @@ class ExecutionProcessor:
execution_queue = ExecutionQueue[NodeExecutionEntry]()
try:
if self.execution_data.get_credits(graph_exec.user_id) <= 0:
if db_client.get_credits(graph_exec.user_id) <= 0:
raise InsufficientBalanceError(
user_id=graph_exec.user_id,
message="You have no credits left to run an agent.",
@@ -748,7 +764,7 @@ class ExecutionProcessor:
try:
if moderation_error := asyncio.run_coroutine_threadsafe(
automod_manager.moderate_graph_execution_inputs(
db_client=self.execution_data.db_client_async,
db_client=get_db_async_client(),
graph_exec=graph_exec,
),
self.node_evaluation_loop,
@@ -763,34 +779,15 @@ class ExecutionProcessor:
# ------------------------------------------------------------
# Prepopulate queue ---------------------------------------
# ------------------------------------------------------------
queued_executions = self.execution_data.get_node_executions(
for node_exec in db_client.get_node_executions(
graph_exec.graph_exec_id,
statuses=[
ExecutionStatus.RUNNING,
ExecutionStatus.QUEUED,
ExecutionStatus.TERMINATED,
],
)
log_metadata.info(
f"Pre-populating queue with {len(queued_executions)} executions from cache"
)
for i, node_exec in enumerate(queued_executions):
log_metadata.info(
f" [{i}] {node_exec.node_exec_id}: status={node_exec.status}, node={node_exec.node_id}"
)
try:
node_entry = node_exec.to_node_execution_entry(
graph_exec.user_context
)
execution_queue.add(node_entry)
log_metadata.info(" Added to execution queue successfully")
except Exception as e:
log_metadata.error(f" Failed to add to execution queue: {e}")
log_metadata.info(
f"Execution queue populated with {len(queued_executions)} executions"
)
):
execution_queue.add(node_exec.to_node_execution_entry())
# ------------------------------------------------------------
# Main dispatch / polling loop -----------------------------
@@ -808,36 +805,31 @@ class ExecutionProcessor:
# Charge usage (may raise) ------------------------------
try:
cost, remaining_balance = self._charge_usage(
cost = self._charge_usage(
node_exec=queued_node_exec,
execution_count=self.execution_data.increment_execution_count(
graph_exec.user_id
),
execution_count=increment_execution_count(graph_exec.user_id),
)
with execution_stats_lock:
execution_stats.cost += cost
# Check if we crossed the low balance threshold
self._handle_low_balance(
user_id=graph_exec.user_id,
current_balance=remaining_balance,
transaction_cost=cost,
)
except InsufficientBalanceError as balance_error:
error = balance_error # Set error to trigger FAILED status
node_exec_id = queued_node_exec.node_exec_id
self.execution_data.upsert_execution_output(
db_client.upsert_execution_output(
node_exec_id=node_exec_id,
output_name="error",
output_data=str(error),
)
self.execution_data.update_node_status_and_publish(
update_node_execution_status(
db_client=db_client,
exec_id=node_exec_id,
status=ExecutionStatus.FAILED,
)
self._handle_insufficient_funds_notif(
self._handle_low_balance_notif(
db_client,
graph_exec.user_id,
graph_exec.graph_id,
execution_stats,
error,
)
# Gracefully stop the execution loop
@@ -922,13 +914,12 @@ class ExecutionProcessor:
time.sleep(0.1)
# loop done --------------------------------------------------
# Background task finalization moved to finally block
# Output moderation
try:
if moderation_error := asyncio.run_coroutine_threadsafe(
automod_manager.moderate_graph_execution_outputs(
db_client=self.execution_data.db_client_async,
db_client=get_db_async_client(),
graph_exec_id=graph_exec.graph_exec_id,
user_id=graph_exec.user_id,
graph_id=graph_exec.graph_id,
@@ -982,6 +973,7 @@ class ExecutionProcessor:
error=error,
graph_exec_id=graph_exec.graph_exec_id,
log_metadata=log_metadata,
db_client=db_client,
)
@error_logged(swallow=True)
@@ -994,6 +986,7 @@ class ExecutionProcessor:
error: Exception | None,
graph_exec_id: str,
log_metadata: LogMetadata,
db_client: "DatabaseManagerClient",
) -> None:
"""
Clean up running node executions and evaluations when graph execution ends.
@@ -1027,7 +1020,8 @@ class ExecutionProcessor:
)
while queued_execution := execution_queue.get_or_none():
self.execution_data.update_node_status_and_publish(
update_node_execution_status(
db_client=db_client,
exec_id=queued_execution.node_exec_id,
status=execution_status,
stats={"error": str(error)} if error else None,
@@ -1055,10 +1049,11 @@ class ExecutionProcessor:
nodes_input_masks: Optional map of node input overrides
execution_queue: Queue to add next executions to
"""
log_metadata.debug(f"Enqueue nodes for {node_id}: {output}")
db_client = get_db_async_client()
log_metadata.debug(f"Enqueue nodes for {node_id}: {output}")
for next_execution in await _enqueue_next_nodes(
execution_data_client=self.execution_data,
db_client=db_client,
node=output.node,
output=output.data,
user_id=graph_exec.user_id,
@@ -1066,19 +1061,20 @@ class ExecutionProcessor:
graph_id=graph_exec.graph_id,
log_metadata=log_metadata,
nodes_input_masks=nodes_input_masks,
user_context=graph_exec.user_context,
):
execution_queue.add(next_execution)
def _handle_agent_run_notif(
self,
db_client: "DatabaseManagerClient",
graph_exec: GraphExecutionEntry,
exec_stats: GraphExecutionStats,
):
metadata = self.execution_data.get_graph_metadata(
metadata = db_client.get_graph_metadata(
graph_exec.graph_id, graph_exec.graph_version
)
outputs = self.execution_data.get_node_executions(
outputs = db_client.get_node_executions(
graph_exec.graph_exec_id,
block_ids=[AgentOutputBlock().id],
)
@@ -1105,24 +1101,25 @@ class ExecutionProcessor:
)
)
def _handle_insufficient_funds_notif(
def _handle_low_balance_notif(
self,
db_client: "DatabaseManagerClient",
user_id: str,
graph_id: str,
exec_stats: GraphExecutionStats,
e: InsufficientBalanceError,
):
shortfall = abs(e.amount) - e.balance
metadata = self.execution_data.get_graph_metadata(graph_id)
shortfall = e.balance - e.amount
metadata = db_client.get_graph_metadata(graph_id)
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
queue_notification(
NotificationEventModel(
user_id=user_id,
type=NotificationType.ZERO_BALANCE,
data=ZeroBalanceData(
current_balance=e.balance,
type=NotificationType.LOW_BALANCE,
data=LowBalanceData(
current_balance=exec_stats.cost,
billing_page_link=f"{base_url}/profile/credits",
shortfall=shortfall,
agent_name=metadata.name if metadata else "Unknown Agent",
@@ -1130,72 +1127,6 @@ class ExecutionProcessor:
)
)
try:
user_email = self.execution_data.get_user_email_by_id(user_id)
alert_message = (
f"❌ **Insufficient Funds Alert**\n"
f"User: {user_email or user_id}\n"
f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
f"Current balance: ${e.balance/100:.2f}\n"
f"Attempted cost: ${abs(e.amount)/100:.2f}\n"
f"Shortfall: ${abs(shortfall)/100:.2f}\n"
f"[View User Details]({base_url}/admin/spending?search={user_email})"
)
get_notification_manager_client().discord_system_alert(
alert_message, DiscordChannel.PRODUCT
)
except Exception as alert_error:
logger.error(
f"Failed to send insufficient funds Discord alert: {alert_error}"
)
def _handle_low_balance(
self,
user_id: str,
current_balance: int,
transaction_cost: int,
):
"""Check and handle low balance scenarios after a transaction"""
LOW_BALANCE_THRESHOLD = settings.config.low_balance_threshold
balance_before = current_balance + transaction_cost
if (
current_balance < LOW_BALANCE_THRESHOLD
and balance_before >= LOW_BALANCE_THRESHOLD
):
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
queue_notification(
NotificationEventModel(
user_id=user_id,
type=NotificationType.LOW_BALANCE,
data=LowBalanceData(
current_balance=current_balance,
billing_page_link=f"{base_url}/profile/credits",
),
)
)
try:
user_email = self.execution_data.get_user_email_by_id(user_id)
alert_message = (
f"⚠️ **Low Balance Alert**\n"
f"User: {user_email or user_id}\n"
f"Balance dropped below ${LOW_BALANCE_THRESHOLD/100:.2f}\n"
f"Current balance: ${current_balance/100:.2f}\n"
f"Transaction cost: ${transaction_cost/100:.2f}\n"
f"[View User Details]({base_url}/admin/spending?search={user_email})"
)
get_notification_manager_client().discord_system_alert(
alert_message, DiscordChannel.PRODUCT
)
except Exception as e:
logger.error(f"Failed to send low balance Discord alert: {e}")
class ExecutionManager(AppProcess):
def __init__(self):
@@ -1377,14 +1308,14 @@ class ExecutionManager(AppProcess):
delivery_tag = method.delivery_tag
@func_retry
def _ack_message(reject: bool, requeue: bool):
def _ack_message(reject: bool = False):
"""Acknowledge or reject the message based on execution status."""
# Connection can be lost, so always get a fresh channel
channel = self.run_client.get_channel()
if reject:
channel.connection.add_callback_threadsafe(
lambda: channel.basic_nack(delivery_tag, requeue=requeue)
lambda: channel.basic_nack(delivery_tag, requeue=True)
)
else:
channel.connection.add_callback_threadsafe(
@@ -1396,13 +1327,13 @@ class ExecutionManager(AppProcess):
logger.info(
f"[{self.service_name}] Rejecting new execution during shutdown"
)
_ack_message(reject=True, requeue=True)
_ack_message(reject=True)
return
# Check if we can accept more runs
self._cleanup_completed_runs()
if len(self.active_graph_runs) >= self.pool_size:
_ack_message(reject=True, requeue=True)
_ack_message(reject=True)
return
try:
@@ -1411,7 +1342,7 @@ class ExecutionManager(AppProcess):
logger.error(
f"[{self.service_name}] Could not parse run message: {e}, body={body}"
)
_ack_message(reject=True, requeue=False)
_ack_message(reject=True)
return
graph_exec_id = graph_exec_entry.graph_exec_id
@@ -1423,7 +1354,7 @@ class ExecutionManager(AppProcess):
logger.error(
f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
)
_ack_message(reject=True, requeue=False)
_ack_message(reject=True)
return
cancel_event = threading.Event()
@@ -1439,9 +1370,9 @@ class ExecutionManager(AppProcess):
logger.error(
f"[{self.service_name}] Execution for {graph_exec_id} failed: {type(exec_error)} {exec_error}"
)
_ack_message(reject=True, requeue=True)
_ack_message(reject=True)
else:
_ack_message(reject=False, requeue=False)
_ack_message(reject=False)
except BaseException as e:
logger.exception(
f"[{self.service_name}] Error in run completion callback: {e}"
@@ -1559,3 +1490,117 @@ class ExecutionManager(AppProcess):
)
logger.info(f"{prefix} ✅ Finished GraphExec cleanup")
# ------- UTILITIES ------- #
def get_db_client() -> "DatabaseManagerClient":
return get_database_manager_client()
def get_db_async_client() -> "DatabaseManagerAsyncClient":
return get_database_manager_async_client()
@func_retry
async def send_async_execution_update(
entry: GraphExecution | NodeExecutionResult | None,
) -> None:
if entry is None:
return
await get_async_execution_event_bus().publish(entry)
@func_retry
def send_execution_update(entry: GraphExecution | NodeExecutionResult | None):
if entry is None:
return
return get_execution_event_bus().publish(entry)
async def async_update_node_execution_status(
db_client: "DatabaseManagerAsyncClient",
exec_id: str,
status: ExecutionStatus,
execution_data: BlockInput | None = None,
stats: dict[str, Any] | None = None,
) -> NodeExecutionResult:
"""Sets status and fetches+broadcasts the latest state of the node execution"""
exec_update = await db_client.update_node_execution_status(
exec_id, status, execution_data, stats
)
await send_async_execution_update(exec_update)
return exec_update
def update_node_execution_status(
db_client: "DatabaseManagerClient",
exec_id: str,
status: ExecutionStatus,
execution_data: BlockInput | None = None,
stats: dict[str, Any] | None = None,
) -> NodeExecutionResult:
"""Sets status and fetches+broadcasts the latest state of the node execution"""
exec_update = db_client.update_node_execution_status(
exec_id, status, execution_data, stats
)
send_execution_update(exec_update)
return exec_update
async def async_update_graph_execution_state(
db_client: "DatabaseManagerAsyncClient",
graph_exec_id: str,
status: ExecutionStatus | None = None,
stats: GraphExecutionStats | None = None,
) -> GraphExecution | None:
"""Sets status and fetches+broadcasts the latest state of the graph execution"""
graph_update = await db_client.update_graph_execution_stats(
graph_exec_id, status, stats
)
if graph_update:
await send_async_execution_update(graph_update)
else:
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
return graph_update
def update_graph_execution_state(
db_client: "DatabaseManagerClient",
graph_exec_id: str,
status: ExecutionStatus | None = None,
stats: GraphExecutionStats | None = None,
) -> GraphExecution | None:
"""Sets status and fetches+broadcasts the latest state of the graph execution"""
graph_update = db_client.update_graph_execution_stats(graph_exec_id, status, stats)
if graph_update:
send_execution_update(graph_update)
else:
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
return graph_update
@asynccontextmanager
async def synchronized(key: str, timeout: int = 60):
r = await redis.get_redis_async()
lock: RedisLock = r.lock(f"lock:{key}", timeout=timeout)
try:
await lock.acquire()
yield
finally:
if await lock.locked() and await lock.owned():
await lock.release()
def increment_execution_count(user_id: str) -> int:
"""
Increment the execution count for a given user,
this will be used to charge the user for the execution cost.
"""
r = redis.get_redis()
k = f"uec:{user_id}" # User Execution Count global key
counter = cast(int, r.incr(k))
if counter == 1:
r.expire(k, settings.config.execution_counter_expiration_time)
return counter

View File

@@ -1,202 +0,0 @@
from unittest.mock import MagicMock, patch
import pytest
from prisma.enums import NotificationType
from backend.data.notifications import LowBalanceData
from backend.executor.manager import ExecutionProcessor
from backend.util.test import SpinTestServer
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
"""Test that _handle_low_balance triggers notification when crossing threshold."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
current_balance = 400 # $4 - below $5 threshold
transaction_cost = 600 # $6 transaction
# Mock dependencies
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings:
# Setup mocks
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.low_balance_threshold = 500 # $5 threshold
mock_settings.config.frontend_base_url = "https://test.com"
# Initialize the execution processor and mock its execution_data
execution_processor.on_graph_executor_start()
# Mock the execution_data attribute since it's created in on_graph_execution
mock_execution_data = MagicMock()
execution_processor.execution_data = mock_execution_data
mock_execution_data.get_user_email_by_id.return_value = "test@example.com"
# Test the low balance handler
execution_processor._handle_low_balance(
user_id=user_id,
current_balance=current_balance,
transaction_cost=transaction_cost,
)
# Verify notification was queued
mock_queue_notif.assert_called_once()
notification_call = mock_queue_notif.call_args[0][0]
# Verify notification details
assert notification_call.type == NotificationType.LOW_BALANCE
assert notification_call.user_id == user_id
assert isinstance(notification_call.data, LowBalanceData)
assert notification_call.data.current_balance == current_balance
# Verify Discord alert was sent
mock_client.discord_system_alert.assert_called_once()
discord_message = mock_client.discord_system_alert.call_args[0][0]
assert "Low Balance Alert" in discord_message
assert "test@example.com" in discord_message
assert "$4.00" in discord_message
assert "$6.00" in discord_message
# Cleanup execution processor threads
try:
execution_processor.node_execution_loop.call_soon_threadsafe(
execution_processor.node_execution_loop.stop
)
execution_processor.node_evaluation_loop.call_soon_threadsafe(
execution_processor.node_evaluation_loop.stop
)
execution_processor.node_execution_thread.join(timeout=1)
execution_processor.node_evaluation_thread.join(timeout=1)
except Exception:
pass # Ignore cleanup errors
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_low_balance_no_notification_when_not_crossing(
server: SpinTestServer,
):
"""Test that no notification is sent when not crossing the threshold."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
current_balance = 600 # $6 - above $5 threshold
transaction_cost = (
100 # $1 transaction (balance before was $7, still above threshold)
)
# Mock dependencies
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings:
# Setup mocks
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.low_balance_threshold = 500 # $5 threshold
# Initialize the execution processor and mock its execution_data
execution_processor.on_graph_executor_start()
# Mock the execution_data attribute since it's created in on_graph_execution
mock_execution_data = MagicMock()
execution_processor.execution_data = mock_execution_data
mock_execution_data.get_user_email_by_id.return_value = "test@example.com"
# Test the low balance handler
execution_processor._handle_low_balance(
user_id=user_id,
current_balance=current_balance,
transaction_cost=transaction_cost,
)
# Verify no notification was sent
mock_queue_notif.assert_not_called()
mock_client.discord_system_alert.assert_not_called()
# Cleanup execution processor threads
try:
execution_processor.node_execution_loop.call_soon_threadsafe(
execution_processor.node_execution_loop.stop
)
execution_processor.node_evaluation_loop.call_soon_threadsafe(
execution_processor.node_evaluation_loop.stop
)
execution_processor.node_execution_thread.join(timeout=1)
execution_processor.node_evaluation_thread.join(timeout=1)
except Exception:
pass # Ignore cleanup errors
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_low_balance_no_duplicate_when_already_below(
server: SpinTestServer,
):
"""Test that no notification is sent when already below threshold."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
current_balance = 300 # $3 - below $5 threshold
transaction_cost = (
100 # $1 transaction (balance before was $4, also below threshold)
)
# Mock dependencies
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings:
# Setup mocks
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.low_balance_threshold = 500 # $5 threshold
# Initialize the execution processor and mock its execution_data
execution_processor.on_graph_executor_start()
# Mock the execution_data attribute since it's created in on_graph_execution
mock_execution_data = MagicMock()
execution_processor.execution_data = mock_execution_data
mock_execution_data.get_user_email_by_id.return_value = "test@example.com"
# Test the low balance handler
execution_processor._handle_low_balance(
user_id=user_id,
current_balance=current_balance,
transaction_cost=transaction_cost,
)
# Verify no notification was sent (user was already below threshold)
mock_queue_notif.assert_not_called()
mock_client.discord_system_alert.assert_not_called()
# Cleanup execution processor threads
try:
execution_processor.node_execution_loop.call_soon_threadsafe(
execution_processor.node_execution_loop.stop
)
execution_processor.node_evaluation_loop.call_soon_threadsafe(
execution_processor.node_evaluation_loop.stop
)
execution_processor.node_execution_thread.join(timeout=1)
execution_processor.node_evaluation_thread.join(timeout=1)
except Exception:
pass # Ignore cleanup errors

View File

@@ -1,5 +1,6 @@
import logging
import autogpt_libs.auth.models
import fastapi.responses
import pytest
@@ -520,7 +521,12 @@ async def test_store_listing_graph(server: SpinTestServer):
is_approved=True,
comments="Test comments",
),
user_id=admin_user.id,
autogpt_libs.auth.models.User(
user_id=admin_user.id,
role="admin",
email=admin_user.email,
phone_number="1234567890",
),
)
alt_test_user = admin_user

View File

@@ -17,7 +17,6 @@ from apscheduler.jobstores.memory import MemoryJobStore
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.util import ZoneInfo
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import MetaData, create_engine
@@ -304,7 +303,6 @@ class Scheduler(AppService):
Jobstores.WEEKLY_NOTIFICATIONS.value: MemoryJobStore(),
},
logger=apscheduler_logger,
timezone=ZoneInfo("UTC"),
)
if self.register_system_tasks:
@@ -408,8 +406,6 @@ class Scheduler(AppService):
)
)
logger.info(f"Scheduling job for user {user_id} in UTC (cron: {cron})")
job_args = GraphExecutionJobArgs(
user_id=user_id,
graph_id=graph_id,
@@ -422,12 +418,12 @@ class Scheduler(AppService):
execute_graph,
kwargs=job_args.model_dump(),
name=name,
trigger=CronTrigger.from_crontab(cron, timezone="UTC"),
trigger=CronTrigger.from_crontab(cron),
jobstore=Jobstores.EXECUTION.value,
replace_existing=True,
)
logger.info(
f"Added job {job.id} with cron schedule '{cron}' in UTC, input data: {input_data}"
f"Added job {job.id} with cron schedule '{cron}' input data: {input_data}"
)
return GraphExecutionJobInfo.from_db(job_args, job)

View File

@@ -18,12 +18,10 @@ from backend.data.execution import (
ExecutionStatus,
GraphExecutionStats,
GraphExecutionWithNodes,
UserContext,
)
from backend.data.graph import GraphModel, Node
from backend.data.model import CredentialsMetaInput
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.user import get_user_by_id
from backend.util.clients import (
get_async_execution_event_bus,
get_async_execution_queue,
@@ -36,27 +34,6 @@ from backend.util.mock import MockObject
from backend.util.settings import Config
from backend.util.type import convert
async def get_user_context(user_id: str) -> UserContext:
"""
Get UserContext for a user, always returns a valid context with timezone.
Defaults to UTC if user has no timezone set.
"""
user_context = UserContext(timezone="UTC") # Default to UTC
try:
user = await get_user_by_id(user_id)
if user and user.timezone and user.timezone != "not-set":
user_context.timezone = user.timezone
logger.debug(f"Retrieved user context: timezone={user.timezone}")
else:
logger.debug("User has no timezone set, using UTC")
except Exception as e:
logger.warning(f"Could not fetch user timezone: {e}")
# Continue with UTC as default
return user_context
config = Config()
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")
@@ -900,11 +877,8 @@ async def add_graph_execution(
preset_id=preset_id,
)
# Fetch user context for the graph execution
user_context = await get_user_context(user_id)
queue = await get_async_execution_queue()
graph_exec_entry = graph_exec.to_graph_execution_entry(user_context)
graph_exec_entry = graph_exec.to_graph_execution_entry()
if nodes_input_masks:
graph_exec_entry.nodes_input_masks = nodes_input_masks

View File

@@ -4,7 +4,6 @@ from pydantic import BaseModel
from backend.integrations.oauth.todoist import TodoistOAuthHandler
from .discord import DiscordOAuthHandler
from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .notion import NotionOAuthHandler
@@ -16,7 +15,6 @@ if TYPE_CHECKING:
# --8<-- [start:HANDLERS_BY_NAMEExample]
# Build handlers dict with string keys for compatibility with SDK auto-registration
_ORIGINAL_HANDLERS = [
DiscordOAuthHandler,
GitHubOAuthHandler,
GoogleOAuthHandler,
NotionOAuthHandler,

View File

@@ -1,175 +0,0 @@
import time
from typing import Optional
from urllib.parse import urlencode
from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.request import Requests
from .base import BaseOAuthHandler
class DiscordOAuthHandler(BaseOAuthHandler):
"""
Discord OAuth2 handler implementation.
Based on the documentation at:
- https://discord.com/developers/docs/topics/oauth2
Discord OAuth2 tokens expire after 7 days by default and include refresh tokens.
"""
PROVIDER_NAME = ProviderName.DISCORD
DEFAULT_SCOPES = ["identify"] # Basic user information
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
self.auth_base_url = "https://discord.com/oauth2/authorize"
self.token_url = "https://discord.com/api/oauth2/token"
self.revoke_url = "https://discord.com/api/oauth2/token/revoke"
def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
) -> str:
# Handle default scopes
scopes = self.handle_default_scopes(scopes)
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"response_type": "code",
"scope": " ".join(scopes),
"state": state,
}
# Discord supports PKCE
if code_challenge:
params["code_challenge"] = code_challenge
params["code_challenge_method"] = "S256"
return f"{self.auth_base_url}?{urlencode(params)}"
async def exchange_code_for_tokens(
self, code: str, scopes: list[str], code_verifier: Optional[str]
) -> OAuth2Credentials:
params = {
"code": code,
"redirect_uri": self.redirect_uri,
"grant_type": "authorization_code",
}
# Include PKCE verifier if provided
if code_verifier:
params["code_verifier"] = code_verifier
return await self._request_tokens(params)
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
if not credentials.access_token:
raise ValueError("No access token to revoke")
# Discord requires client authentication for token revocation
data = {
"token": credentials.access_token.get_secret_value(),
"token_type_hint": "access_token",
}
headers = {
"Content-Type": "application/x-www-form-urlencoded",
}
response = await Requests().post(
url=self.revoke_url,
data=data,
headers=headers,
auth=(self.client_id, self.client_secret),
)
# Discord returns 200 OK for successful revocation
return response.status == 200
async def _refresh_tokens(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
if not credentials.refresh_token:
return credentials
return await self._request_tokens(
{
"refresh_token": credentials.refresh_token.get_secret_value(),
"grant_type": "refresh_token",
},
current_credentials=credentials,
)
async def _request_tokens(
self,
params: dict[str, str],
current_credentials: Optional[OAuth2Credentials] = None,
) -> OAuth2Credentials:
request_body = {
"client_id": self.client_id,
"client_secret": self.client_secret,
**params,
}
headers = {
"Content-Type": "application/x-www-form-urlencoded",
}
response = await Requests().post(
self.token_url, data=request_body, headers=headers
)
token_data: dict = response.json()
# Get username if this is a new token request
username = None
if "access_token" in token_data:
username = await self._request_username(token_data["access_token"])
now = int(time.time())
new_credentials = OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=current_credentials.title if current_credentials else None,
username=username,
access_token=token_data["access_token"],
scopes=token_data.get("scope", "").split()
or (current_credentials.scopes if current_credentials else []),
refresh_token=token_data.get("refresh_token"),
# Discord tokens expire after expires_in seconds (typically 7 days)
access_token_expires_at=(
now + expires_in
if (expires_in := token_data.get("expires_in", None))
else None
),
# Discord doesn't provide separate refresh token expiration
refresh_token_expires_at=None,
)
if current_credentials:
new_credentials.id = current_credentials.id
return new_credentials
async def _request_username(self, access_token: str) -> str | None:
"""
Fetch the username using the Discord OAuth2 @me endpoint.
"""
url = "https://discord.com/api/oauth2/@me"
headers = {
"Authorization": f"Bearer {access_token}",
}
response = await Requests().get(url, headers=headers)
if not response.ok:
return None
# Get user info from the response
data = response.json()
user_info = data.get("user", {})
# Return username (without discriminator)
return user_info.get("username")

View File

@@ -29,7 +29,7 @@ from backend.data.user import generate_unsubscribe_link
from backend.notifications.email import EmailSender
from backend.util.clients import get_database_manager_async_client
from backend.util.logging import TruncatedLogger
from backend.util.metrics import DiscordChannel, discord_send_alert
from backend.util.metrics import discord_send_alert
from backend.util.retry import continuous_retry
from backend.util.service import (
AppService,
@@ -382,10 +382,8 @@ class NotificationManager(AppService):
}
@expose
async def discord_system_alert(
self, content: str, channel: DiscordChannel = DiscordChannel.PLATFORM
):
await discord_send_alert(content, channel)
async def discord_system_alert(self, content: str):
await discord_send_alert(content)
async def _queue_scheduled_notification(self, event: SummaryParamsEventModel):
"""Queue a scheduled notification - exposed method for other services to call"""

View File

@@ -1,73 +0,0 @@
{# Agent Approved Notification Email Template #}
{#
Template variables:
data.agent_name: the name of the approved agent
data.agent_id: the ID of the agent
data.agent_version: the version of the agent
data.reviewer_name: the name of the reviewer who approved it
data.reviewer_email: the email of the reviewer
data.comments: comments from the reviewer
data.reviewed_at: when the agent was reviewed
data.store_url: URL to view the agent in the store
Subject: 🎉 Your agent '{{ data.agent_name }}' has been approved!
#}
{% block content %}
<h1 style="color: #28a745; font-size: 32px; font-weight: 700; margin: 0 0 24px 0; text-align: center;">
🎉 Congratulations!
</h1>
<p style="color: #586069; font-size: 18px; text-align: center; margin: 0 0 24px 0;">
Your agent <strong>'{{ data.agent_name }}'</strong> has been approved and is now live in the store!
</p>
<div style="height: 32px; background: transparent;"></div>
{% if data.comments %}
<div style="background: #d4edda; border: 1px solid #c3e6cb; border-radius: 8px; padding: 20px; margin: 0;">
<h3 style="color: #155724; font-size: 16px; font-weight: 600; margin: 0 0 12px 0;">
💬 Creator feedback area
</h3>
<p style="color: #155724; margin: 0; font-size: 16px; line-height: 1.5;">
{{ data.comments }}
</p>
</div>
<div style="height: 40px; background: transparent;"></div>
{% endif %}
<div style="background: #d1ecf1; border: 1px solid #bee5eb; border-radius: 8px; padding: 20px; margin: 0;">
<h3 style="color: #0c5460; font-size: 16px; font-weight: 600; margin: 0 0 12px 0;">
What's Next?
</h3>
<ul style="color: #0c5460; margin: 0; padding-left: 18px; font-size: 16px; line-height: 1.6;">
<li>Your agent is now live and discoverable in the AutoGPT Store</li>
<li>Users can find, install, and run your agent</li>
<li>You can update your agent anytime by submitting a new version</li>
</ul>
</div>
<div style="height: 32px; background: transparent;"></div>
<div style="text-align: center; margin: 24px 0;">
<a href="{{ data.store_url }}" style="display: inline-block; background: linear-gradient(135deg, #7c3aed 0%, #5b21b6 100%); color: black; text-decoration: none; padding: 14px 28px; border-radius: 6px; font-weight: 600; font-size: 16px;">
View Your Agent in Store
</a>
</div>
<div style="height: 32px; background: transparent;"></div>
<div style="background: #fff3cd; border: 1px solid #ffeaa7; border-radius: 6px; padding: 16px; margin: 24px 0; text-align: center;">
<p style="margin: 0; color: #856404; font-size: 14px;">
<strong>💡 Pro Tip:</strong> Share your agent with the community! Post about it on social media, forums, or your blog to help more users discover and benefit from your creation.
</p>
</div>
<div style="height: 32px; background: transparent;"></div>
<p style="color: #6a737d; font-size: 14px; text-align: center; margin: 24px 0;">
Thank you for contributing to the AutoGPT ecosystem! 🚀
</p>
{% endblock %}

View File

@@ -1,77 +0,0 @@
{# Agent Rejected Notification Email Template #}
{#
Template variables:
data.agent_name: the name of the rejected agent
data.agent_id: the ID of the agent
data.agent_version: the version of the agent
data.reviewer_name: the name of the reviewer who rejected it
data.reviewer_email: the email of the reviewer
data.comments: comments from the reviewer explaining the rejection
data.reviewed_at: when the agent was reviewed
data.resubmit_url: URL to resubmit the agent
Subject: Your agent '{{ data.agent_name }}' needs some updates
#}
{% block content %}
<h1 style="color: #d73a49; font-size: 32px; font-weight: 700; margin: 0 0 24px 0; text-align: center;">
📝 Review Complete
</h1>
<p style="color: #586069; font-size: 18px; text-align: center; margin: 0 0 24px 0;">
Your agent <strong>'{{ data.agent_name }}'</strong> needs some updates before approval.
</p>
<div style="height: 32px; background: transparent;"></div>
<div style="background: #f8d7da; border: 1px solid #f5c6cb; border-radius: 8px; padding: 20px; margin: 0 0 24px 0;">
<h3 style="color: #721c24; font-size: 16px; font-weight: 600; margin: 0 0 12px 0;">
💬 Creator feedback area
</h3>
<p style="color: #721c24; margin: 0; font-size: 16px; line-height: 1.5;">
{{ data.comments }}
</p>
</div>
<div style="height: 40px; background: transparent;"></div>
<div style="background: #d4edda; border: 1px solid #c3e6cb; border-radius: 8px; padding: 20px; margin: 0 0 24px 0;">
<h3 style="color: #155724; font-size: 16px; font-weight: 600; margin: 0 0 12px 0;">
☑ Steps to Resubmit:
</h3>
<ul style="color: #155724; margin: 0; padding-left: 18px; font-size: 16px; line-height: 1.6;">
<li>Review the feedback provided above carefully</li>
<li>Make the necessary updates to your agent</li>
<li>Test your agent thoroughly to ensure it works as expected</li>
<li>Submit your updated agent for review</li>
</ul>
</div>
<div style="height: 32px; background: transparent;"></div>
<div style="background: #fff3cd; border: 1px solid #ffeaa7; border-radius: 6px; padding: 12px; margin: 0 0 24px 0; text-align: center;">
<p style="margin: 0; color: #856404; font-size: 14px;">
<strong>💡 Tip:</strong> Address all the points mentioned in the feedback to increase your chances of approval in the next review.
</p>
</div>
<div style="text-align: center; margin: 32px 0;">
<a href="{{ data.resubmit_url }}" style="display: inline-block; background: linear-gradient(135deg, #7c3aed 0%, #5b21b6 100%); color: black; text-decoration: none; padding: 14px 28px; border-radius: 6px; font-weight: 600; font-size: 16px;">
Update & Resubmit Agent
</a>
</div>
<div style="background: #d1ecf1; border: 1px solid #bee5eb; border-radius: 6px; padding: 16px; margin: 24px 0;">
<p style="margin: 0; color: #0c5460; font-size: 14px; text-align: center;">
<strong>🌟 Don't Give Up!</strong> Many successful agents go through multiple iterations before approval. Our review team is here to help you succeed!
</p>
</div>
<div style="height: 32px; background: transparent;"></div>
<p style="color: #6a737d; font-size: 14px; text-align: center; margin: 32px 0 24px 0;">
We're excited to see your improved agent submission! 🚀
</p>
{% endblock %}

View File

@@ -1,7 +1,9 @@
{# Low Balance Notification Email Template #}
{# Template variables:
data.agent_name: the name of the agent
data.current_balance: the current balance of the user
data.billing_page_link: the link to the billing page
data.shortfall: the shortfall amount
#}
<p style="
@@ -23,7 +25,7 @@ data.billing_page_link: the link to the billing page
margin-top: 0;
margin-bottom: 20px;
">
Your account balance has dropped below the recommended threshold.
Your agent "<strong>{{ data.agent_name }}</strong>" has been stopped due to low balance.
</p>
<div style="
@@ -42,6 +44,15 @@ data.billing_page_link: the link to the billing page
">
<strong>Current Balance:</strong> ${{ "{:.2f}".format((data.current_balance|float)/100) }}
</p>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
margin-top: 0;
margin-bottom: 10px;
">
<strong>Shortfall:</strong> ${{ "{:.2f}".format((data.shortfall|float)/100) }}
</p>
</div>
@@ -68,7 +79,7 @@ data.billing_page_link: the link to the billing page
margin-top: 0;
margin-bottom: 5px;
">
Your account requires additional credits to continue running agents. Please add credits to your account to avoid service interruption.
Your agent "<strong>{{ data.agent_name }}</strong>" requires additional credits to continue running. The current operation has been canceled until your balance is replenished.
</p>
</div>
@@ -99,5 +110,5 @@ data.billing_page_link: the link to the billing page
margin-bottom: 10px;
font-style: italic;
">
This is an automated low balance notification. Consider adding credits soon to avoid service interruption.
This is an automated notification. Your agent is stopped and will need manually restarted unless set to trigger automatically.
</p>

View File

@@ -1,114 +0,0 @@
{# Low Balance Notification Email Template #}
{# Template variables:
data.agent_name: the name of the agent
data.current_balance: the current balance of the user
data.billing_page_link: the link to the billing page
data.shortfall: the shortfall amount
#}
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 10px;
">
<strong>Zero Balance Warning</strong>
</p>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 20px;
">
Your agent "<strong>{{ data.agent_name }}</strong>" has been stopped due to low balance.
</p>
<div style="
margin-left: 15px;
margin-bottom: 20px;
padding: 15px;
border-left: 4px solid #5D23BB;
background-color: #f8f8ff;
">
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
margin-top: 0;
margin-bottom: 10px;
">
<strong>Current Balance:</strong> ${{ "{:.2f}".format((data.current_balance|float)/100) }}
</p>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
margin-top: 0;
margin-bottom: 10px;
">
<strong>Shortfall:</strong> ${{ "{:.2f}".format((data.shortfall|float)/100) }}
</p>
</div>
<div style="
margin-left: 15px;
margin-bottom: 20px;
padding: 15px;
border-left: 4px solid #FF6B6B;
background-color: #FFF0F0;
">
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
margin-top: 0;
margin-bottom: 10px;
">
<strong>Low Balance:</strong>
</p>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
margin-top: 0;
margin-bottom: 5px;
">
Your agent "<strong>{{ data.agent_name }}</strong>" requires additional credits to continue running. The current operation has been canceled until your balance is replenished.
</p>
</div>
<div style="
text-align: center;
margin: 30px 0;
">
<a href="{{ data.billing_page_link }}" style="
font-family: 'Poppins', sans-serif;
background-color: #5D23BB;
color: white;
padding: 12px 24px;
text-decoration: none;
border-radius: 4px;
font-weight: 500;
display: inline-block;
">
Manage Billing
</a>
</div>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 150%;
margin-top: 30px;
margin-bottom: 10px;
font-style: italic;
">
This is an automated notification. Your agent is stopped and will need manually restarted unless set to trigger automatically.
</p>

View File

@@ -14,8 +14,6 @@ from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
CredentialsType,
OAuth2Credentials,
UserPasswordCredentials,
)
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.webhooks._base import BaseWebhooksManager
@@ -106,39 +104,14 @@ class Provider:
)
def get_test_credentials(self) -> Credentials:
"""Get test credentials for the provider based on supported auth types."""
test_id = str(self.test_credentials_uuid)
# Return credentials based on the first supported auth type
if "user_password" in self.supported_auth_types:
return UserPasswordCredentials(
id=test_id,
provider=self.name,
username=SecretStr(f"mock-{self.name}-username"),
password=SecretStr(f"mock-{self.name}-password"),
title=f"Mock {self.name.title()} credentials",
)
elif "oauth2" in self.supported_auth_types:
return OAuth2Credentials(
id=test_id,
provider=self.name,
username=f"mock-{self.name}-username",
access_token=SecretStr(f"mock-{self.name}-access-token"),
access_token_expires_at=None,
refresh_token=SecretStr(f"mock-{self.name}-refresh-token"),
refresh_token_expires_at=None,
scopes=[f"mock-{self.name}-scope"],
title=f"Mock {self.name.title()} OAuth credentials",
)
else:
# Default to API key credentials
return APIKeyCredentials(
id=test_id,
provider=self.name,
api_key=SecretStr(f"mock-{self.name}-api-key"),
title=f"Mock {self.name.title()} API key",
expires_at=None,
)
"""Get test credentials for the provider."""
return APIKeyCredentials(
id=str(self.test_credentials_uuid),
provider=self.name,
api_key=SecretStr("mock-api-key"),
title=f"Mock {self.name.title()} API key",
expires_at=None,
)
def get_api(self, credentials: Credentials) -> Any:
"""Get API client instance for the given credentials."""

View File

@@ -11,45 +11,7 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
return snapshot
@pytest.fixture
def test_user_id() -> str:
"""Test user ID fixture."""
return "test-user-id"
@pytest.fixture
def admin_user_id() -> str:
"""Admin user ID fixture."""
return "admin-user-id"
@pytest.fixture
def target_user_id() -> str:
"""Target user ID fixture."""
return "target-user-id"
@pytest.fixture
def mock_jwt_user(test_user_id):
"""Provide mock JWT payload for regular user testing."""
import fastapi
def override_get_jwt_payload(request: fastapi.Request) -> dict[str, str]:
return {"sub": test_user_id, "role": "user", "email": "test@example.com"}
return {"get_jwt_payload": override_get_jwt_payload, "user_id": test_user_id}
@pytest.fixture
def mock_jwt_admin(admin_user_id):
"""Provide mock JWT payload for admin user testing."""
import fastapi
def override_get_jwt_payload(request: fastapi.Request) -> dict[str, str]:
return {
"sub": admin_user_id,
"role": "admin",
"email": "test-admin@example.com",
}
return {"get_jwt_payload": override_get_jwt_payload, "user_id": admin_user_id}
# Test ID constants
TEST_USER_ID = "test-user-id"
ADMIN_USER_ID = "admin-user-id"
TARGET_USER_ID = "target-user-id"

View File

@@ -58,13 +58,17 @@ class ProviderConstants(BaseModel):
default_factory=lambda: {
name.upper().replace("-", "_"): name for name in get_all_provider_names()
},
examples=[
{
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"EXA": "exa",
"GEM": "gem",
"EXAMPLE_SERVICE": "example-service",
}
],
)
class Config:
schema_extra = {
"example": {
"PROVIDER_NAMES": {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"EXA": "exa",
"GEM": "gem",
"EXAMPLE_SERVICE": "example-service",
}
}
}

View File

@@ -3,15 +3,14 @@ import logging
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Annotated, Awaitable, List, Literal
from autogpt_libs.auth import get_user_id
from fastapi import (
APIRouter,
Body,
Depends,
HTTPException,
Path,
Query,
Request,
Security,
status,
)
from pydantic import BaseModel, Field, SecretStr
@@ -51,6 +50,8 @@ from backend.util.settings import Settings
if TYPE_CHECKING:
from backend.integrations.oauth import BaseOAuthHandler
from ..utils import get_user_id
logger = logging.getLogger(__name__)
settings = Settings()
router = APIRouter()
@@ -68,7 +69,7 @@ async def login(
provider: Annotated[
ProviderName, Path(title="The provider to initiate an OAuth flow for")
],
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
request: Request,
scopes: Annotated[
str, Query(title="Comma-separated list of authorization scopes")
@@ -108,7 +109,7 @@ async def callback(
],
code: Annotated[str, Body(title="Authorization code acquired by user login")],
state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
request: Request,
) -> CredentialsMetaResponse:
logger.debug(f"Received OAuth callback for provider: {provider}")
@@ -181,7 +182,7 @@ async def callback(
@router.get("/credentials")
async def list_credentials(
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
) -> list[CredentialsMetaResponse]:
credentials = await creds_manager.store.get_all_creds(user_id)
return [
@@ -203,7 +204,7 @@ async def list_credentials_by_provider(
provider: Annotated[
ProviderName, Path(title="The provider to list credentials for")
],
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
) -> list[CredentialsMetaResponse]:
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
return [
@@ -226,7 +227,7 @@ async def get_credential(
ProviderName, Path(title="The provider to retrieve credentials for")
],
cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
) -> Credentials:
credential = await creds_manager.get(user_id, cred_id)
if not credential:
@@ -243,7 +244,7 @@ async def get_credential(
@router.post("/{provider}/credentials", status_code=201)
async def create_credentials(
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
provider: Annotated[
ProviderName, Path(title="The provider to create credentials for")
],
@@ -287,7 +288,7 @@ async def delete_credentials(
ProviderName, Path(title="The provider to delete credentials for")
],
cred_id: Annotated[str, Path(title="The ID of the credentials to delete")],
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
force: Annotated[
bool, Query(title="Whether to proceed if any linked webhooks are still in use")
] = False,
@@ -428,7 +429,7 @@ async def webhook_ingress_generic(
@router.post("/webhooks/{webhook_id}/ping")
async def webhook_ping(
webhook_id: Annotated[str, Path(title="Our ID for the webhook")],
user_id: Annotated[str, Security(get_user_id)], # require auth
user_id: Annotated[str, Depends(get_user_id)], # require auth
):
webhook = await get_webhook(webhook_id)
webhook_manager = get_webhook_manager(webhook.provider)
@@ -567,7 +568,7 @@ def _get_provider_oauth_handler(
@router.get("/ayrshare/sso_url")
async def get_ayrshare_sso_url(
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
) -> AyrshareSSOResponse:
"""
Generate an SSO URL for Ayrshare social media integration.

View File

@@ -5,7 +5,6 @@ import pydantic
from backend.data.api_key import APIKeyPermission, APIKeyWithoutHash
from backend.data.graph import Graph
from backend.util.timezone_name import TimeZoneName
class WSMethod(enum.Enum):
@@ -61,6 +60,21 @@ class UpdatePermissionsRequest(pydantic.BaseModel):
permissions: list[APIKeyPermission]
class Pagination(pydantic.BaseModel):
total_items: int = pydantic.Field(
description="Total number of items.", examples=[42]
)
total_pages: int = pydantic.Field(
description="Total number of pages.", examples=[2]
)
current_page: int = pydantic.Field(
description="Current_page page number.", examples=[1]
)
page_size: int = pydantic.Field(
description="Number of items per page.", examples=[25]
)
class RequestTopUp(pydantic.BaseModel):
credit_amount: int
@@ -71,12 +85,3 @@ class UploadFileResponse(pydantic.BaseModel):
size: int
content_type: str
expires_in_hours: int
class TimezoneResponse(pydantic.BaseModel):
# Allow "not-set" as a special value, or any valid IANA timezone
timezone: TimeZoneName | str
class UpdateTimezoneRequest(pydantic.BaseModel):
timezone: TimeZoneName

View File

@@ -3,13 +3,12 @@ import logging
from enum import Enum
from typing import Any, Optional
import autogpt_libs.auth.models
import fastapi
import fastapi.responses
import pydantic
import starlette.middleware.cors
import uvicorn
from autogpt_libs.auth import add_auth_responses_to_openapi
from autogpt_libs.auth import verify_settings as verify_auth_settings
from fastapi.exceptions import RequestValidationError
from fastapi.routing import APIRoute
@@ -21,8 +20,6 @@ import backend.server.routers.postmark.postmark
import backend.server.routers.v1
import backend.server.v2.admin.credit_admin_routes
import backend.server.v2.admin.store_admin_routes
import backend.server.v2.builder
import backend.server.v2.builder.routes
import backend.server.v2.library.db
import backend.server.v2.library.model
import backend.server.v2.library.routes
@@ -62,8 +59,6 @@ def launch_darkly_context():
@contextlib.asynccontextmanager
async def lifespan_context(app: fastapi.FastAPI):
verify_auth_settings()
await backend.data.db.connect()
# Ensure SDK auto-registration is patched before initializing blocks
@@ -134,9 +129,6 @@ app = fastapi.FastAPI(
app.add_middleware(SecurityHeadersMiddleware)
# Add 401 responses to authenticated endpoints in OpenAPI spec
add_auth_responses_to_openapi(app)
def handle_internal_http_error(status_code: int = 500, log_error: bool = True):
def handler(request: fastapi.Request, exc: Exception):
@@ -203,9 +195,6 @@ app.include_router(backend.server.routers.v1.v1_router, tags=["v1"], prefix="/ap
app.include_router(
backend.server.v2.store.routes.router, tags=["v2"], prefix="/api/store"
)
app.include_router(
backend.server.v2.builder.routes.router, tags=["v2"], prefix="/api/builder"
)
app.include_router(
backend.server.v2.admin.store_admin_routes.router,
tags=["v2", "admin"],
@@ -376,10 +365,10 @@ class AgentServer(backend.util.service.AppProcess):
@staticmethod
async def test_review_store_listing(
request: backend.server.v2.store.model.ReviewSubmissionRequest,
user_id: str,
user: autogpt_libs.auth.models.User,
):
return await backend.server.v2.admin.store_admin_routes.review_submission(
request.store_listing_version_id, request, user_id
request.store_listing_version_id, request, user
)
@staticmethod

View File

@@ -5,9 +5,9 @@ from typing import Annotated
import fastapi
import pydantic
from autogpt_libs.auth import get_user_id
import backend.data.analytics
from backend.server.utils import get_user_id
router = fastapi.APIRouter()
logger = logging.getLogger(__name__)
@@ -21,7 +21,7 @@ class LogRawMetricRequest(pydantic.BaseModel):
@router.post(path="/log_raw_metric")
async def log_raw_metric(
user_id: Annotated[str, fastapi.Security(get_user_id)],
user_id: Annotated[str, fastapi.Depends(get_user_id)],
request: LogRawMetricRequest,
):
try:
@@ -47,7 +47,7 @@ async def log_raw_metric(
@router.post("/log_raw_analytics")
async def log_raw_analytics(
user_id: Annotated[str, fastapi.Security(get_user_id)],
user_id: Annotated[str, fastapi.Depends(get_user_id)],
type: Annotated[str, fastapi.Body(..., embed=True)],
data: Annotated[
dict,

View File

@@ -5,17 +5,18 @@ from unittest.mock import AsyncMock, Mock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from pytest_snapshot.plugin import Snapshot
import backend.server.routers.analytics as analytics_routes
from backend.server.conftest import TEST_USER_ID
from backend.server.test_helpers import (
assert_error_response_structure,
assert_mock_called_with_partial,
assert_response_status,
safe_parse_json,
)
from backend.server.utils import get_user_id
app = fastapi.FastAPI()
app.include_router(analytics_routes.router)
@@ -23,20 +24,17 @@ app.include_router(analytics_routes.router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module"""
from autogpt_libs.auth.jwt_utils import get_jwt_payload
def override_get_user_id() -> str:
"""Override get_user_id for testing"""
return TEST_USER_ID
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
app.dependency_overrides[get_user_id] = override_get_user_id
def test_log_raw_metric_success_improved(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
test_user_id: str,
) -> None:
"""Test successful raw metric logging with improved assertions."""
# Mock the analytics function
@@ -65,7 +63,7 @@ def test_log_raw_metric_success_improved(
# Verify the function was called with correct parameters
assert_mock_called_with_partial(
mock_log_metric,
user_id=test_user_id,
user_id=TEST_USER_ID,
metric_name="page_load_time",
metric_value=2.5,
data_string="/dashboard",

View File

@@ -10,6 +10,8 @@ import pytest_mock
from pytest_snapshot.plugin import Snapshot
import backend.server.routers.analytics as analytics_routes
from backend.server.conftest import TEST_USER_ID
from backend.server.utils import get_user_id
app = fastapi.FastAPI()
app.include_router(analytics_routes.router)
@@ -17,14 +19,12 @@ app.include_router(analytics_routes.router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module"""
from autogpt_libs.auth.jwt_utils import get_jwt_payload
def override_get_user_id() -> str:
"""Override get_user_id for testing"""
return TEST_USER_ID
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
app.dependency_overrides[get_user_id] = override_get_user_id
@pytest.mark.parametrize(

View File

@@ -3,11 +3,12 @@ from unittest.mock import AsyncMock, Mock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from pytest_snapshot.plugin import Snapshot
import backend.server.routers.analytics as analytics_routes
from backend.server.conftest import TEST_USER_ID
from backend.server.utils import get_user_id
app = fastapi.FastAPI()
app.include_router(analytics_routes.router)
@@ -15,20 +16,17 @@ app.include_router(analytics_routes.router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module"""
from autogpt_libs.auth.jwt_utils import get_jwt_payload
def override_get_user_id() -> str:
"""Override get_user_id for testing"""
return TEST_USER_ID
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
app.dependency_overrides[get_user_id] = override_get_user_id
def test_log_raw_metric_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
test_user_id: str,
) -> None:
"""Test successful raw metric logging"""
@@ -55,7 +53,7 @@ def test_log_raw_metric_success(
# Verify the function was called with correct parameters
mock_log_metric.assert_called_once_with(
user_id=test_user_id,
user_id=TEST_USER_ID,
metric_name="page_load_time",
metric_value=2.5,
data_string="/dashboard",
@@ -123,7 +121,6 @@ def test_log_raw_metric_various_values(
def test_log_raw_analytics_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
test_user_id: str,
) -> None:
"""Test successful raw analytics logging"""
@@ -158,7 +155,7 @@ def test_log_raw_analytics_success(
# Verify the function was called with correct parameters
mock_log_analytics.assert_called_once_with(
test_user_id,
TEST_USER_ID,
"user_action",
request_data["data"],
"button_click_submit_form",

View File

@@ -1,7 +1,8 @@
import logging
from typing import Annotated
from fastapi import APIRouter, Body, HTTPException, Query, Security
from autogpt_libs.auth.middleware import APIKeyValidator
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from fastapi.responses import JSONResponse
from backend.data.user import (
@@ -19,19 +20,19 @@ from backend.server.routers.postmark.models import (
PostmarkSubscriptionChangeWebhook,
PostmarkWebhook,
)
from backend.server.utils.api_key_auth import APIKeyAuthenticator
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
router = APIRouter()
postmark_api_key_auth = APIKeyAuthenticator(
postmark_validator = APIKeyValidator(
"X-Postmark-Webhook-Token",
settings.secrets.postmark_webhook_token,
)
router = APIRouter()
logger = logging.getLogger(__name__)
@router.post("/unsubscribe", summary="One Click Email Unsubscribe")
async def unsubscribe_via_one_click(token: Annotated[str, Query()]):
@@ -49,7 +50,7 @@ async def unsubscribe_via_one_click(token: Annotated[str, Query()]):
@router.post(
"/",
dependencies=[Security(postmark_api_key_auth)],
dependencies=[Depends(postmark_validator.get_dependency())],
summary="Handle Postmark Email Webhooks",
)
async def postmark_webhook_handler(

View File

@@ -7,18 +7,16 @@ from typing import Annotated, Any, Sequence
import pydantic
import stripe
from autogpt_libs.auth import get_user_id, requires_user
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from autogpt_libs.auth.middleware import auth_middleware
from fastapi import (
APIRouter,
Body,
Depends,
File,
HTTPException,
Path,
Query,
Request,
Response,
Security,
UploadFile,
)
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
@@ -62,11 +60,9 @@ from backend.data.onboarding import (
)
from backend.data.user import (
get_or_create_user,
get_user_by_id,
get_user_notification_preference,
update_user_email,
update_user_notification_preference,
update_user_timezone,
)
from backend.executor import scheduler
from backend.executor import utils as execution_utils
@@ -81,20 +77,15 @@ from backend.server.model import (
ExecuteGraphResponse,
RequestTopUp,
SetGraphActiveVersion,
TimezoneResponse,
UpdatePermissionsRequest,
UpdateTimezoneRequest,
UploadFileResponse,
)
from backend.server.utils import get_user_id
from backend.util.clients import get_scheduler_client
from backend.util.cloud_storage import get_cloud_storage_handler
from backend.util.exceptions import GraphValidationError, NotFoundError
from backend.util.feature_flag import feature_flag
from backend.util.settings import Settings
from backend.util.timezone_utils import (
convert_cron_to_utc,
convert_utc_time_to_user_timezone,
get_user_timezone_or_utc,
)
from backend.util.virus_scanner import scan_content_safe
@@ -124,7 +115,7 @@ v1_router.include_router(
backend.server.routers.analytics.router,
prefix="/analytics",
tags=["analytics"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
@@ -137,9 +128,9 @@ v1_router.include_router(
"/auth/user",
summary="Get or create user",
tags=["auth"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
async def get_or_create_user_route(user_data: dict = Depends(auth_middleware)):
user = await get_or_create_user(user_data)
return user.model_dump()
@@ -148,53 +139,24 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
"/auth/user/email",
summary="Update user email",
tags=["auth"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def update_user_email_route(
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
user_id: Annotated[str, Depends(get_user_id)], email: str = Body(...)
) -> dict[str, str]:
await update_user_email(user_id, email)
return {"email": email}
@v1_router.get(
"/auth/user/timezone",
summary="Get user timezone",
tags=["auth"],
dependencies=[Security(requires_user)],
)
async def get_user_timezone_route(
user_data: dict = Security(get_jwt_payload),
) -> TimezoneResponse:
"""Get user timezone setting."""
user = await get_or_create_user(user_data)
return TimezoneResponse(timezone=user.timezone)
@v1_router.post(
"/auth/user/timezone",
summary="Update user timezone",
tags=["auth"],
dependencies=[Security(requires_user)],
response_model=TimezoneResponse,
)
async def update_user_timezone_route(
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
) -> TimezoneResponse:
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
user = await update_user_timezone(user_id, str(request.timezone))
return TimezoneResponse(timezone=user.timezone)
@v1_router.get(
"/auth/user/preferences",
summary="Get notification preferences",
tags=["auth"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def get_preferences(
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
) -> NotificationPreference:
preferences = await get_user_notification_preference(user_id)
return preferences
@@ -204,10 +166,10 @@ async def get_preferences(
"/auth/user/preferences",
summary="Update notification preferences",
tags=["auth"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def update_preferences(
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
preferences: NotificationPreferenceDTO = Body(...),
) -> NotificationPreference:
output = await update_user_notification_preference(user_id, preferences)
@@ -223,9 +185,9 @@ async def update_preferences(
"/onboarding",
summary="Get onboarding status",
tags=["onboarding"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def get_onboarding(user_id: Annotated[str, Security(get_user_id)]):
async def get_onboarding(user_id: Annotated[str, Depends(get_user_id)]):
return await get_user_onboarding(user_id)
@@ -233,10 +195,10 @@ async def get_onboarding(user_id: Annotated[str, Security(get_user_id)]):
"/onboarding",
summary="Update onboarding progress",
tags=["onboarding"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def update_onboarding(
user_id: Annotated[str, Security(get_user_id)], data: UserOnboardingUpdate
user_id: Annotated[str, Depends(get_user_id)], data: UserOnboardingUpdate
):
return await update_user_onboarding(user_id, data)
@@ -245,10 +207,10 @@ async def update_onboarding(
"/onboarding/agents",
summary="Get recommended agents",
tags=["onboarding"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def get_onboarding_agents(
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
):
return await get_recommended_agents(user_id)
@@ -257,7 +219,7 @@ async def get_onboarding_agents(
"/onboarding/enabled",
summary="Check onboarding enabled",
tags=["onboarding", "public"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def is_onboarding_enabled():
return await onboarding_enabled()
@@ -272,7 +234,7 @@ async def is_onboarding_enabled():
path="/blocks",
summary="List available blocks",
tags=["blocks"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
blocks = [block() for block in get_blocks().values()]
@@ -286,7 +248,7 @@ def get_graph_blocks() -> Sequence[dict[Any, Any]]:
path="/blocks/{block_id}/execute",
summary="Execute graph block",
tags=["blocks"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlockOutput:
obj = get_block(block_id)
@@ -303,10 +265,10 @@ async def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlock
path="/files/upload",
summary="Upload file to cloud storage",
tags=["files"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def upload_file(
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
file: UploadFile = File(...),
provider: str = "gcs",
expiration_hours: int = 24,
@@ -394,10 +356,10 @@ async def upload_file(
path="/credits",
tags=["credits"],
summary="Get user credits",
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def get_user_credits(
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[str, int]:
return {"credits": await _user_credit_model.get_credits(user_id)}
@@ -406,10 +368,10 @@ async def get_user_credits(
path="/credits",
summary="Request credit top up",
tags=["credits"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def request_top_up(
request: RequestTopUp, user_id: Annotated[str, Security(get_user_id)]
request: RequestTopUp, user_id: Annotated[str, Depends(get_user_id)]
):
checkout_url = await _user_credit_model.top_up_intent(
user_id, request.credit_amount
@@ -421,10 +383,10 @@ async def request_top_up(
path="/credits/{transaction_key}/refund",
summary="Refund credit transaction",
tags=["credits"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def refund_top_up(
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
transaction_key: str,
metadata: dict[str, str],
) -> int:
@@ -435,9 +397,9 @@ async def refund_top_up(
path="/credits",
summary="Fulfill checkout session",
tags=["credits"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def fulfill_checkout(user_id: Annotated[str, Security(get_user_id)]):
async def fulfill_checkout(user_id: Annotated[str, Depends(get_user_id)]):
await _user_credit_model.fulfill_checkout(user_id=user_id)
return Response(status_code=200)
@@ -446,10 +408,10 @@ async def fulfill_checkout(user_id: Annotated[str, Security(get_user_id)]):
path="/credits/auto-top-up",
summary="Configure auto top up",
tags=["credits"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def configure_user_auto_top_up(
request: AutoTopUpConfig, user_id: Annotated[str, Security(get_user_id)]
request: AutoTopUpConfig, user_id: Annotated[str, Depends(get_user_id)]
) -> str:
if request.threshold < 0:
raise ValueError("Threshold must be greater than 0")
@@ -475,10 +437,10 @@ async def configure_user_auto_top_up(
path="/credits/auto-top-up",
summary="Get auto top up",
tags=["credits"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def get_user_auto_top_up(
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
) -> AutoTopUpConfig:
return await get_auto_top_up(user_id)
@@ -528,10 +490,10 @@ async def stripe_webhook(request: Request):
path="/credits/manage",
tags=["credits"],
summary="Manage payment methods",
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def manage_payment_method(
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[str, str]:
return {"url": await _user_credit_model.create_billing_portal_session(user_id)}
@@ -540,10 +502,10 @@ async def manage_payment_method(
path="/credits/transactions",
tags=["credits"],
summary="Get credit history",
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def get_credit_history(
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
transaction_time: datetime | None = None,
transaction_type: str | None = None,
transaction_count_limit: int = 100,
@@ -563,10 +525,10 @@ async def get_credit_history(
path="/credits/refunds",
tags=["credits"],
summary="Get refund requests",
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def get_refund_requests(
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
) -> list[RefundRequest]:
return await _user_credit_model.get_refund_requests(user_id)
@@ -584,10 +546,10 @@ class DeleteGraphResponse(TypedDict):
path="/graphs",
summary="List user graphs",
tags=["graphs"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def list_graphs(
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
) -> Sequence[graph_db.GraphMeta]:
return await graph_db.list_graphs(filter_by="active", user_id=user_id)
@@ -596,17 +558,17 @@ async def list_graphs(
path="/graphs/{graph_id}",
summary="Get specific graph",
tags=["graphs"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
@v1_router.get(
path="/graphs/{graph_id}/versions/{version}",
summary="Get graph version",
tags=["graphs"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def get_graph(
graph_id: str,
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
version: int | None = None,
for_export: bool = False,
) -> graph_db.GraphModel:
@@ -626,10 +588,10 @@ async def get_graph(
path="/graphs/{graph_id}/versions",
summary="Get all graph versions",
tags=["graphs"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def get_graph_all_versions(
graph_id: str, user_id: Annotated[str, Security(get_user_id)]
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[graph_db.GraphModel]:
graphs = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
if not graphs:
@@ -641,11 +603,11 @@ async def get_graph_all_versions(
path="/graphs",
summary="Create new graph",
tags=["graphs"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def create_new_graph(
create_graph: CreateGraph,
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
) -> graph_db.GraphModel:
graph = graph_db.make_graph_model(create_graph.graph, user_id)
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
@@ -662,10 +624,10 @@ async def create_new_graph(
path="/graphs/{graph_id}",
summary="Delete graph permanently",
tags=["graphs"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def delete_graph(
graph_id: str, user_id: Annotated[str, Security(get_user_id)]
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> DeleteGraphResponse:
if active_version := await graph_db.get_graph(graph_id, user_id=user_id):
await on_graph_deactivate(active_version, user_id=user_id)
@@ -677,12 +639,12 @@ async def delete_graph(
path="/graphs/{graph_id}",
summary="Update graph version",
tags=["graphs"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def update_graph(
graph_id: str,
graph: graph_db.Graph,
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
) -> graph_db.GraphModel:
# Sanity check
if graph.id and graph.id != graph_id:
@@ -733,12 +695,12 @@ async def update_graph(
path="/graphs/{graph_id}/versions/active",
summary="Set active graph version",
tags=["graphs"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def set_graph_active_version(
graph_id: str,
request_body: SetGraphActiveVersion,
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
):
new_active_version = request_body.active_graph_version
new_active_graph = await graph_db.get_graph(
@@ -772,11 +734,11 @@ async def set_graph_active_version(
path="/graphs/{graph_id}/execute/{graph_version}",
summary="Execute graph agent",
tags=["graphs"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def execute_graph(
graph_id: str,
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
inputs: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
credentials_inputs: Annotated[
dict[str, CredentialsMetaInput], Body(..., embed=True, default_factory=dict)
@@ -818,10 +780,10 @@ async def execute_graph(
path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
summary="Stop graph execution",
tags=["graphs"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def stop_graph_run(
graph_id: str, graph_exec_id: str, user_id: Annotated[str, Security(get_user_id)]
graph_id: str, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> execution_db.GraphExecutionMeta | None:
res = await _stop_graph_run(
user_id=user_id,
@@ -858,48 +820,39 @@ async def _stop_graph_run(
@v1_router.get(
path="/executions",
summary="List all executions",
summary="Get all executions",
tags=["graphs"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def list_graphs_executions(
user_id: Annotated[str, Security(get_user_id)],
async def get_graphs_executions(
user_id: Annotated[str, Depends(get_user_id)],
) -> list[execution_db.GraphExecutionMeta]:
return await execution_db.get_graph_executions(user_id=user_id)
@v1_router.get(
path="/graphs/{graph_id}/executions",
summary="List graph executions",
summary="Get graph executions",
tags=["graphs"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def list_graph_executions(
async def get_graph_executions(
graph_id: str,
user_id: Annotated[str, Security(get_user_id)],
page: int = Query(1, ge=1, description="Page number (1-indexed)"),
page_size: int = Query(
25, ge=1, le=100, description="Number of executions per page"
),
) -> execution_db.GraphExecutionsPaginated:
return await execution_db.get_graph_executions_paginated(
graph_id=graph_id,
user_id=user_id,
page=page,
page_size=page_size,
)
user_id: Annotated[str, Depends(get_user_id)],
) -> list[execution_db.GraphExecutionMeta]:
return await execution_db.get_graph_executions(graph_id=graph_id, user_id=user_id)
@v1_router.get(
path="/graphs/{graph_id}/executions/{graph_exec_id}",
summary="Get execution details",
tags=["graphs"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def get_graph_execution(
graph_id: str,
graph_exec_id: str,
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
) -> execution_db.GraphExecution | execution_db.GraphExecutionWithNodes:
graph = await graph_db.get_graph(graph_id=graph_id, user_id=user_id)
if not graph:
@@ -924,12 +877,12 @@ async def get_graph_execution(
path="/executions/{graph_exec_id}",
summary="Delete graph execution",
tags=["graphs"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
status_code=HTTP_204_NO_CONTENT,
)
async def delete_graph_execution(
graph_exec_id: str,
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
) -> None:
await execution_db.delete_graph_execution(
graph_exec_id=graph_exec_id, user_id=user_id
@@ -953,10 +906,10 @@ class ScheduleCreationRequest(pydantic.BaseModel):
path="/graphs/{graph_id}/schedules",
summary="Create execution schedule",
tags=["schedules"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def create_graph_execution_schedule(
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
graph_id: str = Path(..., description="ID of the graph to schedule"),
schedule_params: ScheduleCreationRequest = Body(),
) -> scheduler.GraphExecutionJobInfo:
@@ -971,99 +924,53 @@ async def create_graph_execution_schedule(
detail=f"Graph #{graph_id} v{schedule_params.graph_version} not found.",
)
user = await get_user_by_id(user_id)
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
# Convert cron expression from user timezone to UTC
try:
utc_cron = convert_cron_to_utc(schedule_params.cron, user_timezone)
except ValueError as e:
raise HTTPException(
status_code=400,
detail=f"Invalid cron expression for timezone {user_timezone}: {e}",
)
result = await get_scheduler_client().add_execution_schedule(
return await get_scheduler_client().add_execution_schedule(
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
name=schedule_params.name,
cron=utc_cron, # Send UTC cron to scheduler
cron=schedule_params.cron,
input_data=schedule_params.inputs,
input_credentials=schedule_params.credentials,
)
# Convert the next_run_time back to user timezone for display
if result.next_run_time:
result.next_run_time = convert_utc_time_to_user_timezone(
result.next_run_time, user_timezone
)
return result
@v1_router.get(
path="/graphs/{graph_id}/schedules",
summary="List execution schedules for a graph",
tags=["schedules"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def list_graph_execution_schedules(
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
graph_id: str = Path(),
) -> list[scheduler.GraphExecutionJobInfo]:
schedules = await get_scheduler_client().get_execution_schedules(
return await get_scheduler_client().get_execution_schedules(
user_id=user_id,
graph_id=graph_id,
)
# Get user timezone for conversion
user = await get_user_by_id(user_id)
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
# Convert next_run_time to user timezone for display
for schedule in schedules:
if schedule.next_run_time:
schedule.next_run_time = convert_utc_time_to_user_timezone(
schedule.next_run_time, user_timezone
)
return schedules
@v1_router.get(
path="/schedules",
summary="List execution schedules for a user",
tags=["schedules"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def list_all_graphs_execution_schedules(
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
) -> list[scheduler.GraphExecutionJobInfo]:
schedules = await get_scheduler_client().get_execution_schedules(user_id=user_id)
# Get user timezone for conversion
user = await get_user_by_id(user_id)
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
# Convert UTC next_run_time to user timezone for display
for schedule in schedules:
if schedule.next_run_time:
schedule.next_run_time = convert_utc_time_to_user_timezone(
schedule.next_run_time, user_timezone
)
return schedules
return await get_scheduler_client().get_execution_schedules(user_id=user_id)
@v1_router.delete(
path="/schedules/{schedule_id}",
summary="Delete execution schedule",
tags=["schedules"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def delete_graph_execution_schedule(
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
schedule_id: str = Path(..., description="ID of the schedule to delete"),
) -> dict[str, Any]:
try:
@@ -1086,10 +993,10 @@ async def delete_graph_execution_schedule(
summary="Create new API key",
response_model=CreateAPIKeyResponse,
tags=["api-keys"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def create_api_key(
request: CreateAPIKeyRequest, user_id: Annotated[str, Security(get_user_id)]
request: CreateAPIKeyRequest, user_id: Annotated[str, Depends(get_user_id)]
) -> CreateAPIKeyResponse:
"""Create a new API key"""
try:
@@ -1117,10 +1024,10 @@ async def create_api_key(
summary="List user API keys",
response_model=list[APIKeyWithoutHash] | dict[str, str],
tags=["api-keys"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def get_api_keys(
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
) -> list[APIKeyWithoutHash]:
"""List all API keys for the user"""
try:
@@ -1138,10 +1045,10 @@ async def get_api_keys(
summary="Get specific API key",
response_model=APIKeyWithoutHash,
tags=["api-keys"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
async def get_api_key(
key_id: str, user_id: Annotated[str, Security(get_user_id)]
key_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> APIKeyWithoutHash:
"""Get a specific API key"""
try:
@@ -1162,10 +1069,11 @@ async def get_api_key(
summary="Revoke API key",
response_model=APIKeyWithoutHash,
tags=["api-keys"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
@feature_flag("api-keys-enabled")
async def delete_api_key(
key_id: str, user_id: Annotated[str, Security(get_user_id)]
key_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> Optional[APIKeyWithoutHash]:
"""Revoke an API key"""
try:
@@ -1190,10 +1098,11 @@ async def delete_api_key(
summary="Suspend API key",
response_model=APIKeyWithoutHash,
tags=["api-keys"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
@feature_flag("api-keys-enabled")
async def suspend_key(
key_id: str, user_id: Annotated[str, Security(get_user_id)]
key_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> Optional[APIKeyWithoutHash]:
"""Suspend an API key"""
try:
@@ -1215,12 +1124,13 @@ async def suspend_key(
summary="Update key permissions",
response_model=APIKeyWithoutHash,
tags=["api-keys"],
dependencies=[Security(requires_user)],
dependencies=[Depends(auth_middleware)],
)
@feature_flag("api-keys-enabled")
async def update_permissions(
key_id: str,
request: UpdatePermissionsRequest,
user_id: Annotated[str, Security(get_user_id)],
user_id: Annotated[str, Depends(get_user_id)],
) -> Optional[APIKeyWithoutHash]:
"""Update API key permissions"""
try:

View File

@@ -2,6 +2,7 @@ import json
from io import BytesIO
from unittest.mock import AsyncMock, Mock, patch
import autogpt_libs.auth.depends
import fastapi
import fastapi.testclient
import pytest
@@ -13,7 +14,9 @@ from pytest_snapshot.plugin import Snapshot
import backend.server.routers.v1 as v1_routes
from backend.data.credit import AutoTopUpConfig
from backend.data.graph import GraphModel
from backend.server.conftest import TEST_USER_ID
from backend.server.routers.v1 import upload_file
from backend.server.utils import get_user_id
app = fastapi.FastAPI()
app.include_router(v1_routes.v1_router)
@@ -21,26 +24,31 @@ app.include_router(v1_routes.v1_router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module"""
from autogpt_libs.auth.jwt_utils import get_jwt_payload
def override_auth_middleware(request: fastapi.Request) -> dict[str, str]:
"""Override auth middleware for testing"""
return {"sub": TEST_USER_ID, "role": "user", "email": "test@example.com"}
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def override_get_user_id() -> str:
"""Override get_user_id for testing"""
return TEST_USER_ID
app.dependency_overrides[autogpt_libs.auth.middleware.auth_middleware] = (
override_auth_middleware
)
app.dependency_overrides[get_user_id] = override_get_user_id
# Auth endpoints tests
def test_get_or_create_user_route(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
test_user_id: str,
) -> None:
"""Test get or create user endpoint"""
mock_user = Mock()
mock_user.model_dump.return_value = {
"id": test_user_id,
"id": TEST_USER_ID,
"email": "test@example.com",
"name": "Test User",
}
@@ -255,7 +263,6 @@ def test_get_auto_top_up(
def test_get_graphs(
mocker: pytest_mock.MockFixture,
snapshot: Snapshot,
test_user_id: str,
) -> None:
"""Test get graphs endpoint"""
mock_graph = GraphModel(
@@ -264,7 +271,7 @@ def test_get_graphs(
is_active=True,
name="Test Graph",
description="A test graph",
user_id=test_user_id,
user_id="test-user-id",
)
mocker.patch(
@@ -289,7 +296,6 @@ def test_get_graphs(
def test_get_graph(
mocker: pytest_mock.MockFixture,
snapshot: Snapshot,
test_user_id: str,
) -> None:
"""Test get single graph endpoint"""
mock_graph = GraphModel(
@@ -298,7 +304,7 @@ def test_get_graph(
is_active=True,
name="Test Graph",
description="A test graph",
user_id=test_user_id,
user_id="test-user-id",
)
mocker.patch(
@@ -337,7 +343,6 @@ def test_get_graph_not_found(
def test_delete_graph(
mocker: pytest_mock.MockFixture,
snapshot: Snapshot,
test_user_id: str,
) -> None:
"""Test delete graph endpoint"""
# Mock active graph for deactivation
@@ -347,7 +352,7 @@ def test_delete_graph(
is_active=True,
name="Test Graph",
description="A test graph",
user_id=test_user_id,
user_id="test-user-id",
)
mocker.patch(
@@ -394,7 +399,7 @@ def test_missing_required_field() -> None:
@pytest.mark.asyncio
async def test_upload_file_success(test_user_id: str):
async def test_upload_file_success():
"""Test successful file upload."""
# Create mock upload file
file_content = b"test file content"
@@ -420,7 +425,7 @@ async def test_upload_file_success(test_user_id: str):
result = await upload_file(
file=upload_file_mock,
user_id=test_user_id,
user_id="test-user-123",
provider="gcs",
expiration_hours=24,
)
@@ -441,12 +446,12 @@ async def test_upload_file_success(test_user_id: str):
filename="test.txt",
provider="gcs",
expiration_hours=24,
user_id=test_user_id,
user_id="test-user-123",
)
@pytest.mark.asyncio
async def test_upload_file_no_filename(test_user_id: str):
async def test_upload_file_no_filename():
"""Test file upload without filename."""
file_content = b"test content"
file_obj = BytesIO(file_content)
@@ -471,7 +476,7 @@ async def test_upload_file_no_filename(test_user_id: str):
upload_file_mock.read = AsyncMock(return_value=file_content)
result = await upload_file(file=upload_file_mock, user_id=test_user_id)
result = await upload_file(file=upload_file_mock, user_id="test-user-123")
assert result.file_name == "uploaded_file"
assert result.content_type == "application/octet-stream"
@@ -481,7 +486,7 @@ async def test_upload_file_no_filename(test_user_id: str):
@pytest.mark.asyncio
async def test_upload_file_invalid_expiration(test_user_id: str):
async def test_upload_file_invalid_expiration():
"""Test file upload with invalid expiration hours."""
file_obj = BytesIO(b"content")
upload_file_mock = UploadFile(
@@ -493,7 +498,7 @@ async def test_upload_file_invalid_expiration(test_user_id: str):
# Test expiration too short
with pytest.raises(HTTPException) as exc_info:
await upload_file(
file=upload_file_mock, user_id=test_user_id, expiration_hours=0
file=upload_file_mock, user_id="test-user-123", expiration_hours=0
)
assert exc_info.value.status_code == 400
assert "between 1 and 48" in exc_info.value.detail
@@ -501,14 +506,14 @@ async def test_upload_file_invalid_expiration(test_user_id: str):
# Test expiration too long
with pytest.raises(HTTPException) as exc_info:
await upload_file(
file=upload_file_mock, user_id=test_user_id, expiration_hours=49
file=upload_file_mock, user_id="test-user-123", expiration_hours=49
)
assert exc_info.value.status_code == 400
assert "between 1 and 48" in exc_info.value.detail
@pytest.mark.asyncio
async def test_upload_file_virus_scan_failure(test_user_id: str):
async def test_upload_file_virus_scan_failure():
"""Test file upload when virus scan fails."""
file_content = b"malicious content"
file_obj = BytesIO(file_content)
@@ -525,11 +530,11 @@ async def test_upload_file_virus_scan_failure(test_user_id: str):
upload_file_mock.read = AsyncMock(return_value=file_content)
with pytest.raises(RuntimeError, match="Virus detected!"):
await upload_file(file=upload_file_mock, user_id=test_user_id)
await upload_file(file=upload_file_mock, user_id="test-user-123")
@pytest.mark.asyncio
async def test_upload_file_cloud_storage_failure(test_user_id: str):
async def test_upload_file_cloud_storage_failure():
"""Test file upload when cloud storage fails."""
file_content = b"test content"
file_obj = BytesIO(file_content)
@@ -551,11 +556,11 @@ async def test_upload_file_cloud_storage_failure(test_user_id: str):
upload_file_mock.read = AsyncMock(return_value=file_content)
with pytest.raises(RuntimeError, match="Storage error!"):
await upload_file(file=upload_file_mock, user_id=test_user_id)
await upload_file(file=upload_file_mock, user_id="test-user-123")
@pytest.mark.asyncio
async def test_upload_file_size_limit_exceeded(test_user_id: str):
async def test_upload_file_size_limit_exceeded():
"""Test file upload when file size exceeds the limit."""
# Create a file that exceeds the default 256MB limit
large_file_content = b"x" * (257 * 1024 * 1024) # 257MB
@@ -569,14 +574,14 @@ async def test_upload_file_size_limit_exceeded(test_user_id: str):
upload_file_mock.read = AsyncMock(return_value=large_file_content)
with pytest.raises(HTTPException) as exc_info:
await upload_file(file=upload_file_mock, user_id=test_user_id)
await upload_file(file=upload_file_mock, user_id="test-user-123")
assert exc_info.value.status_code == 400
assert "exceeds the maximum allowed size of 256MB" in exc_info.value.detail
@pytest.mark.asyncio
async def test_upload_file_gcs_not_configured_fallback(test_user_id: str):
async def test_upload_file_gcs_not_configured_fallback():
"""Test file upload fallback to base64 when GCS is not configured."""
file_content = b"test file content"
file_obj = BytesIO(file_content)
@@ -597,7 +602,7 @@ async def test_upload_file_gcs_not_configured_fallback(test_user_id: str):
upload_file_mock.read = AsyncMock(return_value=file_content)
result = await upload_file(file=upload_file_mock, user_id=test_user_id)
result = await upload_file(file=upload_file_mock, user_id="test-user-123")
# Verify fallback behavior
assert result.file_name == "test.txt"

View File

@@ -0,0 +1,144 @@
"""Common test fixtures with proper setup and teardown."""
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from unittest.mock import Mock, patch
import pytest
from prisma import Prisma
@pytest.fixture
async def test_db_connection() -> AsyncGenerator[Prisma, None]:
"""Provide a test database connection with proper cleanup.
This fixture ensures the database connection is properly
closed after the test, even if the test fails.
"""
db = Prisma()
try:
await db.connect()
yield db
finally:
await db.disconnect()
@pytest.fixture
def mock_transaction():
"""Mock database transaction with proper async context manager."""
@asynccontextmanager
async def mock_context(*args, **kwargs):
yield None
with patch("backend.data.db.locked_transaction", side_effect=mock_context) as mock:
yield mock
@pytest.fixture
def isolated_app_state():
"""Fixture that ensures app state is isolated between tests."""
# Example: Save original state
# from backend.server.app import app
# original_overrides = app.dependency_overrides.copy()
# try:
# yield app
# finally:
# # Restore original state
# app.dependency_overrides = original_overrides
# For now, just yield None as this is an example
yield None
@pytest.fixture
def cleanup_files():
"""Fixture to track and cleanup files created during tests."""
created_files = []
def track_file(filepath: str):
created_files.append(filepath)
yield track_file
# Cleanup
import os
for filepath in created_files:
try:
if os.path.exists(filepath):
os.remove(filepath)
except Exception as e:
print(f"Warning: Failed to cleanup {filepath}: {e}")
@pytest.fixture
async def async_mock_with_cleanup():
"""Create async mocks that are properly cleaned up."""
mocks = []
def create_mock(**kwargs):
mock = Mock(**kwargs)
mocks.append(mock)
return mock
yield create_mock
# Reset all mocks
for mock in mocks:
mock.reset_mock()
class TestDatabaseIsolation:
"""Example of proper test isolation with database operations."""
@pytest.fixture(autouse=True)
async def setup_and_teardown(self, test_db_connection):
"""Setup and teardown for each test method."""
# Setup: Clear test data
await test_db_connection.user.delete_many(
where={"email": {"contains": "@test.example"}}
)
yield
# Teardown: Clear test data again
await test_db_connection.user.delete_many(
where={"email": {"contains": "@test.example"}}
)
@pytest.fixture(scope="session")
async def test_create_user(self, test_db_connection):
"""Test that demonstrates proper isolation."""
# This test has access to a clean database
user = await test_db_connection.user.create(
data={
"id": "test-user-id",
"email": "test@test.example",
"name": "Test User",
}
)
assert user.email == "test@test.example"
# User will be cleaned up automatically
@pytest.fixture(scope="function") # Explicitly use function scope
def reset_singleton_state():
"""Reset singleton state between tests."""
# Example: Reset a singleton instance
# from backend.data.some_singleton import SingletonClass
# # Save original state
# original_instance = getattr(SingletonClass, "_instance", None)
# try:
# # Clear singleton
# SingletonClass._instance = None
# yield
# finally:
# # Restore original state
# SingletonClass._instance = original_instance
# For now, just yield None as this is an example
yield None

View File

@@ -0,0 +1,74 @@
"""Common test utilities and constants for server tests."""
from typing import Any, Dict
from unittest.mock import Mock
import pytest
# Test ID constants
TEST_USER_ID = "test-user-id"
ADMIN_USER_ID = "admin-user-id"
TARGET_USER_ID = "target-user-id"
# Common test data constants
FIXED_TIMESTAMP = "2024-01-01T00:00:00Z"
TRANSACTION_UUID = "transaction-123-uuid"
METRIC_UUID = "metric-123-uuid"
ANALYTICS_UUID = "analytics-123-uuid"
def create_mock_with_id(mock_id: str) -> Mock:
"""Create a mock object with an id attribute.
Args:
mock_id: The ID value to set on the mock
Returns:
Mock object with id attribute set
"""
return Mock(id=mock_id)
def assert_status_and_parse_json(
response: Any, expected_status: int = 200
) -> Dict[str, Any]:
"""Assert response status and return parsed JSON.
Args:
response: The HTTP response object
expected_status: Expected status code (default: 200)
Returns:
Parsed JSON response data
Raises:
AssertionError: If status code doesn't match expected
"""
assert (
response.status_code == expected_status
), f"Expected status {expected_status}, got {response.status_code}: {response.text}"
return response.json()
@pytest.mark.parametrize(
"metric_value,metric_name,data_string",
[
(100, "api_calls_count", "external_api"),
(0, "error_count", "no_errors"),
(-5.2, "temperature_delta", "cooling"),
(1.23456789, "precision_test", "float_precision"),
(999999999, "large_number", "max_value"),
],
)
def parametrized_metric_values_decorator(func):
"""Decorator for parametrized metric value tests."""
return pytest.mark.parametrize(
"metric_value,metric_name,data_string",
[
(100, "api_calls_count", "external_api"),
(0, "error_count", "no_errors"),
(-5.2, "temperature_delta", "cooling"),
(1.23456789, "precision_test", "float_precision"),
(999999999, "large_number", "max_value"),
],
)(func)

View File

@@ -0,0 +1,11 @@
from autogpt_libs.auth.depends import requires_user
from autogpt_libs.auth.models import User
from fastapi import Depends
from backend.util.settings import Settings
settings = Settings()
def get_user_id(user: User = Depends(requires_user)) -> str:
return user.user_id

View File

@@ -1,109 +0,0 @@
"""
API Key authentication utilities for FastAPI applications.
"""
import inspect
import secrets
from typing import Any, Callable, Optional
from fastapi import HTTPException, Request
from fastapi.security import APIKeyHeader
from starlette.status import HTTP_401_UNAUTHORIZED
from backend.util.exceptions import MissingConfigError
class APIKeyAuthenticator(APIKeyHeader):
"""
Configurable API key authenticator for FastAPI applications,
with support for custom validation functions.
This class provides a flexible way to implement API key authentication with optional
custom validation logic. It can be used for simple token matching
or more complex validation scenarios like database lookups.
Examples:
Simple token validation:
```python
api_key_auth = APIKeyAuthenticator(
header_name="X-API-Key",
expected_token="your-secret-token"
)
@app.get("/protected", dependencies=[Security(api_key_auth)])
def protected_endpoint():
return {"message": "Access granted"}
```
Custom validation with database lookup:
```python
async def validate_with_db(api_key: str):
api_key_obj = await db.get_api_key(api_key)
return api_key_obj if api_key_obj and api_key_obj.is_active else None
api_key_auth = APIKeyAuthenticator(
header_name="X-API-Key",
validator=validate_with_db
)
```
Args:
header_name (str): The name of the header containing the API key
expected_token (Optional[str]): The expected API key value for simple token matching
validator (Optional[Callable]): Custom validation function that takes an API key
string and returns a boolean or object. Can be async.
status_if_missing (int): HTTP status code to use for validation errors
message_if_invalid (str): Error message to return when validation fails
"""
def __init__(
self,
header_name: str,
expected_token: Optional[str] = None,
validator: Optional[Callable[[str], bool]] = None,
status_if_missing: int = HTTP_401_UNAUTHORIZED,
message_if_invalid: str = "Invalid API key",
):
super().__init__(
name=header_name,
scheme_name=f"{__class__.__name__}-{header_name}",
auto_error=False,
)
self.expected_token = expected_token
self.custom_validator = validator
self.status_if_missing = status_if_missing
self.message_if_invalid = message_if_invalid
async def __call__(self, request: Request) -> Any:
api_key = await super()(request)
if api_key is None:
raise HTTPException(
status_code=self.status_if_missing, detail="No API key in request"
)
# Use custom validation if provided, otherwise use default equality check
validator = self.custom_validator or self.default_validator
result = (
await validator(api_key)
if inspect.iscoroutinefunction(validator)
else validator(api_key)
)
if not result:
raise HTTPException(
status_code=self.status_if_missing, detail=self.message_if_invalid
)
# Store validation result in request state if it's not just a boolean
if result is not True:
request.state.api_key = result
return result
async def default_validator(self, api_key: str) -> bool:
if not self.expected_token:
raise MissingConfigError(
f"{self.__class__.__name__}.expected_token is not set; "
"either specify it or provide a custom validator"
)
return secrets.compare_digest(api_key, self.expected_token)

View File

@@ -83,7 +83,7 @@ class AutoModManager:
f"Moderating inputs for graph execution {graph_exec.graph_exec_id}"
)
try:
moderation_passed, content_id = await self._moderate_content(
moderation_passed = await self._moderate_content(
content,
{
"user_id": graph_exec.user_id,
@@ -99,7 +99,7 @@ class AutoModManager:
)
# Update node statuses for frontend display before raising error
await self._update_failed_nodes_for_moderation(
db_client, graph_exec.graph_exec_id, "input", content_id
db_client, graph_exec.graph_exec_id, "input"
)
return ModerationError(
@@ -107,7 +107,6 @@ class AutoModManager:
user_id=graph_exec.user_id,
graph_exec_id=graph_exec.graph_exec_id,
moderation_type="input",
content_id=content_id,
)
return None
@@ -147,10 +146,8 @@ class AutoModManager:
return None
# Get completed executions and collect outputs
completed_executions = await db_client.get_node_executions( # type: ignore
graph_exec_id=graph_exec_id,
statuses=[ExecutionStatus.COMPLETED],
include_exec_data=True,
completed_executions = await db_client.get_node_executions(
graph_exec_id, statuses=[ExecutionStatus.COMPLETED], include_exec_data=True
)
if not completed_executions:
@@ -170,7 +167,7 @@ class AutoModManager:
# Run moderation
logger.warning(f"Moderating outputs for graph execution {graph_exec_id}")
try:
moderation_passed, content_id = await self._moderate_content(
moderation_passed = await self._moderate_content(
content,
{
"user_id": user_id,
@@ -184,7 +181,7 @@ class AutoModManager:
logger.warning(f"Moderation failed for graph execution {graph_exec_id}")
# Update node statuses for frontend display before raising error
await self._update_failed_nodes_for_moderation(
db_client, graph_exec_id, "output", content_id
db_client, graph_exec_id, "output"
)
return ModerationError(
@@ -192,7 +189,6 @@ class AutoModManager:
user_id=user_id,
graph_exec_id=graph_exec_id,
moderation_type="output",
content_id=content_id,
)
return None
@@ -216,11 +212,10 @@ class AutoModManager:
db_client: "DatabaseManagerAsyncClient",
graph_exec_id: str,
moderation_type: Literal["input", "output"],
content_id: str | None = None,
):
"""Update node execution statuses for frontend display when moderation fails"""
# Import here to avoid circular imports
from backend.util.clients import get_async_execution_event_bus
from backend.executor.manager import send_async_execution_update
if moderation_type == "input":
# For input moderation, mark queued/running/incomplete nodes as failed
@@ -234,20 +229,13 @@ class AutoModManager:
target_statuses = [ExecutionStatus.COMPLETED]
# Get the executions that need to be updated
executions_to_update = await db_client.get_node_executions( # type: ignore
graph_exec_id=graph_exec_id,
statuses=target_statuses,
include_exec_data=True,
executions_to_update = await db_client.get_node_executions(
graph_exec_id, statuses=target_statuses, include_exec_data=True
)
if not executions_to_update:
return
# Create error message with content_id if available
error_message = "Failed due to content moderation"
if content_id:
error_message += f" (Moderation ID: {content_id})"
# Prepare database update tasks
exec_updates = []
for exec_entry in executions_to_update:
@@ -257,11 +245,11 @@ class AutoModManager:
if exec_entry.input_data:
for name in exec_entry.input_data.keys():
cleared_inputs[name] = [error_message]
cleared_inputs[name] = ["Failed due to content moderation"]
if exec_entry.output_data:
for name in exec_entry.output_data.keys():
cleared_outputs[name] = [error_message]
cleared_outputs[name] = ["Failed due to content moderation"]
# Add update task to list
exec_updates.append(
@@ -269,7 +257,7 @@ class AutoModManager:
exec_entry.node_exec_id,
status=ExecutionStatus.FAILED,
stats={
"error": error_message,
"error": "Failed due to content moderation",
"cleared_inputs": cleared_inputs,
"cleared_outputs": cleared_outputs,
},
@@ -280,24 +268,19 @@ class AutoModManager:
updated_execs = await asyncio.gather(*exec_updates)
# Send all websocket updates in parallel
event_bus = get_async_execution_event_bus()
await asyncio.gather(
*[
event_bus.publish(updated_exec)
send_async_execution_update(updated_exec)
for updated_exec in updated_execs
if updated_exec is not None
]
)
async def _moderate_content(
self, content: str, metadata: dict[str, Any]
) -> tuple[bool, str | None]:
async def _moderate_content(self, content: str, metadata: dict[str, Any]) -> bool:
"""Moderate content using AutoMod API
Returns:
Tuple of (approval_status, content_id)
- approval_status: True if approved or timeout occurred, False if rejected
- content_id: Reference ID from moderation API, or None if not available
True: Content approved or timeout occurred
False: Content rejected by moderation
Raises:
asyncio.TimeoutError: When moderation times out (should be bypassed)
@@ -315,12 +298,12 @@ class AutoModManager:
logger.debug(
f"Content approved for {metadata.get('graph_exec_id', 'unknown')}"
)
return True, response.content_id
return True
else:
reasons = [r.reason for r in response.moderation_results if r.reason]
error_msg = f"Content rejected by AutoMod: {'; '.join(reasons)}"
logger.warning(f"Content rejected: {error_msg}")
return False, response.content_id
return False
except asyncio.TimeoutError:
# Re-raise timeout to be handled by calling methods
@@ -330,7 +313,7 @@ class AutoModManager:
raise
except Exception as e:
logger.error(f"AutoMod moderation error: {e}")
return self.config.fail_open, None
return self.config.fail_open
async def _make_request(self, request_data: AutoModRequest) -> AutoModResponse:
"""Make HTTP request to AutoMod API using the standard request utility"""

View File

@@ -26,9 +26,6 @@ class AutoModResponse(BaseModel):
"""Response model for AutoMod API"""
success: bool = Field(..., description="Whether the request was successful")
content_id: str = Field(
..., description="Unique reference ID for this moderation request"
)
status: str = Field(
..., description="Overall status: 'approved', 'rejected', 'flagged', 'pending'"
)

View File

@@ -1,8 +1,9 @@
import logging
import typing
from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, Body, Security
from autogpt_libs.auth import requires_admin_user
from autogpt_libs.auth.depends import get_user_id
from fastapi import APIRouter, Body, Depends
from prisma.enums import CreditTransactionType
from backend.data.credit import admin_get_user_history, get_user_credit_model
@@ -17,7 +18,7 @@ _user_credit_model = get_user_credit_model()
router = APIRouter(
prefix="/admin",
tags=["credits", "admin"],
dependencies=[Security(requires_admin_user)],
dependencies=[Depends(requires_admin_user)],
)
@@ -28,16 +29,18 @@ async def add_user_credits(
user_id: typing.Annotated[str, Body()],
amount: typing.Annotated[int, Body()],
comments: typing.Annotated[str, Body()],
admin_user_id: str = Security(get_user_id),
admin_user: typing.Annotated[
str,
Depends(get_user_id),
],
):
logger.info(
f"Admin user {admin_user_id} is adding {amount} credits to user {user_id}"
)
""" """
logger.info(f"Admin user {admin_user} is adding {amount} credits to user {user_id}")
new_balance, transaction_key = await _user_credit_model._add_transaction(
user_id,
amount,
transaction_type=CreditTransactionType.GRANT,
metadata=SafeJson({"admin_id": admin_user_id, "reason": comments}),
metadata=SafeJson({"admin_id": admin_user, "reason": comments}),
)
return {
"new_balance": new_balance,
@@ -51,14 +54,17 @@ async def add_user_credits(
summary="Get All Users History",
)
async def admin_get_all_user_history(
admin_user_id: str = Security(get_user_id),
admin_user: typing.Annotated[
str,
Depends(get_user_id),
],
search: typing.Optional[str] = None,
page: int = 1,
page_size: int = 20,
transaction_filter: typing.Optional[CreditTransactionType] = None,
):
""" """
logger.info(f"Admin user {admin_user_id} is getting grant history")
logger.info(f"Admin user {admin_user} is getting grant history")
try:
resp = await admin_get_user_history(
@@ -67,7 +73,7 @@ async def admin_get_all_user_history(
search=search,
transaction_filter=transaction_filter,
)
logger.info(f"Admin user {admin_user_id} got {len(resp.history)} grant history")
logger.info(f"Admin user {admin_user} got {len(resp.history)} grant history")
return resp
except Exception as e:
logger.exception(f"Error getting grant history: {e}")

View File

@@ -1,19 +1,20 @@
import json
from unittest.mock import AsyncMock
import autogpt_libs.auth
import autogpt_libs.auth.depends
import fastapi
import fastapi.testclient
import prisma.enums
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from prisma import Json
from pytest_snapshot.plugin import Snapshot
import backend.server.v2.admin.credit_admin_routes as credit_admin_routes
import backend.server.v2.admin.model as admin_model
from backend.data.model import UserTransaction
from backend.util.models import Pagination
from backend.server.conftest import ADMIN_USER_ID, TARGET_USER_ID
from backend.server.model import Pagination
app = fastapi.FastAPI()
app.include_router(credit_admin_routes.router)
@@ -21,19 +22,25 @@ app.include_router(credit_admin_routes.router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
"""Setup admin auth overrides for all tests in this module"""
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def override_requires_admin_user() -> dict[str, str]:
"""Override admin user check for testing"""
return {"sub": ADMIN_USER_ID, "role": "admin"}
def override_get_user_id() -> str:
"""Override get_user_id for testing"""
return ADMIN_USER_ID
app.dependency_overrides[autogpt_libs.auth.requires_admin_user] = (
override_requires_admin_user
)
app.dependency_overrides[autogpt_libs.auth.depends.get_user_id] = override_get_user_id
def test_add_user_credits_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
admin_user_id: str,
target_user_id: str,
) -> None:
"""Test successful credit addition by admin"""
# Mock the credit model
@@ -45,7 +52,7 @@ def test_add_user_credits_success(
)
request_data = {
"user_id": target_user_id,
"user_id": TARGET_USER_ID,
"amount": 500,
"comments": "Test credit grant for debugging",
}
@@ -60,12 +67,12 @@ def test_add_user_credits_success(
# Verify the function was called with correct parameters
mock_credit_model._add_transaction.assert_called_once()
call_args = mock_credit_model._add_transaction.call_args
assert call_args[0] == (target_user_id, 500)
assert call_args[0] == (TARGET_USER_ID, 500)
assert call_args[1]["transaction_type"] == prisma.enums.CreditTransactionType.GRANT
# Check that metadata is a Json object with the expected content
assert isinstance(call_args[1]["metadata"], Json)
assert call_args[1]["metadata"] == Json(
{"admin_id": admin_user_id, "reason": "Test credit grant for debugging"}
{"admin_id": ADMIN_USER_ID, "reason": "Test credit grant for debugging"}
)
# Snapshot test the response
@@ -283,10 +290,18 @@ def test_add_credits_invalid_request() -> None:
assert response.status_code == 422
def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
def test_admin_endpoints_require_admin_role(mocker: pytest_mock.MockFixture) -> None:
"""Test that admin endpoints require admin role"""
# Simulate regular non-admin user
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
# Clear the admin override to test authorization
app.dependency_overrides.clear()
# Mock requires_admin_user to raise an exception
mocker.patch(
"autogpt_libs.auth.requires_admin_user",
side_effect=fastapi.HTTPException(
status_code=403, detail="Admin access required"
),
)
# Test add_credits endpoint
response = client.post(
@@ -297,8 +312,20 @@ def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
"comments": "test",
},
)
assert response.status_code == 403
assert (
response.status_code == 401
) # Auth middleware returns 401 when auth is disabled
# Test users_history endpoint
response = client.get("/admin/users_history")
assert response.status_code == 403
assert (
response.status_code == 401
) # Auth middleware returns 401 when auth is disabled
# Restore the override
app.dependency_overrides[autogpt_libs.auth.requires_admin_user] = (
override_requires_admin_user
)
app.dependency_overrides[autogpt_libs.auth.depends.get_user_id] = (
override_get_user_id
)

View File

@@ -1,7 +1,7 @@
from pydantic import BaseModel
from backend.data.model import UserTransaction
from backend.util.models import Pagination
from backend.server.model import Pagination
class UserHistoryResponse(BaseModel):

View File

@@ -2,28 +2,26 @@ import logging
import tempfile
import typing
import autogpt_libs.auth
import autogpt_libs.auth.depends
import fastapi
import fastapi.responses
import prisma.enums
import backend.server.v2.store.db
import backend.server.v2.store.exceptions
import backend.server.v2.store.model
import backend.util.json
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
prefix="/admin",
tags=["store", "admin"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_admin_user)],
)
router = fastapi.APIRouter(prefix="/admin", tags=["store", "admin"])
@router.get(
"/listings",
summary="Get Admin Listings History",
response_model=backend.server.v2.store.model.StoreListingsWithVersionsResponse,
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
)
async def get_admin_listings_with_versions(
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
@@ -68,11 +66,15 @@ async def get_admin_listings_with_versions(
"/submissions/{store_listing_version_id}/review",
summary="Review Store Submission",
response_model=backend.server.v2.store.model.StoreSubmission,
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
)
async def review_submission(
store_listing_version_id: str,
request: backend.server.v2.store.model.ReviewSubmissionRequest,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
user: typing.Annotated[
autogpt_libs.auth.models.User,
fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user),
],
):
"""
Review a store listing submission.
@@ -80,7 +82,7 @@ async def review_submission(
Args:
store_listing_version_id: ID of the submission to review
request: Review details including approval status and comments
user_id: Authenticated admin user performing the review
user: Authenticated admin user performing the review
Returns:
StoreSubmission with updated review information
@@ -91,7 +93,7 @@ async def review_submission(
is_approved=request.is_approved,
external_comments=request.comments,
internal_comments=request.internal_comments or "",
reviewer_id=user_id,
reviewer_id=user.user_id,
)
return submission
except Exception as e:
@@ -106,9 +108,13 @@ async def review_submission(
"/submissions/download/{store_listing_version_id}",
summary="Admin Download Agent File",
tags=["store", "admin"],
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
)
async def admin_download_agent_file(
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
user: typing.Annotated[
autogpt_libs.auth.models.User,
fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user),
],
store_listing_version_id: str = fastapi.Path(
..., description="The ID of the agent to download"
),
@@ -126,7 +132,7 @@ async def admin_download_agent_file(
HTTPException: If the agent is not found or an unexpected error occurs.
"""
graph_data = await backend.server.v2.store.db.get_agent_as_admin(
user_id=user_id,
user_id=user.user_id,
store_listing_version_id=store_listing_version_id,
)
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"

View File

@@ -1,376 +0,0 @@
import functools
import logging
from datetime import datetime, timedelta, timezone
import prisma
import backend.data.block
from backend.blocks import load_all_blocks
from backend.blocks.llm import LlmModel
from backend.data.block import Block, BlockCategory, BlockSchema
from backend.data.credit import get_block_costs
from backend.integrations.providers import ProviderName
from backend.server.v2.builder.model import (
BlockCategoryResponse,
BlockData,
BlockResponse,
BlockType,
CountResponse,
Provider,
ProviderResponse,
SearchBlocksResponse,
)
from backend.util.models import Pagination
logger = logging.getLogger(__name__)
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
_static_counts_cache: dict | None = None
_suggested_blocks: list[BlockData] | None = None
def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse]:
categories: dict[BlockCategory, BlockCategoryResponse] = {}
for block_type in load_all_blocks().values():
block: Block[BlockSchema, BlockSchema] = block_type()
# Skip disabled blocks
if block.disabled:
continue
# Skip blocks that don't have categories (all should have at least one)
if not block.categories:
continue
# Add block to the categories
for category in block.categories:
if category not in categories:
categories[category] = BlockCategoryResponse(
name=category.name.lower(),
total_blocks=0,
blocks=[],
)
categories[category].total_blocks += 1
# Append if the category has less than the specified number of blocks
if len(categories[category].blocks) < category_blocks:
categories[category].blocks.append(block.to_dict())
# Sort categories by name
return sorted(categories.values(), key=lambda x: x.name)
def get_blocks(
*,
category: str | None = None,
type: BlockType | None = None,
provider: ProviderName | None = None,
page: int = 1,
page_size: int = 50,
) -> BlockResponse:
"""
Get blocks based on either category, type or provider.
Providing nothing fetches all block types.
"""
# Only one of category, type, or provider can be specified
if (category and type) or (category and provider) or (type and provider):
raise ValueError("Only one of category, type, or provider can be specified")
blocks: list[Block[BlockSchema, BlockSchema]] = []
skip = (page - 1) * page_size
take = page_size
total = 0
for block_type in load_all_blocks().values():
block: Block[BlockSchema, BlockSchema] = block_type()
# Skip disabled blocks
if block.disabled:
continue
# Skip blocks that don't match the category
if category and category not in {c.name.lower() for c in block.categories}:
continue
# Skip blocks that don't match the type
if (
(type == "input" and block.block_type.value != "Input")
or (type == "output" and block.block_type.value != "Output")
or (type == "action" and block.block_type.value in ("Input", "Output"))
):
continue
# Skip blocks that don't match the provider
if provider:
credentials_info = block.input_schema.get_credentials_fields_info().values()
if not any(provider in info.provider for info in credentials_info):
continue
total += 1
if skip > 0:
skip -= 1
continue
if take > 0:
take -= 1
blocks.append(block)
costs = get_block_costs()
return BlockResponse(
blocks=[{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks],
pagination=Pagination(
total_items=total,
total_pages=(total + page_size - 1) // page_size,
current_page=page,
page_size=page_size,
),
)
def search_blocks(
include_blocks: bool = True,
include_integrations: bool = True,
query: str = "",
page: int = 1,
page_size: int = 50,
) -> SearchBlocksResponse:
"""
Get blocks based on the filter and query.
`providers` only applies for `integrations` filter.
"""
blocks: list[Block[BlockSchema, BlockSchema]] = []
query = query.lower()
total = 0
skip = (page - 1) * page_size
take = page_size
block_count = 0
integration_count = 0
for block_type in load_all_blocks().values():
block: Block[BlockSchema, BlockSchema] = block_type()
# Skip disabled blocks
if block.disabled:
continue
# Skip blocks that don't match the query
if (
query not in block.name.lower()
and query not in block.description.lower()
and not _matches_llm_model(block.input_schema, query)
):
continue
keep = False
credentials = list(block.input_schema.get_credentials_fields().values())
if include_integrations and len(credentials) > 0:
keep = True
integration_count += 1
if include_blocks and len(credentials) == 0:
keep = True
block_count += 1
if not keep:
continue
total += 1
if skip > 0:
skip -= 1
continue
if take > 0:
take -= 1
blocks.append(block)
costs = get_block_costs()
return SearchBlocksResponse(
blocks=BlockResponse(
blocks=[{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks],
pagination=Pagination(
total_items=total,
total_pages=(total + page_size - 1) // page_size,
current_page=page,
page_size=page_size,
),
),
total_block_count=block_count,
total_integration_count=integration_count,
)
def get_providers(
query: str = "",
page: int = 1,
page_size: int = 50,
) -> ProviderResponse:
providers = []
query = query.lower()
skip = (page - 1) * page_size
take = page_size
all_providers = _get_all_providers()
for provider in all_providers.values():
if (
query not in provider.name.value.lower()
and query not in provider.description.lower()
):
continue
if skip > 0:
skip -= 1
continue
if take > 0:
take -= 1
providers.append(provider)
total = len(all_providers)
return ProviderResponse(
providers=providers,
pagination=Pagination(
total_items=total,
total_pages=(total + page_size - 1) // page_size,
current_page=page,
page_size=page_size,
),
)
async def get_counts(user_id: str) -> CountResponse:
my_agents = await prisma.models.LibraryAgent.prisma().count(
where={
"userId": user_id,
"isDeleted": False,
"isArchived": False,
}
)
counts = await _get_static_counts()
return CountResponse(
my_agents=my_agents,
**counts,
)
async def _get_static_counts():
"""
Get counts of blocks, integrations, and marketplace agents.
This is cached to avoid unnecessary database queries and calculations.
Can't use functools.cache here because the function is async.
"""
global _static_counts_cache
if _static_counts_cache is not None:
return _static_counts_cache
all_blocks = 0
input_blocks = 0
action_blocks = 0
output_blocks = 0
integrations = 0
for block_type in load_all_blocks().values():
block: Block[BlockSchema, BlockSchema] = block_type()
if block.disabled:
continue
all_blocks += 1
if block.block_type.value == "Input":
input_blocks += 1
elif block.block_type.value == "Output":
output_blocks += 1
else:
action_blocks += 1
credentials = list(block.input_schema.get_credentials_fields().values())
if len(credentials) > 0:
integrations += 1
marketplace_agents = await prisma.models.StoreAgent.prisma().count()
_static_counts_cache = {
"all_blocks": all_blocks,
"input_blocks": input_blocks,
"action_blocks": action_blocks,
"output_blocks": output_blocks,
"integrations": integrations,
"marketplace_agents": marketplace_agents,
}
return _static_counts_cache
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
for field in schema_cls.model_fields.values():
if field.annotation == LlmModel:
# Check if query matches any value in llm_models
if any(query in name for name in llm_models):
return True
return False
@functools.cache
def _get_all_providers() -> dict[ProviderName, Provider]:
providers: dict[ProviderName, Provider] = {}
for block_type in load_all_blocks().values():
block: Block[BlockSchema, BlockSchema] = block_type()
if block.disabled:
continue
credentials_info = block.input_schema.get_credentials_fields_info().values()
for info in credentials_info:
for provider in info.provider: # provider is a ProviderName enum member
if provider in providers:
providers[provider].integration_count += 1
else:
providers[provider] = Provider(
name=provider, description="", integration_count=1
)
return providers
async def get_suggested_blocks(count: int = 5) -> list[BlockData]:
global _suggested_blocks
if _suggested_blocks is not None and len(_suggested_blocks) >= count:
return _suggested_blocks[:count]
_suggested_blocks = []
# Sum the number of executions for each block type
# Prisma cannot group by nested relations, so we do a raw query
# Calculate the cutoff timestamp
timestamp_threshold = datetime.now(timezone.utc) - timedelta(days=30)
results = await prisma.get_client().query_raw(
"""
SELECT
agent_node."agentBlockId" AS block_id,
COUNT(execution.id) AS execution_count
FROM "AgentNodeExecution" execution
JOIN "AgentNode" agent_node ON execution."agentNodeId" = agent_node.id
WHERE execution."endedTime" >= $1::timestamp
GROUP BY agent_node."agentBlockId"
ORDER BY execution_count DESC;
""",
timestamp_threshold,
)
# Get the top blocks based on execution count
# But ignore Input and Output blocks
blocks: list[tuple[BlockData, int]] = []
for block_type in load_all_blocks().values():
block: Block[BlockSchema, BlockSchema] = block_type()
if block.disabled or block.block_type in (
backend.data.block.BlockType.INPUT,
backend.data.block.BlockType.OUTPUT,
backend.data.block.BlockType.AGENT,
):
continue
# Find the execution count for this block
execution_count = next(
(row["execution_count"] for row in results if row["block_id"] == block.id),
0,
)
blocks.append((block.to_dict(), execution_count))
# Sort blocks by execution count
blocks.sort(key=lambda x: x[1], reverse=True)
_suggested_blocks = [block[0] for block in blocks]
# Return the top blocks
return _suggested_blocks[:count]

View File

@@ -1,87 +0,0 @@
from typing import Any, Literal
from pydantic import BaseModel
import backend.server.v2.library.model as library_model
import backend.server.v2.store.model as store_model
from backend.integrations.providers import ProviderName
from backend.util.models import Pagination
FilterType = Literal[
"blocks",
"integrations",
"marketplace_agents",
"my_agents",
]
BlockType = Literal["all", "input", "action", "output"]
BlockData = dict[str, Any]
# Suggestions
class SuggestionsResponse(BaseModel):
otto_suggestions: list[str]
recent_searches: list[str]
providers: list[ProviderName]
top_blocks: list[BlockData]
# All blocks
class BlockCategoryResponse(BaseModel):
name: str
total_blocks: int
blocks: list[BlockData]
model_config = {"use_enum_values": False} # <== use enum names like "AI"
# Input/Action/Output and see all for block categories
class BlockResponse(BaseModel):
blocks: list[BlockData]
pagination: Pagination
# Providers
class Provider(BaseModel):
name: ProviderName
description: str
integration_count: int
class ProviderResponse(BaseModel):
providers: list[Provider]
pagination: Pagination
# Search
class SearchRequest(BaseModel):
search_query: str | None = None
filter: list[FilterType] | None = None
by_creator: list[str] | None = None
search_id: str | None = None
page: int | None = None
page_size: int | None = None
class SearchBlocksResponse(BaseModel):
blocks: BlockResponse
total_block_count: int
total_integration_count: int
class SearchResponse(BaseModel):
items: list[BlockData | library_model.LibraryAgent | store_model.StoreAgent]
total_items: dict[FilterType, int]
page: int
more_pages: bool
class CountResponse(BaseModel):
all_blocks: int
input_blocks: int
action_blocks: int
output_blocks: int
integrations: int
marketplace_agents: int
my_agents: int

Some files were not shown because too many files have changed in this diff Show More