mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-11 16:18:07 -05:00
Compare commits
42 Commits
fixes-to-t
...
feat/execu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4f652cb978 | ||
|
|
279552a2a3 | ||
|
|
fb6ac1d6ca | ||
|
|
9db15bff02 | ||
|
|
3e4ca19036 | ||
|
|
a595da02f7 | ||
|
|
12cdd45551 | ||
|
|
df3c81a7a6 | ||
|
|
db4b94e0dc | ||
|
|
0f477e2392 | ||
|
|
b713093276 | ||
|
|
3718b948ea | ||
|
|
44d739386b | ||
|
|
533d2d0277 | ||
|
|
c6821484c7 | ||
|
|
fecbd3042d | ||
|
|
c0172c93aa | ||
|
|
8a68e03eb1 | ||
|
|
da585a34e1 | ||
|
|
bc43d05cac | ||
|
|
469b1fccbb | ||
|
|
1aa7e10cbd | ||
|
|
890bb3b8b4 | ||
|
|
2bb8e91040 | ||
|
|
76090f0ba2 | ||
|
|
5b12e02c4e | ||
|
|
e0520f5e0a | ||
|
|
a9530b7304 | ||
|
|
8a9c165faf | ||
|
|
476bfc6c84 | ||
|
|
61f17e5b97 | ||
|
|
a54bed6d68 | ||
|
|
5502256bea | ||
|
|
bd97727763 | ||
|
|
aa256f21cd | ||
|
|
2848e62f8a | ||
|
|
fa5ff9ca3c | ||
|
|
7c908c10b8 | ||
|
|
68cd1cb398 | ||
|
|
f4538d6f5a | ||
|
|
4589b15450 | ||
|
|
ccc4d0dc6c |
244
.github/copilot-instructions.md
vendored
Normal file
244
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1,244 @@
|
||||
# GitHub Copilot Instructions for AutoGPT
|
||||
|
||||
This file provides comprehensive onboarding information for GitHub Copilot coding agent to work efficiently with the AutoGPT repository.
|
||||
|
||||
## Repository Overview
|
||||
|
||||
**AutoGPT** is a powerful platform for creating, deploying, and managing continuous AI agents that automate complex workflows. This is a large monorepo (~150MB) containing multiple components:
|
||||
|
||||
- **AutoGPT Platform** (`autogpt_platform/`) - Main focus: Modern AI agent platform (Polyform Shield License)
|
||||
- **Classic AutoGPT** (`classic/`) - Legacy agent system (MIT License)
|
||||
- **Documentation** (`docs/`) - MkDocs-based documentation site
|
||||
- **Infrastructure** - Docker configurations, CI/CD, and development tools
|
||||
|
||||
**Primary Languages & Frameworks:**
|
||||
- **Backend**: Python 3.10-3.13, FastAPI, Prisma ORM, PostgreSQL, RabbitMQ
|
||||
- **Frontend**: TypeScript, Next.js 15, React, Tailwind CSS, Radix UI
|
||||
- **Development**: Docker, Poetry, pnpm, Playwright, Storybook
|
||||
|
||||
## Build and Validation Instructions
|
||||
|
||||
### Essential Setup Commands
|
||||
|
||||
**Always run these commands in the correct directory and in this order:**
|
||||
|
||||
1. **Initial Setup** (required once):
|
||||
```bash
|
||||
# Clone and enter repository
|
||||
git clone <repo> && cd AutoGPT
|
||||
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
cd autogpt_platform && docker compose --profile local up deps --build --detach
|
||||
```
|
||||
|
||||
2. **Backend Setup** (always run before backend development):
|
||||
```bash
|
||||
cd autogpt_platform/backend
|
||||
poetry install # Install dependencies
|
||||
poetry run prisma migrate dev # Run database migrations
|
||||
poetry run prisma generate # Generate Prisma client
|
||||
```
|
||||
|
||||
3. **Frontend Setup** (always run before frontend development):
|
||||
```bash
|
||||
cd autogpt_platform/frontend
|
||||
pnpm install # Install dependencies
|
||||
```
|
||||
|
||||
### Runtime Requirements
|
||||
|
||||
**Critical:** Always ensure Docker services are running before starting development:
|
||||
```bash
|
||||
cd autogpt_platform && docker compose --profile local up deps --build --detach
|
||||
```
|
||||
|
||||
**Python Version:** Use Python 3.11 (required; managed by Poetry via pyproject.toml)
|
||||
**Node.js Version:** Use Node.js 21+ with pnpm package manager
|
||||
|
||||
### Development Commands
|
||||
|
||||
**Backend Development:**
|
||||
```bash
|
||||
cd autogpt_platform/backend
|
||||
poetry run serve # Start development server (port 8000)
|
||||
poetry run test # Run all tests (requires ~5 minutes)
|
||||
poetry run pytest path/to/test.py # Run specific test
|
||||
poetry run format # Format code (Black + isort) - always run first
|
||||
poetry run lint # Lint code (ruff) - run after format
|
||||
```
|
||||
|
||||
**Frontend Development:**
|
||||
```bash
|
||||
cd autogpt_platform/frontend
|
||||
pnpm dev # Start development server (port 3000) - use for active development
|
||||
pnpm build # Build for production (only needed for E2E tests or deployment)
|
||||
pnpm test # Run Playwright E2E tests (requires build first)
|
||||
pnpm test-ui # Run tests with UI
|
||||
pnpm format # Format and lint code
|
||||
pnpm storybook # Start component development server
|
||||
```
|
||||
|
||||
### Testing Strategy
|
||||
|
||||
**Backend Tests:**
|
||||
- **Block Tests**: `poetry run pytest backend/blocks/test/test_block.py -xvs` (validates all blocks)
|
||||
- **Specific Block**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[BlockName]' -xvs`
|
||||
- **Snapshot Tests**: Use `--snapshot-update` when output changes, always review with `git diff`
|
||||
|
||||
**Frontend Tests:**
|
||||
- **E2E Tests**: Always run `pnpm dev` before `pnpm test` (Playwright requires running instance)
|
||||
- **Component Tests**: Use Storybook for isolated component development
|
||||
|
||||
### Critical Validation Steps
|
||||
|
||||
**Before committing changes:**
|
||||
1. Run `poetry run format` (backend) and `pnpm format` (frontend)
|
||||
2. Ensure all tests pass in modified areas
|
||||
3. Verify Docker services are still running
|
||||
4. Check that database migrations apply cleanly
|
||||
|
||||
**Common Issues & Workarounds:**
|
||||
- **Prisma issues**: Run `poetry run prisma generate` after schema changes
|
||||
- **Permission errors**: Ensure Docker has proper permissions
|
||||
- **Port conflicts**: Check the `docker-compose.yml` file for the current list of exposed ports. You can list all mapped ports with:
|
||||
- **Test timeouts**: Backend tests can take 5+ minutes, use `-x` flag to stop on first failure
|
||||
|
||||
## Project Layout & Architecture
|
||||
|
||||
### Core Architecture
|
||||
|
||||
**AutoGPT Platform** (`autogpt_platform/`):
|
||||
- `backend/` - FastAPI server with async support
|
||||
- `backend/backend/` - Core API logic
|
||||
- `backend/blocks/` - Agent execution blocks
|
||||
- `backend/data/` - Database models and schemas
|
||||
- `schema.prisma` - Database schema definition
|
||||
- `frontend/` - Next.js application
|
||||
- `src/app/` - App Router pages and layouts
|
||||
- `src/components/` - Reusable React components
|
||||
- `src/lib/` - Utilities and configurations
|
||||
- `autogpt_libs/` - Shared Python utilities
|
||||
- `docker-compose.yml` - Development stack orchestration
|
||||
|
||||
**Key Configuration Files:**
|
||||
- `pyproject.toml` - Python dependencies and tooling
|
||||
- `package.json` - Node.js dependencies and scripts
|
||||
- `schema.prisma` - Database schema and migrations
|
||||
- `next.config.mjs` - Next.js configuration
|
||||
- `tailwind.config.ts` - Styling configuration
|
||||
|
||||
### Security & Middleware
|
||||
|
||||
**Cache Protection**: Backend includes middleware preventing sensitive data caching in browsers/proxies
|
||||
**Authentication**: JWT-based with Supabase integration
|
||||
**User ID Validation**: All data access requires user ID checks - verify this for any `data/*.py` changes
|
||||
|
||||
### Development Workflow
|
||||
|
||||
**GitHub Actions**: Multiple CI/CD workflows in `.github/workflows/`
|
||||
- `platform-backend-ci.yml` - Backend testing and validation
|
||||
- `platform-frontend-ci.yml` - Frontend testing and validation
|
||||
- `platform-fullstack-ci.yml` - End-to-end integration tests
|
||||
|
||||
**Pre-commit Hooks**: Run linting and formatting checks
|
||||
**Conventional Commits**: Use format `type(scope): description` (e.g., `feat(backend): add API`)
|
||||
|
||||
### Key Source Files
|
||||
|
||||
**Backend Entry Points:**
|
||||
- `backend/backend/server/server.py` - FastAPI application setup
|
||||
- `backend/backend/data/` - Database models and user management
|
||||
- `backend/blocks/` - Agent execution blocks and logic
|
||||
|
||||
**Frontend Entry Points:**
|
||||
- `frontend/src/app/layout.tsx` - Root application layout
|
||||
- `frontend/src/app/page.tsx` - Home page
|
||||
- `frontend/src/lib/supabase/` - Authentication and database client
|
||||
|
||||
**Protected Routes**: Update `frontend/lib/supabase/middleware.ts` when adding protected routes
|
||||
|
||||
### Agent Block System
|
||||
|
||||
Agents are built using a visual block-based system where each block performs a single action. Blocks are defined in `backend/blocks/` and must include:
|
||||
- Block definition with input/output schemas
|
||||
- Execution logic with proper error handling
|
||||
- Tests validating functionality
|
||||
|
||||
### Database & ORM
|
||||
|
||||
**Prisma ORM** with PostgreSQL backend including pgvector for embeddings:
|
||||
- Schema in `schema.prisma`
|
||||
- Migrations in `backend/migrations/`
|
||||
- Always run `prisma migrate dev` and `prisma generate` after schema changes
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
### Configuration Files Priority Order
|
||||
1. **Backend**: `/backend/.env.default` → `/backend/.env` (user overrides)
|
||||
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
|
||||
3. **Platform**: `/.env.default` (Supabase/shared) → `/.env` (user overrides)
|
||||
4. Docker Compose `environment:` sections override file-based config
|
||||
5. Shell environment variables have highest precedence
|
||||
|
||||
### Docker Environment Setup
|
||||
- All services use hardcoded defaults (no `${VARIABLE}` substitutions)
|
||||
- The `env_file` directive loads variables INTO containers at runtime
|
||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||
- Copy `.env.default` files to `.env` for local development customization
|
||||
|
||||
## Advanced Development Patterns
|
||||
|
||||
### Adding New Blocks
|
||||
1. Create file in `/backend/backend/blocks/`
|
||||
2. Inherit from `Block` base class with input/output schemas
|
||||
3. Implement `run` method with proper error handling
|
||||
4. Generate block UUID using `uuid.uuid4()`
|
||||
5. Register in block registry
|
||||
6. Write tests alongside block implementation
|
||||
7. Consider how inputs/outputs connect with other blocks in graph editor
|
||||
|
||||
### API Development
|
||||
1. Update routes in `/backend/backend/server/routers/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside route files
|
||||
4. For `data/*.py` changes, validate user ID checks
|
||||
5. Run `poetry run test` to verify changes
|
||||
|
||||
### Frontend Development
|
||||
1. Components in `/frontend/src/components/`
|
||||
2. Use existing UI components from `/frontend/src/components/ui/`
|
||||
3. Add Storybook stories for component development
|
||||
4. Test user-facing features with Playwright E2E tests
|
||||
5. Update protected routes in middleware when needed
|
||||
|
||||
### Security Guidelines
|
||||
**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):
|
||||
- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)
|
||||
- Prevents sensitive data caching in browsers/proxies
|
||||
- Add new cacheable endpoints to `CACHEABLE_PATHS`
|
||||
|
||||
### CI/CD Alignment
|
||||
The repository has comprehensive CI workflows that test:
|
||||
- **Backend**: Python 3.11-3.13, services (Redis/RabbitMQ/ClamAV), Prisma migrations, Poetry lock validation
|
||||
- **Frontend**: Node.js 21, pnpm, Playwright with Docker Compose stack, API schema validation
|
||||
- **Integration**: Full-stack type checking and E2E testing
|
||||
|
||||
Match these patterns when developing locally - the copilot setup environment mirrors these CI configurations.
|
||||
|
||||
## Collaboration with Other AI Assistants
|
||||
|
||||
This repository is actively developed with assistance from Claude (via CLAUDE.md files). When working on this codebase:
|
||||
- Check for existing CLAUDE.md files that provide additional context
|
||||
- Follow established patterns and conventions already in the codebase
|
||||
- Maintain consistency with existing code style and architecture
|
||||
- Consider that changes may be reviewed and extended by both human developers and AI assistants
|
||||
|
||||
## Trust These Instructions
|
||||
|
||||
These instructions are comprehensive and tested. Only perform additional searches if:
|
||||
1. Information here is incomplete for your specific task
|
||||
2. You encounter errors not covered by the workarounds
|
||||
3. You need to understand implementation details not covered above
|
||||
|
||||
For detailed platform development patterns, refer to `autogpt_platform/CLAUDE.md` and `AGENTS.md` in the repository root.
|
||||
302
.github/workflows/copilot-setup-steps.yml
vendored
Normal file
302
.github/workflows/copilot-setup-steps.yml
vendored
Normal file
@@ -0,0 +1,302 @@
|
||||
name: "Copilot Setup Steps"
|
||||
|
||||
# Automatically run the setup steps when they are changed to allow for easy validation, and
|
||||
# allow manual testing through the repository's "Actions" tab
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
paths:
|
||||
- .github/workflows/copilot-setup-steps.yml
|
||||
pull_request:
|
||||
paths:
|
||||
- .github/workflows/copilot-setup-steps.yml
|
||||
|
||||
jobs:
|
||||
# The job MUST be called `copilot-setup-steps` or it will not be picked up by Copilot.
|
||||
copilot-setup-steps:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
|
||||
# Set the permissions to the lowest permissions possible needed for your steps.
|
||||
# Copilot will be given its own token for its operations.
|
||||
permissions:
|
||||
# If you want to clone the repository as part of your setup steps, for example to install dependencies, you'll need the `contents: read` permission. If you don't clone the repository in your setup steps, Copilot will do this for you automatically after the steps complete.
|
||||
contents: read
|
||||
|
||||
# You can define any steps you want, and they will run before the agent starts.
|
||||
# If you do not check out your code, Copilot will do this for you.
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
# Extract Poetry version from backend/poetry.lock (matches CI)
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
|
||||
|
||||
# Install Poetry
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
# Add Poetry to PATH
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Check poetry.lock
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
poetry lock
|
||||
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
|
||||
echo "Warning: poetry.lock not up to date, but continuing for setup"
|
||||
git checkout poetry.lock # Reset for clean setup
|
||||
fi
|
||||
|
||||
- name: Install Python dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set pnpm store directory
|
||||
run: |
|
||||
pnpm config set store-dir ~/.pnpm-store
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install JavaScript dependencies
|
||||
working-directory: autogpt_platform/frontend
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
# Install Playwright browsers for frontend testing
|
||||
# NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
|
||||
# - name: Install Playwright browsers
|
||||
# working-directory: autogpt_platform/frontend
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Copy default environment files
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Copy default environment files for development
|
||||
cp .env.default .env
|
||||
cp backend/.env.default backend/.env
|
||||
cp frontend/.env.default frontend/.env
|
||||
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
|
||||
restore-keys: |
|
||||
docker-images-v2-${{ runner.os }}-
|
||||
docker-images-v1-${{ runner.os }}-
|
||||
|
||||
- name: Load or pull Docker images
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
mkdir -p ~/docker-cache
|
||||
|
||||
# Define image list for easy maintenance
|
||||
IMAGES=(
|
||||
"redis:latest"
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
|
||||
echo "Docker cache found, loading images in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
if [ -f ~/docker-cache/${filename}.tar ]; then
|
||||
echo "Loading $image..."
|
||||
docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
|
||||
fi
|
||||
done
|
||||
wait
|
||||
echo "All cached images loaded"
|
||||
else
|
||||
echo "No Docker cache found, pulling images in parallel..."
|
||||
# Pull all images in parallel
|
||||
for image in "${IMAGES[@]}"; do
|
||||
docker pull "$image" &
|
||||
done
|
||||
wait
|
||||
|
||||
# Only save cache on main branches (not PRs) to avoid cache pollution
|
||||
if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
|
||||
echo "Saving Docker images to cache in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
echo "Saving $image..."
|
||||
docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
|
||||
done
|
||||
wait
|
||||
echo "Docker image cache saved"
|
||||
else
|
||||
echo "Skipping cache save for PR/feature branch"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Docker images ready for use"
|
||||
|
||||
# Phase 2: Build migrate service with GitHub Actions cache
|
||||
- name: Build migrate Docker image with cache
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Build the migrate image with buildx for GHA caching
|
||||
docker buildx build \
|
||||
--cache-from type=gha \
|
||||
--cache-to type=gha,mode=max \
|
||||
--target migrate \
|
||||
--tag autogpt_platform-migrate:latest \
|
||||
--load \
|
||||
-f backend/Dockerfile \
|
||||
..
|
||||
|
||||
# Start services using pre-built images
|
||||
- name: Start Docker services for development
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Start essential services (migrate image already built with correct tag)
|
||||
docker compose --profile local up deps --no-build --detach
|
||||
echo "Waiting for services to be ready..."
|
||||
|
||||
# Wait for database to be ready
|
||||
echo "Checking database readiness..."
|
||||
timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
|
||||
echo " Waiting for database..."
|
||||
sleep 2
|
||||
done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."
|
||||
|
||||
# Check migrate service status
|
||||
echo "Checking migration status..."
|
||||
docker compose ps migrate || echo " Migrate service not visible in ps output"
|
||||
|
||||
# Wait for migrate service to complete
|
||||
echo "Waiting for migrations to complete..."
|
||||
timeout 30 bash -c '
|
||||
ATTEMPTS=0
|
||||
while [ $ATTEMPTS -lt 15 ]; do
|
||||
ATTEMPTS=$((ATTEMPTS + 1))
|
||||
|
||||
# Check using docker directly (more reliable than docker compose ps)
|
||||
CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)
|
||||
|
||||
if [ -z "$CONTAINER_STATUS" ]; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container not found yet..."
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
|
||||
echo "✅ Migrations completed successfully"
|
||||
docker compose logs migrate --tail=5 2>/dev/null || true
|
||||
exit 0
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
|
||||
EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
|
||||
echo "❌ Migrations failed with exit code: $EXIT_CODE"
|
||||
echo "Migration logs:"
|
||||
docker compose logs migrate --tail=20 2>/dev/null || true
|
||||
exit 1
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
|
||||
else
|
||||
echo " Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
|
||||
fi
|
||||
|
||||
sleep 2
|
||||
done
|
||||
|
||||
echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
|
||||
echo "Final container check:"
|
||||
docker ps -a --filter "label=com.docker.compose.service=migrate" || true
|
||||
echo "Migration logs (if available):"
|
||||
docker compose logs migrate --tail=10 2>/dev/null || echo " No logs available"
|
||||
' || echo "⚠️ Migration check completed with warnings, continuing..."
|
||||
|
||||
# Brief wait for other services to stabilize
|
||||
echo "Waiting 5 seconds for other services to stabilize..."
|
||||
sleep 5
|
||||
|
||||
# Verify installations and provide environment info
|
||||
- name: Verify setup and show environment info
|
||||
run: |
|
||||
echo "=== Python Setup ==="
|
||||
python --version
|
||||
poetry --version
|
||||
|
||||
echo "=== Node.js Setup ==="
|
||||
node --version
|
||||
pnpm --version
|
||||
|
||||
echo "=== Additional Tools ==="
|
||||
docker --version
|
||||
docker compose version
|
||||
gh --version || true
|
||||
|
||||
echo "=== Services Status ==="
|
||||
cd autogpt_platform
|
||||
docker compose ps || true
|
||||
|
||||
echo "=== Backend Dependencies ==="
|
||||
cd backend
|
||||
poetry show | head -10 || true
|
||||
|
||||
echo "=== Frontend Dependencies ==="
|
||||
cd ../frontend
|
||||
pnpm list --depth=0 | head -10 || true
|
||||
|
||||
echo "=== Environment Files ==="
|
||||
ls -la ../.env* || true
|
||||
ls -la .env* || true
|
||||
ls -la ../backend/.env* || true
|
||||
|
||||
echo "✅ AutoGPT Platform development environment setup complete!"
|
||||
echo "🚀 Ready for development with Docker services running"
|
||||
echo "📝 Backend server: poetry run serve (port 8000)"
|
||||
echo "🌐 Frontend server: pnpm dev (port 3000)"
|
||||
4
.github/workflows/platform-backend-ci.yml
vendored
4
.github/workflows/platform-backend-ci.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.11"]
|
||||
python-version: ["3.11", "3.12", "3.13"]
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
@@ -201,7 +201,7 @@ jobs:
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
|
||||
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
|
||||
SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
|
||||
JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
|
||||
REDIS_HOST: "localhost"
|
||||
REDIS_PORT: "6379"
|
||||
REDIS_PASSWORD: "testpassword"
|
||||
|
||||
@@ -149,14 +149,23 @@ Key models (defined in `/backend/schema.prisma`):
|
||||
|
||||
**Adding a new block:**
|
||||
|
||||
1. Create new file in `/backend/backend/blocks/`
|
||||
2. Inherit from `Block` base class
|
||||
3. Define input/output schemas
|
||||
4. Implement `run` method
|
||||
5. Register in block registry
|
||||
6. Generate the block uuid using `uuid.uuid4()`
|
||||
Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:
|
||||
- Provider configuration with `ProviderBuilder`
|
||||
- Block schema definition
|
||||
- Authentication (API keys, OAuth, webhooks)
|
||||
- Testing and validation
|
||||
- File organization
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blcoks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
|
||||
Quick steps:
|
||||
1. Create new file in `/backend/backend/blocks/`
|
||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||
3. Inherit from `Block` base class
|
||||
4. Define input/output schemas using `BlockSchema`
|
||||
5. Implement async `run` method
|
||||
6. Generate unique block ID using `uuid.uuid4()`
|
||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
**Modifying the API:**
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from .depends import requires_admin_user, requires_user
|
||||
from .jwt_utils import parse_jwt_token
|
||||
from .middleware import APIKeyValidator, auth_middleware
|
||||
from .config import verify_settings
|
||||
from .dependencies import get_user_id, requires_admin_user, requires_user
|
||||
from .helpers import add_auth_responses_to_openapi
|
||||
from .models import User
|
||||
|
||||
__all__ = [
|
||||
"parse_jwt_token",
|
||||
"requires_user",
|
||||
"verify_settings",
|
||||
"get_user_id",
|
||||
"requires_admin_user",
|
||||
"APIKeyValidator",
|
||||
"auth_middleware",
|
||||
"requires_user",
|
||||
"add_auth_responses_to_openapi",
|
||||
"User",
|
||||
]
|
||||
|
||||
@@ -1,11 +1,90 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from jwt.algorithms import get_default_algorithms, has_crypto
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthConfigError(ValueError):
|
||||
"""Raised when authentication configuration is invalid."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
ALGO_RECOMMENDATION = (
|
||||
"We highly recommend using an asymmetric algorithm such as ES256, "
|
||||
"because when leaked, a shared secret would allow anyone to "
|
||||
"forge valid tokens and impersonate users. "
|
||||
"More info: https://supabase.com/docs/guides/auth/signing-keys#choosing-the-right-signing-algorithm" # noqa
|
||||
)
|
||||
|
||||
|
||||
class Settings:
|
||||
def __init__(self):
|
||||
self.JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
|
||||
self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
|
||||
self.JWT_ALGORITHM: str = "HS256"
|
||||
self.JWT_VERIFY_KEY: str = os.getenv(
|
||||
"JWT_VERIFY_KEY", os.getenv("SUPABASE_JWT_SECRET", "")
|
||||
).strip()
|
||||
self.JWT_ALGORITHM: str = os.getenv("JWT_SIGN_ALGORITHM", "HS256").strip()
|
||||
|
||||
self.validate()
|
||||
|
||||
def validate(self):
|
||||
if not self.JWT_VERIFY_KEY:
|
||||
raise AuthConfigError(
|
||||
"JWT_VERIFY_KEY must be set. "
|
||||
"An empty JWT secret would allow anyone to forge valid tokens."
|
||||
)
|
||||
|
||||
if len(self.JWT_VERIFY_KEY) < 32:
|
||||
logger.warning(
|
||||
"⚠️ JWT_VERIFY_KEY appears weak (less than 32 characters). "
|
||||
"Consider using a longer, cryptographically secure secret."
|
||||
)
|
||||
|
||||
supported_algorithms = get_default_algorithms().keys()
|
||||
|
||||
if not has_crypto:
|
||||
logger.warning(
|
||||
"⚠️ Asymmetric JWT verification is not available "
|
||||
"because the 'cryptography' package is not installed. "
|
||||
+ ALGO_RECOMMENDATION
|
||||
)
|
||||
|
||||
if (
|
||||
self.JWT_ALGORITHM not in supported_algorithms
|
||||
or self.JWT_ALGORITHM == "none"
|
||||
):
|
||||
raise AuthConfigError(
|
||||
f"Invalid JWT_SIGN_ALGORITHM: '{self.JWT_ALGORITHM}'. "
|
||||
"Supported algorithms are listed on "
|
||||
"https://pyjwt.readthedocs.io/en/stable/algorithms.html"
|
||||
)
|
||||
|
||||
if self.JWT_ALGORITHM.startswith("HS"):
|
||||
logger.warning(
|
||||
f"⚠️ JWT_SIGN_ALGORITHM is set to '{self.JWT_ALGORITHM}', "
|
||||
"a symmetric shared-key signature algorithm. " + ALGO_RECOMMENDATION
|
||||
)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
_settings: Settings = None # type: ignore
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
global _settings
|
||||
|
||||
if not _settings:
|
||||
_settings = Settings()
|
||||
|
||||
return _settings
|
||||
|
||||
|
||||
def verify_settings() -> None:
|
||||
global _settings
|
||||
|
||||
if not _settings:
|
||||
_settings = Settings() # calls validation indirectly
|
||||
return
|
||||
|
||||
_settings.validate()
|
||||
|
||||
306
autogpt_platform/autogpt_libs/autogpt_libs/auth/config_test.py
Normal file
306
autogpt_platform/autogpt_libs/autogpt_libs/auth/config_test.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""
|
||||
Comprehensive tests for auth configuration to ensure 100% line and branch coverage.
|
||||
These tests verify critical security checks preventing JWT token forgery.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from autogpt_libs.auth.config import AuthConfigError, Settings
|
||||
|
||||
|
||||
def test_environment_variable_precedence(mocker: MockerFixture):
|
||||
"""Test that environment variables take precedence over defaults."""
|
||||
secret = "environment-secret-key-with-proper-length-123456"
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == secret
|
||||
|
||||
|
||||
def test_environment_variable_backwards_compatible(mocker: MockerFixture):
|
||||
"""Test that SUPABASE_JWT_SECRET is read if JWT_VERIFY_KEY is not set."""
|
||||
secret = "environment-secret-key-with-proper-length-123456"
|
||||
mocker.patch.dict(os.environ, {"SUPABASE_JWT_SECRET": secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == secret
|
||||
|
||||
|
||||
def test_auth_config_error_inheritance():
|
||||
"""Test that AuthConfigError is properly defined as an Exception."""
|
||||
assert issubclass(AuthConfigError, Exception)
|
||||
error = AuthConfigError("test message")
|
||||
assert str(error) == "test message"
|
||||
|
||||
|
||||
def test_settings_static_after_creation(mocker: MockerFixture):
|
||||
"""Test that settings maintain their values after creation."""
|
||||
secret = "immutable-secret-key-with-proper-length-12345"
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
original_secret = settings.JWT_VERIFY_KEY
|
||||
|
||||
# Changing environment after creation shouldn't affect settings
|
||||
os.environ["JWT_VERIFY_KEY"] = "different-secret"
|
||||
|
||||
assert settings.JWT_VERIFY_KEY == original_secret
|
||||
|
||||
|
||||
def test_settings_load_with_valid_secret(mocker: MockerFixture):
|
||||
"""Test auth enabled with a valid JWT secret."""
|
||||
valid_secret = "a" * 32 # 32 character secret
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": valid_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == valid_secret
|
||||
|
||||
|
||||
def test_settings_load_with_strong_secret(mocker: MockerFixture):
|
||||
"""Test auth enabled with a cryptographically strong secret."""
|
||||
strong_secret = "super-secret-jwt-token-with-at-least-32-characters-long"
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": strong_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == strong_secret
|
||||
assert len(settings.JWT_VERIFY_KEY) >= 32
|
||||
|
||||
|
||||
def test_secret_empty_raises_error(mocker: MockerFixture):
|
||||
"""Test that auth enabled with empty secret raises AuthConfigError."""
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": ""}, clear=True)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
Settings()
|
||||
assert "JWT_VERIFY_KEY" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_secret_missing_raises_error(mocker: MockerFixture):
|
||||
"""Test that auth enabled without secret env var raises AuthConfigError."""
|
||||
mocker.patch.dict(os.environ, {}, clear=True)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
Settings()
|
||||
assert "JWT_VERIFY_KEY" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("secret", [" ", " ", "\t", "\n", " \t\n "])
|
||||
def test_secret_only_whitespace_raises_error(mocker: MockerFixture, secret: str):
|
||||
"""Test that auth enabled with whitespace-only secret raises error."""
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Settings()
|
||||
|
||||
|
||||
def test_secret_weak_logs_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
"""Test that weak JWT secret triggers warning log."""
|
||||
weak_secret = "short" # Less than 32 characters
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": weak_secret}, clear=True)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == weak_secret
|
||||
assert "key appears weak" in caplog.text.lower()
|
||||
assert "less than 32 characters" in caplog.text
|
||||
|
||||
|
||||
def test_secret_31_char_logs_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
"""Test that 31-character secret triggers warning (boundary test)."""
|
||||
secret_31 = "a" * 31 # Exactly 31 characters
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret_31}, clear=True)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
assert len(settings.JWT_VERIFY_KEY) == 31
|
||||
assert "key appears weak" in caplog.text.lower()
|
||||
|
||||
|
||||
def test_secret_32_char_no_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
"""Test that 32-character secret does not trigger warning (boundary test)."""
|
||||
secret_32 = "a" * 32 # Exactly 32 characters
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret_32}, clear=True)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
assert len(settings.JWT_VERIFY_KEY) == 32
|
||||
assert "JWT secret appears weak" not in caplog.text
|
||||
|
||||
|
||||
def test_secret_whitespace_stripped(mocker: MockerFixture):
|
||||
"""Test that JWT secret whitespace is stripped."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": f" {secret} "}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == secret
|
||||
|
||||
|
||||
def test_secret_with_special_characters(mocker: MockerFixture):
|
||||
"""Test JWT secret with special characters."""
|
||||
special_secret = "!@#$%^&*()_+-=[]{}|;:,.<>?`~" + "a" * 10 # 40 chars total
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": special_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == special_secret
|
||||
|
||||
|
||||
def test_secret_with_unicode(mocker: MockerFixture):
|
||||
"""Test JWT secret with unicode characters."""
|
||||
unicode_secret = "秘密🔐キー" + "a" * 25 # Ensure >32 bytes
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": unicode_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == unicode_secret
|
||||
|
||||
|
||||
def test_secret_very_long(mocker: MockerFixture):
|
||||
"""Test JWT secret with excessive length."""
|
||||
long_secret = "a" * 1000 # 1000 character secret
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": long_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == long_secret
|
||||
assert len(settings.JWT_VERIFY_KEY) == 1000
|
||||
|
||||
|
||||
def test_secret_with_newline(mocker: MockerFixture):
|
||||
"""Test JWT secret containing newlines."""
|
||||
multiline_secret = "secret\nwith\nnewlines" + "a" * 20
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": multiline_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == multiline_secret
|
||||
|
||||
|
||||
def test_secret_base64_encoded(mocker: MockerFixture):
|
||||
"""Test JWT secret that looks like base64."""
|
||||
base64_secret = "dGhpc19pc19hX3NlY3JldF9rZXlfd2l0aF9wcm9wZXJfbGVuZ3Ro"
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": base64_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == base64_secret
|
||||
|
||||
|
||||
def test_secret_numeric_only(mocker: MockerFixture):
|
||||
"""Test JWT secret with only numbers."""
|
||||
numeric_secret = "1234567890" * 4 # 40 character numeric secret
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": numeric_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == numeric_secret
|
||||
|
||||
|
||||
def test_algorithm_default_hs256(mocker: MockerFixture):
|
||||
"""Test that JWT algorithm defaults to HS256."""
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": "a" * 32}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_ALGORITHM == "HS256"
|
||||
|
||||
|
||||
def test_algorithm_whitespace_stripped(mocker: MockerFixture):
|
||||
"""Test that JWT algorithm whitespace is stripped."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": " HS256 "},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_ALGORITHM == "HS256"
|
||||
|
||||
|
||||
def test_no_crypto_warning(mocker: MockerFixture, caplog: pytest.LogCaptureFixture):
|
||||
"""Test warning when crypto package is not available."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
|
||||
|
||||
# Mock has_crypto to return False
|
||||
mocker.patch("autogpt_libs.auth.config.has_crypto", False)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
Settings()
|
||||
assert "Asymmetric JWT verification is not available" in caplog.text
|
||||
assert "cryptography" in caplog.text
|
||||
|
||||
|
||||
def test_algorithm_invalid_raises_error(mocker: MockerFixture):
|
||||
"""Test that invalid JWT algorithm raises AuthConfigError."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": "INVALID_ALG"},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
with pytest.raises(AuthConfigError) as exc_info:
|
||||
Settings()
|
||||
assert "Invalid JWT_SIGN_ALGORITHM" in str(exc_info.value)
|
||||
assert "INVALID_ALG" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_algorithm_none_raises_error(mocker: MockerFixture):
|
||||
"""Test that 'none' algorithm raises AuthConfigError."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": "none"},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
with pytest.raises(AuthConfigError) as exc_info:
|
||||
Settings()
|
||||
assert "Invalid JWT_SIGN_ALGORITHM" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("algorithm", ["HS256", "HS384", "HS512"])
|
||||
def test_algorithm_symmetric_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture, algorithm: str
|
||||
):
|
||||
"""Test warning for symmetric algorithms (HS256, HS384, HS512)."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": algorithm},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
assert algorithm in caplog.text
|
||||
assert "symmetric shared-key signature algorithm" in caplog.text
|
||||
assert settings.JWT_ALGORITHM == algorithm
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"algorithm",
|
||||
["ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512"],
|
||||
)
|
||||
def test_algorithm_asymmetric_no_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture, algorithm: str
|
||||
):
|
||||
"""Test that asymmetric algorithms do not trigger warning."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": algorithm},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
# Should not contain the symmetric algorithm warning
|
||||
assert "symmetric shared-key signature algorithm" not in caplog.text
|
||||
assert settings.JWT_ALGORITHM == algorithm
|
||||
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
FastAPI dependency functions for JWT-based authentication and authorization.
|
||||
|
||||
These are the high-level dependency functions used in route definitions.
|
||||
"""
|
||||
|
||||
import fastapi
|
||||
|
||||
from .jwt_utils import get_jwt_payload, verify_user
|
||||
from .models import User
|
||||
|
||||
|
||||
def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
|
||||
"""
|
||||
FastAPI dependency that requires a valid authenticated user.
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 for authentication failures
|
||||
"""
|
||||
return verify_user(jwt_payload, admin_only=False)
|
||||
|
||||
|
||||
def requires_admin_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
|
||||
"""
|
||||
FastAPI dependency that requires a valid admin user.
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 for authentication failures, 403 for insufficient permissions
|
||||
"""
|
||||
return verify_user(jwt_payload, admin_only=True)
|
||||
|
||||
|
||||
def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
|
||||
"""
|
||||
FastAPI dependency that returns the ID of the authenticated user.
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 for authentication failures or missing user ID
|
||||
"""
|
||||
user_id = jwt_payload.get("sub")
|
||||
if not user_id:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=401, detail="User ID not found in token"
|
||||
)
|
||||
return user_id
|
||||
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Comprehensive integration tests for authentication dependencies.
|
||||
Tests the full authentication flow from HTTP requests to user validation.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException, Security
|
||||
from fastapi.testclient import TestClient
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from autogpt_libs.auth.dependencies import (
|
||||
get_user_id,
|
||||
requires_admin_user,
|
||||
requires_user,
|
||||
)
|
||||
from autogpt_libs.auth.models import User
|
||||
|
||||
|
||||
class TestAuthDependencies:
|
||||
"""Test suite for authentication dependency functions."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create a test FastAPI application."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/user")
|
||||
def get_user_endpoint(user: User = Security(requires_user)):
|
||||
return {"user_id": user.user_id, "role": user.role}
|
||||
|
||||
@app.get("/admin")
|
||||
def get_admin_endpoint(user: User = Security(requires_admin_user)):
|
||||
return {"user_id": user.user_id, "role": user.role}
|
||||
|
||||
@app.get("/user-id")
|
||||
def get_user_id_endpoint(user_id: str = Security(get_user_id)):
|
||||
return {"user_id": user_id}
|
||||
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def client(self, app):
|
||||
"""Create a test client."""
|
||||
return TestClient(app)
|
||||
|
||||
def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
|
||||
"""Test requires_user with valid JWT payload."""
|
||||
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
|
||||
|
||||
# Mock get_jwt_payload to return our test payload
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = requires_user(jwt_payload)
|
||||
assert isinstance(user, User)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.role == "user"
|
||||
|
||||
def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
|
||||
"""Test requires_user accepts admin users."""
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = requires_user(jwt_payload)
|
||||
assert user.user_id == "admin-456"
|
||||
assert user.role == "admin"
|
||||
|
||||
def test_requires_user_missing_sub(self):
|
||||
"""Test requires_user with missing user ID."""
|
||||
jwt_payload = {"role": "user", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
requires_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found" in exc_info.value.detail
|
||||
|
||||
def test_requires_user_empty_sub(self):
|
||||
"""Test requires_user with empty user ID."""
|
||||
jwt_payload = {"sub": "", "role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
requires_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
|
||||
"""Test requires_admin_user with admin role."""
|
||||
jwt_payload = {
|
||||
"sub": "admin-789",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = requires_admin_user(jwt_payload)
|
||||
assert user.user_id == "admin-789"
|
||||
assert user.role == "admin"
|
||||
|
||||
def test_requires_admin_user_with_regular_user(self):
|
||||
"""Test requires_admin_user rejects regular users."""
|
||||
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
requires_admin_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Admin access required" in exc_info.value.detail
|
||||
|
||||
def test_requires_admin_user_missing_role(self):
|
||||
"""Test requires_admin_user with missing role."""
|
||||
jwt_payload = {"sub": "user-123", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
requires_admin_user(jwt_payload)
|
||||
|
||||
def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
|
||||
"""Test get_user_id extracts user ID correctly."""
|
||||
jwt_payload = {"sub": "user-id-xyz", "role": "user"}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user_id = get_user_id(jwt_payload)
|
||||
assert user_id == "user-id-xyz"
|
||||
|
||||
def test_get_user_id_missing_sub(self):
|
||||
"""Test get_user_id with missing user ID."""
|
||||
jwt_payload = {"role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_user_id(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found" in exc_info.value.detail
|
||||
|
||||
def test_get_user_id_none_sub(self):
|
||||
"""Test get_user_id with None user ID."""
|
||||
jwt_payload = {"sub": None, "role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_user_id(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
class TestAuthDependenciesIntegration:
|
||||
"""Integration tests for auth dependencies with FastAPI."""
|
||||
|
||||
acceptable_jwt_secret = "test-secret-with-proper-length-123456"
|
||||
|
||||
@pytest.fixture
|
||||
def create_token(self, mocker: MockerFixture):
|
||||
"""Helper to create JWT tokens."""
|
||||
import jwt
|
||||
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": self.acceptable_jwt_secret},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
def _create_token(payload, secret=self.acceptable_jwt_secret):
|
||||
return jwt.encode(payload, secret, algorithm="HS256")
|
||||
|
||||
return _create_token
|
||||
|
||||
def test_endpoint_auth_enabled_no_token(self):
|
||||
"""Test endpoints require token when auth is enabled."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint(user: User = Security(requires_user)):
|
||||
return {"user_id": user.user_id}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Should fail without auth header
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_endpoint_with_valid_token(self, create_token):
|
||||
"""Test endpoint with valid JWT token."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint(user: User = Security(requires_user)):
|
||||
return {"user_id": user.user_id, "role": user.role}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
token = create_token(
|
||||
{"sub": "test-user", "role": "user", "aud": "authenticated"},
|
||||
secret=self.acceptable_jwt_secret,
|
||||
)
|
||||
|
||||
response = client.get("/test", headers={"Authorization": f"Bearer {token}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["user_id"] == "test-user"
|
||||
|
||||
def test_admin_endpoint_requires_admin_role(self, create_token):
|
||||
"""Test admin endpoint rejects non-admin users."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/admin")
|
||||
def admin_endpoint(user: User = Security(requires_admin_user)):
|
||||
return {"user_id": user.user_id}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Regular user token
|
||||
user_token = create_token(
|
||||
{"sub": "regular-user", "role": "user", "aud": "authenticated"},
|
||||
secret=self.acceptable_jwt_secret,
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin", headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Admin token
|
||||
admin_token = create_token(
|
||||
{"sub": "admin-user", "role": "admin", "aud": "authenticated"},
|
||||
secret=self.acceptable_jwt_secret,
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin", headers={"Authorization": f"Bearer {admin_token}"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["user_id"] == "admin-user"
|
||||
|
||||
|
||||
class TestAuthDependenciesEdgeCases:
|
||||
"""Edge case tests for authentication dependencies."""
|
||||
|
||||
def test_dependency_with_complex_payload(self):
|
||||
"""Test dependencies handle complex JWT payloads."""
|
||||
complex_payload = {
|
||||
"sub": "user-123",
|
||||
"role": "admin",
|
||||
"email": "test@example.com",
|
||||
"app_metadata": {"provider": "email", "providers": ["email"]},
|
||||
"user_metadata": {
|
||||
"full_name": "Test User",
|
||||
"avatar_url": "https://example.com/avatar.jpg",
|
||||
},
|
||||
"aud": "authenticated",
|
||||
"iat": 1234567890,
|
||||
"exp": 9999999999,
|
||||
}
|
||||
|
||||
user = requires_user(complex_payload)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.email == "test@example.com"
|
||||
|
||||
admin = requires_admin_user(complex_payload)
|
||||
assert admin.role == "admin"
|
||||
|
||||
def test_dependency_with_unicode_in_payload(self):
|
||||
"""Test dependencies handle unicode in JWT payloads."""
|
||||
unicode_payload = {
|
||||
"sub": "user-😀-123",
|
||||
"role": "user",
|
||||
"email": "测试@example.com",
|
||||
"name": "日本語",
|
||||
}
|
||||
|
||||
user = requires_user(unicode_payload)
|
||||
assert "😀" in user.user_id
|
||||
assert user.email == "测试@example.com"
|
||||
|
||||
def test_dependency_with_null_values(self):
|
||||
"""Test dependencies handle null values in payload."""
|
||||
null_payload = {
|
||||
"sub": "user-123",
|
||||
"role": "user",
|
||||
"email": None,
|
||||
"phone": None,
|
||||
"metadata": None,
|
||||
}
|
||||
|
||||
user = requires_user(null_payload)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.email is None
|
||||
|
||||
def test_concurrent_requests_isolation(self):
|
||||
"""Test that concurrent requests don't interfere with each other."""
|
||||
payload1 = {"sub": "user-1", "role": "user"}
|
||||
payload2 = {"sub": "user-2", "role": "admin"}
|
||||
|
||||
# Simulate concurrent processing
|
||||
user1 = requires_user(payload1)
|
||||
user2 = requires_admin_user(payload2)
|
||||
|
||||
assert user1.user_id == "user-1"
|
||||
assert user2.user_id == "user-2"
|
||||
assert user1.role == "user"
|
||||
assert user2.role == "admin"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"payload,expected_error,admin_only",
|
||||
[
|
||||
(None, "Authorization header is missing", False),
|
||||
({}, "User ID not found", False),
|
||||
({"sub": ""}, "User ID not found", False),
|
||||
({"role": "user"}, "User ID not found", False),
|
||||
({"sub": "user", "role": "user"}, "Admin access required", True),
|
||||
],
|
||||
)
|
||||
def test_dependency_error_cases(
|
||||
self, payload, expected_error: str, admin_only: bool
|
||||
):
|
||||
"""Test that errors propagate correctly through dependencies."""
|
||||
# Import verify_user to test it directly since dependencies use FastAPI Security
|
||||
from autogpt_libs.auth.jwt_utils import verify_user
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
verify_user(payload, admin_only=admin_only)
|
||||
assert expected_error in exc_info.value.detail
|
||||
|
||||
def test_dependency_valid_user(self):
|
||||
"""Test valid user case for dependency."""
|
||||
# Import verify_user to test it directly since dependencies use FastAPI Security
|
||||
from autogpt_libs.auth.jwt_utils import verify_user
|
||||
|
||||
# Valid case
|
||||
user = verify_user({"sub": "user", "role": "user"}, admin_only=False)
|
||||
assert user.user_id == "user"
|
||||
@@ -1,46 +0,0 @@
|
||||
import fastapi
|
||||
|
||||
from .config import settings
|
||||
from .middleware import auth_middleware
|
||||
from .models import DEFAULT_USER_ID, User
|
||||
|
||||
|
||||
def requires_user(payload: dict = fastapi.Depends(auth_middleware)) -> User:
|
||||
return verify_user(payload, admin_only=False)
|
||||
|
||||
|
||||
def requires_admin_user(
|
||||
payload: dict = fastapi.Depends(auth_middleware),
|
||||
) -> User:
|
||||
return verify_user(payload, admin_only=True)
|
||||
|
||||
|
||||
def verify_user(payload: dict | None, admin_only: bool) -> User:
|
||||
if not payload:
|
||||
if settings.ENABLE_AUTH:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=401, detail="Authorization header is missing"
|
||||
)
|
||||
# This handles the case when authentication is disabled
|
||||
payload = {"sub": DEFAULT_USER_ID, "role": "admin"}
|
||||
|
||||
user_id = payload.get("sub")
|
||||
|
||||
if not user_id:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=401, detail="User ID not found in token"
|
||||
)
|
||||
|
||||
if admin_only and payload["role"] != "admin":
|
||||
raise fastapi.HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
return User.from_payload(payload)
|
||||
|
||||
|
||||
def get_user_id(payload: dict = fastapi.Depends(auth_middleware)) -> str:
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=401, detail="User ID not found in token"
|
||||
)
|
||||
return user_id
|
||||
@@ -1,68 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from .depends import requires_admin_user, requires_user, verify_user
|
||||
|
||||
|
||||
def test_verify_user_no_payload():
|
||||
user = verify_user(None, admin_only=False)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "admin"
|
||||
|
||||
|
||||
def test_verify_user_no_user_id():
|
||||
with pytest.raises(Exception):
|
||||
verify_user({"role": "admin"}, admin_only=False)
|
||||
|
||||
|
||||
def test_verify_user_not_admin():
|
||||
with pytest.raises(Exception):
|
||||
verify_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"},
|
||||
admin_only=True,
|
||||
)
|
||||
|
||||
|
||||
def test_verify_user_with_admin_role():
|
||||
user = verify_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "admin"},
|
||||
admin_only=True,
|
||||
)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "admin"
|
||||
|
||||
|
||||
def test_verify_user_with_user_role():
|
||||
user = verify_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"},
|
||||
admin_only=False,
|
||||
)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "user"
|
||||
|
||||
|
||||
def test_requires_user():
|
||||
user = requires_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"}
|
||||
)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "user"
|
||||
|
||||
|
||||
def test_requires_user_no_user_id():
|
||||
with pytest.raises(Exception):
|
||||
requires_user({"role": "user"})
|
||||
|
||||
|
||||
def test_requires_admin_user():
|
||||
user = requires_admin_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "admin"}
|
||||
)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "admin"
|
||||
|
||||
|
||||
def test_requires_admin_user_not_admin():
|
||||
with pytest.raises(Exception):
|
||||
requires_admin_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"}
|
||||
)
|
||||
68
autogpt_platform/autogpt_libs/autogpt_libs/auth/helpers.py
Normal file
68
autogpt_platform/autogpt_libs/autogpt_libs/auth/helpers.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from .jwt_utils import bearer_jwt_auth
|
||||
|
||||
|
||||
def add_auth_responses_to_openapi(app: FastAPI) -> None:
|
||||
"""
|
||||
Set up custom OpenAPI schema generation that adds 401 responses
|
||||
to all authenticated endpoints.
|
||||
|
||||
This is needed when using HTTPBearer with auto_error=False to get proper
|
||||
401 responses instead of 403, but FastAPI only automatically adds security
|
||||
responses when auto_error=True.
|
||||
"""
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
|
||||
# Add 401 response to all endpoints that have security requirements
|
||||
for path, methods in openapi_schema["paths"].items():
|
||||
for method, details in methods.items():
|
||||
security_schemas = [
|
||||
schema
|
||||
for auth_option in details.get("security", [])
|
||||
for schema in auth_option.keys()
|
||||
]
|
||||
if bearer_jwt_auth.scheme_name not in security_schemas:
|
||||
continue
|
||||
|
||||
if "responses" not in details:
|
||||
details["responses"] = {}
|
||||
|
||||
details["responses"]["401"] = {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
}
|
||||
|
||||
# Ensure #/components/responses exists
|
||||
if "components" not in openapi_schema:
|
||||
openapi_schema["components"] = {}
|
||||
if "responses" not in openapi_schema["components"]:
|
||||
openapi_schema["components"]["responses"] = {}
|
||||
|
||||
# Define 401 response
|
||||
openapi_schema["components"]["responses"]["HTTP401NotAuthenticatedError"] = {
|
||||
"description": "Authentication required",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"detail": {"type": "string"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
app.openapi = custom_openapi
|
||||
435
autogpt_platform/autogpt_libs/autogpt_libs/auth/helpers_test.py
Normal file
435
autogpt_platform/autogpt_libs/autogpt_libs/auth/helpers_test.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""
|
||||
Comprehensive tests for auth helpers module to achieve 100% coverage.
|
||||
Tests OpenAPI schema generation and authentication response handling.
|
||||
"""
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from autogpt_libs.auth.helpers import add_auth_responses_to_openapi
|
||||
from autogpt_libs.auth.jwt_utils import bearer_jwt_auth
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_basic():
|
||||
"""Test adding 401 responses to OpenAPI schema."""
|
||||
app = FastAPI(title="Test App", version="1.0.0")
|
||||
|
||||
# Add some test endpoints with authentication
|
||||
from fastapi import Depends
|
||||
|
||||
from autogpt_libs.auth.dependencies import requires_user
|
||||
|
||||
@app.get("/protected", dependencies=[Depends(requires_user)])
|
||||
def protected_endpoint():
|
||||
return {"message": "Protected"}
|
||||
|
||||
@app.get("/public")
|
||||
def public_endpoint():
|
||||
return {"message": "Public"}
|
||||
|
||||
# Apply the OpenAPI customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get the OpenAPI schema
|
||||
schema = app.openapi()
|
||||
|
||||
# Verify basic schema properties
|
||||
assert schema["info"]["title"] == "Test App"
|
||||
assert schema["info"]["version"] == "1.0.0"
|
||||
|
||||
# Verify 401 response component is added
|
||||
assert "components" in schema
|
||||
assert "responses" in schema["components"]
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
# Verify 401 response structure
|
||||
error_response = schema["components"]["responses"]["HTTP401NotAuthenticatedError"]
|
||||
assert error_response["description"] == "Authentication required"
|
||||
assert "application/json" in error_response["content"]
|
||||
assert "schema" in error_response["content"]["application/json"]
|
||||
|
||||
# Verify schema properties
|
||||
response_schema = error_response["content"]["application/json"]["schema"]
|
||||
assert response_schema["type"] == "object"
|
||||
assert "detail" in response_schema["properties"]
|
||||
assert response_schema["properties"]["detail"]["type"] == "string"
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_with_security():
|
||||
"""Test that 401 responses are added only to secured endpoints."""
|
||||
app = FastAPI()
|
||||
|
||||
# Mock endpoint with security
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.dependencies import get_user_id
|
||||
|
||||
@app.get("/secured")
|
||||
def secured_endpoint(user_id: str = Security(get_user_id)):
|
||||
return {"user_id": user_id}
|
||||
|
||||
@app.post("/also-secured")
|
||||
def another_secured(user_id: str = Security(get_user_id)):
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/unsecured")
|
||||
def unsecured_endpoint():
|
||||
return {"public": True}
|
||||
|
||||
# Apply OpenAPI customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get schema
|
||||
schema = app.openapi()
|
||||
|
||||
# Check that secured endpoints have 401 responses
|
||||
if "/secured" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/secured"]:
|
||||
secured_get = schema["paths"]["/secured"]["get"]
|
||||
if "responses" in secured_get:
|
||||
assert "401" in secured_get["responses"]
|
||||
assert (
|
||||
secured_get["responses"]["401"]["$ref"]
|
||||
== "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
)
|
||||
|
||||
if "/also-secured" in schema["paths"]:
|
||||
if "post" in schema["paths"]["/also-secured"]:
|
||||
secured_post = schema["paths"]["/also-secured"]["post"]
|
||||
if "responses" in secured_post:
|
||||
assert "401" in secured_post["responses"]
|
||||
|
||||
# Check that unsecured endpoint does not have 401 response
|
||||
if "/unsecured" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/unsecured"]:
|
||||
unsecured_get = schema["paths"]["/unsecured"]["get"]
|
||||
if "responses" in unsecured_get:
|
||||
assert "401" not in unsecured_get.get("responses", {})
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_cached_schema():
|
||||
"""Test that OpenAPI schema is cached after first generation."""
|
||||
app = FastAPI()
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get schema twice
|
||||
schema1 = app.openapi()
|
||||
schema2 = app.openapi()
|
||||
|
||||
# Should return the same cached object
|
||||
assert schema1 is schema2
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_existing_responses():
|
||||
"""Test handling endpoints that already have responses defined."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
@app.get(
|
||||
"/with-responses",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
404: {"description": "Not found"},
|
||||
},
|
||||
)
|
||||
def endpoint_with_responses(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"data": "test"}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Check that existing responses are preserved and 401 is added
|
||||
if "/with-responses" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/with-responses"]:
|
||||
responses = schema["paths"]["/with-responses"]["get"].get("responses", {})
|
||||
# Original responses should be preserved
|
||||
if "200" in responses:
|
||||
assert responses["200"]["description"] == "Success"
|
||||
if "404" in responses:
|
||||
assert responses["404"]["description"] == "Not found"
|
||||
# 401 should be added
|
||||
if "401" in responses:
|
||||
assert (
|
||||
responses["401"]["$ref"]
|
||||
== "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
)
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_no_security_endpoints():
|
||||
"""Test with app that has no secured endpoints."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/public1")
|
||||
def public1():
|
||||
return {"message": "public1"}
|
||||
|
||||
@app.post("/public2")
|
||||
def public2():
|
||||
return {"message": "public2"}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Component should still be added for consistency
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
# But no endpoints should have 401 responses
|
||||
for path in schema["paths"].values():
|
||||
for method in path.values():
|
||||
if isinstance(method, dict) and "responses" in method:
|
||||
assert "401" not in method["responses"]
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_multiple_security_schemes():
|
||||
"""Test endpoints with multiple security requirements."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.dependencies import requires_admin_user, requires_user
|
||||
from autogpt_libs.auth.models import User
|
||||
|
||||
@app.get("/multi-auth")
|
||||
def multi_auth(
|
||||
user: User = Security(requires_user),
|
||||
admin: User = Security(requires_admin_user),
|
||||
):
|
||||
return {"status": "super secure"}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Should have 401 response
|
||||
if "/multi-auth" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/multi-auth"]:
|
||||
responses = schema["paths"]["/multi-auth"]["get"].get("responses", {})
|
||||
if "401" in responses:
|
||||
assert (
|
||||
responses["401"]["$ref"]
|
||||
== "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
)
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_empty_components():
|
||||
"""Test when OpenAPI schema has no components section initially."""
|
||||
app = FastAPI()
|
||||
|
||||
# Mock get_openapi to return schema without components
|
||||
original_get_openapi = get_openapi
|
||||
|
||||
def mock_get_openapi(*args, **kwargs):
|
||||
schema = original_get_openapi(*args, **kwargs)
|
||||
# Remove components if it exists
|
||||
if "components" in schema:
|
||||
del schema["components"]
|
||||
return schema
|
||||
|
||||
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Components should be created
|
||||
assert "components" in schema
|
||||
assert "responses" in schema["components"]
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_all_http_methods():
|
||||
"""Test that all HTTP methods are handled correctly."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
@app.get("/resource")
|
||||
def get_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "GET"}
|
||||
|
||||
@app.post("/resource")
|
||||
def post_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "POST"}
|
||||
|
||||
@app.put("/resource")
|
||||
def put_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "PUT"}
|
||||
|
||||
@app.patch("/resource")
|
||||
def patch_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "PATCH"}
|
||||
|
||||
@app.delete("/resource")
|
||||
def delete_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "DELETE"}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# All methods should have 401 response
|
||||
if "/resource" in schema["paths"]:
|
||||
for method in ["get", "post", "put", "patch", "delete"]:
|
||||
if method in schema["paths"]["/resource"]:
|
||||
method_spec = schema["paths"]["/resource"][method]
|
||||
if "responses" in method_spec:
|
||||
assert "401" in method_spec["responses"]
|
||||
|
||||
|
||||
def test_bearer_jwt_auth_scheme_config():
|
||||
"""Test that bearer_jwt_auth is configured correctly."""
|
||||
assert bearer_jwt_auth.scheme_name == "HTTPBearerJWT"
|
||||
assert bearer_jwt_auth.auto_error is False
|
||||
|
||||
|
||||
def test_add_auth_responses_with_no_routes():
|
||||
"""Test OpenAPI generation with app that has no routes."""
|
||||
app = FastAPI(title="Empty App")
|
||||
|
||||
# Apply customization to empty app
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Should still have basic structure
|
||||
assert schema["info"]["title"] == "Empty App"
|
||||
assert "components" in schema
|
||||
assert "responses" in schema["components"]
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
|
||||
def test_custom_openapi_function_replacement():
|
||||
"""Test that the custom openapi function properly replaces the default."""
|
||||
app = FastAPI()
|
||||
|
||||
# Store original function
|
||||
original_openapi = app.openapi
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Function should be replaced
|
||||
assert app.openapi != original_openapi
|
||||
assert callable(app.openapi)
|
||||
|
||||
|
||||
def test_endpoint_without_responses_section():
|
||||
"""Test endpoint that has security but no responses section initially."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
from fastapi.openapi.utils import get_openapi as original_get_openapi
|
||||
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
# Create endpoint
|
||||
@app.get("/no-responses")
|
||||
def endpoint_without_responses(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"data": "test"}
|
||||
|
||||
# Mock get_openapi to remove responses from the endpoint
|
||||
def mock_get_openapi(*args, **kwargs):
|
||||
schema = original_get_openapi(*args, **kwargs)
|
||||
# Remove responses from our endpoint to trigger line 40
|
||||
if "/no-responses" in schema.get("paths", {}):
|
||||
if "get" in schema["paths"]["/no-responses"]:
|
||||
# Delete responses to force the code to create it
|
||||
if "responses" in schema["paths"]["/no-responses"]["get"]:
|
||||
del schema["paths"]["/no-responses"]["get"]["responses"]
|
||||
return schema
|
||||
|
||||
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get schema and verify 401 was added
|
||||
schema = app.openapi()
|
||||
|
||||
# The endpoint should now have 401 response
|
||||
if "/no-responses" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/no-responses"]:
|
||||
responses = schema["paths"]["/no-responses"]["get"].get("responses", {})
|
||||
assert "401" in responses
|
||||
assert (
|
||||
responses["401"]["$ref"]
|
||||
== "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
)
|
||||
|
||||
|
||||
def test_components_with_existing_responses():
|
||||
"""Test when components already has a responses section."""
|
||||
app = FastAPI()
|
||||
|
||||
# Mock get_openapi to return schema with existing components/responses
|
||||
from fastapi.openapi.utils import get_openapi as original_get_openapi
|
||||
|
||||
def mock_get_openapi(*args, **kwargs):
|
||||
schema = original_get_openapi(*args, **kwargs)
|
||||
# Add existing components/responses
|
||||
if "components" not in schema:
|
||||
schema["components"] = {}
|
||||
schema["components"]["responses"] = {
|
||||
"ExistingResponse": {"description": "An existing response"}
|
||||
}
|
||||
return schema
|
||||
|
||||
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Both responses should exist
|
||||
assert "ExistingResponse" in schema["components"]["responses"]
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
# Verify our 401 response structure
|
||||
error_response = schema["components"]["responses"][
|
||||
"HTTP401NotAuthenticatedError"
|
||||
]
|
||||
assert error_response["description"] == "Authentication required"
|
||||
|
||||
|
||||
def test_openapi_schema_persistence():
|
||||
"""Test that modifications to OpenAPI schema persist correctly."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"test": True}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get schema multiple times
|
||||
schema1 = app.openapi()
|
||||
|
||||
# Modify the cached schema (shouldn't affect future calls)
|
||||
schema1["info"]["title"] = "Modified Title"
|
||||
|
||||
# Clear cache and get again
|
||||
app.openapi_schema = None
|
||||
schema2 = app.openapi()
|
||||
|
||||
# Should regenerate with original title
|
||||
assert schema2["info"]["title"] == app.title
|
||||
assert schema2["info"]["title"] != "Modified Title"
|
||||
@@ -1,11 +1,48 @@
|
||||
from typing import Any, Dict
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from .config import settings
|
||||
from .config import get_settings
|
||||
from .models import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Bearer token authentication scheme
|
||||
bearer_jwt_auth = HTTPBearer(
|
||||
bearerFormat="jwt", scheme_name="HTTPBearerJWT", auto_error=False
|
||||
)
|
||||
|
||||
|
||||
def parse_jwt_token(token: str) -> Dict[str, Any]:
|
||||
def get_jwt_payload(
|
||||
credentials: HTTPAuthorizationCredentials | None = Security(bearer_jwt_auth),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Extract and validate JWT payload from HTTP Authorization header.
|
||||
|
||||
This is the core authentication function that handles:
|
||||
- Reading the `Authorization` header to obtain the JWT token
|
||||
- Verifying the JWT token's signature
|
||||
- Decoding the JWT token's payload
|
||||
|
||||
:param credentials: HTTP Authorization credentials from bearer token
|
||||
:return: JWT payload dictionary
|
||||
:raises HTTPException: 401 if authentication fails
|
||||
"""
|
||||
if not credentials:
|
||||
raise HTTPException(status_code=401, detail="Authorization header is missing")
|
||||
|
||||
try:
|
||||
payload = parse_jwt_token(credentials.credentials)
|
||||
logger.debug("Token decoded successfully")
|
||||
return payload
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
|
||||
|
||||
def parse_jwt_token(token: str) -> dict[str, Any]:
|
||||
"""
|
||||
Parse and validate a JWT token.
|
||||
|
||||
@@ -13,10 +50,11 @@ def parse_jwt_token(token: str) -> Dict[str, Any]:
|
||||
:return: The decoded payload
|
||||
:raises ValueError: If the token is invalid or expired
|
||||
"""
|
||||
settings = get_settings()
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.JWT_SECRET_KEY,
|
||||
settings.JWT_VERIFY_KEY,
|
||||
algorithms=[settings.JWT_ALGORITHM],
|
||||
audience="authenticated",
|
||||
)
|
||||
@@ -25,3 +63,18 @@ def parse_jwt_token(token: str) -> Dict[str, Any]:
|
||||
raise ValueError("Token has expired")
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise ValueError(f"Invalid token: {str(e)}")
|
||||
|
||||
|
||||
def verify_user(jwt_payload: dict | None, admin_only: bool) -> User:
|
||||
if jwt_payload is None:
|
||||
raise HTTPException(status_code=401, detail="Authorization header is missing")
|
||||
|
||||
user_id = jwt_payload.get("sub")
|
||||
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found in token")
|
||||
|
||||
if admin_only and jwt_payload["role"] != "admin":
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
return User.from_payload(jwt_payload)
|
||||
|
||||
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
Comprehensive tests for JWT token parsing and validation.
|
||||
Ensures 100% line and branch coverage for JWT security functions.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from autogpt_libs.auth import config, jwt_utils
|
||||
from autogpt_libs.auth.config import Settings
|
||||
from autogpt_libs.auth.models import User
|
||||
|
||||
MOCK_JWT_SECRET = "test-secret-key-with-at-least-32-characters"
|
||||
TEST_USER_PAYLOAD = {
|
||||
"sub": "test-user-id",
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
TEST_ADMIN_PAYLOAD = {
|
||||
"sub": "admin-user-id",
|
||||
"role": "admin",
|
||||
"aud": "authenticated",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_config(mocker: MockerFixture):
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": MOCK_JWT_SECRET}, clear=True)
|
||||
mocker.patch.object(config, "_settings", Settings())
|
||||
yield
|
||||
|
||||
|
||||
def create_token(payload, secret=None, algorithm="HS256"):
|
||||
"""Helper to create JWT tokens."""
|
||||
if secret is None:
|
||||
secret = MOCK_JWT_SECRET
|
||||
return jwt.encode(payload, secret, algorithm=algorithm)
|
||||
|
||||
|
||||
def test_parse_jwt_token_valid():
|
||||
"""Test parsing a valid JWT token."""
|
||||
token = create_token(TEST_USER_PAYLOAD)
|
||||
result = jwt_utils.parse_jwt_token(token)
|
||||
|
||||
assert result["sub"] == "test-user-id"
|
||||
assert result["role"] == "user"
|
||||
assert result["aud"] == "authenticated"
|
||||
|
||||
|
||||
def test_parse_jwt_token_expired():
|
||||
"""Test parsing an expired JWT token."""
|
||||
expired_payload = {
|
||||
**TEST_USER_PAYLOAD,
|
||||
"exp": datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
}
|
||||
token = create_token(expired_payload)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Token has expired" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_parse_jwt_token_invalid_signature():
|
||||
"""Test parsing a token with invalid signature."""
|
||||
# Create token with different secret
|
||||
token = create_token(TEST_USER_PAYLOAD, secret="wrong-secret")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_parse_jwt_token_malformed():
|
||||
"""Test parsing a malformed token."""
|
||||
malformed_tokens = [
|
||||
"not.a.token",
|
||||
"invalid",
|
||||
"",
|
||||
# Header only
|
||||
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9",
|
||||
# No signature
|
||||
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0",
|
||||
]
|
||||
|
||||
for token in malformed_tokens:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_parse_jwt_token_wrong_audience():
|
||||
"""Test parsing a token with wrong audience."""
|
||||
wrong_aud_payload = {**TEST_USER_PAYLOAD, "aud": "wrong-audience"}
|
||||
token = create_token(wrong_aud_payload)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_parse_jwt_token_missing_audience():
|
||||
"""Test parsing a token without audience claim."""
|
||||
no_aud_payload = {k: v for k, v in TEST_USER_PAYLOAD.items() if k != "aud"}
|
||||
token = create_token(no_aud_payload)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_get_jwt_payload_with_valid_token():
|
||||
"""Test extracting JWT payload with valid bearer token."""
|
||||
token = create_token(TEST_USER_PAYLOAD)
|
||||
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
|
||||
|
||||
result = jwt_utils.get_jwt_payload(credentials)
|
||||
assert result["sub"] == "test-user-id"
|
||||
assert result["role"] == "user"
|
||||
|
||||
|
||||
def test_get_jwt_payload_no_credentials():
|
||||
"""Test JWT payload when no credentials provided."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.get_jwt_payload(None)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Authorization header is missing" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_get_jwt_payload_invalid_token():
|
||||
"""Test JWT payload extraction with invalid token."""
|
||||
credentials = HTTPAuthorizationCredentials(
|
||||
scheme="Bearer", credentials="invalid.token.here"
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.get_jwt_payload(credentials)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid token" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_with_valid_user():
|
||||
"""Test verifying a valid user."""
|
||||
user = jwt_utils.verify_user(TEST_USER_PAYLOAD, admin_only=False)
|
||||
assert isinstance(user, User)
|
||||
assert user.user_id == "test-user-id"
|
||||
assert user.role == "user"
|
||||
assert user.email == "test@example.com"
|
||||
|
||||
|
||||
def test_verify_user_with_admin():
|
||||
"""Test verifying an admin user."""
|
||||
user = jwt_utils.verify_user(TEST_ADMIN_PAYLOAD, admin_only=True)
|
||||
assert isinstance(user, User)
|
||||
assert user.user_id == "admin-user-id"
|
||||
assert user.role == "admin"
|
||||
|
||||
|
||||
def test_verify_user_admin_only_with_regular_user():
|
||||
"""Test verifying regular user when admin is required."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(TEST_USER_PAYLOAD, admin_only=True)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Admin access required" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_no_payload():
|
||||
"""Test verifying user with no payload."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(None, admin_only=False)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Authorization header is missing" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_missing_sub():
|
||||
"""Test verifying user with payload missing 'sub' field."""
|
||||
invalid_payload = {"role": "user", "email": "test@example.com"}
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(invalid_payload, admin_only=False)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found in token" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_empty_sub():
|
||||
"""Test verifying user with empty 'sub' field."""
|
||||
invalid_payload = {"sub": "", "role": "user"}
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(invalid_payload, admin_only=False)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found in token" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_none_sub():
|
||||
"""Test verifying user with None 'sub' field."""
|
||||
invalid_payload = {"sub": None, "role": "user"}
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(invalid_payload, admin_only=False)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found in token" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_missing_role_admin_check():
|
||||
"""Test verifying admin when role field is missing."""
|
||||
no_role_payload = {"sub": "user-id"}
|
||||
with pytest.raises(KeyError):
|
||||
# This will raise KeyError when checking payload["role"]
|
||||
jwt_utils.verify_user(no_role_payload, admin_only=True)
|
||||
|
||||
|
||||
# ======================== EDGE CASES ======================== #
|
||||
|
||||
|
||||
def test_jwt_with_additional_claims():
|
||||
"""Test JWT token with additional custom claims."""
|
||||
extra_claims_payload = {
|
||||
"sub": "user-id",
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
"custom_claim": "custom_value",
|
||||
"permissions": ["read", "write"],
|
||||
"metadata": {"key": "value"},
|
||||
}
|
||||
token = create_token(extra_claims_payload)
|
||||
|
||||
result = jwt_utils.parse_jwt_token(token)
|
||||
assert result["sub"] == "user-id"
|
||||
assert result["custom_claim"] == "custom_value"
|
||||
assert result["permissions"] == ["read", "write"]
|
||||
|
||||
|
||||
def test_jwt_with_numeric_sub():
|
||||
"""Test JWT token with numeric user ID."""
|
||||
payload = {
|
||||
"sub": 12345, # Numeric ID
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
}
|
||||
# Should convert to string internally
|
||||
user = jwt_utils.verify_user(payload, admin_only=False)
|
||||
assert user.user_id == 12345
|
||||
|
||||
|
||||
def test_jwt_with_very_long_sub():
|
||||
"""Test JWT token with very long user ID."""
|
||||
long_id = "a" * 1000
|
||||
payload = {
|
||||
"sub": long_id,
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
}
|
||||
user = jwt_utils.verify_user(payload, admin_only=False)
|
||||
assert user.user_id == long_id
|
||||
|
||||
|
||||
def test_jwt_with_special_characters_in_claims():
|
||||
"""Test JWT token with special characters in claims."""
|
||||
payload = {
|
||||
"sub": "user@example.com/special-chars!@#$%",
|
||||
"role": "admin",
|
||||
"aud": "authenticated",
|
||||
"email": "test+special@example.com",
|
||||
}
|
||||
user = jwt_utils.verify_user(payload, admin_only=True)
|
||||
assert "special-chars!@#$%" in user.user_id
|
||||
|
||||
|
||||
def test_jwt_with_future_iat():
|
||||
"""Test JWT token with issued-at time in future."""
|
||||
future_payload = {
|
||||
"sub": "user-id",
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
"iat": datetime.now(timezone.utc) + timedelta(hours=1),
|
||||
}
|
||||
token = create_token(future_payload)
|
||||
|
||||
# PyJWT validates iat claim and should reject future tokens
|
||||
with pytest.raises(ValueError, match="not yet valid"):
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
|
||||
|
||||
def test_jwt_with_different_algorithms():
|
||||
"""Test that only HS256 algorithm is accepted."""
|
||||
payload = {
|
||||
"sub": "user-id",
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
}
|
||||
|
||||
# Try different algorithms
|
||||
algorithms = ["HS384", "HS512", "none"]
|
||||
for algo in algorithms:
|
||||
if algo == "none":
|
||||
# Special case for 'none' algorithm (security vulnerability if accepted)
|
||||
token = create_token(payload, "", algorithm="none")
|
||||
else:
|
||||
token = create_token(payload, algorithm=algo)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
@@ -1,140 +0,0 @@
|
||||
import inspect
|
||||
import logging
|
||||
import secrets
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from fastapi import HTTPException, Request, Security
|
||||
from fastapi.security import APIKeyHeader, HTTPBearer
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
|
||||
from .config import settings
|
||||
from .jwt_utils import parse_jwt_token
|
||||
|
||||
security = HTTPBearer()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def auth_middleware(request: Request):
|
||||
if not settings.ENABLE_AUTH:
|
||||
# If authentication is disabled, allow the request to proceed
|
||||
logger.warning("Auth disabled")
|
||||
return {}
|
||||
|
||||
security = HTTPBearer()
|
||||
credentials = await security(request)
|
||||
|
||||
if not credentials:
|
||||
raise HTTPException(status_code=401, detail="Authorization header is missing")
|
||||
|
||||
try:
|
||||
payload = parse_jwt_token(credentials.credentials)
|
||||
request.state.user = payload
|
||||
logger.debug("Token decoded successfully")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
return payload
|
||||
|
||||
|
||||
class APIKeyValidator:
|
||||
"""
|
||||
Configurable API key validator that supports custom validation functions
|
||||
for FastAPI applications.
|
||||
|
||||
This class provides a flexible way to implement API key authentication with optional
|
||||
custom validation logic. It can be used for simple token matching
|
||||
or more complex validation scenarios like database lookups.
|
||||
|
||||
Examples:
|
||||
Simple token validation:
|
||||
```python
|
||||
validator = APIKeyValidator(
|
||||
header_name="X-API-Key",
|
||||
expected_token="your-secret-token"
|
||||
)
|
||||
|
||||
@app.get("/protected", dependencies=[Depends(validator.get_dependency())])
|
||||
def protected_endpoint():
|
||||
return {"message": "Access granted"}
|
||||
```
|
||||
|
||||
Custom validation with database lookup:
|
||||
```python
|
||||
async def validate_with_db(api_key: str):
|
||||
api_key_obj = await db.get_api_key(api_key)
|
||||
return api_key_obj if api_key_obj and api_key_obj.is_active else None
|
||||
|
||||
validator = APIKeyValidator(
|
||||
header_name="X-API-Key",
|
||||
validate_fn=validate_with_db
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
header_name (str): The name of the header containing the API key
|
||||
expected_token (Optional[str]): The expected API key value for simple token matching
|
||||
validate_fn (Optional[Callable]): Custom validation function that takes an API key
|
||||
string and returns a boolean or object. Can be async.
|
||||
error_status (int): HTTP status code to use for validation errors
|
||||
error_message (str): Error message to return when validation fails
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
header_name: str,
|
||||
expected_token: Optional[str] = None,
|
||||
validate_fn: Optional[Callable[[str], bool]] = None,
|
||||
error_status: int = HTTP_401_UNAUTHORIZED,
|
||||
error_message: str = "Invalid API key",
|
||||
):
|
||||
# Create the APIKeyHeader as a class property
|
||||
self.security_scheme = APIKeyHeader(name=header_name)
|
||||
self.expected_token = expected_token
|
||||
self.custom_validate_fn = validate_fn
|
||||
self.error_status = error_status
|
||||
self.error_message = error_message
|
||||
|
||||
async def default_validator(self, api_key: str) -> bool:
|
||||
if not self.expected_token:
|
||||
raise ValueError(
|
||||
"Expected Token Required to be set when uisng API Key Validator default validation"
|
||||
)
|
||||
return secrets.compare_digest(api_key, self.expected_token)
|
||||
|
||||
async def __call__(
|
||||
self, request: Request, api_key: str = Security(APIKeyHeader)
|
||||
) -> Any:
|
||||
if api_key is None:
|
||||
raise HTTPException(status_code=self.error_status, detail="Missing API key")
|
||||
|
||||
# Use custom validation if provided, otherwise use default equality check
|
||||
validator = self.custom_validate_fn or self.default_validator
|
||||
result = (
|
||||
await validator(api_key)
|
||||
if inspect.iscoroutinefunction(validator)
|
||||
else validator(api_key)
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=self.error_status, detail=self.error_message
|
||||
)
|
||||
|
||||
# Store validation result in request state if it's not just a boolean
|
||||
if result is not True:
|
||||
request.state.api_key = result
|
||||
|
||||
return result
|
||||
|
||||
def get_dependency(self):
|
||||
"""
|
||||
Returns a callable dependency that FastAPI will recognize as a security scheme
|
||||
"""
|
||||
|
||||
async def validate_api_key(
|
||||
request: Request, api_key: str = Security(self.security_scheme)
|
||||
) -> Any:
|
||||
return await self(request, api_key)
|
||||
|
||||
# This helps FastAPI recognize it as a security dependency
|
||||
validate_api_key.__name__ = f"validate_{self.security_scheme.model.name}"
|
||||
return validate_api_key
|
||||
347
autogpt_platform/autogpt_libs/poetry.lock
generated
347
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -54,7 +54,7 @@ version = "1.2.0"
|
||||
description = "Backport of asyncio.Runner, a context manager that controls event loop life cycle."
|
||||
optional = false
|
||||
python-versions = "<3.11,>=3.8"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
files = [
|
||||
{file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
|
||||
@@ -85,6 +85,87 @@ files = [
|
||||
{file = "certifi-2025.7.14.tar.gz", hash = "sha256:8ea99dbdfaaf2ba2f9bac77b9249ef62ec5218e7c2b2e903378ed5fccf765995"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cffi"
|
||||
version = "1.17.1"
|
||||
description = "Foreign Function Interface for Python calling C code."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
markers = "platform_python_implementation != \"PyPy\""
|
||||
files = [
|
||||
{file = "cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5da5719280082ac6bd9aa7becb3938dc9f9cbd57fac7d2871717b1feb0902ab6"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb1a08b8008b281856e5971307cc386a8e9c5b625ac297e853d36da6efe9c17"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-win32.whl", hash = "sha256:c9c3d058ebabb74db66e431095118094d06abf53284d9c81f27300d0e0d8bc7c"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:0f048dcf80db46f0098ccac01132761580d28e28bc0f78ae0d58048063317e15"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-win32.whl", hash = "sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:805b4371bf7197c329fcb3ead37e710d1bca9da5d583f5073b799d5c5bd1eee4"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:733e99bc2df47476e3848417c5a4540522f234dfd4ef3ab7fafdf555b082ec0c"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f79fc4fc25f1c8698ff97788206bb3c2598949bfe0fef03d299eb1b5356ada99"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-win32.whl", hash = "sha256:a08d7e755f8ed21095a310a693525137cfe756ce62d066e53f502a83dc550f65"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:51392eae71afec0d0c8fb1a53b204dbb3bcabcb3c9b807eedf3e1e6ccf2de903"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f3a2b4222ce6b60e2e8b337bb9596923045681d71e5a082783484d845390938e"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0984a4925a435b1da406122d4d7968dd861c1385afe3b45ba82b750f229811e2"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-win32.whl", hash = "sha256:e03eab0a8677fa80d646b5ddece1cbeaf556c313dcfac435ba11f107ba117b5d"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:636062ea65bd0195bc012fea9321aca499c0504409f413dc88af450b57ffd03b"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c7eac2ef9b63c79431bc4b25f1cd649d7f061a28808cbc6c47b534bd789ef964"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e221cf152cff04059d011ee126477f0d9588303eb57e88923578ace7baad17f9"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:31000ec67d4221a71bd3f67df918b1f88f676f1c3b535a7eb473255fdc0b83fc"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f17be4345073b0a7b8ea599688f692ac3ef23ce28e5df79c04de519dbc4912c"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e2b1fac190ae3ebfe37b979cc1ce69c81f4e4fe5746bb401dca63a9062cdaf1"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-win32.whl", hash = "sha256:7596d6620d3fa590f677e9ee430df2958d2d6d6de2feeae5b20e82c00b76fbf8"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-win_amd64.whl", hash = "sha256:78122be759c3f8a014ce010908ae03364d00a1f81ab5c7f4a7a5120607ea56e1"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b2ab587605f4ba0bf81dc0cb08a41bd1c0a5906bd59243d56bad7668a6fc6c16"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:28b16024becceed8c6dfbc75629e27788d8a3f9030691a1dbf9821a128b22c36"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d599671f396c4723d016dbddb72fe8e0397082b0a77a4fab8028923bec050e8"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca74b8dbe6e8e8263c0ffd60277de77dcee6c837a3d0881d8c1ead7268c9e576"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7f5baafcc48261359e14bcd6d9bff6d4b28d9103847c9e136694cb0501aef87"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:98e3969bcff97cae1b2def8ba499ea3d6f31ddfdb7635374834cf89a1a08ecf0"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cdf5ce3acdfd1661132f2a9c19cac174758dc2352bfe37d98aa7512c6b7178b3"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9755e4345d1ec879e3849e62222a18c7174d65a6a92d5b346b1863912168b595"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f1e22e8c4419538cb197e4dd60acc919d7696e5ef98ee4da4e01d3f8cfa4cc5a"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c03e868a0b3bc35839ba98e74211ed2b05d2119be4e8a0f224fba9384f1fe02e"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-win32.whl", hash = "sha256:e31ae45bc2e29f6b2abd0de1cc3b9d5205aa847cafaecb8af1476a609a2f6eb7"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662"},
|
||||
{file = "cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pycparser = "*"
|
||||
|
||||
[[package]]
|
||||
name = "charset-normalizer"
|
||||
version = "3.4.2"
|
||||
@@ -208,12 +289,176 @@ version = "0.4.6"
|
||||
description = "Cross-platform colored terminal text."
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
||||
groups = ["main"]
|
||||
groups = ["main", "dev"]
|
||||
files = [
|
||||
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
|
||||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "coverage"
|
||||
version = "7.10.5"
|
||||
description = "Code coverage measurement for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "coverage-7.10.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c6a5c3414bfc7451b879141ce772c546985163cf553f08e0f135f0699a911801"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bc8e4d99ce82f1710cc3c125adc30fd1487d3cf6c2cd4994d78d68a47b16989a"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:02252dc1216e512a9311f596b3169fad54abcb13827a8d76d5630c798a50a754"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:73269df37883e02d460bee0cc16be90509faea1e3bd105d77360b512d5bb9c33"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1f8a81b0614642f91c9effd53eec284f965577591f51f547a1cbeb32035b4c2f"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6a29f8e0adb7f8c2b95fa2d4566a1d6e6722e0a637634c6563cb1ab844427dd9"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:fcf6ab569436b4a647d4e91accba12509ad9f2554bc93d3aee23cc596e7f99c3"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:90dc3d6fb222b194a5de60af8d190bedeeddcbc7add317e4a3cd333ee6b7c879"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-win32.whl", hash = "sha256:414a568cd545f9dc75f0686a0049393de8098414b58ea071e03395505b73d7a8"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-win_amd64.whl", hash = "sha256:e551f9d03347196271935fd3c0c165f0e8c049220280c1120de0084d65e9c7ff"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c177e6ffe2ebc7c410785307758ee21258aa8e8092b44d09a2da767834f075f2"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:14d6071c51ad0f703d6440827eaa46386169b5fdced42631d5a5ac419616046f"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:61f78c7c3bc272a410c5ae3fde7792b4ffb4acc03d35a7df73ca8978826bb7ab"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f39071caa126f69d63f99b324fb08c7b1da2ec28cbb1fe7b5b1799926492f65c"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:343a023193f04d46edc46b2616cdbee68c94dd10208ecd3adc56fcc54ef2baa1"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:585ffe93ae5894d1ebdee69fc0b0d4b7c75d8007983692fb300ac98eed146f78"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b0ef4e66f006ed181df29b59921bd8fc7ed7cd6a9289295cd8b2824b49b570df"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:eb7b0bbf7cc1d0453b843eca7b5fa017874735bef9bfdfa4121373d2cc885ed6"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-win32.whl", hash = "sha256:1d043a8a06987cc0c98516e57c4d3fc2c1591364831e9deb59c9e1b4937e8caf"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-win_amd64.whl", hash = "sha256:fefafcca09c3ac56372ef64a40f5fe17c5592fab906e0fdffd09543f3012ba50"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-win_arm64.whl", hash = "sha256:7e78b767da8b5fc5b2faa69bb001edafcd6f3995b42a331c53ef9572c55ceb82"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c2d05c7e73c60a4cecc7d9b60dbfd603b4ebc0adafaef371445b47d0f805c8a9"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:32ddaa3b2c509778ed5373b177eb2bf5662405493baeff52278a0b4f9415188b"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:dd382410039fe062097aa0292ab6335a3f1e7af7bba2ef8d27dcda484918f20c"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7fa22800f3908df31cea6fb230f20ac49e343515d968cc3a42b30d5c3ebf9b5a"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f366a57ac81f5e12797136552f5b7502fa053c861a009b91b80ed51f2ce651c6"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5f1dc8f1980a272ad4a6c84cba7981792344dad33bf5869361576b7aef42733a"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:2285c04ee8676f7938b02b4936d9b9b672064daab3187c20f73a55f3d70e6b4a"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c2492e4dd9daab63f5f56286f8a04c51323d237631eb98505d87e4c4ff19ec34"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-win32.whl", hash = "sha256:38a9109c4ee8135d5df5505384fc2f20287a47ccbe0b3f04c53c9a1989c2bbaf"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-win_amd64.whl", hash = "sha256:6b87f1ad60b30bc3c43c66afa7db6b22a3109902e28c5094957626a0143a001f"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-win_arm64.whl", hash = "sha256:672a6c1da5aea6c629819a0e1461e89d244f78d7b60c424ecf4f1f2556c041d8"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ef3b83594d933020f54cf65ea1f4405d1f4e41a009c46df629dd964fcb6e907c"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2b96bfdf7c0ea9faebce088a3ecb2382819da4fbc05c7b80040dbc428df6af44"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:63df1fdaffa42d914d5c4d293e838937638bf75c794cf20bee12978fc8c4e3bc"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8002dc6a049aac0e81ecec97abfb08c01ef0c1fbf962d0c98da3950ace89b869"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:63d4bb2966d6f5f705a6b0c6784c8969c468dbc4bcf9d9ded8bff1c7e092451f"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:1f672efc0731a6846b157389b6e6d5d5e9e59d1d1a23a5c66a99fd58339914d5"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:3f39cef43d08049e8afc1fde4a5da8510fc6be843f8dea350ee46e2a26b2f54c"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2968647e3ed5a6c019a419264386b013979ff1fb67dd11f5c9886c43d6a31fc2"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-win32.whl", hash = "sha256:0d511dda38595b2b6934c2b730a1fd57a3635c6aa2a04cb74714cdfdd53846f4"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-win_amd64.whl", hash = "sha256:9a86281794a393513cf117177fd39c796b3f8e3759bb2764259a2abba5cce54b"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-win_arm64.whl", hash = "sha256:cebd8e906eb98bb09c10d1feed16096700b1198d482267f8bf0474e63a7b8d84"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0520dff502da5e09d0d20781df74d8189ab334a1e40d5bafe2efaa4158e2d9e7"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:d9cd64aca68f503ed3f1f18c7c9174cbb797baba02ca8ab5112f9d1c0328cd4b"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0913dd1613a33b13c4f84aa6e3f4198c1a21ee28ccb4f674985c1f22109f0aae"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:1b7181c0feeb06ed8a02da02792f42f829a7b29990fef52eff257fef0885d760"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:36d42b7396b605f774d4372dd9c49bed71cbabce4ae1ccd074d155709dd8f235"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b4fdc777e05c4940b297bf47bf7eedd56a39a61dc23ba798e4b830d585486ca5"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:42144e8e346de44a6f1dbd0a56575dd8ab8dfa7e9007da02ea5b1c30ab33a7db"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:66c644cbd7aed8fe266d5917e2c9f65458a51cfe5eeff9c05f15b335f697066e"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-win32.whl", hash = "sha256:2d1b73023854068c44b0c554578a4e1ef1b050ed07cf8b431549e624a29a66ee"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-win_amd64.whl", hash = "sha256:54a1532c8a642d8cc0bd5a9a51f5a9dcc440294fd06e9dda55e743c5ec1a8f14"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-win_arm64.whl", hash = "sha256:74d5b63fe3f5f5d372253a4ef92492c11a4305f3550631beaa432fc9df16fcff"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:68c5e0bc5f44f68053369fa0d94459c84548a77660a5f2561c5e5f1e3bed7031"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:cf33134ffae93865e32e1e37df043bef15a5e857d8caebc0099d225c579b0fa3"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:ad8fa9d5193bafcf668231294241302b5e683a0518bf1e33a9a0dfb142ec3031"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:146fa1531973d38ab4b689bc764592fe6c2f913e7e80a39e7eeafd11f0ef6db2"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6013a37b8a4854c478d3219ee8bc2392dea51602dd0803a12d6f6182a0061762"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:eb90fe20db9c3d930fa2ad7a308207ab5b86bf6a76f54ab6a40be4012d88fcae"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:384b34482272e960c438703cafe63316dfbea124ac62006a455c8410bf2a2262"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:467dc74bd0a1a7de2bedf8deaf6811f43602cb532bd34d81ffd6038d6d8abe99"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-win32.whl", hash = "sha256:556d23d4e6393ca898b2e63a5bca91e9ac2d5fb13299ec286cd69a09a7187fde"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-win_amd64.whl", hash = "sha256:f4446a9547681533c8fa3e3c6cf62121eeee616e6a92bd9201c6edd91beffe13"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-win_arm64.whl", hash = "sha256:5e78bd9cf65da4c303bf663de0d73bf69f81e878bf72a94e9af67137c69b9fe9"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:5661bf987d91ec756a47c7e5df4fbcb949f39e32f9334ccd3f43233bbb65e508"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:a46473129244db42a720439a26984f8c6f834762fc4573616c1f37f13994b357"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:1f64b8d3415d60f24b058b58d859e9512624bdfa57a2d1f8aff93c1ec45c429b"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:44d43de99a9d90b20e0163f9770542357f58860a26e24dc1d924643bd6aa7cb4"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a931a87e5ddb6b6404e65443b742cb1c14959622777f2a4efd81fba84f5d91ba"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:f9559b906a100029274448f4c8b8b0a127daa4dade5661dfd821b8c188058842"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:b08801e25e3b4526ef9ced1aa29344131a8f5213c60c03c18fe4c6170ffa2874"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ed9749bb8eda35f8b636fb7632f1c62f735a236a5d4edadd8bbcc5ea0542e732"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-win32.whl", hash = "sha256:609b60d123fc2cc63ccee6d17e4676699075db72d14ac3c107cc4976d516f2df"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-win_amd64.whl", hash = "sha256:0666cf3d2c1626b5a3463fd5b05f5e21f99e6aec40a3192eee4d07a15970b07f"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-win_arm64.whl", hash = "sha256:bc85eb2d35e760120540afddd3044a5bf69118a91a296a8b3940dfc4fdcfe1e2"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:62835c1b00c4a4ace24c1a88561a5a59b612fbb83a525d1c70ff5720c97c0610"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5255b3bbcc1d32a4069d6403820ac8e6dbcc1d68cb28a60a1ebf17e47028e898"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3876385722e335d6e991c430302c24251ef9c2a9701b2b390f5473199b1b8ebf"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8048ce4b149c93447a55d279078c8ae98b08a6951a3c4d2d7e87f4efc7bfe100"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4028e7558e268dd8bcf4d9484aad393cafa654c24b4885f6f9474bf53183a82a"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:03f47dc870eec0367fcdd603ca6a01517d2504e83dc18dbfafae37faec66129a"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:2d488d7d42b6ded7ea0704884f89dcabd2619505457de8fc9a6011c62106f6e5"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b3dcf2ead47fa8be14224ee817dfc1df98043af568fe120a22f81c0eb3c34ad2"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-win32.whl", hash = "sha256:02650a11324b80057b8c9c29487020073d5e98a498f1857f37e3f9b6ea1b2426"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-win_amd64.whl", hash = "sha256:b45264dd450a10f9e03237b41a9a24e85cbb1e278e5a32adb1a303f58f0017f3"},
|
||||
{file = "coverage-7.10.5-py3-none-any.whl", hash = "sha256:0be24d35e4db1d23d0db5c0f6a74a962e2ec83c426b5cac09f4234aadef38e4a"},
|
||||
{file = "coverage-7.10.5.tar.gz", hash = "sha256:f2e57716a78bc3ae80b2207be0709a3b2b63b9f2dcf9740ee6ac03588a2015b6"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}
|
||||
|
||||
[package.extras]
|
||||
toml = ["tomli ; python_full_version <= \"3.11.0a6\""]
|
||||
|
||||
[[package]]
|
||||
name = "cryptography"
|
||||
version = "45.0.6"
|
||||
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
|
||||
optional = false
|
||||
python-versions = "!=3.9.0,!=3.9.1,>=3.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "cryptography-45.0.6-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:048e7ad9e08cf4c0ab07ff7f36cc3115924e22e2266e034450a890d9e312dd74"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:44647c5d796f5fc042bbc6d61307d04bf29bccb74d188f18051b635f20a9c75f"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e40b80ecf35ec265c452eea0ba94c9587ca763e739b8e559c128d23bff7ebbbf"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:00e8724bdad672d75e6f069b27970883179bd472cd24a63f6e620ca7e41cc0c5"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7a3085d1b319d35296176af31c90338eeb2ddac8104661df79f80e1d9787b8b2"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1b7fa6a1c1188c7ee32e47590d16a5a0646270921f8020efc9a511648e1b2e08"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:275ba5cc0d9e320cd70f8e7b96d9e59903c815ca579ab96c1e37278d231fc402"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:f4028f29a9f38a2025abedb2e409973709c660d44319c61762202206ed577c42"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ee411a1b977f40bd075392c80c10b58025ee5c6b47a822a33c1198598a7a5f05"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e2a21a8eda2d86bb604934b6b37691585bd095c1f788530c1fcefc53a82b3453"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-win32.whl", hash = "sha256:d063341378d7ee9c91f9d23b431a3502fc8bfacd54ef0a27baa72a0843b29159"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-win_amd64.whl", hash = "sha256:833dc32dfc1e39b7376a87b9a6a4288a10aae234631268486558920029b086ec"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:3436128a60a5e5490603ab2adbabc8763613f638513ffa7d311c900a8349a2a0"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0d9ef57b6768d9fa58e92f4947cea96ade1233c0e236db22ba44748ffedca394"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea3c42f2016a5bbf71825537c2ad753f2870191134933196bee408aac397b3d9"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:20ae4906a13716139d6d762ceb3e0e7e110f7955f3bc3876e3a07f5daadec5f3"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:2dac5ec199038b8e131365e2324c03d20e97fe214af051d20c49db129844e8b3"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:18f878a34b90d688982e43f4b700408b478102dd58b3e39de21b5ebf6509c301"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:5bd6020c80c5b2b2242d6c48487d7b85700f5e0038e67b29d706f98440d66eb5"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:eccddbd986e43014263eda489abbddfbc287af5cddfd690477993dbb31e31016"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:550ae02148206beb722cfe4ef0933f9352bab26b087af00e48fdfb9ade35c5b3"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5b64e668fc3528e77efa51ca70fadcd6610e8ab231e3e06ae2bab3b31c2b8ed9"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-win32.whl", hash = "sha256:780c40fb751c7d2b0c6786ceee6b6f871e86e8718a8ff4bc35073ac353c7cd02"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-win_amd64.whl", hash = "sha256:20d15aed3ee522faac1a39fbfdfee25d17b1284bafd808e1640a74846d7c4d1b"},
|
||||
{file = "cryptography-45.0.6-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:705bb7c7ecc3d79a50f236adda12ca331c8e7ecfbea51edd931ce5a7a7c4f012"},
|
||||
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:826b46dae41a1155a0c0e66fafba43d0ede1dc16570b95e40c4d83bfcf0a451d"},
|
||||
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:cc4d66f5dc4dc37b89cfef1bd5044387f7a1f6f0abb490815628501909332d5d"},
|
||||
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:f68f833a9d445cc49f01097d95c83a850795921b3f7cc6488731e69bde3288da"},
|
||||
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:3b5bf5267e98661b9b888a9250d05b063220dfa917a8203744454573c7eb79db"},
|
||||
{file = "cryptography-45.0.6-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2384f2ab18d9be88a6e4f8972923405e2dbb8d3e16c6b43f15ca491d7831bd18"},
|
||||
{file = "cryptography-45.0.6-pp311-pypy311_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fc022c1fa5acff6def2fc6d7819bbbd31ccddfe67d075331a65d9cfb28a20983"},
|
||||
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:3de77e4df42ac8d4e4d6cdb342d989803ad37707cf8f3fbf7b088c9cbdd46427"},
|
||||
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:599c8d7df950aa68baa7e98f7b73f4f414c9f02d0e8104a30c0182a07732638b"},
|
||||
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:31a2b9a10530a1cb04ffd6aa1cd4d3be9ed49f7d77a4dafe198f3b382f41545c"},
|
||||
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:e5b3dda1b00fb41da3af4c5ef3f922a200e33ee5ba0f0bc9ecf0b0c173958385"},
|
||||
{file = "cryptography-45.0.6-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:629127cfdcdc6806dfe234734d7cb8ac54edaf572148274fa377a7d3405b0043"},
|
||||
{file = "cryptography-45.0.6.tar.gz", hash = "sha256:5c966c732cf6e4a276ce83b6e4c729edda2df6929083a952cc7da973c539c719"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
cffi = {version = ">=1.14", markers = "platform_python_implementation != \"PyPy\""}
|
||||
|
||||
[package.extras]
|
||||
docs = ["sphinx (>=5.3.0)", "sphinx-inline-tabs ; python_full_version >= \"3.8.0\"", "sphinx-rtd-theme (>=3.0.0) ; python_full_version >= \"3.8.0\""]
|
||||
docstest = ["pyenchant (>=3)", "readme-renderer (>=30.0)", "sphinxcontrib-spelling (>=7.3.1)"]
|
||||
nox = ["nox (>=2024.4.15)", "nox[uv] (>=2024.3.2) ; python_full_version >= \"3.8.0\""]
|
||||
pep8test = ["check-sdist ; python_full_version >= \"3.8.0\"", "click (>=8.0.1)", "mypy (>=1.4)", "ruff (>=0.3.6)"]
|
||||
sdist = ["build (>=1.0.0)"]
|
||||
ssh = ["bcrypt (>=3.1.5)"]
|
||||
test = ["certifi (>=2024)", "cryptography-vectors (==45.0.6)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
|
||||
test-randomorder = ["pytest-randomly"]
|
||||
|
||||
[[package]]
|
||||
name = "deprecation"
|
||||
version = "2.1.0"
|
||||
@@ -235,7 +480,7 @@ version = "1.3.0"
|
||||
description = "Backport of PEP 654 (exception groups)"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main"]
|
||||
groups = ["main", "dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
files = [
|
||||
{file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"},
|
||||
@@ -710,7 +955,7 @@ version = "2.1.0"
|
||||
description = "brain-dead simple config-ini parsing"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"},
|
||||
{file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"},
|
||||
@@ -779,7 +1024,7 @@ version = "25.0"
|
||||
description = "Core utilities for Python packages"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
groups = ["main", "dev"]
|
||||
files = [
|
||||
{file = "packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484"},
|
||||
{file = "packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f"},
|
||||
@@ -791,7 +1036,7 @@ version = "1.6.0"
|
||||
description = "plugin and hook calling mechanisms for python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746"},
|
||||
{file = "pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3"},
|
||||
@@ -883,6 +1128,19 @@ files = [
|
||||
[package.dependencies]
|
||||
pyasn1 = ">=0.6.1,<0.7.0"
|
||||
|
||||
[[package]]
|
||||
name = "pycparser"
|
||||
version = "2.22"
|
||||
description = "C parser in Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
markers = "platform_python_implementation != \"PyPy\""
|
||||
files = [
|
||||
{file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"},
|
||||
{file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "2.11.7"
|
||||
@@ -1047,7 +1305,7 @@ version = "2.19.2"
|
||||
description = "Pygments is a syntax highlighting package written in Python."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b"},
|
||||
{file = "pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887"},
|
||||
@@ -1068,6 +1326,9 @@ files = [
|
||||
{file = "pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
cryptography = {version = ">=3.4.0", optional = true, markers = "extra == \"crypto\""}
|
||||
|
||||
[package.extras]
|
||||
crypto = ["cryptography (>=3.4.0)"]
|
||||
dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx", "sphinx-rtd-theme", "zope.interface"]
|
||||
@@ -1092,7 +1353,7 @@ version = "8.4.1"
|
||||
description = "pytest: simple powerful testing with Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7"},
|
||||
{file = "pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c"},
|
||||
@@ -1116,7 +1377,7 @@ version = "1.1.0"
|
||||
description = "Pytest support for asyncio"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest_asyncio-1.1.0-py3-none-any.whl", hash = "sha256:5fe2d69607b0bd75c656d1211f969cadba035030156745ee09e7d71740e58ecf"},
|
||||
{file = "pytest_asyncio-1.1.0.tar.gz", hash = "sha256:796aa822981e01b68c12e4827b8697108f7205020f24b5793b3c41555dab68ea"},
|
||||
@@ -1130,13 +1391,33 @@ pytest = ">=8.2,<9"
|
||||
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"]
|
||||
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "6.2.1"
|
||||
description = "Pytest plugin for measuring coverage."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest_cov-6.2.1-py3-none-any.whl", hash = "sha256:f5bc4c23f42f1cdd23c70b1dab1bbaef4fc505ba950d53e0081d0730dd7e86d5"},
|
||||
{file = "pytest_cov-6.2.1.tar.gz", hash = "sha256:25cc6cc0a5358204b8108ecedc51a9b57b34cc6b8c967cc2c01a4e00d8a67da2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
coverage = {version = ">=7.5", extras = ["toml"]}
|
||||
pluggy = ">=1.2"
|
||||
pytest = ">=6.2.5"
|
||||
|
||||
[package.extras]
|
||||
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-mock"
|
||||
version = "3.14.1"
|
||||
description = "Thin-wrapper around the mock package for easier use with pytest"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest_mock-3.14.1-py3-none-any.whl", hash = "sha256:178aefcd11307d874b4cd3100344e7e2d888d9791a6a1d9bfe90fbc1b74fd1d0"},
|
||||
{file = "pytest_mock-3.14.1.tar.gz", hash = "sha256:159e9edac4c451ce77a5cdb9fc5d1100708d2dd4ba3c3df572f14097351af80e"},
|
||||
@@ -1253,30 +1534,31 @@ pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.12.3"
|
||||
version = "0.12.9"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "ruff-0.12.3-py3-none-linux_armv6l.whl", hash = "sha256:47552138f7206454eaf0c4fe827e546e9ddac62c2a3d2585ca54d29a890137a2"},
|
||||
{file = "ruff-0.12.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:0a9153b000c6fe169bb307f5bd1b691221c4286c133407b8827c406a55282041"},
|
||||
{file = "ruff-0.12.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa6b24600cf3b750e48ddb6057e901dd5b9aa426e316addb2a1af185a7509882"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2506961bf6ead54887ba3562604d69cb430f59b42133d36976421bc8bd45901"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c4faaff1f90cea9d3033cbbcdf1acf5d7fb11d8180758feb31337391691f3df0"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40dced4a79d7c264389de1c59467d5d5cefd79e7e06d1dfa2c75497b5269a5a6"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:0262d50ba2767ed0fe212aa7e62112a1dcbfd46b858c5bf7bbd11f326998bafc"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12371aec33e1a3758597c5c631bae9a5286f3c963bdfb4d17acdd2d395406687"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:560f13b6baa49785665276c963edc363f8ad4b4fc910a883e2625bdb14a83a9e"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:023040a3499f6f974ae9091bcdd0385dd9e9eb4942f231c23c57708147b06311"},
|
||||
{file = "ruff-0.12.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:883d844967bffff5ab28bba1a4d246c1a1b2933f48cb9840f3fdc5111c603b07"},
|
||||
{file = "ruff-0.12.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2120d3aa855ff385e0e562fdee14d564c9675edbe41625c87eeab744a7830d12"},
|
||||
{file = "ruff-0.12.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6b16647cbb470eaf4750d27dddc6ebf7758b918887b56d39e9c22cce2049082b"},
|
||||
{file = "ruff-0.12.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e1417051edb436230023575b149e8ff843a324557fe0a265863b7602df86722f"},
|
||||
{file = "ruff-0.12.3-py3-none-win32.whl", hash = "sha256:dfd45e6e926deb6409d0616078a666ebce93e55e07f0fb0228d4b2608b2c248d"},
|
||||
{file = "ruff-0.12.3-py3-none-win_amd64.whl", hash = "sha256:a946cf1e7ba3209bdef039eb97647f1c77f6f540e5845ec9c114d3af8df873e7"},
|
||||
{file = "ruff-0.12.3-py3-none-win_arm64.whl", hash = "sha256:5f9c7c9c8f84c2d7f27e93674d27136fbf489720251544c4da7fb3d742e011b1"},
|
||||
{file = "ruff-0.12.3.tar.gz", hash = "sha256:f1b5a4b6668fd7b7ea3697d8d98857390b40c1320a63a178eee6be0899ea2d77"},
|
||||
{file = "ruff-0.12.9-py3-none-linux_armv6l.whl", hash = "sha256:fcebc6c79fcae3f220d05585229463621f5dbf24d79fdc4936d9302e177cfa3e"},
|
||||
{file = "ruff-0.12.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:aed9d15f8c5755c0e74467731a007fcad41f19bcce41cd75f768bbd687f8535f"},
|
||||
{file = "ruff-0.12.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5b15ea354c6ff0d7423814ba6d44be2807644d0c05e9ed60caca87e963e93f70"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d596c2d0393c2502eaabfef723bd74ca35348a8dac4267d18a94910087807c53"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1b15599931a1a7a03c388b9c5df1bfa62be7ede6eb7ef753b272381f39c3d0ff"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3d02faa2977fb6f3f32ddb7828e212b7dd499c59eb896ae6c03ea5c303575756"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:17d5b6b0b3a25259b69ebcba87908496e6830e03acfb929ef9fd4c58675fa2ea"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:72db7521860e246adbb43f6ef464dd2a532ef2ef1f5dd0d470455b8d9f1773e0"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a03242c1522b4e0885af63320ad754d53983c9599157ee33e77d748363c561ce"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fc83e4e9751e6c13b5046d7162f205d0a7bac5840183c5beebf824b08a27340"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:881465ed56ba4dd26a691954650de6ad389a2d1fdb130fe51ff18a25639fe4bb"},
|
||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:43f07a3ccfc62cdb4d3a3348bf0588358a66da756aa113e071b8ca8c3b9826af"},
|
||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:07adb221c54b6bba24387911e5734357f042e5669fa5718920ee728aba3cbadc"},
|
||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f5cd34fabfdea3933ab85d72359f118035882a01bff15bd1d2b15261d85d5f66"},
|
||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:f6be1d2ca0686c54564da8e7ee9e25f93bdd6868263805f8c0b8fc6a449db6d7"},
|
||||
{file = "ruff-0.12.9-py3-none-win32.whl", hash = "sha256:cc7a37bd2509974379d0115cc5608a1a4a6c4bff1b452ea69db83c8855d53f93"},
|
||||
{file = "ruff-0.12.9-py3-none-win_amd64.whl", hash = "sha256:6fb15b1977309741d7d098c8a3cb7a30bc112760a00fb6efb7abc85f00ba5908"},
|
||||
{file = "ruff-0.12.9-py3-none-win_arm64.whl", hash = "sha256:63c8c819739d86b96d500cce885956a1a48ab056bbcbc61b747ad494b2485089"},
|
||||
{file = "ruff-0.12.9.tar.gz", hash = "sha256:fbd94b2e3c623f659962934e52c2bea6fc6da11f667a427a368adaf3af2c866a"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1410,7 +1692,7 @@ version = "2.2.1"
|
||||
description = "A lil' TOML parser"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
files = [
|
||||
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
|
||||
@@ -1453,11 +1735,12 @@ version = "4.14.1"
|
||||
description = "Backported and Experimental Type Hints for Python 3.9+"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
groups = ["main", "dev"]
|
||||
files = [
|
||||
{file = "typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76"},
|
||||
{file = "typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36"},
|
||||
]
|
||||
markers = {dev = "python_version < \"3.11\""}
|
||||
|
||||
[[package]]
|
||||
name = "typing-inspection"
|
||||
@@ -1614,4 +1897,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "f67db13e6f68b1d67a55eee908c1c560bfa44da8509f98f842889a7570a9830f"
|
||||
content-hash = "ef7818fba061cea2841c6d7ca4852acde83e4f73b32fca1315e58660002bb0d0"
|
||||
|
||||
@@ -15,15 +15,17 @@ google-cloud-logging = "^3.12.1"
|
||||
launchdarkly-server-sdk = "^9.12.0"
|
||||
pydantic = "^2.11.7"
|
||||
pydantic-settings = "^2.10.1"
|
||||
pyjwt = "^2.10.1"
|
||||
pytest-asyncio = "^1.1.0"
|
||||
pytest-mock = "^3.14.1"
|
||||
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
|
||||
redis = "^6.2.0"
|
||||
supabase = "^2.16.0"
|
||||
uvicorn = "^0.35.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "^0.12.3"
|
||||
ruff = "^0.12.9"
|
||||
pytest = "^8.4.1"
|
||||
pytest-asyncio = "^1.1.0"
|
||||
pytest-mock = "^3.14.1"
|
||||
pytest-cov = "^6.2.1"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -16,7 +16,6 @@ DB_SCHEMA=platform
|
||||
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
|
||||
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
|
||||
PRISMA_SCHEMA="postgres/schema.prisma"
|
||||
ENABLE_AUTH=true
|
||||
|
||||
## ===== REQUIRED SERVICE CREDENTIALS ===== ##
|
||||
# Redis Configuration
|
||||
@@ -31,7 +30,7 @@ RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
|
||||
# Supabase Authentication
|
||||
SUPABASE_URL=http://localhost:8000
|
||||
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
|
||||
## ===== REQUIRED SECURITY KEYS ===== ##
|
||||
# Generate using: from cryptography.fernet import Fernet;Fernet.generate_key().decode()
|
||||
@@ -106,6 +105,15 @@ TODOIST_CLIENT_SECRET=
|
||||
|
||||
NOTION_CLIENT_ID=
|
||||
NOTION_CLIENT_SECRET=
|
||||
|
||||
# Discord OAuth App credentials
|
||||
# 1. Go to https://discord.com/developers/applications
|
||||
# 2. Create a new application
|
||||
# 3. Go to OAuth2 section and add redirect URI: http://localhost:3000/auth/integrations/oauth_callback
|
||||
# 4. Copy Client ID and Client Secret below
|
||||
DISCORD_CLIENT_ID=
|
||||
DISCORD_CLIENT_SECRET=
|
||||
|
||||
REDDIT_CLIENT_ID=
|
||||
REDDIT_CLIENT_SECRET=
|
||||
|
||||
@@ -166,4 +174,4 @@ SMARTLEAD_API_KEY=
|
||||
ZEROBOUNCE_API_KEY=
|
||||
|
||||
# Other Services
|
||||
AUTOMOD_API_KEY=
|
||||
AUTOMOD_API_KEY=
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
FROM python:3.11.10-slim-bookworm AS builder
|
||||
FROM debian:13-slim AS builder
|
||||
|
||||
# Set environment variables
|
||||
ENV PYTHONDONTWRITEBYTECODE 1
|
||||
ENV PYTHONUNBUFFERED 1
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy
|
||||
|
||||
# Update package list and install build dependencies in a single layer
|
||||
# Update package list and install Python and build dependencies
|
||||
RUN apt-get update --allow-releaseinfo-change --fix-missing \
|
||||
&& apt-get install -y \
|
||||
python3.13 \
|
||||
python3.13-dev \
|
||||
python3.13-venv \
|
||||
python3-pip \
|
||||
build-essential \
|
||||
libpq5 \
|
||||
libz-dev \
|
||||
@@ -19,13 +24,11 @@ RUN apt-get update --allow-releaseinfo-change --fix-missing \
|
||||
|
||||
ENV POETRY_HOME=/opt/poetry
|
||||
ENV POETRY_NO_INTERACTION=1
|
||||
ENV POETRY_VIRTUALENVS_CREATE=false
|
||||
ENV POETRY_VIRTUALENVS_CREATE=true
|
||||
ENV POETRY_VIRTUALENVS_IN_PROJECT=true
|
||||
ENV PATH=/opt/poetry/bin:$PATH
|
||||
|
||||
# Upgrade pip and setuptools to fix security vulnerabilities
|
||||
RUN pip3 install --upgrade pip setuptools
|
||||
|
||||
RUN pip3 install poetry
|
||||
RUN pip3 install poetry --break-system-packages
|
||||
|
||||
# Copy and install dependencies
|
||||
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
|
||||
@@ -37,27 +40,30 @@ RUN poetry install --no-ansi --no-root
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
RUN poetry run prisma generate
|
||||
|
||||
FROM python:3.11.10-slim-bookworm AS server_dependencies
|
||||
FROM debian:13-slim AS server_dependencies
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
ENV POETRY_HOME=/opt/poetry \
|
||||
POETRY_NO_INTERACTION=1 \
|
||||
POETRY_VIRTUALENVS_CREATE=false
|
||||
POETRY_VIRTUALENVS_CREATE=true \
|
||||
POETRY_VIRTUALENVS_IN_PROJECT=true \
|
||||
DEBIAN_FRONTEND=noninteractive
|
||||
ENV PATH=/opt/poetry/bin:$PATH
|
||||
|
||||
# Upgrade pip and setuptools to fix security vulnerabilities
|
||||
RUN pip3 install --upgrade pip setuptools
|
||||
# Install Python without upgrading system-managed packages
|
||||
RUN apt-get update && apt-get install -y \
|
||||
python3.13 \
|
||||
python3-pip
|
||||
|
||||
# Copy only necessary files from builder
|
||||
COPY --from=builder /app /app
|
||||
COPY --from=builder /usr/local/lib/python3.11 /usr/local/lib/python3.11
|
||||
COPY --from=builder /usr/local/bin /usr/local/bin
|
||||
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
||||
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
|
||||
# Copy Prisma binaries
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
|
||||
|
||||
RUN mkdir -p /app/autogpt_platform/autogpt_libs
|
||||
RUN mkdir -p /app/autogpt_platform/backend
|
||||
|
||||
@@ -132,17 +132,58 @@ def test_endpoint_success(snapshot: Snapshot):
|
||||
|
||||
### Testing with Authentication
|
||||
|
||||
For the main API routes that use JWT authentication, auth is provided by the `autogpt_libs.auth` module. If the test actually uses the `user_id`, the recommended approach for testing is to mock the `get_jwt_payload` function, which underpins all higher-level auth functions used in the API (`requires_user`, `requires_admin_user`, `get_user_id`).
|
||||
|
||||
If the test doesn't need the `user_id` specifically, mocking is not necessary as during tests auth is disabled anyway (see `conftest.py`).
|
||||
|
||||
#### Using Global Auth Fixtures
|
||||
|
||||
Two global auth fixtures are provided by `backend/server/conftest.py`:
|
||||
|
||||
- `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id")
|
||||
- `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id")
|
||||
|
||||
These provide the easiest way to set up authentication mocking in test modules:
|
||||
|
||||
```python
|
||||
def override_auth_middleware():
|
||||
return {"sub": "test-user-id"}
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
from backend.server.v2.myroute import router
|
||||
|
||||
def override_get_user_id():
|
||||
return "test-user-id"
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
app.dependency_overrides[auth_middleware] = override_auth_middleware
|
||||
app.dependency_overrides[get_user_id] = override_get_user_id
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user['get_jwt_payload']
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
```
|
||||
|
||||
For admin-only endpoints, use `mock_jwt_admin` instead:
|
||||
|
||||
```python
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_admin):
|
||||
"""Setup auth overrides for admin tests"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin['get_jwt_payload']
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
```
|
||||
|
||||
The IDs are also available separately as fixtures:
|
||||
|
||||
- `test_user_id`
|
||||
- `admin_user_id`
|
||||
- `target_user_id` (for admin <-> user operations)
|
||||
|
||||
### Mocking External Services
|
||||
|
||||
```python
|
||||
@@ -153,10 +194,10 @@ def test_external_api_call(mocker, snapshot):
|
||||
"backend.services.external_api.call",
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
|
||||
response = client.post("/api/process")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
snapshot.assert_match(
|
||||
json.dumps(response.json(), indent=2, sort_keys=True),
|
||||
@@ -187,6 +228,17 @@ def test_external_api_call(mocker, snapshot):
|
||||
- Use `async def` with `@pytest.mark.asyncio` for testing async functions directly
|
||||
|
||||
### 5. Fixtures
|
||||
|
||||
#### Global Fixtures (conftest.py)
|
||||
|
||||
Authentication fixtures are available globally from `conftest.py`:
|
||||
|
||||
- `mock_jwt_user` - Standard user authentication
|
||||
- `mock_jwt_admin` - Admin user authentication
|
||||
- `configured_snapshot` - Pre-configured snapshot fixture
|
||||
|
||||
#### Custom Fixtures
|
||||
|
||||
Create reusable fixtures for common test data:
|
||||
|
||||
```python
|
||||
@@ -202,9 +254,18 @@ def test_create_user(sample_user, snapshot):
|
||||
# ... test implementation
|
||||
```
|
||||
|
||||
#### Test Isolation
|
||||
|
||||
All tests must use fixtures that ensure proper isolation:
|
||||
|
||||
- Authentication overrides are automatically cleaned up after each test
|
||||
- Database connections are properly managed with cleanup
|
||||
- Mock objects are reset between tests
|
||||
|
||||
## CI/CD Integration
|
||||
|
||||
The GitHub Actions workflow automatically runs tests on:
|
||||
|
||||
- Pull requests
|
||||
- Pushes to main branch
|
||||
|
||||
@@ -216,16 +277,19 @@ Snapshot tests work in CI by:
|
||||
## Troubleshooting
|
||||
|
||||
### Snapshot Mismatches
|
||||
|
||||
- Review the diff carefully
|
||||
- If changes are expected: `poetry run pytest --snapshot-update`
|
||||
- If changes are unexpected: Fix the code causing the difference
|
||||
|
||||
### Async Test Issues
|
||||
|
||||
- Ensure async functions use `@pytest.mark.asyncio`
|
||||
- Use `AsyncMock` for mocking async functions
|
||||
- FastAPI TestClient handles async automatically
|
||||
|
||||
### Import Errors
|
||||
|
||||
- Check that all dependencies are in `pyproject.toml`
|
||||
- Run `poetry install` to ensure dependencies are installed
|
||||
- Verify import paths are correct
|
||||
@@ -234,4 +298,4 @@ Snapshot tests work in CI by:
|
||||
|
||||
Snapshot testing provides a powerful way to ensure API responses remain consistent. Combined with traditional assertions, it creates a robust test suite that catches regressions while remaining maintainable.
|
||||
|
||||
Remember: Good tests are as important as good code!
|
||||
Remember: Good tests are as important as good code!
|
||||
|
||||
@@ -25,6 +25,9 @@ class AgentExecutorBlock(Block):
|
||||
user_id: str = SchemaField(description="User ID")
|
||||
graph_id: str = SchemaField(description="Graph ID")
|
||||
graph_version: int = SchemaField(description="Graph Version")
|
||||
agent_name: Optional[str] = SchemaField(
|
||||
default=None, description="Name to display in the Builder UI"
|
||||
)
|
||||
|
||||
inputs: BlockInput = SchemaField(description="Input data for the graph")
|
||||
input_schema: dict = SchemaField(description="Input schema for the graph")
|
||||
|
||||
@@ -166,7 +166,7 @@ class AIMusicGeneratorBlock(Block):
|
||||
output_format=input_data.output_format,
|
||||
normalization_strategy=input_data.normalization_strategy,
|
||||
)
|
||||
if result and result != "No output received":
|
||||
if result and isinstance(result, str) and result.startswith("http"):
|
||||
yield "result", result
|
||||
return
|
||||
else:
|
||||
|
||||
205
autogpt_platform/backend/backend/blocks/baas/_api.py
Normal file
205
autogpt_platform/backend/backend/blocks/baas/_api.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""
|
||||
Meeting BaaS API client module.
|
||||
All API calls centralized for consistency and maintainability.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.sdk import Requests
|
||||
|
||||
|
||||
class MeetingBaasAPI:
|
||||
"""Client for Meeting BaaS API endpoints."""
|
||||
|
||||
BASE_URL = "https://api.meetingbaas.com"
|
||||
|
||||
def __init__(self, api_key: str):
|
||||
"""Initialize API client with authentication key."""
|
||||
self.api_key = api_key
|
||||
self.headers = {"x-meeting-baas-api-key": api_key}
|
||||
self.requests = Requests()
|
||||
|
||||
# Bot Management Endpoints
|
||||
|
||||
async def join_meeting(
|
||||
self,
|
||||
bot_name: str,
|
||||
meeting_url: str,
|
||||
reserved: bool = False,
|
||||
bot_image: Optional[str] = None,
|
||||
entry_message: Optional[str] = None,
|
||||
start_time: Optional[int] = None,
|
||||
speech_to_text: Optional[Dict[str, Any]] = None,
|
||||
webhook_url: Optional[str] = None,
|
||||
automatic_leave: Optional[Dict[str, Any]] = None,
|
||||
extra: Optional[Dict[str, Any]] = None,
|
||||
recording_mode: str = "speaker_view",
|
||||
streaming: Optional[Dict[str, Any]] = None,
|
||||
deduplication_key: Optional[str] = None,
|
||||
zoom_sdk_id: Optional[str] = None,
|
||||
zoom_sdk_pwd: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Deploy a bot to join and record a meeting.
|
||||
|
||||
POST /bots
|
||||
"""
|
||||
body = {
|
||||
"bot_name": bot_name,
|
||||
"meeting_url": meeting_url,
|
||||
"reserved": reserved,
|
||||
"recording_mode": recording_mode,
|
||||
}
|
||||
|
||||
# Add optional fields if provided
|
||||
if bot_image is not None:
|
||||
body["bot_image"] = bot_image
|
||||
if entry_message is not None:
|
||||
body["entry_message"] = entry_message
|
||||
if start_time is not None:
|
||||
body["start_time"] = start_time
|
||||
if speech_to_text is not None:
|
||||
body["speech_to_text"] = speech_to_text
|
||||
if webhook_url is not None:
|
||||
body["webhook_url"] = webhook_url
|
||||
if automatic_leave is not None:
|
||||
body["automatic_leave"] = automatic_leave
|
||||
if extra is not None:
|
||||
body["extra"] = extra
|
||||
if streaming is not None:
|
||||
body["streaming"] = streaming
|
||||
if deduplication_key is not None:
|
||||
body["deduplication_key"] = deduplication_key
|
||||
if zoom_sdk_id is not None:
|
||||
body["zoom_sdk_id"] = zoom_sdk_id
|
||||
if zoom_sdk_pwd is not None:
|
||||
body["zoom_sdk_pwd"] = zoom_sdk_pwd
|
||||
|
||||
response = await self.requests.post(
|
||||
f"{self.BASE_URL}/bots",
|
||||
headers=self.headers,
|
||||
json=body,
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def leave_meeting(self, bot_id: str) -> bool:
|
||||
"""
|
||||
Remove a bot from an ongoing meeting.
|
||||
|
||||
DELETE /bots/{uuid}
|
||||
"""
|
||||
response = await self.requests.delete(
|
||||
f"{self.BASE_URL}/bots/{bot_id}",
|
||||
headers=self.headers,
|
||||
)
|
||||
return response.status in [200, 204]
|
||||
|
||||
async def retranscribe(
|
||||
self,
|
||||
bot_uuid: str,
|
||||
speech_to_text: Optional[Dict[str, Any]] = None,
|
||||
webhook_url: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Re-run transcription on a bot's audio.
|
||||
|
||||
POST /bots/retranscribe
|
||||
"""
|
||||
body: Dict[str, Any] = {"bot_uuid": bot_uuid}
|
||||
|
||||
if speech_to_text is not None:
|
||||
body["speech_to_text"] = speech_to_text
|
||||
if webhook_url is not None:
|
||||
body["webhook_url"] = webhook_url
|
||||
|
||||
response = await self.requests.post(
|
||||
f"{self.BASE_URL}/bots/retranscribe",
|
||||
headers=self.headers,
|
||||
json=body,
|
||||
)
|
||||
|
||||
if response.status == 202:
|
||||
return {"accepted": True}
|
||||
return response.json()
|
||||
|
||||
# Data Retrieval Endpoints
|
||||
|
||||
async def get_meeting_data(
|
||||
self, bot_id: str, include_transcripts: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieve meeting data including recording and transcripts.
|
||||
|
||||
GET /bots/meeting_data
|
||||
"""
|
||||
params = {
|
||||
"bot_id": bot_id,
|
||||
"include_transcripts": str(include_transcripts).lower(),
|
||||
}
|
||||
|
||||
response = await self.requests.get(
|
||||
f"{self.BASE_URL}/bots/meeting_data",
|
||||
headers=self.headers,
|
||||
params=params,
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def get_screenshots(self, bot_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve screenshots captured during a meeting.
|
||||
|
||||
GET /bots/{uuid}/screenshots
|
||||
"""
|
||||
response = await self.requests.get(
|
||||
f"{self.BASE_URL}/bots/{bot_id}/screenshots",
|
||||
headers=self.headers,
|
||||
)
|
||||
result = response.json()
|
||||
# Ensure we return a list
|
||||
if isinstance(result, list):
|
||||
return result
|
||||
return []
|
||||
|
||||
async def delete_data(self, bot_id: str) -> bool:
|
||||
"""
|
||||
Delete a bot's recorded data.
|
||||
|
||||
POST /bots/{uuid}/delete_data
|
||||
"""
|
||||
response = await self.requests.post(
|
||||
f"{self.BASE_URL}/bots/{bot_id}/delete_data",
|
||||
headers=self.headers,
|
||||
)
|
||||
return response.status == 200
|
||||
|
||||
async def list_bots_with_metadata(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
sort_by: Optional[str] = None,
|
||||
sort_order: Optional[str] = None,
|
||||
filter_by: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
List bots with metadata including IDs, names, and meeting details.
|
||||
|
||||
GET /bots/bots_with_metadata
|
||||
"""
|
||||
params = {}
|
||||
if limit is not None:
|
||||
params["limit"] = limit
|
||||
if offset is not None:
|
||||
params["offset"] = offset
|
||||
if sort_by is not None:
|
||||
params["sort_by"] = sort_by
|
||||
if sort_order is not None:
|
||||
params["sort_order"] = sort_order
|
||||
if filter_by is not None:
|
||||
params.update(filter_by)
|
||||
|
||||
response = await self.requests.get(
|
||||
f"{self.BASE_URL}/bots/bots_with_metadata",
|
||||
headers=self.headers,
|
||||
params=params,
|
||||
)
|
||||
return response.json()
|
||||
13
autogpt_platform/backend/backend/blocks/baas/_config.py
Normal file
13
autogpt_platform/backend/backend/blocks/baas/_config.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Shared configuration for all Meeting BaaS blocks using the SDK pattern.
|
||||
"""
|
||||
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
# Configure the Meeting BaaS provider with API key authentication
|
||||
baas = (
|
||||
ProviderBuilder("baas")
|
||||
.with_api_key("MEETING_BAAS_API_KEY", "Meeting BaaS API Key")
|
||||
.with_base_cost(5, BlockCostType.RUN) # Higher cost for meeting recording service
|
||||
.build()
|
||||
)
|
||||
217
autogpt_platform/backend/backend/blocks/baas/bots.py
Normal file
217
autogpt_platform/backend/backend/blocks/baas/bots.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Meeting BaaS bot (recording) blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import MeetingBaasAPI
|
||||
from ._config import baas
|
||||
|
||||
|
||||
class BaasBotJoinMeetingBlock(Block):
|
||||
"""
|
||||
Deploy a bot immediately or at a scheduled start_time to join and record a meeting.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
meeting_url: str = SchemaField(
|
||||
description="The URL of the meeting the bot should join"
|
||||
)
|
||||
bot_name: str = SchemaField(
|
||||
description="Display name for the bot in the meeting"
|
||||
)
|
||||
bot_image: str = SchemaField(
|
||||
description="URL to an image for the bot's avatar (16:9 ratio recommended)",
|
||||
default="",
|
||||
)
|
||||
entry_message: str = SchemaField(
|
||||
description="Chat message the bot will post upon entry", default=""
|
||||
)
|
||||
reserved: bool = SchemaField(
|
||||
description="Use a reserved bot slot (joins 4 min before meeting)",
|
||||
default=False,
|
||||
)
|
||||
start_time: Optional[int] = SchemaField(
|
||||
description="Unix timestamp (ms) when bot should join", default=None
|
||||
)
|
||||
webhook_url: str | None = SchemaField(
|
||||
description="URL to receive webhook events for this bot", default=None
|
||||
)
|
||||
timeouts: dict = SchemaField(
|
||||
description="Automatic leave timeouts configuration", default={}
|
||||
)
|
||||
extra: dict = SchemaField(
|
||||
description="Custom metadata to attach to the bot", default={}
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
bot_id: str = SchemaField(description="UUID of the deployed bot")
|
||||
join_response: dict = SchemaField(
|
||||
description="Full response from join operation"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="377d1a6a-a99b-46cf-9af3-1d1b12758e04",
|
||||
description="Deploy a bot to join and record a meeting",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
api = MeetingBaasAPI(api_key)
|
||||
|
||||
# Call API with all parameters
|
||||
data = await api.join_meeting(
|
||||
bot_name=input_data.bot_name,
|
||||
meeting_url=input_data.meeting_url,
|
||||
reserved=input_data.reserved,
|
||||
bot_image=input_data.bot_image if input_data.bot_image else None,
|
||||
entry_message=(
|
||||
input_data.entry_message if input_data.entry_message else None
|
||||
),
|
||||
start_time=input_data.start_time,
|
||||
speech_to_text={"provider": "Default"},
|
||||
webhook_url=input_data.webhook_url if input_data.webhook_url else None,
|
||||
automatic_leave=input_data.timeouts if input_data.timeouts else None,
|
||||
extra=input_data.extra if input_data.extra else None,
|
||||
)
|
||||
|
||||
yield "bot_id", data.get("bot_id", "")
|
||||
yield "join_response", data
|
||||
|
||||
|
||||
class BaasBotLeaveMeetingBlock(Block):
|
||||
"""
|
||||
Force the bot to exit the call.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(description="UUID of the bot to remove from meeting")
|
||||
|
||||
class Output(BlockSchema):
|
||||
left: bool = SchemaField(description="Whether the bot successfully left")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bf77d128-8b25-4280-b5c7-2d553ba7e482",
|
||||
description="Remove a bot from an ongoing meeting",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
api = MeetingBaasAPI(api_key)
|
||||
|
||||
# Leave meeting
|
||||
left = await api.leave_meeting(input_data.bot_id)
|
||||
|
||||
yield "left", left
|
||||
|
||||
|
||||
class BaasBotFetchMeetingDataBlock(Block):
|
||||
"""
|
||||
Pull MP4 URL, transcript & metadata for a completed meeting.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(description="UUID of the bot whose data to fetch")
|
||||
include_transcripts: bool = SchemaField(
|
||||
description="Include transcript data in response", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
mp4_url: str = SchemaField(
|
||||
description="URL to download the meeting recording (time-limited)"
|
||||
)
|
||||
transcript: list = SchemaField(description="Meeting transcript data")
|
||||
metadata: dict = SchemaField(description="Meeting metadata and bot information")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ea7c1309-303c-4da1-893f-89c0e9d64e78",
|
||||
description="Retrieve recorded meeting data",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
api = MeetingBaasAPI(api_key)
|
||||
|
||||
# Fetch meeting data
|
||||
data = await api.get_meeting_data(
|
||||
bot_id=input_data.bot_id,
|
||||
include_transcripts=input_data.include_transcripts,
|
||||
)
|
||||
|
||||
yield "mp4_url", data.get("mp4", "")
|
||||
yield "transcript", data.get("bot_data", {}).get("transcripts", [])
|
||||
yield "metadata", data.get("bot_data", {}).get("bot", {})
|
||||
|
||||
|
||||
class BaasBotDeleteRecordingBlock(Block):
|
||||
"""
|
||||
Purge MP4 + transcript data for privacy or storage management.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(description="UUID of the bot whose data to delete")
|
||||
|
||||
class Output(BlockSchema):
|
||||
deleted: bool = SchemaField(
|
||||
description="Whether the data was successfully deleted"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bf8d1aa6-42d8-4944-b6bd-6bac554c0d3b",
|
||||
description="Permanently delete a meeting's recorded data",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
api = MeetingBaasAPI(api_key)
|
||||
|
||||
# Delete recording data
|
||||
deleted = await api.delete_data(input_data.bot_id)
|
||||
|
||||
yield "deleted", deleted
|
||||
178
autogpt_platform/backend/backend/blocks/dataforseo/_api.py
Normal file
178
autogpt_platform/backend/backend/blocks/dataforseo/_api.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
DataForSEO API client with async support using the SDK patterns.
|
||||
"""
|
||||
|
||||
import base64
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.sdk import Requests, UserPasswordCredentials
|
||||
|
||||
|
||||
class DataForSeoClient:
|
||||
"""Client for the DataForSEO API using async requests."""
|
||||
|
||||
API_URL = "https://api.dataforseo.com"
|
||||
|
||||
def __init__(self, credentials: UserPasswordCredentials):
|
||||
self.credentials = credentials
|
||||
self.requests = Requests(
|
||||
trusted_origins=["https://api.dataforseo.com"],
|
||||
raise_for_status=False,
|
||||
)
|
||||
|
||||
def _get_headers(self) -> Dict[str, str]:
|
||||
"""Generate the authorization header using Basic Auth."""
|
||||
username = self.credentials.username.get_secret_value()
|
||||
password = self.credentials.password.get_secret_value()
|
||||
credentials_str = f"{username}:{password}"
|
||||
encoded = base64.b64encode(credentials_str.encode("ascii")).decode("ascii")
|
||||
return {
|
||||
"Authorization": f"Basic {encoded}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async def keyword_suggestions(
|
||||
self,
|
||||
keyword: str,
|
||||
location_code: Optional[int] = None,
|
||||
language_code: Optional[str] = None,
|
||||
include_seed_keyword: bool = True,
|
||||
include_serp_info: bool = False,
|
||||
include_clickstream_data: bool = False,
|
||||
limit: int = 100,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get keyword suggestions from DataForSEO Labs.
|
||||
|
||||
Args:
|
||||
keyword: Seed keyword
|
||||
location_code: Location code for targeting
|
||||
language_code: Language code (e.g., "en")
|
||||
include_seed_keyword: Include seed keyword in results
|
||||
include_serp_info: Include SERP data
|
||||
include_clickstream_data: Include clickstream metrics
|
||||
limit: Maximum number of results (up to 3000)
|
||||
|
||||
Returns:
|
||||
API response with keyword suggestions
|
||||
"""
|
||||
endpoint = f"{self.API_URL}/v3/dataforseo_labs/google/keyword_suggestions/live"
|
||||
|
||||
# Build payload only with non-None values to avoid sending null fields
|
||||
task_data: dict[str, Any] = {
|
||||
"keyword": keyword,
|
||||
}
|
||||
|
||||
if location_code is not None:
|
||||
task_data["location_code"] = location_code
|
||||
if language_code is not None:
|
||||
task_data["language_code"] = language_code
|
||||
if include_seed_keyword is not None:
|
||||
task_data["include_seed_keyword"] = include_seed_keyword
|
||||
if include_serp_info is not None:
|
||||
task_data["include_serp_info"] = include_serp_info
|
||||
if include_clickstream_data is not None:
|
||||
task_data["include_clickstream_data"] = include_clickstream_data
|
||||
if limit is not None:
|
||||
task_data["limit"] = limit
|
||||
|
||||
payload = [task_data]
|
||||
|
||||
response = await self.requests.post(
|
||||
endpoint,
|
||||
headers=self._get_headers(),
|
||||
json=payload,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Check for API errors
|
||||
if response.status != 200:
|
||||
error_message = data.get("status_message", "Unknown error")
|
||||
raise Exception(
|
||||
f"DataForSEO API error ({response.status}): {error_message}"
|
||||
)
|
||||
|
||||
# Extract the results from the response
|
||||
if data.get("tasks") and len(data["tasks"]) > 0:
|
||||
task = data["tasks"][0]
|
||||
if task.get("status_code") == 20000: # Success code
|
||||
return task.get("result", [])
|
||||
else:
|
||||
error_msg = task.get("status_message", "Task failed")
|
||||
raise Exception(f"DataForSEO task error: {error_msg}")
|
||||
|
||||
return []
|
||||
|
||||
async def related_keywords(
|
||||
self,
|
||||
keyword: str,
|
||||
location_code: Optional[int] = None,
|
||||
language_code: Optional[str] = None,
|
||||
include_seed_keyword: bool = True,
|
||||
include_serp_info: bool = False,
|
||||
include_clickstream_data: bool = False,
|
||||
limit: int = 100,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get related keywords from DataForSEO Labs.
|
||||
|
||||
Args:
|
||||
keyword: Seed keyword
|
||||
location_code: Location code for targeting
|
||||
language_code: Language code (e.g., "en")
|
||||
include_seed_keyword: Include seed keyword in results
|
||||
include_serp_info: Include SERP data
|
||||
include_clickstream_data: Include clickstream metrics
|
||||
limit: Maximum number of results (up to 3000)
|
||||
|
||||
Returns:
|
||||
API response with related keywords
|
||||
"""
|
||||
endpoint = f"{self.API_URL}/v3/dataforseo_labs/google/related_keywords/live"
|
||||
|
||||
# Build payload only with non-None values to avoid sending null fields
|
||||
task_data: dict[str, Any] = {
|
||||
"keyword": keyword,
|
||||
}
|
||||
|
||||
if location_code is not None:
|
||||
task_data["location_code"] = location_code
|
||||
if language_code is not None:
|
||||
task_data["language_code"] = language_code
|
||||
if include_seed_keyword is not None:
|
||||
task_data["include_seed_keyword"] = include_seed_keyword
|
||||
if include_serp_info is not None:
|
||||
task_data["include_serp_info"] = include_serp_info
|
||||
if include_clickstream_data is not None:
|
||||
task_data["include_clickstream_data"] = include_clickstream_data
|
||||
if limit is not None:
|
||||
task_data["limit"] = limit
|
||||
|
||||
payload = [task_data]
|
||||
|
||||
response = await self.requests.post(
|
||||
endpoint,
|
||||
headers=self._get_headers(),
|
||||
json=payload,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Check for API errors
|
||||
if response.status != 200:
|
||||
error_message = data.get("status_message", "Unknown error")
|
||||
raise Exception(
|
||||
f"DataForSEO API error ({response.status}): {error_message}"
|
||||
)
|
||||
|
||||
# Extract the results from the response
|
||||
if data.get("tasks") and len(data["tasks"]) > 0:
|
||||
task = data["tasks"][0]
|
||||
if task.get("status_code") == 20000: # Success code
|
||||
return task.get("result", [])
|
||||
else:
|
||||
error_msg = task.get("status_message", "Task failed")
|
||||
raise Exception(f"DataForSEO task error: {error_msg}")
|
||||
|
||||
return []
|
||||
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Configuration for all DataForSEO blocks using the new SDK pattern.
|
||||
"""
|
||||
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
# Build the DataForSEO provider with username/password authentication
|
||||
dataforseo = (
|
||||
ProviderBuilder("dataforseo")
|
||||
.with_user_password(
|
||||
username_env_var="DATAFORSEO_USERNAME",
|
||||
password_env_var="DATAFORSEO_PASSWORD",
|
||||
title="DataForSEO Credentials",
|
||||
)
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
DataForSEO Google Keyword Suggestions block.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
|
||||
from ._api import DataForSeoClient
|
||||
from ._config import dataforseo
|
||||
|
||||
|
||||
class KeywordSuggestion(BlockSchema):
|
||||
"""Schema for a keyword suggestion result."""
|
||||
|
||||
keyword: str = SchemaField(description="The keyword suggestion")
|
||||
search_volume: Optional[int] = SchemaField(
|
||||
description="Monthly search volume", default=None
|
||||
)
|
||||
competition: Optional[float] = SchemaField(
|
||||
description="Competition level (0-1)", default=None
|
||||
)
|
||||
cpc: Optional[float] = SchemaField(
|
||||
description="Cost per click in USD", default=None
|
||||
)
|
||||
keyword_difficulty: Optional[int] = SchemaField(
|
||||
description="Keyword difficulty score", default=None
|
||||
)
|
||||
serp_info: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="data from SERP for each keyword", default=None
|
||||
)
|
||||
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="Clickstream data metrics", default=None
|
||||
)
|
||||
|
||||
|
||||
class DataForSeoKeywordSuggestionsBlock(Block):
|
||||
"""Block for getting keyword suggestions from DataForSEO Labs."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = dataforseo.credentials_field(
|
||||
description="DataForSEO credentials (username and password)"
|
||||
)
|
||||
keyword: str = SchemaField(description="Seed keyword to get suggestions for")
|
||||
location_code: Optional[int] = SchemaField(
|
||||
description="Location code for targeting (e.g., 2840 for USA)",
|
||||
default=2840, # USA
|
||||
)
|
||||
language_code: Optional[str] = SchemaField(
|
||||
description="Language code (e.g., 'en' for English)",
|
||||
default="en",
|
||||
)
|
||||
include_seed_keyword: bool = SchemaField(
|
||||
description="Include the seed keyword in results",
|
||||
default=True,
|
||||
)
|
||||
include_serp_info: bool = SchemaField(
|
||||
description="Include SERP information",
|
||||
default=False,
|
||||
)
|
||||
include_clickstream_data: bool = SchemaField(
|
||||
description="Include clickstream metrics",
|
||||
default=False,
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of results (up to 3000)",
|
||||
default=100,
|
||||
ge=1,
|
||||
le=3000,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
suggestions: List[KeywordSuggestion] = SchemaField(
|
||||
description="List of keyword suggestions with metrics"
|
||||
)
|
||||
suggestion: KeywordSuggestion = SchemaField(
|
||||
description="A single keyword suggestion with metrics"
|
||||
)
|
||||
total_count: int = SchemaField(
|
||||
description="Total number of suggestions returned"
|
||||
)
|
||||
seed_keyword: str = SchemaField(
|
||||
description="The seed keyword used for the query"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="73c3e7c4-2b3f-4e9f-9e3e-8f7a5c3e2d45",
|
||||
description="Get keyword suggestions from DataForSEO Labs Google API",
|
||||
categories={BlockCategory.SEARCH, BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"credentials": dataforseo.get_test_credentials().model_dump(),
|
||||
"keyword": "digital marketing",
|
||||
"location_code": 2840,
|
||||
"language_code": "en",
|
||||
"limit": 1,
|
||||
},
|
||||
test_credentials=dataforseo.get_test_credentials(),
|
||||
test_output=[
|
||||
(
|
||||
"suggestion",
|
||||
lambda x: hasattr(x, "keyword")
|
||||
and x.keyword == "digital marketing strategy",
|
||||
),
|
||||
("suggestions", lambda x: isinstance(x, list) and len(x) == 1),
|
||||
("total_count", 1),
|
||||
("seed_keyword", "digital marketing"),
|
||||
],
|
||||
test_mock={
|
||||
"_fetch_keyword_suggestions": lambda *args, **kwargs: [
|
||||
{
|
||||
"items": [
|
||||
{
|
||||
"keyword": "digital marketing strategy",
|
||||
"keyword_info": {
|
||||
"search_volume": 10000,
|
||||
"competition": 0.5,
|
||||
"cpc": 2.5,
|
||||
},
|
||||
"keyword_properties": {
|
||||
"keyword_difficulty": 50,
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
async def _fetch_keyword_suggestions(
|
||||
self,
|
||||
client: DataForSeoClient,
|
||||
input_data: Input,
|
||||
) -> Any:
|
||||
"""Private method to fetch keyword suggestions - can be mocked for testing."""
|
||||
return await client.keyword_suggestions(
|
||||
keyword=input_data.keyword,
|
||||
location_code=input_data.location_code,
|
||||
language_code=input_data.language_code,
|
||||
include_seed_keyword=input_data.include_seed_keyword,
|
||||
include_serp_info=input_data.include_serp_info,
|
||||
include_clickstream_data=input_data.include_clickstream_data,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: UserPasswordCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the keyword suggestions query."""
|
||||
client = DataForSeoClient(credentials)
|
||||
|
||||
results = await self._fetch_keyword_suggestions(client, input_data)
|
||||
|
||||
# Process and format the results
|
||||
suggestions = []
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
items = (
|
||||
first_result.get("items", []) if isinstance(first_result, dict) else []
|
||||
)
|
||||
for item in items:
|
||||
# Create the KeywordSuggestion object
|
||||
suggestion = KeywordSuggestion(
|
||||
keyword=item.get("keyword", ""),
|
||||
search_volume=item.get("keyword_info", {}).get("search_volume"),
|
||||
competition=item.get("keyword_info", {}).get("competition"),
|
||||
cpc=item.get("keyword_info", {}).get("cpc"),
|
||||
keyword_difficulty=item.get("keyword_properties", {}).get(
|
||||
"keyword_difficulty"
|
||||
),
|
||||
serp_info=(
|
||||
item.get("serp_info") if input_data.include_serp_info else None
|
||||
),
|
||||
clickstream_data=(
|
||||
item.get("clickstream_keyword_info")
|
||||
if input_data.include_clickstream_data
|
||||
else None
|
||||
),
|
||||
)
|
||||
yield "suggestion", suggestion
|
||||
suggestions.append(suggestion)
|
||||
|
||||
yield "suggestions", suggestions
|
||||
yield "total_count", len(suggestions)
|
||||
yield "seed_keyword", input_data.keyword
|
||||
|
||||
|
||||
class KeywordSuggestionExtractorBlock(Block):
|
||||
"""Extracts individual fields from a KeywordSuggestion object."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
suggestion: KeywordSuggestion = SchemaField(
|
||||
description="The keyword suggestion object to extract fields from"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
keyword: str = SchemaField(description="The keyword suggestion")
|
||||
search_volume: Optional[int] = SchemaField(
|
||||
description="Monthly search volume", default=None
|
||||
)
|
||||
competition: Optional[float] = SchemaField(
|
||||
description="Competition level (0-1)", default=None
|
||||
)
|
||||
cpc: Optional[float] = SchemaField(
|
||||
description="Cost per click in USD", default=None
|
||||
)
|
||||
keyword_difficulty: Optional[int] = SchemaField(
|
||||
description="Keyword difficulty score", default=None
|
||||
)
|
||||
serp_info: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="data from SERP for each keyword", default=None
|
||||
)
|
||||
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="Clickstream data metrics", default=None
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4193cb94-677c-48b0-9eec-6ac72fffd0f2",
|
||||
description="Extract individual fields from a KeywordSuggestion object",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"suggestion": KeywordSuggestion(
|
||||
keyword="test keyword",
|
||||
search_volume=1000,
|
||||
competition=0.5,
|
||||
cpc=2.5,
|
||||
keyword_difficulty=60,
|
||||
).model_dump()
|
||||
},
|
||||
test_output=[
|
||||
("keyword", "test keyword"),
|
||||
("search_volume", 1000),
|
||||
("competition", 0.5),
|
||||
("cpc", 2.5),
|
||||
("keyword_difficulty", 60),
|
||||
("serp_info", None),
|
||||
("clickstream_data", None),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Extract fields from the KeywordSuggestion object."""
|
||||
suggestion = input_data.suggestion
|
||||
|
||||
yield "keyword", suggestion.keyword
|
||||
yield "search_volume", suggestion.search_volume
|
||||
yield "competition", suggestion.competition
|
||||
yield "cpc", suggestion.cpc
|
||||
yield "keyword_difficulty", suggestion.keyword_difficulty
|
||||
yield "serp_info", suggestion.serp_info
|
||||
yield "clickstream_data", suggestion.clickstream_data
|
||||
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
DataForSEO Google Related Keywords block.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
|
||||
from ._api import DataForSeoClient
|
||||
from ._config import dataforseo
|
||||
|
||||
|
||||
class RelatedKeyword(BlockSchema):
|
||||
"""Schema for a related keyword result."""
|
||||
|
||||
keyword: str = SchemaField(description="The related keyword")
|
||||
search_volume: Optional[int] = SchemaField(
|
||||
description="Monthly search volume", default=None
|
||||
)
|
||||
competition: Optional[float] = SchemaField(
|
||||
description="Competition level (0-1)", default=None
|
||||
)
|
||||
cpc: Optional[float] = SchemaField(
|
||||
description="Cost per click in USD", default=None
|
||||
)
|
||||
keyword_difficulty: Optional[int] = SchemaField(
|
||||
description="Keyword difficulty score", default=None
|
||||
)
|
||||
serp_info: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="SERP data for the keyword", default=None
|
||||
)
|
||||
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="Clickstream data metrics", default=None
|
||||
)
|
||||
|
||||
|
||||
class DataForSeoRelatedKeywordsBlock(Block):
|
||||
"""Block for getting related keywords from DataForSEO Labs."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = dataforseo.credentials_field(
|
||||
description="DataForSEO credentials (username and password)"
|
||||
)
|
||||
keyword: str = SchemaField(
|
||||
description="Seed keyword to find related keywords for"
|
||||
)
|
||||
location_code: Optional[int] = SchemaField(
|
||||
description="Location code for targeting (e.g., 2840 for USA)",
|
||||
default=2840, # USA
|
||||
)
|
||||
language_code: Optional[str] = SchemaField(
|
||||
description="Language code (e.g., 'en' for English)",
|
||||
default="en",
|
||||
)
|
||||
include_seed_keyword: bool = SchemaField(
|
||||
description="Include the seed keyword in results",
|
||||
default=True,
|
||||
)
|
||||
include_serp_info: bool = SchemaField(
|
||||
description="Include SERP information",
|
||||
default=False,
|
||||
)
|
||||
include_clickstream_data: bool = SchemaField(
|
||||
description="Include clickstream metrics",
|
||||
default=False,
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of results (up to 3000)",
|
||||
default=100,
|
||||
ge=1,
|
||||
le=3000,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
related_keywords: List[RelatedKeyword] = SchemaField(
|
||||
description="List of related keywords with metrics"
|
||||
)
|
||||
related_keyword: RelatedKeyword = SchemaField(
|
||||
description="A related keyword with metrics"
|
||||
)
|
||||
total_count: int = SchemaField(
|
||||
description="Total number of related keywords returned"
|
||||
)
|
||||
seed_keyword: str = SchemaField(
|
||||
description="The seed keyword used for the query"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8f2e4d6a-1b3c-4a5e-9d7f-2c8e6a4b3f1d",
|
||||
description="Get related keywords from DataForSEO Labs Google API",
|
||||
categories={BlockCategory.SEARCH, BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"credentials": dataforseo.get_test_credentials().model_dump(),
|
||||
"keyword": "content marketing",
|
||||
"location_code": 2840,
|
||||
"language_code": "en",
|
||||
"limit": 1,
|
||||
},
|
||||
test_credentials=dataforseo.get_test_credentials(),
|
||||
test_output=[
|
||||
(
|
||||
"related_keyword",
|
||||
lambda x: hasattr(x, "keyword") and x.keyword == "content strategy",
|
||||
),
|
||||
("related_keywords", lambda x: isinstance(x, list) and len(x) == 1),
|
||||
("total_count", 1),
|
||||
("seed_keyword", "content marketing"),
|
||||
],
|
||||
test_mock={
|
||||
"_fetch_related_keywords": lambda *args, **kwargs: [
|
||||
{
|
||||
"items": [
|
||||
{
|
||||
"keyword_data": {
|
||||
"keyword": "content strategy",
|
||||
"keyword_info": {
|
||||
"search_volume": 8000,
|
||||
"competition": 0.4,
|
||||
"cpc": 3.0,
|
||||
},
|
||||
"keyword_properties": {
|
||||
"keyword_difficulty": 45,
|
||||
},
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
async def _fetch_related_keywords(
|
||||
self,
|
||||
client: DataForSeoClient,
|
||||
input_data: Input,
|
||||
) -> Any:
|
||||
"""Private method to fetch related keywords - can be mocked for testing."""
|
||||
return await client.related_keywords(
|
||||
keyword=input_data.keyword,
|
||||
location_code=input_data.location_code,
|
||||
language_code=input_data.language_code,
|
||||
include_seed_keyword=input_data.include_seed_keyword,
|
||||
include_serp_info=input_data.include_serp_info,
|
||||
include_clickstream_data=input_data.include_clickstream_data,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: UserPasswordCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the related keywords query."""
|
||||
client = DataForSeoClient(credentials)
|
||||
|
||||
results = await self._fetch_related_keywords(client, input_data)
|
||||
|
||||
# Process and format the results
|
||||
related_keywords = []
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
items = (
|
||||
first_result.get("items", []) if isinstance(first_result, dict) else []
|
||||
)
|
||||
for item in items:
|
||||
# Extract keyword_data from the item
|
||||
keyword_data = item.get("keyword_data", {})
|
||||
|
||||
# Create the RelatedKeyword object
|
||||
keyword = RelatedKeyword(
|
||||
keyword=keyword_data.get("keyword", ""),
|
||||
search_volume=keyword_data.get("keyword_info", {}).get(
|
||||
"search_volume"
|
||||
),
|
||||
competition=keyword_data.get("keyword_info", {}).get("competition"),
|
||||
cpc=keyword_data.get("keyword_info", {}).get("cpc"),
|
||||
keyword_difficulty=keyword_data.get("keyword_properties", {}).get(
|
||||
"keyword_difficulty"
|
||||
),
|
||||
serp_info=(
|
||||
keyword_data.get("serp_info")
|
||||
if input_data.include_serp_info
|
||||
else None
|
||||
),
|
||||
clickstream_data=(
|
||||
keyword_data.get("clickstream_keyword_info")
|
||||
if input_data.include_clickstream_data
|
||||
else None
|
||||
),
|
||||
)
|
||||
yield "related_keyword", keyword
|
||||
related_keywords.append(keyword)
|
||||
|
||||
yield "related_keywords", related_keywords
|
||||
yield "total_count", len(related_keywords)
|
||||
yield "seed_keyword", input_data.keyword
|
||||
|
||||
|
||||
class RelatedKeywordExtractorBlock(Block):
|
||||
"""Extracts individual fields from a RelatedKeyword object."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
related_keyword: RelatedKeyword = SchemaField(
|
||||
description="The related keyword object to extract fields from"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
keyword: str = SchemaField(description="The related keyword")
|
||||
search_volume: Optional[int] = SchemaField(
|
||||
description="Monthly search volume", default=None
|
||||
)
|
||||
competition: Optional[float] = SchemaField(
|
||||
description="Competition level (0-1)", default=None
|
||||
)
|
||||
cpc: Optional[float] = SchemaField(
|
||||
description="Cost per click in USD", default=None
|
||||
)
|
||||
keyword_difficulty: Optional[int] = SchemaField(
|
||||
description="Keyword difficulty score", default=None
|
||||
)
|
||||
serp_info: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="SERP data for the keyword", default=None
|
||||
)
|
||||
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="Clickstream data metrics", default=None
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="98342061-09d2-4952-bf77-0761fc8cc9a8",
|
||||
description="Extract individual fields from a RelatedKeyword object",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"related_keyword": RelatedKeyword(
|
||||
keyword="test related keyword",
|
||||
search_volume=800,
|
||||
competition=0.4,
|
||||
cpc=3.0,
|
||||
keyword_difficulty=55,
|
||||
).model_dump()
|
||||
},
|
||||
test_output=[
|
||||
("keyword", "test related keyword"),
|
||||
("search_volume", 800),
|
||||
("competition", 0.4),
|
||||
("cpc", 3.0),
|
||||
("keyword_difficulty", 55),
|
||||
("serp_info", None),
|
||||
("clickstream_data", None),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Extract fields from the RelatedKeyword object."""
|
||||
related_keyword = input_data.related_keyword
|
||||
|
||||
yield "keyword", related_keyword.keyword
|
||||
yield "search_volume", related_keyword.search_volume
|
||||
yield "competition", related_keyword.competition
|
||||
yield "cpc", related_keyword.cpc
|
||||
yield "keyword_difficulty", related_keyword.keyword_difficulty
|
||||
yield "serp_info", related_keyword.serp_info
|
||||
yield "clickstream_data", related_keyword.clickstream_data
|
||||
117
autogpt_platform/backend/backend/blocks/discord/_api.py
Normal file
117
autogpt_platform/backend/backend/blocks/discord/_api.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
Discord API helper functions for making authenticated requests.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.util.request import Requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DiscordAPIException(Exception):
|
||||
"""Exception raised for Discord API errors."""
|
||||
|
||||
def __init__(self, message: str, status_code: int):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class DiscordOAuthUser(BaseModel):
|
||||
"""Model for Discord OAuth user response."""
|
||||
|
||||
user_id: str
|
||||
username: str
|
||||
avatar_url: str
|
||||
banner: Optional[str] = None
|
||||
accent_color: Optional[int] = None
|
||||
|
||||
|
||||
def get_api(credentials: OAuth2Credentials) -> Requests:
|
||||
"""
|
||||
Create a Requests instance configured for Discord API calls with OAuth2 credentials.
|
||||
|
||||
Args:
|
||||
credentials: The OAuth2 credentials containing the access token.
|
||||
|
||||
Returns:
|
||||
A configured Requests instance for Discord API calls.
|
||||
"""
|
||||
return Requests(
|
||||
trusted_origins=[],
|
||||
extra_headers={
|
||||
"Authorization": f"Bearer {credentials.access_token.get_secret_value()}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
raise_for_status=False,
|
||||
)
|
||||
|
||||
|
||||
async def get_current_user(credentials: OAuth2Credentials) -> DiscordOAuthUser:
|
||||
"""
|
||||
Fetch the current user's information using Discord OAuth2 API.
|
||||
|
||||
Reference: https://discord.com/developers/docs/resources/user#get-current-user
|
||||
|
||||
Args:
|
||||
credentials: The OAuth2 credentials.
|
||||
|
||||
Returns:
|
||||
A model containing user data with avatar URL.
|
||||
|
||||
Raises:
|
||||
DiscordAPIException: If the API request fails.
|
||||
"""
|
||||
api = get_api(credentials)
|
||||
response = await api.get("https://discord.com/api/oauth2/@me")
|
||||
|
||||
if not response.ok:
|
||||
error_text = response.text()
|
||||
raise DiscordAPIException(
|
||||
f"Failed to fetch user info: {response.status} - {error_text}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
logger.info(f"Discord OAuth2 API Response: {data}")
|
||||
|
||||
# The /api/oauth2/@me endpoint returns a user object nested in the response
|
||||
user_info = data.get("user", {})
|
||||
logger.info(f"User info extracted: {user_info}")
|
||||
|
||||
# Build avatar URL
|
||||
user_id = user_info.get("id")
|
||||
avatar_hash = user_info.get("avatar")
|
||||
if avatar_hash:
|
||||
# Custom avatar
|
||||
avatar_ext = "gif" if avatar_hash.startswith("a_") else "png"
|
||||
avatar_url = (
|
||||
f"https://cdn.discordapp.com/avatars/{user_id}/{avatar_hash}.{avatar_ext}"
|
||||
)
|
||||
else:
|
||||
# Default avatar based on discriminator or user ID
|
||||
discriminator = user_info.get("discriminator", "0")
|
||||
if discriminator == "0":
|
||||
# New username system - use user ID for default avatar
|
||||
default_avatar_index = (int(user_id) >> 22) % 6
|
||||
else:
|
||||
# Legacy discriminator system
|
||||
default_avatar_index = int(discriminator) % 5
|
||||
avatar_url = (
|
||||
f"https://cdn.discordapp.com/embed/avatars/{default_avatar_index}.png"
|
||||
)
|
||||
|
||||
result = DiscordOAuthUser(
|
||||
user_id=user_id,
|
||||
username=user_info.get("username", ""),
|
||||
avatar_url=avatar_url,
|
||||
banner=user_info.get("banner"),
|
||||
accent_color=user_info.get("accent_color"),
|
||||
)
|
||||
|
||||
logger.info(f"Returning user data: {result.model_dump()}")
|
||||
return result
|
||||
74
autogpt_platform/backend/backend/blocks/discord/_auth.py
Normal file
74
autogpt_platform/backend/backend/blocks/discord/_auth.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
OAuth2Credentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.settings import Secrets
|
||||
|
||||
secrets = Secrets()
|
||||
DISCORD_OAUTH_IS_CONFIGURED = bool(
|
||||
secrets.discord_client_id and secrets.discord_client_secret
|
||||
)
|
||||
|
||||
# Bot token credentials (existing)
|
||||
DiscordBotCredentials = APIKeyCredentials
|
||||
DiscordBotCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.DISCORD], Literal["api_key"]
|
||||
]
|
||||
|
||||
# OAuth2 credentials (new)
|
||||
DiscordOAuthCredentials = OAuth2Credentials
|
||||
DiscordOAuthCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.DISCORD], Literal["oauth2"]
|
||||
]
|
||||
|
||||
|
||||
def DiscordBotCredentialsField() -> DiscordBotCredentialsInput:
|
||||
"""Creates a Discord bot token credentials field."""
|
||||
return CredentialsField(description="Discord bot token")
|
||||
|
||||
|
||||
def DiscordOAuthCredentialsField(scopes: list[str]) -> DiscordOAuthCredentialsInput:
|
||||
"""Creates a Discord OAuth2 credentials field."""
|
||||
return CredentialsField(
|
||||
description="Discord OAuth2 credentials",
|
||||
required_scopes=set(scopes) | {"identify"}, # Basic user info scope
|
||||
)
|
||||
|
||||
|
||||
# Test credentials for bot tokens
|
||||
TEST_BOT_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="discord",
|
||||
api_key=SecretStr("test_api_key"),
|
||||
title="Mock Discord API key",
|
||||
expires_at=None,
|
||||
)
|
||||
TEST_BOT_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_BOT_CREDENTIALS.provider,
|
||||
"id": TEST_BOT_CREDENTIALS.id,
|
||||
"type": TEST_BOT_CREDENTIALS.type,
|
||||
"title": TEST_BOT_CREDENTIALS.type,
|
||||
}
|
||||
|
||||
# Test credentials for OAuth2
|
||||
TEST_OAUTH_CREDENTIALS = OAuth2Credentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="discord",
|
||||
access_token=SecretStr("test_access_token"),
|
||||
title="Mock Discord OAuth",
|
||||
scopes=["identify"],
|
||||
username="testuser",
|
||||
)
|
||||
TEST_OAUTH_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_OAUTH_CREDENTIALS.provider,
|
||||
"id": TEST_OAUTH_CREDENTIALS.id,
|
||||
"type": TEST_OAUTH_CREDENTIALS.type,
|
||||
"title": TEST_OAUTH_CREDENTIALS.type,
|
||||
}
|
||||
@@ -2,45 +2,29 @@ import base64
|
||||
import io
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import discord
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.data.model import APIKeyCredentials, SchemaField
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
DiscordCredentials = CredentialsMetaInput[
|
||||
Literal[ProviderName.DISCORD], Literal["api_key"]
|
||||
]
|
||||
|
||||
|
||||
def DiscordCredentialsField() -> DiscordCredentials:
|
||||
return CredentialsField(description="Discord bot token")
|
||||
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="discord",
|
||||
api_key=SecretStr("test_api_key"),
|
||||
title="Mock Discord API key",
|
||||
expires_at=None,
|
||||
from ._auth import (
|
||||
TEST_BOT_CREDENTIALS,
|
||||
TEST_BOT_CREDENTIALS_INPUT,
|
||||
DiscordBotCredentialsField,
|
||||
DiscordBotCredentialsInput,
|
||||
)
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.type,
|
||||
}
|
||||
|
||||
# Keep backward compatibility alias
|
||||
DiscordCredentials = DiscordBotCredentialsInput
|
||||
DiscordCredentialsField = DiscordBotCredentialsField
|
||||
TEST_CREDENTIALS = TEST_BOT_CREDENTIALS
|
||||
TEST_CREDENTIALS_INPUT = TEST_BOT_CREDENTIALS_INPUT
|
||||
|
||||
|
||||
class ReadDiscordMessagesBlock(Block):
|
||||
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Discord OAuth-based blocks.
|
||||
"""
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import OAuth2Credentials, SchemaField
|
||||
|
||||
from ._api import DiscordOAuthUser, get_current_user
|
||||
from ._auth import (
|
||||
DISCORD_OAUTH_IS_CONFIGURED,
|
||||
TEST_OAUTH_CREDENTIALS,
|
||||
TEST_OAUTH_CREDENTIALS_INPUT,
|
||||
DiscordOAuthCredentialsField,
|
||||
DiscordOAuthCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class DiscordGetCurrentUserBlock(Block):
|
||||
"""
|
||||
Gets information about the currently authenticated Discord user using OAuth2.
|
||||
This block requires Discord OAuth2 credentials (not bot tokens).
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: DiscordOAuthCredentialsInput = DiscordOAuthCredentialsField(
|
||||
["identify"]
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
user_id: str = SchemaField(description="The authenticated user's Discord ID")
|
||||
username: str = SchemaField(description="The user's username")
|
||||
avatar_url: str = SchemaField(description="URL to the user's avatar image")
|
||||
banner_url: str = SchemaField(
|
||||
description="URL to the user's banner image (if set)", default=""
|
||||
)
|
||||
accent_color: int = SchemaField(
|
||||
description="The user's accent color as an integer", default=0
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8c7e39b8-4e9d-4f3a-b4e1-2a8c9d5f6e3b",
|
||||
input_schema=DiscordGetCurrentUserBlock.Input,
|
||||
output_schema=DiscordGetCurrentUserBlock.Output,
|
||||
description="Gets information about the currently authenticated Discord user using OAuth2 credentials.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
disabled=not DISCORD_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"credentials": TEST_OAUTH_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_OAUTH_CREDENTIALS,
|
||||
test_output=[
|
||||
("user_id", "123456789012345678"),
|
||||
("username", "testuser"),
|
||||
(
|
||||
"avatar_url",
|
||||
"https://cdn.discordapp.com/avatars/123456789012345678/avatar.png",
|
||||
),
|
||||
("banner_url", ""),
|
||||
("accent_color", 0),
|
||||
],
|
||||
test_mock={
|
||||
"get_user": lambda _: DiscordOAuthUser(
|
||||
user_id="123456789012345678",
|
||||
username="testuser",
|
||||
avatar_url="https://cdn.discordapp.com/avatars/123456789012345678/avatar.png",
|
||||
banner=None,
|
||||
accent_color=0,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_user(credentials: OAuth2Credentials) -> DiscordOAuthUser:
|
||||
user_info = await get_current_user(credentials)
|
||||
return user_info
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: OAuth2Credentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
result = await self.get_user(credentials)
|
||||
|
||||
# Yield each output field
|
||||
yield "user_id", result.user_id
|
||||
yield "username", result.username
|
||||
yield "avatar_url", result.avatar_url
|
||||
|
||||
# Handle banner URL if banner hash exists
|
||||
if result.banner:
|
||||
banner_url = f"https://cdn.discordapp.com/banners/{result.user_id}/{result.banner}.png"
|
||||
yield "banner_url", banner_url
|
||||
else:
|
||||
yield "banner_url", ""
|
||||
|
||||
yield "accent_color", result.accent_color or 0
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to get Discord user info: {e}")
|
||||
@@ -93,11 +93,11 @@ class Webset(BaseModel):
|
||||
"""
|
||||
Set of key-value pairs you want to associate with this object.
|
||||
"""
|
||||
created_at: Annotated[datetime, Field(alias="createdAt")] | None = None
|
||||
created_at: Annotated[datetime | None, Field(alias="createdAt")] = None
|
||||
"""
|
||||
The date and time the webset was created
|
||||
"""
|
||||
updated_at: Annotated[datetime, Field(alias="updatedAt")] | None = None
|
||||
updated_at: Annotated[datetime | None, Field(alias="updatedAt")] = None
|
||||
"""
|
||||
The date and time the webset was last updated
|
||||
"""
|
||||
|
||||
@@ -30,6 +30,7 @@ TEST_CREDENTIALS_INPUT = {
|
||||
|
||||
|
||||
class IdeogramModelName(str, Enum):
|
||||
V3 = "V_3"
|
||||
V2 = "V_2"
|
||||
V1 = "V_1"
|
||||
V1_TURBO = "V_1_TURBO"
|
||||
@@ -95,8 +96,8 @@ class IdeogramModelBlock(Block):
|
||||
title="Prompt",
|
||||
)
|
||||
ideogram_model_name: IdeogramModelName = SchemaField(
|
||||
description="The name of the Image Generation Model, e.g., V_2",
|
||||
default=IdeogramModelName.V2,
|
||||
description="The name of the Image Generation Model, e.g., V_3",
|
||||
default=IdeogramModelName.V3,
|
||||
title="Image Generation Model",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -236,6 +237,111 @@ class IdeogramModelBlock(Block):
|
||||
negative_prompt: Optional[str],
|
||||
color_palette_name: str,
|
||||
custom_colors: Optional[list[str]],
|
||||
):
|
||||
# Use V3 endpoint for V3 model, legacy endpoint for others
|
||||
if model_name == "V_3":
|
||||
return await self._run_model_v3(
|
||||
api_key,
|
||||
prompt,
|
||||
seed,
|
||||
aspect_ratio,
|
||||
magic_prompt_option,
|
||||
style_type,
|
||||
negative_prompt,
|
||||
color_palette_name,
|
||||
custom_colors,
|
||||
)
|
||||
else:
|
||||
return await self._run_model_legacy(
|
||||
api_key,
|
||||
model_name,
|
||||
prompt,
|
||||
seed,
|
||||
aspect_ratio,
|
||||
magic_prompt_option,
|
||||
style_type,
|
||||
negative_prompt,
|
||||
color_palette_name,
|
||||
custom_colors,
|
||||
)
|
||||
|
||||
async def _run_model_v3(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
prompt: str,
|
||||
seed: Optional[int],
|
||||
aspect_ratio: str,
|
||||
magic_prompt_option: str,
|
||||
style_type: str,
|
||||
negative_prompt: Optional[str],
|
||||
color_palette_name: str,
|
||||
custom_colors: Optional[list[str]],
|
||||
):
|
||||
url = "https://api.ideogram.ai/v1/ideogram-v3/generate"
|
||||
headers = {
|
||||
"Api-Key": api_key.get_secret_value(),
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# Map legacy aspect ratio values to V3 format
|
||||
aspect_ratio_map = {
|
||||
"ASPECT_10_16": "10x16",
|
||||
"ASPECT_16_10": "16x10",
|
||||
"ASPECT_9_16": "9x16",
|
||||
"ASPECT_16_9": "16x9",
|
||||
"ASPECT_3_2": "3x2",
|
||||
"ASPECT_2_3": "2x3",
|
||||
"ASPECT_4_3": "4x3",
|
||||
"ASPECT_3_4": "3x4",
|
||||
"ASPECT_1_1": "1x1",
|
||||
"ASPECT_1_3": "1x3",
|
||||
"ASPECT_3_1": "3x1",
|
||||
# Additional V3 supported ratios
|
||||
"ASPECT_1_2": "1x2",
|
||||
"ASPECT_2_1": "2x1",
|
||||
"ASPECT_4_5": "4x5",
|
||||
"ASPECT_5_4": "5x4",
|
||||
}
|
||||
|
||||
v3_aspect_ratio = aspect_ratio_map.get(
|
||||
aspect_ratio, "1x1"
|
||||
) # Default to 1x1 if not found
|
||||
|
||||
# Use JSON for V3 endpoint (simpler than multipart/form-data)
|
||||
data: Dict[str, Any] = {
|
||||
"prompt": prompt,
|
||||
"aspect_ratio": v3_aspect_ratio,
|
||||
"magic_prompt": magic_prompt_option,
|
||||
"style_type": style_type,
|
||||
}
|
||||
|
||||
if seed is not None:
|
||||
data["seed"] = seed
|
||||
|
||||
if negative_prompt:
|
||||
data["negative_prompt"] = negative_prompt
|
||||
|
||||
# Note: V3 endpoint may have different color palette support
|
||||
# For now, we'll omit color palettes for V3 to avoid errors
|
||||
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=data)
|
||||
return response.json()["data"][0]["url"]
|
||||
except RequestException as e:
|
||||
raise Exception(f"Failed to fetch image with V3 endpoint: {str(e)}")
|
||||
|
||||
async def _run_model_legacy(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
seed: Optional[int],
|
||||
aspect_ratio: str,
|
||||
magic_prompt_option: str,
|
||||
style_type: str,
|
||||
negative_prompt: Optional[str],
|
||||
color_palette_name: str,
|
||||
custom_colors: Optional[list[str]],
|
||||
):
|
||||
url = "https://api.ideogram.ai/generate"
|
||||
headers = {
|
||||
@@ -249,28 +355,33 @@ class IdeogramModelBlock(Block):
|
||||
"model": model_name,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"magic_prompt_option": magic_prompt_option,
|
||||
"style_type": style_type,
|
||||
}
|
||||
}
|
||||
|
||||
# Only add style_type for V2, V2_TURBO, and V3 models (V1 models don't support it)
|
||||
if model_name in ["V_2", "V_2_TURBO", "V_3"]:
|
||||
data["image_request"]["style_type"] = style_type
|
||||
|
||||
if seed is not None:
|
||||
data["image_request"]["seed"] = seed
|
||||
|
||||
if negative_prompt:
|
||||
data["image_request"]["negative_prompt"] = negative_prompt
|
||||
|
||||
if color_palette_name != "NONE":
|
||||
data["color_palette"] = {"name": color_palette_name}
|
||||
elif custom_colors:
|
||||
data["color_palette"] = {
|
||||
"members": [{"color_hex": color} for color in custom_colors]
|
||||
}
|
||||
# Only add color palette for V2 and V2_TURBO models (V1 models don't support it)
|
||||
if model_name in ["V_2", "V_2_TURBO"]:
|
||||
if color_palette_name != "NONE":
|
||||
data["color_palette"] = {"name": color_palette_name}
|
||||
elif custom_colors:
|
||||
data["color_palette"] = {
|
||||
"members": [{"color_hex": color} for color in custom_colors]
|
||||
}
|
||||
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=data)
|
||||
return response.json()["data"][0]["url"]
|
||||
except RequestException as e:
|
||||
raise Exception(f"Failed to fetch image: {str(e)}")
|
||||
raise Exception(f"Failed to fetch image with legacy endpoint: {str(e)}")
|
||||
|
||||
async def upscale_image(self, api_key: SecretStr, image_url: str):
|
||||
url = "https://api.ideogram.ai/upscale"
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
stagehand = (
|
||||
ProviderBuilder("stagehand")
|
||||
.with_api_key("STAGEHAND_API_KEY", "Stagehand API Key")
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
393
autogpt_platform/backend/backend/blocks/stagehand/blocks.py
Normal file
393
autogpt_platform/backend/backend/blocks/stagehand/blocks.py
Normal file
@@ -0,0 +1,393 @@
|
||||
import logging
|
||||
import signal
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
|
||||
# Monkey patch Stagehands to prevent signal handling in worker threads
|
||||
import stagehand.main
|
||||
from stagehand import Stagehand
|
||||
|
||||
from backend.blocks.llm import (
|
||||
MODEL_METADATA,
|
||||
AICredentials,
|
||||
AICredentialsField,
|
||||
LlmModel,
|
||||
ModelMetadata,
|
||||
)
|
||||
from backend.blocks.stagehand._config import stagehand as stagehand_provider
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
# Store the original method
|
||||
original_register_signal_handlers = stagehand.main.Stagehand._register_signal_handlers
|
||||
|
||||
|
||||
def safe_register_signal_handlers(self):
|
||||
"""Only register signal handlers in the main thread"""
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
original_register_signal_handlers(self)
|
||||
else:
|
||||
# Skip signal handling in worker threads
|
||||
pass
|
||||
|
||||
|
||||
# Replace the method
|
||||
stagehand.main.Stagehand._register_signal_handlers = safe_register_signal_handlers
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_signal_handling():
|
||||
"""Context manager to temporarily disable signal handling"""
|
||||
if threading.current_thread() is not threading.main_thread():
|
||||
# In worker threads, temporarily replace signal.signal with a no-op
|
||||
original_signal = signal.signal
|
||||
|
||||
def noop_signal(*args, **kwargs):
|
||||
pass
|
||||
|
||||
signal.signal = noop_signal
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.signal = original_signal
|
||||
else:
|
||||
# In main thread, don't modify anything
|
||||
yield
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StagehandRecommendedLlmModel(str, Enum):
|
||||
"""
|
||||
This is subset of LLModel from autogpt_platform/backend/backend/blocks/llm.py
|
||||
|
||||
It contains only the models recommended by Stagehand
|
||||
"""
|
||||
|
||||
# OpenAI
|
||||
GPT41 = "gpt-4.1-2025-04-14"
|
||||
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
|
||||
|
||||
# Anthropic
|
||||
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
"""
|
||||
Returns the provider name for the model in the required format for Stagehand:
|
||||
provider/model_name
|
||||
"""
|
||||
model_metadata = MODEL_METADATA[LlmModel(self.value)]
|
||||
model_name = self.value
|
||||
|
||||
if len(model_name.split("/")) == 1 and not self.value.startswith(
|
||||
model_metadata.provider
|
||||
):
|
||||
assert (
|
||||
model_metadata.provider != "open_router"
|
||||
), "Logic failed and open_router provider attempted to be prepended to model name! in stagehand/_config.py"
|
||||
model_name = f"{model_metadata.provider}/{model_name}"
|
||||
|
||||
logger.error(f"Model name: {model_name}")
|
||||
return model_name
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
return MODEL_METADATA[LlmModel(self.value)].provider
|
||||
|
||||
@property
|
||||
def metadata(self) -> ModelMetadata:
|
||||
return MODEL_METADATA[LlmModel(self.value)]
|
||||
|
||||
@property
|
||||
def context_window(self) -> int:
|
||||
return MODEL_METADATA[LlmModel(self.value)].context_window
|
||||
|
||||
@property
|
||||
def max_output_tokens(self) -> int | None:
|
||||
return MODEL_METADATA[LlmModel(self.value)].max_output_tokens
|
||||
|
||||
|
||||
class StagehandObserveBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
# Browserbase credentials (Stagehand provider) or raw API key
|
||||
stagehand_credentials: CredentialsMetaInput = (
|
||||
stagehand_provider.credentials_field(
|
||||
description="Stagehand/Browserbase API key"
|
||||
)
|
||||
)
|
||||
browserbase_project_id: str = SchemaField(
|
||||
description="Browserbase project ID (required if using Browserbase)",
|
||||
)
|
||||
# Model selection and credentials (provider-discriminated like llm.py)
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
url: str = SchemaField(
|
||||
description="URL to navigate to.",
|
||||
)
|
||||
instruction: str = SchemaField(
|
||||
description="Natural language description of elements or actions to discover.",
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
selector: str = SchemaField(description="XPath selector to locate element.")
|
||||
description: str = SchemaField(description="Human-readable description")
|
||||
method: str | None = SchemaField(description="Suggested action method")
|
||||
arguments: list[str] | None = SchemaField(
|
||||
description="Additional action parameters"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d3863944-0eaf-45c4-a0c9-63e0fe1ee8b9",
|
||||
description="Find suggested actions for your workflows",
|
||||
categories={BlockCategory.AI, BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=StagehandObserveBlock.Input,
|
||||
output_schema=StagehandObserveBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
stagehand_credentials: APIKeyCredentials,
|
||||
model_credentials: APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
logger.info(f"OBSERVE: Stagehand credentials: {stagehand_credentials}")
|
||||
logger.info(
|
||||
f"OBSERVE: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
|
||||
)
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
|
||||
observe_results = await page.observe(
|
||||
input_data.instruction,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
)
|
||||
for result in observe_results:
|
||||
yield "selector", result.selector
|
||||
yield "description", result.description
|
||||
yield "method", result.method
|
||||
yield "arguments", result.arguments
|
||||
|
||||
|
||||
class StagehandActBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
# Browserbase credentials (Stagehand provider) or raw API key
|
||||
stagehand_credentials: CredentialsMetaInput = (
|
||||
stagehand_provider.credentials_field(
|
||||
description="Stagehand/Browserbase API key"
|
||||
)
|
||||
)
|
||||
browserbase_project_id: str = SchemaField(
|
||||
description="Browserbase project ID (required if using Browserbase)",
|
||||
)
|
||||
# Model selection and credentials (provider-discriminated like llm.py)
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
url: str = SchemaField(
|
||||
description="URL to navigate to.",
|
||||
)
|
||||
action: list[str] = SchemaField(
|
||||
description="Action to perform. Suggested actions are: click, fill, type, press, scroll, select from dropdown. For multi-step actions, add an entry for each step.",
|
||||
)
|
||||
variables: dict[str, str] = SchemaField(
|
||||
description="Variables to use in the action. Variables contains data you want the action to use.",
|
||||
default_factory=dict,
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
)
|
||||
timeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM ready. Extended timeout for slow-loading forms",
|
||||
default=60000,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the action was completed successfully"
|
||||
)
|
||||
message: str = SchemaField(description="Details about the action’s execution.")
|
||||
action: str = SchemaField(description="Action performed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="86eba68b-9549-4c0b-a0db-47d85a56cc27",
|
||||
description="Interact with a web page by performing actions on a web page. Use it to build self-healing and deterministic automations that adapt to website chang.",
|
||||
categories={BlockCategory.AI, BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=StagehandActBlock.Input,
|
||||
output_schema=StagehandActBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
stagehand_credentials: APIKeyCredentials,
|
||||
model_credentials: APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
logger.info(f"ACT: Stagehand credentials: {stagehand_credentials}")
|
||||
logger.info(
|
||||
f"ACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
|
||||
)
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
for action in input_data.action:
|
||||
action_results = await page.act(
|
||||
action,
|
||||
variables=input_data.variables,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
timeoutMs=input_data.timeoutMs,
|
||||
)
|
||||
yield "success", action_results.success
|
||||
yield "message", action_results.message
|
||||
yield "action", action_results.action
|
||||
|
||||
|
||||
class StagehandExtractBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
# Browserbase credentials (Stagehand provider) or raw API key
|
||||
stagehand_credentials: CredentialsMetaInput = (
|
||||
stagehand_provider.credentials_field(
|
||||
description="Stagehand/Browserbase API key"
|
||||
)
|
||||
)
|
||||
browserbase_project_id: str = SchemaField(
|
||||
description="Browserbase project ID (required if using Browserbase)",
|
||||
)
|
||||
# Model selection and credentials (provider-discriminated like llm.py)
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
url: str = SchemaField(
|
||||
description="URL to navigate to.",
|
||||
)
|
||||
instruction: str = SchemaField(
|
||||
description="Natural language description of elements or actions to discover.",
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
extraction: str = SchemaField(description="Extracted data from the page.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="fd3c0b18-2ba6-46ae-9339-fcb40537ad98",
|
||||
description="Extract structured data from a webpage.",
|
||||
categories={BlockCategory.AI, BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=StagehandExtractBlock.Input,
|
||||
output_schema=StagehandExtractBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
stagehand_credentials: APIKeyCredentials,
|
||||
model_credentials: APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
logger.info(f"EXTRACT: Stagehand credentials: {stagehand_credentials}")
|
||||
logger.info(
|
||||
f"EXTRACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
|
||||
)
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
extraction = await page.extract(
|
||||
input_data.instruction,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
)
|
||||
yield "extraction", str(extraction.model_dump()["extraction"])
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,283 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Duplicate pydantic models for store data so we don't accidently change the data shape in the blocks unintentionally when editing the backend
|
||||
class LibraryAgent(BaseModel):
|
||||
"""Model representing an agent in the user's library."""
|
||||
|
||||
library_agent_id: str = ""
|
||||
agent_id: str = ""
|
||||
agent_version: int = 0
|
||||
agent_name: str = ""
|
||||
description: str = ""
|
||||
creator: str = ""
|
||||
is_archived: bool = False
|
||||
categories: list[str] = []
|
||||
|
||||
|
||||
class AddToLibraryFromStoreBlock(Block):
|
||||
"""
|
||||
Block that adds an agent from the store to the user's library.
|
||||
This enables users to easily import agents from the marketplace into their personal collection.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
store_listing_version_id: str = SchemaField(
|
||||
description="The ID of the store listing version to add to library"
|
||||
)
|
||||
agent_name: str | None = SchemaField(
|
||||
description="Optional custom name for the agent in your library",
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the agent was successfully added to library"
|
||||
)
|
||||
library_agent_id: str = SchemaField(
|
||||
description="The ID of the library agent entry"
|
||||
)
|
||||
agent_id: str = SchemaField(description="The ID of the agent graph")
|
||||
agent_version: int = SchemaField(
|
||||
description="The version number of the agent graph"
|
||||
)
|
||||
agent_name: str = SchemaField(description="The name of the agent")
|
||||
message: str = SchemaField(description="Success or error message")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2602a7b1-3f4d-4e5f-9c8b-1a2b3c4d5e6f",
|
||||
description="Add an agent from the store to your personal library",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=AddToLibraryFromStoreBlock.Input,
|
||||
output_schema=AddToLibraryFromStoreBlock.Output,
|
||||
test_input={
|
||||
"store_listing_version_id": "test-listing-id",
|
||||
"agent_name": "My Custom Agent",
|
||||
},
|
||||
test_output=[
|
||||
("success", True),
|
||||
("library_agent_id", "test-library-id"),
|
||||
("agent_id", "test-agent-id"),
|
||||
("agent_version", 1),
|
||||
("agent_name", "Test Agent"),
|
||||
("message", "Agent successfully added to library"),
|
||||
],
|
||||
test_mock={
|
||||
"_add_to_library": lambda *_, **__: LibraryAgent(
|
||||
library_agent_id="test-library-id",
|
||||
agent_id="test-agent-id",
|
||||
agent_version=1,
|
||||
agent_name="Test Agent",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
library_agent = await self._add_to_library(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=input_data.store_listing_version_id,
|
||||
custom_name=input_data.agent_name,
|
||||
)
|
||||
|
||||
yield "success", True
|
||||
yield "library_agent_id", library_agent.library_agent_id
|
||||
yield "agent_id", library_agent.agent_id
|
||||
yield "agent_version", library_agent.agent_version
|
||||
yield "agent_name", library_agent.agent_name
|
||||
yield "message", "Agent successfully added to library"
|
||||
|
||||
async def _add_to_library(
|
||||
self,
|
||||
user_id: str,
|
||||
store_listing_version_id: str,
|
||||
custom_name: str | None = None,
|
||||
) -> LibraryAgent:
|
||||
"""
|
||||
Add a store agent to the user's library using the existing library database function.
|
||||
"""
|
||||
library_agent = (
|
||||
await get_database_manager_async_client().add_store_agent_to_library(
|
||||
store_listing_version_id=store_listing_version_id, user_id=user_id
|
||||
)
|
||||
)
|
||||
|
||||
# If custom name is provided, we could update the library agent name here
|
||||
# For now, we'll just return the agent info
|
||||
agent_name = custom_name if custom_name else library_agent.name
|
||||
|
||||
return LibraryAgent(
|
||||
library_agent_id=library_agent.id,
|
||||
agent_id=library_agent.graph_id,
|
||||
agent_version=library_agent.graph_version,
|
||||
agent_name=agent_name,
|
||||
)
|
||||
|
||||
|
||||
class ListLibraryAgentsBlock(Block):
|
||||
"""
|
||||
Block that lists all agents in the user's library.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
search_query: str | None = SchemaField(
|
||||
description="Optional search query to filter agents", default=None
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of agents to return", default=50, ge=1, le=100
|
||||
)
|
||||
page: int = SchemaField(
|
||||
description="Page number for pagination", default=1, ge=1
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
agents: list[LibraryAgent] = SchemaField(
|
||||
description="List of agents in the library",
|
||||
default_factory=list,
|
||||
)
|
||||
agent: LibraryAgent = SchemaField(
|
||||
description="Individual library agent (yielded for each agent)"
|
||||
)
|
||||
total_count: int = SchemaField(
|
||||
description="Total number of agents in library", default=0
|
||||
)
|
||||
page: int = SchemaField(description="Current page number", default=1)
|
||||
total_pages: int = SchemaField(description="Total number of pages", default=1)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="082602d3-a74d-4600-9e9c-15b3af7eae98",
|
||||
description="List all agents in your personal library",
|
||||
categories={BlockCategory.BASIC, BlockCategory.DATA},
|
||||
input_schema=ListLibraryAgentsBlock.Input,
|
||||
output_schema=ListLibraryAgentsBlock.Output,
|
||||
test_input={
|
||||
"search_query": None,
|
||||
"limit": 10,
|
||||
"page": 1,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"agents",
|
||||
[
|
||||
LibraryAgent(
|
||||
library_agent_id="test-lib-id",
|
||||
agent_id="test-agent-id",
|
||||
agent_version=1,
|
||||
agent_name="Test Library Agent",
|
||||
description="A test agent in library",
|
||||
creator="Test User",
|
||||
),
|
||||
],
|
||||
),
|
||||
("total_count", 1),
|
||||
("page", 1),
|
||||
("total_pages", 1),
|
||||
(
|
||||
"agent",
|
||||
LibraryAgent(
|
||||
library_agent_id="test-lib-id",
|
||||
agent_id="test-agent-id",
|
||||
agent_version=1,
|
||||
agent_name="Test Library Agent",
|
||||
description="A test agent in library",
|
||||
creator="Test User",
|
||||
),
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"_list_library_agents": lambda *_, **__: {
|
||||
"agents": [
|
||||
LibraryAgent(
|
||||
library_agent_id="test-lib-id",
|
||||
agent_id="test-agent-id",
|
||||
agent_version=1,
|
||||
agent_name="Test Library Agent",
|
||||
description="A test agent in library",
|
||||
creator="Test User",
|
||||
)
|
||||
],
|
||||
"total": 1,
|
||||
"page": 1,
|
||||
"total_pages": 1,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
result = await self._list_library_agents(
|
||||
user_id=user_id,
|
||||
search_query=input_data.search_query,
|
||||
limit=input_data.limit,
|
||||
page=input_data.page,
|
||||
)
|
||||
|
||||
agents = result["agents"]
|
||||
|
||||
yield "agents", agents
|
||||
yield "total_count", result["total"]
|
||||
yield "page", result["page"]
|
||||
yield "total_pages", result["total_pages"]
|
||||
|
||||
# Yield each agent individually for better graph connectivity
|
||||
for agent in agents:
|
||||
yield "agent", agent
|
||||
|
||||
async def _list_library_agents(
|
||||
self,
|
||||
user_id: str,
|
||||
search_query: str | None = None,
|
||||
limit: int = 50,
|
||||
page: int = 1,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
List agents in the user's library using the database client.
|
||||
"""
|
||||
result = await get_database_manager_async_client().list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=search_query,
|
||||
page=page,
|
||||
page_size=limit,
|
||||
)
|
||||
|
||||
agents = [
|
||||
LibraryAgent(
|
||||
library_agent_id=agent.id,
|
||||
agent_id=agent.graph_id,
|
||||
agent_version=agent.graph_version,
|
||||
agent_name=agent.name,
|
||||
description=getattr(agent, "description", ""),
|
||||
creator=getattr(agent, "creator", ""),
|
||||
is_archived=getattr(agent, "is_archived", False),
|
||||
categories=getattr(agent, "categories", []),
|
||||
)
|
||||
for agent in result.agents
|
||||
]
|
||||
|
||||
return {
|
||||
"agents": agents,
|
||||
"total": result.pagination.total_items,
|
||||
"page": result.pagination.current_page,
|
||||
"total_pages": result.pagination.total_pages,
|
||||
}
|
||||
@@ -0,0 +1,311 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Duplicate pydantic models for store data so we don't accidently change the data shape in the blocks unintentionally when editing the backend
|
||||
class StoreAgent(BaseModel):
|
||||
"""Model representing a store agent."""
|
||||
|
||||
slug: str = ""
|
||||
name: str = ""
|
||||
description: str = ""
|
||||
creator: str = ""
|
||||
rating: float = 0.0
|
||||
runs: int = 0
|
||||
categories: list[str] = []
|
||||
|
||||
|
||||
class StoreAgentDict(BaseModel):
|
||||
"""Dictionary representation of a store agent."""
|
||||
|
||||
slug: str
|
||||
name: str
|
||||
description: str
|
||||
creator: str
|
||||
rating: float
|
||||
runs: int
|
||||
|
||||
|
||||
class SearchAgentsResponse(BaseModel):
|
||||
"""Response from searching store agents."""
|
||||
|
||||
agents: list[StoreAgentDict]
|
||||
total_count: int
|
||||
|
||||
|
||||
class StoreAgentDetails(BaseModel):
|
||||
"""Detailed information about a store agent."""
|
||||
|
||||
found: bool
|
||||
store_listing_version_id: str = ""
|
||||
agent_name: str = ""
|
||||
description: str = ""
|
||||
creator: str = ""
|
||||
categories: list[str] = []
|
||||
runs: int = 0
|
||||
rating: float = 0.0
|
||||
|
||||
|
||||
class GetStoreAgentDetailsBlock(Block):
|
||||
"""
|
||||
Block that retrieves detailed information about an agent from the store.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
creator: str = SchemaField(description="The username of the agent creator")
|
||||
slug: str = SchemaField(description="The name of the agent")
|
||||
|
||||
class Output(BlockSchema):
|
||||
found: bool = SchemaField(
|
||||
description="Whether the agent was found in the store"
|
||||
)
|
||||
store_listing_version_id: str = SchemaField(
|
||||
description="The store listing version ID"
|
||||
)
|
||||
agent_name: str = SchemaField(description="Name of the agent")
|
||||
description: str = SchemaField(description="Description of the agent")
|
||||
creator: str = SchemaField(description="Creator of the agent")
|
||||
categories: list[str] = SchemaField(
|
||||
description="Categories the agent belongs to", default_factory=list
|
||||
)
|
||||
runs: int = SchemaField(
|
||||
description="Number of times the agent has been run", default=0
|
||||
)
|
||||
rating: float = SchemaField(
|
||||
description="Average rating of the agent", default=0.0
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b604f0ec-6e0d-40a7-bf55-9fd09997cced",
|
||||
description="Get detailed information about an agent from the store",
|
||||
categories={BlockCategory.BASIC, BlockCategory.DATA},
|
||||
input_schema=GetStoreAgentDetailsBlock.Input,
|
||||
output_schema=GetStoreAgentDetailsBlock.Output,
|
||||
test_input={"creator": "test-creator", "slug": "test-agent-slug"},
|
||||
test_output=[
|
||||
("found", True),
|
||||
("store_listing_version_id", "test-listing-id"),
|
||||
("agent_name", "Test Agent"),
|
||||
("description", "A test agent"),
|
||||
("creator", "Test Creator"),
|
||||
("categories", ["productivity", "automation"]),
|
||||
("runs", 100),
|
||||
("rating", 4.5),
|
||||
],
|
||||
test_mock={
|
||||
"_get_agent_details": lambda *_, **__: StoreAgentDetails(
|
||||
found=True,
|
||||
store_listing_version_id="test-listing-id",
|
||||
agent_name="Test Agent",
|
||||
description="A test agent",
|
||||
creator="Test Creator",
|
||||
categories=["productivity", "automation"],
|
||||
runs=100,
|
||||
rating=4.5,
|
||||
)
|
||||
},
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
details = await self._get_agent_details(
|
||||
creator=input_data.creator, slug=input_data.slug
|
||||
)
|
||||
yield "found", details.found
|
||||
yield "store_listing_version_id", details.store_listing_version_id
|
||||
yield "agent_name", details.agent_name
|
||||
yield "description", details.description
|
||||
yield "creator", details.creator
|
||||
yield "categories", details.categories
|
||||
yield "runs", details.runs
|
||||
yield "rating", details.rating
|
||||
|
||||
async def _get_agent_details(self, creator: str, slug: str) -> StoreAgentDetails:
|
||||
"""
|
||||
Retrieve detailed information about a store agent.
|
||||
"""
|
||||
# Get by specific version ID
|
||||
agent_details = (
|
||||
await get_database_manager_async_client().get_store_agent_details(
|
||||
username=creator, agent_name=slug
|
||||
)
|
||||
)
|
||||
|
||||
return StoreAgentDetails(
|
||||
found=True,
|
||||
store_listing_version_id=agent_details.store_listing_version_id,
|
||||
agent_name=agent_details.agent_name,
|
||||
description=agent_details.description,
|
||||
creator=agent_details.creator,
|
||||
categories=(
|
||||
agent_details.categories if hasattr(agent_details, "categories") else []
|
||||
),
|
||||
runs=agent_details.runs,
|
||||
rating=agent_details.rating,
|
||||
)
|
||||
|
||||
|
||||
class SearchStoreAgentsBlock(Block):
|
||||
"""
|
||||
Block that searches for agents in the store based on various criteria.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
query: str | None = SchemaField(
|
||||
description="Search query to find agents", default=None
|
||||
)
|
||||
category: str | None = SchemaField(
|
||||
description="Filter by category", default=None
|
||||
)
|
||||
sort_by: Literal["rating", "runs", "name", "recent"] = SchemaField(
|
||||
description="How to sort the results", default="rating"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of results to return", default=10, ge=1, le=100
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
agents: list[StoreAgent] = SchemaField(
|
||||
description="List of agents matching the search criteria",
|
||||
default_factory=list,
|
||||
)
|
||||
agent: StoreAgent = SchemaField(description="Basic information of the agent")
|
||||
total_count: int = SchemaField(
|
||||
description="Total number of agents found", default=0
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="39524701-026c-4328-87cc-1b88c8e2cb4c",
|
||||
description="Search for agents in the store",
|
||||
categories={BlockCategory.BASIC, BlockCategory.DATA},
|
||||
input_schema=SearchStoreAgentsBlock.Input,
|
||||
output_schema=SearchStoreAgentsBlock.Output,
|
||||
test_input={
|
||||
"query": "productivity",
|
||||
"category": None,
|
||||
"sort_by": "rating",
|
||||
"limit": 10,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"agents",
|
||||
[
|
||||
{
|
||||
"slug": "test-agent",
|
||||
"name": "Test Agent",
|
||||
"description": "A test agent",
|
||||
"creator": "Test Creator",
|
||||
"rating": 4.5,
|
||||
"runs": 100,
|
||||
}
|
||||
],
|
||||
),
|
||||
("total_count", 1),
|
||||
(
|
||||
"agent",
|
||||
{
|
||||
"slug": "test-agent",
|
||||
"name": "Test Agent",
|
||||
"description": "A test agent",
|
||||
"creator": "Test Creator",
|
||||
"rating": 4.5,
|
||||
"runs": 100,
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"_search_agents": lambda *_, **__: SearchAgentsResponse(
|
||||
agents=[
|
||||
StoreAgentDict(
|
||||
slug="test-agent",
|
||||
name="Test Agent",
|
||||
description="A test agent",
|
||||
creator="Test Creator",
|
||||
rating=4.5,
|
||||
runs=100,
|
||||
)
|
||||
],
|
||||
total_count=1,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
result = await self._search_agents(
|
||||
query=input_data.query,
|
||||
category=input_data.category,
|
||||
sort_by=input_data.sort_by,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
|
||||
agents = result.agents
|
||||
total_count = result.total_count
|
||||
|
||||
# Convert to dict for output
|
||||
agents_as_dicts = [agent.model_dump() for agent in agents]
|
||||
|
||||
yield "agents", agents_as_dicts
|
||||
yield "total_count", total_count
|
||||
|
||||
for agent_dict in agents_as_dicts:
|
||||
yield "agent", agent_dict
|
||||
|
||||
async def _search_agents(
|
||||
self,
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
sort_by: str = "rating",
|
||||
limit: int = 10,
|
||||
) -> SearchAgentsResponse:
|
||||
"""
|
||||
Search for agents in the store using the existing store database function.
|
||||
"""
|
||||
# Map our sort_by to the store's sorted_by parameter
|
||||
sorted_by_map = {
|
||||
"rating": "most_popular",
|
||||
"runs": "most_runs",
|
||||
"name": "alphabetical",
|
||||
"recent": "recently_updated",
|
||||
}
|
||||
|
||||
result = await get_database_manager_async_client().get_store_agents(
|
||||
featured=False,
|
||||
creators=None,
|
||||
sorted_by=sorted_by_map.get(sort_by, "most_popular"),
|
||||
search_query=query,
|
||||
category=category,
|
||||
page=1,
|
||||
page_size=limit,
|
||||
)
|
||||
|
||||
agents = [
|
||||
StoreAgentDict(
|
||||
slug=agent.slug,
|
||||
name=agent.agent_name,
|
||||
description=agent.description,
|
||||
creator=agent.creator,
|
||||
rating=agent.rating,
|
||||
runs=agent.runs,
|
||||
)
|
||||
for agent in result.agents
|
||||
]
|
||||
|
||||
return SearchAgentsResponse(agents=agents, total_count=len(agents))
|
||||
@@ -0,0 +1,155 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.system.library_operations import (
|
||||
AddToLibraryFromStoreBlock,
|
||||
LibraryAgent,
|
||||
)
|
||||
from backend.blocks.system.store_operations import (
|
||||
GetStoreAgentDetailsBlock,
|
||||
SearchAgentsResponse,
|
||||
SearchStoreAgentsBlock,
|
||||
StoreAgentDetails,
|
||||
StoreAgentDict,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_to_library_from_store_block_success(mocker):
|
||||
"""Test successful addition of agent from store to library."""
|
||||
block = AddToLibraryFromStoreBlock()
|
||||
|
||||
# Mock the library agent response
|
||||
mock_library_agent = MagicMock()
|
||||
mock_library_agent.id = "lib-agent-123"
|
||||
mock_library_agent.graph_id = "graph-456"
|
||||
mock_library_agent.graph_version = 1
|
||||
mock_library_agent.name = "Test Agent"
|
||||
|
||||
mocker.patch.object(
|
||||
block,
|
||||
"_add_to_library",
|
||||
return_value=LibraryAgent(
|
||||
library_agent_id="lib-agent-123",
|
||||
agent_id="graph-456",
|
||||
agent_version=1,
|
||||
agent_name="Test Agent",
|
||||
),
|
||||
)
|
||||
|
||||
input_data = block.Input(
|
||||
store_listing_version_id="store-listing-v1", agent_name="Custom Agent Name"
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, user_id="test-user"):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["success"] is True
|
||||
assert outputs["library_agent_id"] == "lib-agent-123"
|
||||
assert outputs["agent_id"] == "graph-456"
|
||||
assert outputs["agent_version"] == 1
|
||||
assert outputs["agent_name"] == "Test Agent"
|
||||
assert outputs["message"] == "Agent successfully added to library"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_store_agent_details_block_success(mocker):
|
||||
"""Test successful retrieval of store agent details."""
|
||||
block = GetStoreAgentDetailsBlock()
|
||||
|
||||
mocker.patch.object(
|
||||
block,
|
||||
"_get_agent_details",
|
||||
return_value=StoreAgentDetails(
|
||||
found=True,
|
||||
store_listing_version_id="version-123",
|
||||
agent_name="Test Agent",
|
||||
description="A test agent for testing",
|
||||
creator="Test Creator",
|
||||
categories=["productivity", "automation"],
|
||||
runs=100,
|
||||
rating=4.5,
|
||||
),
|
||||
)
|
||||
|
||||
input_data = block.Input(creator="Test Creator", slug="test-slug")
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["found"] is True
|
||||
assert outputs["store_listing_version_id"] == "version-123"
|
||||
assert outputs["agent_name"] == "Test Agent"
|
||||
assert outputs["description"] == "A test agent for testing"
|
||||
assert outputs["creator"] == "Test Creator"
|
||||
assert outputs["categories"] == ["productivity", "automation"]
|
||||
assert outputs["runs"] == 100
|
||||
assert outputs["rating"] == 4.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_store_agents_block(mocker):
|
||||
"""Test searching for store agents."""
|
||||
block = SearchStoreAgentsBlock()
|
||||
|
||||
mocker.patch.object(
|
||||
block,
|
||||
"_search_agents",
|
||||
return_value=SearchAgentsResponse(
|
||||
agents=[
|
||||
StoreAgentDict(
|
||||
slug="creator1/agent1",
|
||||
name="Agent One",
|
||||
description="First test agent",
|
||||
creator="Creator 1",
|
||||
rating=4.8,
|
||||
runs=500,
|
||||
),
|
||||
StoreAgentDict(
|
||||
slug="creator2/agent2",
|
||||
name="Agent Two",
|
||||
description="Second test agent",
|
||||
creator="Creator 2",
|
||||
rating=4.2,
|
||||
runs=200,
|
||||
),
|
||||
],
|
||||
total_count=2,
|
||||
),
|
||||
)
|
||||
|
||||
input_data = block.Input(
|
||||
query="test", category="productivity", sort_by="rating", limit=10
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data):
|
||||
outputs[name] = value
|
||||
|
||||
assert len(outputs["agents"]) == 2
|
||||
assert outputs["total_count"] == 2
|
||||
assert outputs["agents"][0]["name"] == "Agent One"
|
||||
assert outputs["agents"][0]["rating"] == 4.8
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_store_agents_block_empty_results(mocker):
|
||||
"""Test searching with no results."""
|
||||
block = SearchStoreAgentsBlock()
|
||||
|
||||
mocker.patch.object(
|
||||
block,
|
||||
"_search_agents",
|
||||
return_value=SearchAgentsResponse(agents=[], total_count=0),
|
||||
)
|
||||
|
||||
input_data = block.Input(query="nonexistent", limit=10)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["agents"] == []
|
||||
assert outputs["total_count"] == 0
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Literal, Union
|
||||
@@ -7,6 +8,7 @@ from zoneinfo import ZoneInfo
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.execution import UserContext
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
# Shared timezone literal type for all time/date blocks
|
||||
@@ -51,16 +53,80 @@ TimezoneLiteral = Literal[
|
||||
"Etc/GMT+12", # UTC-12:00
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_timezone(
|
||||
format_type: Any, # Any format type with timezone and use_user_timezone attributes
|
||||
user_timezone: str | None,
|
||||
) -> ZoneInfo:
|
||||
"""
|
||||
Determine which timezone to use based on format settings and user context.
|
||||
|
||||
Args:
|
||||
format_type: The format configuration containing timezone settings
|
||||
user_timezone: The user's timezone from context
|
||||
|
||||
Returns:
|
||||
ZoneInfo object for the determined timezone
|
||||
"""
|
||||
if format_type.use_user_timezone and user_timezone:
|
||||
tz = ZoneInfo(user_timezone)
|
||||
logger.debug(f"Using user timezone: {user_timezone}")
|
||||
else:
|
||||
tz = ZoneInfo(format_type.timezone)
|
||||
logger.debug(f"Using specified timezone: {format_type.timezone}")
|
||||
return tz
|
||||
|
||||
|
||||
def _format_datetime_iso8601(dt: datetime, include_microseconds: bool = False) -> str:
|
||||
"""
|
||||
Format a datetime object to ISO8601 string.
|
||||
|
||||
Args:
|
||||
dt: The datetime object to format
|
||||
include_microseconds: Whether to include microseconds in the output
|
||||
|
||||
Returns:
|
||||
ISO8601 formatted string
|
||||
"""
|
||||
if include_microseconds:
|
||||
return dt.isoformat()
|
||||
else:
|
||||
return dt.isoformat(timespec="seconds")
|
||||
|
||||
|
||||
# BACKWARDS COMPATIBILITY NOTE:
|
||||
# The timezone field is kept at the format level (not block level) for backwards compatibility.
|
||||
# Existing graphs have timezone saved within format_type, moving it would break them.
|
||||
#
|
||||
# The use_user_timezone flag was added to allow using the user's profile timezone.
|
||||
# Default is False to maintain backwards compatibility - existing graphs will continue
|
||||
# using their specified timezone.
|
||||
#
|
||||
# KNOWN ISSUE: If a user switches between format types (strftime <-> iso8601),
|
||||
# the timezone setting doesn't carry over. This is a UX issue but fixing it would
|
||||
# require either:
|
||||
# 1. Moving timezone to block level (breaking change, needs migration)
|
||||
# 2. Complex state management to sync timezone across format types
|
||||
#
|
||||
# Future migration path: When we do a major version bump, consider moving timezone
|
||||
# to the block Input level for better UX.
|
||||
|
||||
|
||||
class TimeStrftimeFormat(BaseModel):
|
||||
discriminator: Literal["strftime"]
|
||||
format: str = "%H:%M:%S"
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
# When True, overrides timezone with user's profile timezone
|
||||
use_user_timezone: bool = False
|
||||
|
||||
|
||||
class TimeISO8601Format(BaseModel):
|
||||
discriminator: Literal["iso8601"]
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
# When True, overrides timezone with user's profile timezone
|
||||
use_user_timezone: bool = False
|
||||
include_microseconds: bool = False
|
||||
|
||||
|
||||
@@ -115,25 +181,27 @@ class GetCurrentTimeBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
async def run(
|
||||
self, input_data: Input, *, user_context: UserContext, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Extract timezone from user_context (always present)
|
||||
effective_timezone = user_context.timezone
|
||||
|
||||
# Get the appropriate timezone
|
||||
tz = _get_timezone(input_data.format_type, effective_timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
|
||||
if isinstance(input_data.format_type, TimeISO8601Format):
|
||||
# ISO 8601 format for time only (extract time portion from full ISO datetime)
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
|
||||
# Get the full ISO format and extract just the time portion with timezone
|
||||
if input_data.format_type.include_microseconds:
|
||||
full_iso = dt.isoformat()
|
||||
else:
|
||||
full_iso = dt.isoformat(timespec="seconds")
|
||||
|
||||
full_iso = _format_datetime_iso8601(
|
||||
dt, input_data.format_type.include_microseconds
|
||||
)
|
||||
# Extract time portion (everything after 'T')
|
||||
current_time = full_iso.split("T")[1] if "T" in full_iso else full_iso
|
||||
current_time = f"T{current_time}" # Add T prefix for ISO 8601 time format
|
||||
else: # TimeStrftimeFormat
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
current_time = dt.strftime(input_data.format_type.format)
|
||||
|
||||
yield "time", current_time
|
||||
|
||||
|
||||
@@ -141,11 +209,15 @@ class DateStrftimeFormat(BaseModel):
|
||||
discriminator: Literal["strftime"]
|
||||
format: str = "%Y-%m-%d"
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
# When True, overrides timezone with user's profile timezone
|
||||
use_user_timezone: bool = False
|
||||
|
||||
|
||||
class DateISO8601Format(BaseModel):
|
||||
discriminator: Literal["iso8601"]
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
# When True, overrides timezone with user's profile timezone
|
||||
use_user_timezone: bool = False
|
||||
|
||||
|
||||
class GetCurrentDateBlock(Block):
|
||||
@@ -217,20 +289,23 @@ class GetCurrentDateBlock(Block):
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# Extract timezone from user_context (required keyword argument)
|
||||
user_context: UserContext = kwargs["user_context"]
|
||||
effective_timezone = user_context.timezone
|
||||
|
||||
try:
|
||||
offset = int(input_data.offset)
|
||||
except ValueError:
|
||||
offset = 0
|
||||
|
||||
# Get the appropriate timezone
|
||||
tz = _get_timezone(input_data.format_type, effective_timezone)
|
||||
current_date = datetime.now(tz=tz) - timedelta(days=offset)
|
||||
|
||||
if isinstance(input_data.format_type, DateISO8601Format):
|
||||
# ISO 8601 format for date only (YYYY-MM-DD)
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
current_date = datetime.now(tz=tz) - timedelta(days=offset)
|
||||
# ISO 8601 date format is YYYY-MM-DD
|
||||
date_str = current_date.date().isoformat()
|
||||
else: # DateStrftimeFormat
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
current_date = datetime.now(tz=tz) - timedelta(days=offset)
|
||||
date_str = current_date.strftime(input_data.format_type.format)
|
||||
|
||||
yield "date", date_str
|
||||
@@ -240,11 +315,15 @@ class StrftimeFormat(BaseModel):
|
||||
discriminator: Literal["strftime"]
|
||||
format: str = "%Y-%m-%d %H:%M:%S"
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
# When True, overrides timezone with user's profile timezone
|
||||
use_user_timezone: bool = False
|
||||
|
||||
|
||||
class ISO8601Format(BaseModel):
|
||||
discriminator: Literal["iso8601"]
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
# When True, overrides timezone with user's profile timezone
|
||||
use_user_timezone: bool = False
|
||||
include_microseconds: bool = False
|
||||
|
||||
|
||||
@@ -316,20 +395,22 @@ class GetCurrentDateAndTimeBlock(Block):
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# Extract timezone from user_context (required keyword argument)
|
||||
user_context: UserContext = kwargs["user_context"]
|
||||
effective_timezone = user_context.timezone
|
||||
|
||||
# Get the appropriate timezone
|
||||
tz = _get_timezone(input_data.format_type, effective_timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
|
||||
if isinstance(input_data.format_type, ISO8601Format):
|
||||
# ISO 8601 format with specified timezone (also RFC3339-compliant)
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
|
||||
# Format with or without microseconds
|
||||
if input_data.format_type.include_microseconds:
|
||||
current_date_time = dt.isoformat()
|
||||
else:
|
||||
current_date_time = dt.isoformat(timespec="seconds")
|
||||
current_date_time = _format_datetime_iso8601(
|
||||
dt, input_data.format_type.include_microseconds
|
||||
)
|
||||
else: # StrftimeFormat
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
current_date_time = dt.strftime(input_data.format_type.format)
|
||||
|
||||
yield "date_time", current_date_time
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ from logging import getLogger
|
||||
from typing import Any, Dict, List, Union
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from pydantic import field_serializer
|
||||
|
||||
from backend.sdk import BaseModel, Credentials, Requests
|
||||
|
||||
logger = getLogger(__name__)
|
||||
@@ -382,8 +384,9 @@ class CreatePostRequest(BaseModel):
|
||||
# Advanced
|
||||
metadata: List[Dict[str, Any]] | None = None
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
@field_serializer("date")
|
||||
def serialize_date(self, value: datetime | None) -> str | None:
|
||||
return value.isoformat() if value else None
|
||||
|
||||
|
||||
class PostAuthor(BaseModel):
|
||||
|
||||
@@ -6,8 +6,6 @@ from dotenv import load_dotenv
|
||||
|
||||
from backend.util.logging import configure_logging
|
||||
|
||||
os.environ["ENABLE_AUTH"] = "false"
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# NOTE: You can run tests like with the --log-cli-level=INFO to see the logs
|
||||
|
||||
@@ -307,7 +307,18 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
"type": ideogram_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
),
|
||||
BlockCost(
|
||||
cost_amount=18,
|
||||
cost_filter={
|
||||
"ideogram_model_name": "V_3",
|
||||
"credentials": {
|
||||
"id": ideogram_credentials.id,
|
||||
"provider": ideogram_credentials.provider,
|
||||
"type": ideogram_credentials.type,
|
||||
},
|
||||
},
|
||||
),
|
||||
],
|
||||
AIShortformVideoCreatorBlock: [
|
||||
BlockCost(
|
||||
|
||||
@@ -7,7 +7,7 @@ from prisma.models import CreditTransaction
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.data.block import get_block
|
||||
from backend.data.credit import BetaUserCredit, UsageTransactionMetadata
|
||||
from backend.data.execution import NodeExecutionEntry
|
||||
from backend.data.execution import NodeExecutionEntry, UserContext
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.executor.utils import block_usage_cost
|
||||
from backend.integrations.credentials_store import openai_credentials
|
||||
@@ -75,6 +75,7 @@ async def test_block_credit_usage(server: SpinTestServer):
|
||||
"type": openai_credentials.type,
|
||||
},
|
||||
},
|
||||
user_context=UserContext(timezone="UTC"),
|
||||
),
|
||||
)
|
||||
assert spending_amount_1 > 0
|
||||
@@ -88,6 +89,7 @@ async def test_block_credit_usage(server: SpinTestServer):
|
||||
node_exec_id="test_node_exec",
|
||||
block_id=AITextGeneratorBlock().id,
|
||||
inputs={"model": "gpt-4-turbo", "api_key": "owned_api_key"},
|
||||
user_context=UserContext(timezone="UTC"),
|
||||
),
|
||||
)
|
||||
assert spending_amount_2 == 0
|
||||
|
||||
@@ -39,6 +39,7 @@ from pydantic.fields import Field
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
from backend.util import type as type_utils
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.settings import Config
|
||||
from backend.util.truncate import truncate
|
||||
@@ -89,6 +90,7 @@ ExecutionStatus = AgentExecutionStatus
|
||||
|
||||
|
||||
class GraphExecutionMeta(BaseDbModel):
|
||||
id: str # type: ignore # Override base class to make this required
|
||||
user_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
@@ -290,13 +292,14 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
node_executions=node_executions,
|
||||
)
|
||||
|
||||
def to_graph_execution_entry(self):
|
||||
def to_graph_execution_entry(self, user_context: "UserContext"):
|
||||
return GraphExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
graph_id=self.graph_id,
|
||||
graph_version=self.graph_version or 0,
|
||||
graph_exec_id=self.id,
|
||||
nodes_input_masks={}, # FIXME: store credentials on AgentGraphExecution
|
||||
user_context=user_context,
|
||||
)
|
||||
|
||||
|
||||
@@ -312,9 +315,10 @@ class NodeExecutionResult(BaseModel):
|
||||
input_data: BlockInput
|
||||
output_data: CompletedBlockOutput
|
||||
add_time: datetime
|
||||
queue_time: datetime | None
|
||||
start_time: datetime | None
|
||||
end_time: datetime | None
|
||||
queue_time: datetime | None = None
|
||||
start_time: datetime | None = None
|
||||
end_time: datetime | None = None
|
||||
stats: NodeExecutionStats | None = None
|
||||
|
||||
@staticmethod
|
||||
def from_db(_node_exec: AgentNodeExecution, user_id: Optional[str] = None):
|
||||
@@ -366,9 +370,12 @@ class NodeExecutionResult(BaseModel):
|
||||
queue_time=_node_exec.queuedTime,
|
||||
start_time=_node_exec.startedTime,
|
||||
end_time=_node_exec.endedTime,
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
def to_node_execution_entry(self) -> "NodeExecutionEntry":
|
||||
def to_node_execution_entry(
|
||||
self, user_context: "UserContext"
|
||||
) -> "NodeExecutionEntry":
|
||||
return NodeExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
graph_exec_id=self.graph_exec_id,
|
||||
@@ -377,6 +384,7 @@ class NodeExecutionResult(BaseModel):
|
||||
node_id=self.node_id,
|
||||
block_id=self.block_id,
|
||||
inputs=self.input_data,
|
||||
user_context=user_context,
|
||||
)
|
||||
|
||||
|
||||
@@ -384,13 +392,13 @@ class NodeExecutionResult(BaseModel):
|
||||
|
||||
|
||||
async def get_graph_executions(
|
||||
graph_exec_id: str | None = None,
|
||||
graph_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
statuses: list[ExecutionStatus] | None = None,
|
||||
created_time_gte: datetime | None = None,
|
||||
created_time_lte: datetime | None = None,
|
||||
limit: int | None = None,
|
||||
graph_exec_id: Optional[str] = None,
|
||||
graph_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
statuses: Optional[list[ExecutionStatus]] = None,
|
||||
created_time_gte: Optional[datetime] = None,
|
||||
created_time_lte: Optional[datetime] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> list[GraphExecutionMeta]:
|
||||
"""⚠️ **Optional `user_id` check**: MUST USE check in user-facing endpoints."""
|
||||
where_filter: AgentGraphExecutionWhereInput = {
|
||||
@@ -418,6 +426,60 @@ async def get_graph_executions(
|
||||
return [GraphExecutionMeta.from_db(execution) for execution in executions]
|
||||
|
||||
|
||||
class GraphExecutionsPaginated(BaseModel):
|
||||
"""Response schema for paginated graph executions."""
|
||||
|
||||
executions: list[GraphExecutionMeta]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
async def get_graph_executions_paginated(
|
||||
user_id: str,
|
||||
graph_id: Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 25,
|
||||
statuses: Optional[list[ExecutionStatus]] = None,
|
||||
created_time_gte: Optional[datetime] = None,
|
||||
created_time_lte: Optional[datetime] = None,
|
||||
) -> GraphExecutionsPaginated:
|
||||
"""Get paginated graph executions for a specific graph."""
|
||||
where_filter: AgentGraphExecutionWhereInput = {
|
||||
"isDeleted": False,
|
||||
"userId": user_id,
|
||||
}
|
||||
|
||||
if graph_id:
|
||||
where_filter["agentGraphId"] = graph_id
|
||||
if created_time_gte or created_time_lte:
|
||||
where_filter["createdAt"] = {
|
||||
"gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
|
||||
"lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
|
||||
}
|
||||
if statuses:
|
||||
where_filter["OR"] = [{"executionStatus": status} for status in statuses]
|
||||
|
||||
total_count = await AgentGraphExecution.prisma().count(where=where_filter)
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
executions = await AgentGraphExecution.prisma().find_many(
|
||||
where=where_filter,
|
||||
order={"createdAt": "desc"},
|
||||
take=page_size,
|
||||
skip=offset,
|
||||
)
|
||||
|
||||
return GraphExecutionsPaginated(
|
||||
executions=[GraphExecutionMeta.from_db(execution) for execution in executions],
|
||||
pagination=Pagination(
|
||||
total_items=total_count,
|
||||
total_pages=total_pages,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def get_graph_execution_meta(
|
||||
user_id: str, execution_id: str
|
||||
) -> GraphExecutionMeta | None:
|
||||
@@ -594,6 +656,42 @@ async def upsert_execution_input(
|
||||
)
|
||||
|
||||
|
||||
async def create_node_execution(
|
||||
node_exec_id: str,
|
||||
node_id: str,
|
||||
graph_exec_id: str,
|
||||
input_name: str,
|
||||
input_data: Any,
|
||||
) -> None:
|
||||
"""Create a new node execution with the first input."""
|
||||
json_input_data = SafeJson(input_data)
|
||||
await AgentNodeExecution.prisma().create(
|
||||
data=AgentNodeExecutionCreateInput(
|
||||
id=node_exec_id,
|
||||
agentNodeId=node_id,
|
||||
agentGraphExecutionId=graph_exec_id,
|
||||
executionStatus=ExecutionStatus.INCOMPLETE,
|
||||
Input={"create": {"name": input_name, "data": json_input_data}},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def add_input_to_node_execution(
|
||||
node_exec_id: str,
|
||||
input_name: str,
|
||||
input_data: Any,
|
||||
) -> None:
|
||||
"""Add an input to an existing node execution."""
|
||||
json_input_data = SafeJson(input_data)
|
||||
await AgentNodeExecutionInputOutput.prisma().create(
|
||||
data=AgentNodeExecutionInputOutputCreateInput(
|
||||
name=input_name,
|
||||
data=json_input_data,
|
||||
referencedByInputExecId=node_exec_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def upsert_execution_output(
|
||||
node_exec_id: str,
|
||||
output_name: str,
|
||||
@@ -817,12 +915,19 @@ async def get_latest_node_execution(
|
||||
# ----------------- Execution Infrastructure ----------------- #
|
||||
|
||||
|
||||
class UserContext(BaseModel):
|
||||
"""Generic user context for graph execution containing user-specific settings."""
|
||||
|
||||
timezone: str
|
||||
|
||||
|
||||
class GraphExecutionEntry(BaseModel):
|
||||
user_id: str
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None
|
||||
user_context: UserContext
|
||||
|
||||
|
||||
class NodeExecutionEntry(BaseModel):
|
||||
@@ -833,6 +938,7 @@ class NodeExecutionEntry(BaseModel):
|
||||
node_id: str
|
||||
block_id: str
|
||||
inputs: BlockInput
|
||||
user_context: UserContext
|
||||
|
||||
|
||||
class ExecutionQueue(Generic[T]):
|
||||
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import autogpt_libs.auth.models
|
||||
import fastapi.exceptions
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
@@ -317,12 +316,7 @@ async def test_access_store_listing_graph(server: SpinTestServer):
|
||||
is_approved=True,
|
||||
comments="Test comments",
|
||||
),
|
||||
autogpt_libs.auth.models.User(
|
||||
user_id=admin_user.id,
|
||||
role="admin",
|
||||
email=admin_user.email,
|
||||
phone_number="1234567890",
|
||||
),
|
||||
user_id=admin_user.id,
|
||||
)
|
||||
|
||||
# Now we check the graph can be accessed by a user that does not own the graph
|
||||
|
||||
@@ -96,6 +96,12 @@ class User(BaseModel):
|
||||
default=True, description="Notify on monthly summary"
|
||||
)
|
||||
|
||||
# User timezone for scheduling and time display
|
||||
timezone: str = Field(
|
||||
default="not-set",
|
||||
description="User timezone (IANA timezone identifier or 'not-set')",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, prisma_user: "PrismaUser") -> "User":
|
||||
"""Convert a database User object to application User model."""
|
||||
@@ -149,6 +155,7 @@ class User(BaseModel):
|
||||
notify_on_daily_summary=prisma_user.notifyOnDailySummary or True,
|
||||
notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True,
|
||||
notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True,
|
||||
timezone=prisma_user.timezone or "not-set",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -54,19 +54,6 @@ class AgentRunData(BaseNotificationData):
|
||||
|
||||
|
||||
class ZeroBalanceData(BaseNotificationData):
|
||||
last_transaction: float
|
||||
last_transaction_time: datetime
|
||||
top_up_link: str
|
||||
|
||||
@field_validator("last_transaction_time")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: datetime):
|
||||
if value.tzinfo is None:
|
||||
raise ValueError("datetime must have timezone information")
|
||||
return value
|
||||
|
||||
|
||||
class LowBalanceData(BaseNotificationData):
|
||||
agent_name: str = Field(..., description="Name of the agent")
|
||||
current_balance: float = Field(
|
||||
..., description="Current balance in credits (100 = $1)"
|
||||
@@ -75,6 +62,13 @@ class LowBalanceData(BaseNotificationData):
|
||||
shortfall: float = Field(..., description="Amount of credits needed to continue")
|
||||
|
||||
|
||||
class LowBalanceData(BaseNotificationData):
|
||||
current_balance: float = Field(
|
||||
..., description="Current balance in credits (100 = $1)"
|
||||
)
|
||||
billing_page_link: str = Field(..., description="Link to billing page")
|
||||
|
||||
|
||||
class BlockExecutionFailedData(BaseNotificationData):
|
||||
block_name: str
|
||||
block_id: str
|
||||
@@ -181,6 +175,42 @@ class RefundRequestData(BaseNotificationData):
|
||||
balance: int
|
||||
|
||||
|
||||
class AgentApprovalData(BaseNotificationData):
|
||||
agent_name: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
reviewer_name: str
|
||||
reviewer_email: str
|
||||
comments: str
|
||||
reviewed_at: datetime
|
||||
store_url: str
|
||||
|
||||
@field_validator("reviewed_at")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: datetime):
|
||||
if value.tzinfo is None:
|
||||
raise ValueError("datetime must have timezone information")
|
||||
return value
|
||||
|
||||
|
||||
class AgentRejectionData(BaseNotificationData):
|
||||
agent_name: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
reviewer_name: str
|
||||
reviewer_email: str
|
||||
comments: str
|
||||
reviewed_at: datetime
|
||||
resubmit_url: str
|
||||
|
||||
@field_validator("reviewed_at")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: datetime):
|
||||
if value.tzinfo is None:
|
||||
raise ValueError("datetime must have timezone information")
|
||||
return value
|
||||
|
||||
|
||||
NotificationData = Annotated[
|
||||
Union[
|
||||
AgentRunData,
|
||||
@@ -240,6 +270,8 @@ def get_notif_data_type(
|
||||
NotificationType.MONTHLY_SUMMARY: MonthlySummaryData,
|
||||
NotificationType.REFUND_REQUEST: RefundRequestData,
|
||||
NotificationType.REFUND_PROCESSED: RefundRequestData,
|
||||
NotificationType.AGENT_APPROVED: AgentApprovalData,
|
||||
NotificationType.AGENT_REJECTED: AgentRejectionData,
|
||||
}[notification_type]
|
||||
|
||||
|
||||
@@ -274,7 +306,7 @@ class NotificationTypeOverride:
|
||||
# These are batched by the notification service
|
||||
NotificationType.AGENT_RUN: QueueType.BATCH,
|
||||
# These are batched by the notification service, but with a backoff strategy
|
||||
NotificationType.ZERO_BALANCE: QueueType.BACKOFF,
|
||||
NotificationType.ZERO_BALANCE: QueueType.IMMEDIATE,
|
||||
NotificationType.LOW_BALANCE: QueueType.IMMEDIATE,
|
||||
NotificationType.BLOCK_EXECUTION_FAILED: QueueType.BACKOFF,
|
||||
NotificationType.CONTINUOUS_AGENT_ERROR: QueueType.BACKOFF,
|
||||
@@ -283,6 +315,8 @@ class NotificationTypeOverride:
|
||||
NotificationType.MONTHLY_SUMMARY: QueueType.SUMMARY,
|
||||
NotificationType.REFUND_REQUEST: QueueType.ADMIN,
|
||||
NotificationType.REFUND_PROCESSED: QueueType.ADMIN,
|
||||
NotificationType.AGENT_APPROVED: QueueType.IMMEDIATE,
|
||||
NotificationType.AGENT_REJECTED: QueueType.IMMEDIATE,
|
||||
}
|
||||
return BATCHING_RULES.get(self.notification_type, QueueType.IMMEDIATE)
|
||||
|
||||
@@ -300,6 +334,8 @@ class NotificationTypeOverride:
|
||||
NotificationType.MONTHLY_SUMMARY: "monthly_summary.html",
|
||||
NotificationType.REFUND_REQUEST: "refund_request.html",
|
||||
NotificationType.REFUND_PROCESSED: "refund_processed.html",
|
||||
NotificationType.AGENT_APPROVED: "agent_approved.html",
|
||||
NotificationType.AGENT_REJECTED: "agent_rejected.html",
|
||||
}[self.notification_type]
|
||||
|
||||
@property
|
||||
@@ -315,6 +351,8 @@ class NotificationTypeOverride:
|
||||
NotificationType.MONTHLY_SUMMARY: "We did a lot this month!",
|
||||
NotificationType.REFUND_REQUEST: "[ACTION REQUIRED] You got a ${{data.amount / 100}} refund request from {{data.user_name}}",
|
||||
NotificationType.REFUND_PROCESSED: "Refund for ${{data.amount / 100}} to {{data.user_name}} has been processed",
|
||||
NotificationType.AGENT_APPROVED: "🎉 Your agent '{{data.agent_name}}' has been approved!",
|
||||
NotificationType.AGENT_REJECTED: "Your agent '{{data.agent_name}}' needs some updates",
|
||||
}[self.notification_type]
|
||||
|
||||
|
||||
|
||||
151
autogpt_platform/backend/backend/data/notifications_test.py
Normal file
151
autogpt_platform/backend/backend/data/notifications_test.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Tests for notification data models."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.data.notifications import AgentApprovalData, AgentRejectionData
|
||||
|
||||
|
||||
class TestAgentApprovalData:
|
||||
"""Test cases for AgentApprovalData model."""
|
||||
|
||||
def test_valid_agent_approval_data(self):
|
||||
"""Test creating valid AgentApprovalData."""
|
||||
data = AgentApprovalData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="John Doe",
|
||||
reviewer_email="john@example.com",
|
||||
comments="Great agent, approved!",
|
||||
reviewed_at=datetime.now(timezone.utc),
|
||||
store_url="https://app.autogpt.com/store/test-agent-123",
|
||||
)
|
||||
|
||||
assert data.agent_name == "Test Agent"
|
||||
assert data.agent_id == "test-agent-123"
|
||||
assert data.agent_version == 1
|
||||
assert data.reviewer_name == "John Doe"
|
||||
assert data.reviewer_email == "john@example.com"
|
||||
assert data.comments == "Great agent, approved!"
|
||||
assert data.store_url == "https://app.autogpt.com/store/test-agent-123"
|
||||
assert data.reviewed_at.tzinfo is not None
|
||||
|
||||
def test_agent_approval_data_without_timezone_raises_error(self):
|
||||
"""Test that AgentApprovalData raises error without timezone."""
|
||||
with pytest.raises(
|
||||
ValidationError, match="datetime must have timezone information"
|
||||
):
|
||||
AgentApprovalData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="John Doe",
|
||||
reviewer_email="john@example.com",
|
||||
comments="Great agent, approved!",
|
||||
reviewed_at=datetime.now(), # No timezone
|
||||
store_url="https://app.autogpt.com/store/test-agent-123",
|
||||
)
|
||||
|
||||
def test_agent_approval_data_with_empty_comments(self):
|
||||
"""Test AgentApprovalData with empty comments."""
|
||||
data = AgentApprovalData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="John Doe",
|
||||
reviewer_email="john@example.com",
|
||||
comments="", # Empty comments
|
||||
reviewed_at=datetime.now(timezone.utc),
|
||||
store_url="https://app.autogpt.com/store/test-agent-123",
|
||||
)
|
||||
|
||||
assert data.comments == ""
|
||||
|
||||
|
||||
class TestAgentRejectionData:
|
||||
"""Test cases for AgentRejectionData model."""
|
||||
|
||||
def test_valid_agent_rejection_data(self):
|
||||
"""Test creating valid AgentRejectionData."""
|
||||
data = AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments="Please fix the security issues before resubmitting.",
|
||||
reviewed_at=datetime.now(timezone.utc),
|
||||
resubmit_url="https://app.autogpt.com/build/test-agent-123",
|
||||
)
|
||||
|
||||
assert data.agent_name == "Test Agent"
|
||||
assert data.agent_id == "test-agent-123"
|
||||
assert data.agent_version == 1
|
||||
assert data.reviewer_name == "Jane Doe"
|
||||
assert data.reviewer_email == "jane@example.com"
|
||||
assert data.comments == "Please fix the security issues before resubmitting."
|
||||
assert data.resubmit_url == "https://app.autogpt.com/build/test-agent-123"
|
||||
assert data.reviewed_at.tzinfo is not None
|
||||
|
||||
def test_agent_rejection_data_without_timezone_raises_error(self):
|
||||
"""Test that AgentRejectionData raises error without timezone."""
|
||||
with pytest.raises(
|
||||
ValidationError, match="datetime must have timezone information"
|
||||
):
|
||||
AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments="Please fix the security issues.",
|
||||
reviewed_at=datetime.now(), # No timezone
|
||||
resubmit_url="https://app.autogpt.com/build/test-agent-123",
|
||||
)
|
||||
|
||||
def test_agent_rejection_data_with_long_comments(self):
|
||||
"""Test AgentRejectionData with long comments."""
|
||||
long_comment = "A" * 1000 # Very long comment
|
||||
data = AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments=long_comment,
|
||||
reviewed_at=datetime.now(timezone.utc),
|
||||
resubmit_url="https://app.autogpt.com/build/test-agent-123",
|
||||
)
|
||||
|
||||
assert data.comments == long_comment
|
||||
|
||||
def test_model_serialization(self):
|
||||
"""Test that models can be serialized and deserialized."""
|
||||
original_data = AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments="Please fix the issues.",
|
||||
reviewed_at=datetime.now(timezone.utc),
|
||||
resubmit_url="https://app.autogpt.com/build/test-agent-123",
|
||||
)
|
||||
|
||||
# Serialize to dict
|
||||
data_dict = original_data.model_dump()
|
||||
|
||||
# Deserialize back
|
||||
restored_data = AgentRejectionData.model_validate(data_dict)
|
||||
|
||||
assert restored_data.agent_name == original_data.agent_name
|
||||
assert restored_data.agent_id == original_data.agent_id
|
||||
assert restored_data.agent_version == original_data.agent_version
|
||||
assert restored_data.reviewer_name == original_data.reviewer_name
|
||||
assert restored_data.reviewer_email == original_data.reviewer_email
|
||||
assert restored_data.comments == original_data.comments
|
||||
assert restored_data.reviewed_at == original_data.reviewed_at
|
||||
assert restored_data.resubmit_url == original_data.resubmit_url
|
||||
@@ -208,6 +208,8 @@ async def get_user_notification_preference(user_id: str) -> NotificationPreferen
|
||||
NotificationType.DAILY_SUMMARY: user.notifyOnDailySummary or False,
|
||||
NotificationType.WEEKLY_SUMMARY: user.notifyOnWeeklySummary or False,
|
||||
NotificationType.MONTHLY_SUMMARY: user.notifyOnMonthlySummary or False,
|
||||
NotificationType.AGENT_APPROVED: user.notifyOnAgentApproved or False,
|
||||
NotificationType.AGENT_REJECTED: user.notifyOnAgentRejected or False,
|
||||
}
|
||||
daily_limit = user.maxEmailsPerDay or 3
|
||||
notification_preference = NotificationPreference(
|
||||
@@ -266,6 +268,14 @@ async def update_user_notification_preference(
|
||||
update_data["notifyOnMonthlySummary"] = data.preferences[
|
||||
NotificationType.MONTHLY_SUMMARY
|
||||
]
|
||||
if NotificationType.AGENT_APPROVED in data.preferences:
|
||||
update_data["notifyOnAgentApproved"] = data.preferences[
|
||||
NotificationType.AGENT_APPROVED
|
||||
]
|
||||
if NotificationType.AGENT_REJECTED in data.preferences:
|
||||
update_data["notifyOnAgentRejected"] = data.preferences[
|
||||
NotificationType.AGENT_REJECTED
|
||||
]
|
||||
if data.daily_limit:
|
||||
update_data["maxEmailsPerDay"] = data.daily_limit
|
||||
|
||||
@@ -286,6 +296,8 @@ async def update_user_notification_preference(
|
||||
NotificationType.DAILY_SUMMARY: user.notifyOnDailySummary or True,
|
||||
NotificationType.WEEKLY_SUMMARY: user.notifyOnWeeklySummary or True,
|
||||
NotificationType.MONTHLY_SUMMARY: user.notifyOnMonthlySummary or True,
|
||||
NotificationType.AGENT_APPROVED: user.notifyOnAgentApproved or True,
|
||||
NotificationType.AGENT_REJECTED: user.notifyOnAgentRejected or True,
|
||||
}
|
||||
notification_preference = NotificationPreference(
|
||||
user_id=user.id,
|
||||
@@ -384,3 +396,17 @@ async def unsubscribe_user_by_token(token: str) -> None:
|
||||
)
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to unsubscribe user by token {token}: {e}") from e
|
||||
|
||||
|
||||
async def update_user_timezone(user_id: str, timezone: str) -> User:
|
||||
"""Update a user's timezone setting."""
|
||||
try:
|
||||
user = await PrismaUser.prisma().update(
|
||||
where={"id": user_id},
|
||||
data={"timezone": timezone},
|
||||
)
|
||||
if not user:
|
||||
raise ValueError(f"User not found with ID: {user_id}")
|
||||
return User.from_db(user)
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to update timezone for user {user_id}: {e}") from e
|
||||
|
||||
@@ -115,7 +115,7 @@ async def generate_activity_status_for_execution(
|
||||
|
||||
# Get all node executions for this graph execution
|
||||
node_executions = await db_client.get_node_executions(
|
||||
graph_exec_id, include_exec_data=True
|
||||
graph_exec_id=graph_exec_id, include_exec_data=True
|
||||
)
|
||||
|
||||
# Get graph metadata and full graph structure for name, description, and links
|
||||
|
||||
@@ -4,12 +4,13 @@ from typing import Callable, Concatenate, ParamSpec, TypeVar, cast
|
||||
from backend.data import db
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.execution import (
|
||||
add_input_to_node_execution,
|
||||
create_graph_execution,
|
||||
create_node_execution,
|
||||
get_block_error_stats,
|
||||
get_execution_kv_data,
|
||||
get_graph_execution_meta,
|
||||
get_graph_executions,
|
||||
get_latest_node_execution,
|
||||
get_node_execution,
|
||||
get_node_executions,
|
||||
set_execution_kv_data,
|
||||
@@ -17,7 +18,6 @@ from backend.data.execution import (
|
||||
update_graph_execution_stats,
|
||||
update_node_execution_status,
|
||||
update_node_execution_status_batch,
|
||||
upsert_execution_input,
|
||||
upsert_execution_output,
|
||||
)
|
||||
from backend.data.generate_data import get_user_execution_summary_data
|
||||
@@ -42,6 +42,8 @@ from backend.data.user import (
|
||||
get_user_notification_preference,
|
||||
update_user_integrations,
|
||||
)
|
||||
from backend.server.v2.library.db import add_store_agent_to_library, list_library_agents
|
||||
from backend.server.v2.store.db import get_store_agent_details, get_store_agents
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
@@ -103,13 +105,13 @@ class DatabaseManager(AppService):
|
||||
create_graph_execution = _(create_graph_execution)
|
||||
get_node_execution = _(get_node_execution)
|
||||
get_node_executions = _(get_node_executions)
|
||||
get_latest_node_execution = _(get_latest_node_execution)
|
||||
update_node_execution_status = _(update_node_execution_status)
|
||||
update_node_execution_status_batch = _(update_node_execution_status_batch)
|
||||
update_graph_execution_start_time = _(update_graph_execution_start_time)
|
||||
update_graph_execution_stats = _(update_graph_execution_stats)
|
||||
upsert_execution_input = _(upsert_execution_input)
|
||||
upsert_execution_output = _(upsert_execution_output)
|
||||
create_node_execution = _(create_node_execution)
|
||||
add_input_to_node_execution = _(add_input_to_node_execution)
|
||||
get_execution_kv_data = _(get_execution_kv_data)
|
||||
set_execution_kv_data = _(set_execution_kv_data)
|
||||
get_block_error_stats = _(get_block_error_stats)
|
||||
@@ -145,6 +147,14 @@ class DatabaseManager(AppService):
|
||||
get_user_notification_oldest_message_in_batch
|
||||
)
|
||||
|
||||
# Library
|
||||
list_library_agents = _(list_library_agents)
|
||||
add_store_agent_to_library = _(add_store_agent_to_library)
|
||||
|
||||
# Store
|
||||
get_store_agents = _(get_store_agents)
|
||||
get_store_agent_details = _(get_store_agent_details)
|
||||
|
||||
# Summary data - async
|
||||
get_user_execution_summary_data = _(get_user_execution_summary_data)
|
||||
|
||||
@@ -161,10 +171,12 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
get_graph_executions = _(d.get_graph_executions)
|
||||
get_graph_execution_meta = _(d.get_graph_execution_meta)
|
||||
get_node_executions = _(d.get_node_executions)
|
||||
create_node_execution = _(d.create_node_execution)
|
||||
update_node_execution_status = _(d.update_node_execution_status)
|
||||
update_graph_execution_start_time = _(d.update_graph_execution_start_time)
|
||||
update_graph_execution_stats = _(d.update_graph_execution_stats)
|
||||
upsert_execution_output = _(d.upsert_execution_output)
|
||||
add_input_to_node_execution = _(d.add_input_to_node_execution)
|
||||
|
||||
# Graphs
|
||||
get_graph_metadata = _(d.get_graph_metadata)
|
||||
@@ -173,12 +185,12 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
spend_credits = _(d.spend_credits)
|
||||
get_credits = _(d.get_credits)
|
||||
|
||||
# Summary data - async
|
||||
get_user_execution_summary_data = _(d.get_user_execution_summary_data)
|
||||
|
||||
# Block error monitoring
|
||||
get_block_error_stats = _(d.get_block_error_stats)
|
||||
|
||||
# User Emails
|
||||
get_user_email_by_id = _(d.get_user_email_by_id)
|
||||
|
||||
|
||||
class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
d = DatabaseManager
|
||||
@@ -189,16 +201,12 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
|
||||
create_graph_execution = d.create_graph_execution
|
||||
get_connected_output_nodes = d.get_connected_output_nodes
|
||||
get_latest_node_execution = d.get_latest_node_execution
|
||||
get_graph = d.get_graph
|
||||
get_graph_metadata = d.get_graph_metadata
|
||||
get_graph_execution_meta = d.get_graph_execution_meta
|
||||
get_node = d.get_node
|
||||
get_node_execution = d.get_node_execution
|
||||
get_node_executions = d.get_node_executions
|
||||
get_user_integrations = d.get_user_integrations
|
||||
upsert_execution_input = d.upsert_execution_input
|
||||
upsert_execution_output = d.upsert_execution_output
|
||||
update_graph_execution_stats = d.update_graph_execution_stats
|
||||
update_node_execution_status = d.update_node_execution_status
|
||||
update_node_execution_status_batch = d.update_node_execution_status_batch
|
||||
@@ -223,5 +231,13 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
d.get_user_notification_oldest_message_in_batch
|
||||
)
|
||||
|
||||
# Library
|
||||
list_library_agents = d.list_library_agents
|
||||
add_store_agent_to_library = d.add_store_agent_to_library
|
||||
|
||||
# Store
|
||||
get_store_agents = d.get_store_agents
|
||||
get_store_agent_details = d.get_store_agent_details
|
||||
|
||||
# Summary data
|
||||
get_user_execution_summary_data = d.get_user_execution_summary_data
|
||||
|
||||
154
autogpt_platform/backend/backend/executor/execution_cache.py
Normal file
154
autogpt_platform/backend/backend/executor/execution_cache.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from backend.data.execution import ExecutionStatus, NodeExecutionResult
|
||||
from backend.data.model import GraphExecutionStats, NodeExecutionStats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def with_lock(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
with self._lock:
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ExecutionCache:
|
||||
def __init__(self, graph_exec_id: str, db_client: "DatabaseManagerClient"):
|
||||
self._lock = threading.RLock()
|
||||
self._graph_exec_id = graph_exec_id
|
||||
self._graph_stats: GraphExecutionStats = GraphExecutionStats()
|
||||
self._node_executions: OrderedDict[str, NodeExecutionResult] = OrderedDict()
|
||||
|
||||
for execution in db_client.get_node_executions(self._graph_exec_id):
|
||||
self._node_executions[execution.node_exec_id] = execution
|
||||
|
||||
@with_lock
|
||||
def get_node_execution(self, node_exec_id: str) -> Optional[NodeExecutionResult]:
|
||||
execution = self._node_executions.get(node_exec_id)
|
||||
return execution.model_copy(deep=True) if execution else None
|
||||
|
||||
@with_lock
|
||||
def get_latest_node_execution(self, node_id: str) -> Optional[NodeExecutionResult]:
|
||||
for execution in reversed(self._node_executions.values()):
|
||||
if (
|
||||
execution.node_id == node_id
|
||||
and execution.status != ExecutionStatus.INCOMPLETE
|
||||
):
|
||||
return execution.model_copy(deep=True)
|
||||
return None
|
||||
|
||||
@with_lock
|
||||
def get_node_executions(
|
||||
self,
|
||||
*,
|
||||
statuses: Optional[list] = None,
|
||||
block_ids: Optional[list] = None,
|
||||
node_id: Optional[str] = None,
|
||||
):
|
||||
results = []
|
||||
for execution in self._node_executions.values():
|
||||
if statuses and execution.status not in statuses:
|
||||
continue
|
||||
if block_ids and execution.block_id not in block_ids:
|
||||
continue
|
||||
if node_id and execution.node_id != node_id:
|
||||
continue
|
||||
results.append(execution.model_copy(deep=True))
|
||||
return results
|
||||
|
||||
@with_lock
|
||||
def update_node_execution_status(
|
||||
self,
|
||||
exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
execution_data: Optional[dict] = None,
|
||||
stats: Optional[dict] = None,
|
||||
):
|
||||
if exec_id not in self._node_executions:
|
||||
raise RuntimeError(f"Execution {exec_id} not found in cache")
|
||||
|
||||
execution = self._node_executions[exec_id]
|
||||
execution.status = status
|
||||
|
||||
if execution_data:
|
||||
execution.input_data.update(execution_data)
|
||||
|
||||
if stats:
|
||||
execution.stats = execution.stats or NodeExecutionStats()
|
||||
current_stats = execution.stats.model_dump()
|
||||
current_stats.update(stats)
|
||||
execution.stats = NodeExecutionStats.model_validate(current_stats)
|
||||
|
||||
@with_lock
|
||||
def upsert_execution_output(
|
||||
self, node_exec_id: str, output_name: str, output_data: Any
|
||||
) -> NodeExecutionResult:
|
||||
if node_exec_id not in self._node_executions:
|
||||
raise RuntimeError(f"Execution {node_exec_id} not found in cache")
|
||||
|
||||
execution = self._node_executions[node_exec_id]
|
||||
if output_name not in execution.output_data:
|
||||
execution.output_data[output_name] = []
|
||||
execution.output_data[output_name].append(output_data)
|
||||
|
||||
return execution
|
||||
|
||||
@with_lock
|
||||
def update_graph_stats(
|
||||
self, status: Optional[ExecutionStatus] = None, stats: Optional[dict] = None
|
||||
):
|
||||
if status is not None:
|
||||
pass
|
||||
if stats is not None:
|
||||
current_stats = self._graph_stats.model_dump()
|
||||
current_stats.update(stats)
|
||||
self._graph_stats = GraphExecutionStats.model_validate(current_stats)
|
||||
|
||||
@with_lock
|
||||
def update_graph_start_time(self):
|
||||
"""Update graph start time (handled by database persistence)."""
|
||||
pass
|
||||
|
||||
@with_lock
|
||||
def find_incomplete_execution_for_input(
|
||||
self, node_id: str, input_name: str
|
||||
) -> tuple[str, NodeExecutionResult] | None:
|
||||
for exec_id, execution in self._node_executions.items():
|
||||
if (
|
||||
execution.node_id == node_id
|
||||
and execution.status == ExecutionStatus.INCOMPLETE
|
||||
and input_name not in execution.input_data
|
||||
):
|
||||
return exec_id, execution
|
||||
return None
|
||||
|
||||
@with_lock
|
||||
def add_node_execution(
|
||||
self, node_exec_id: str, execution: NodeExecutionResult
|
||||
) -> None:
|
||||
self._node_executions[node_exec_id] = execution
|
||||
|
||||
@with_lock
|
||||
def update_execution_input(
|
||||
self, exec_id: str, input_name: str, input_data: Any
|
||||
) -> dict:
|
||||
if exec_id not in self._node_executions:
|
||||
raise RuntimeError(f"Execution {exec_id} not found in cache")
|
||||
execution = self._node_executions[exec_id]
|
||||
execution.input_data[input_name] = input_data
|
||||
return execution.input_data.copy()
|
||||
|
||||
def finalize(self) -> None:
|
||||
with self._lock:
|
||||
self._node_executions.clear()
|
||||
self._graph_stats = GraphExecutionStats()
|
||||
@@ -0,0 +1,355 @@
|
||||
"""Test execution creation with proper ID generation and persistence."""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.executor.execution_data import ExecutionDataClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def execution_client_with_mock_db(event_loop):
|
||||
"""Create an ExecutionDataClient with proper database records."""
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from prisma.models import AgentGraph, AgentGraphExecution, User
|
||||
|
||||
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||
|
||||
# Create test database records to satisfy foreign key constraints
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": "test_user_123",
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
)
|
||||
|
||||
await AgentGraph.prisma().create(
|
||||
data={
|
||||
"id": "test_graph_456",
|
||||
"version": 1,
|
||||
"userId": "test_user_123",
|
||||
"name": "Test Graph",
|
||||
"description": "Test graph for execution tests",
|
||||
}
|
||||
)
|
||||
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
|
||||
await AgentGraphExecution.prisma().create(
|
||||
data={
|
||||
"id": "test_graph_exec_id",
|
||||
"userId": "test_user_123",
|
||||
"agentGraphId": "test_graph_456",
|
||||
"agentGraphVersion": 1,
|
||||
"executionStatus": AgentExecutionStatus.RUNNING,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# Records might already exist, that's fine
|
||||
pass
|
||||
|
||||
# Mock the graph execution metadata - align with assertions below
|
||||
mock_graph_meta = GraphExecutionMeta(
|
||||
id="test_graph_exec_id",
|
||||
user_id="test_user_123",
|
||||
graph_id="test_graph_456",
|
||||
graph_version=1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
# Create client with ThreadPoolExecutor and graph metadata (constructed inside patch)
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
# Storage for tracking created executions
|
||||
created_executions = []
|
||||
|
||||
async def mock_create_node_execution(
|
||||
node_exec_id, node_id, graph_exec_id, input_name, input_data
|
||||
):
|
||||
"""Mock execution creation that records what was created."""
|
||||
created_executions.append(
|
||||
{
|
||||
"node_exec_id": node_exec_id,
|
||||
"node_id": node_id,
|
||||
"graph_exec_id": graph_exec_id,
|
||||
"input_name": input_name,
|
||||
"input_data": input_data,
|
||||
}
|
||||
)
|
||||
return node_exec_id
|
||||
|
||||
def sync_mock_create_node_execution(
|
||||
node_exec_id, node_id, graph_exec_id, input_name, input_data
|
||||
):
|
||||
"""Mock sync execution creation that records what was created."""
|
||||
created_executions.append(
|
||||
{
|
||||
"node_exec_id": node_exec_id,
|
||||
"node_id": node_id,
|
||||
"graph_exec_id": graph_exec_id,
|
||||
"input_name": input_name,
|
||||
"input_data": input_data,
|
||||
}
|
||||
)
|
||||
return node_exec_id
|
||||
|
||||
# Prepare mock async and sync DB clients
|
||||
async_mock_client = AsyncMock()
|
||||
async_mock_client.create_node_execution = mock_create_node_execution
|
||||
|
||||
sync_mock_client = MagicMock()
|
||||
sync_mock_client.create_node_execution = sync_mock_create_node_execution
|
||||
# Mock graph execution for return values
|
||||
from backend.data.execution import GraphExecutionMeta
|
||||
|
||||
mock_graph_update = GraphExecutionMeta(
|
||||
id="test_graph_exec_id",
|
||||
user_id="test_user_123",
|
||||
graph_id="test_graph_456",
|
||||
graph_version=1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
# No-ops for other sync methods used by the client during tests
|
||||
sync_mock_client.add_input_to_node_execution.side_effect = lambda **kwargs: None
|
||||
sync_mock_client.update_node_execution_status.side_effect = (
|
||||
lambda *args, **kwargs: None
|
||||
)
|
||||
sync_mock_client.upsert_execution_output.side_effect = lambda **kwargs: None
|
||||
sync_mock_client.update_graph_execution_stats.side_effect = (
|
||||
lambda *args, **kwargs: mock_graph_update
|
||||
)
|
||||
sync_mock_client.update_graph_execution_start_time.side_effect = (
|
||||
lambda *args, **kwargs: mock_graph_update
|
||||
)
|
||||
|
||||
thread = threading.Thread(target=event_loop.run_forever, daemon=True)
|
||||
thread.start()
|
||||
with patch(
|
||||
"backend.executor.execution_data.get_database_manager_async_client",
|
||||
return_value=async_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_database_manager_client",
|
||||
return_value=sync_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_execution_event_bus"
|
||||
), patch(
|
||||
"backend.executor.execution_data.non_blocking_persist", lambda func: func
|
||||
):
|
||||
# Now construct the client under the patch so it captures the mocked clients
|
||||
client = ExecutionDataClient(executor, "test_graph_exec_id", mock_graph_meta)
|
||||
# Store the mocks for the test to access if needed
|
||||
setattr(client, "_test_async_client", async_mock_client)
|
||||
setattr(client, "_test_sync_client", sync_mock_client)
|
||||
setattr(client, "_created_executions", created_executions)
|
||||
yield client
|
||||
|
||||
# Cleanup test database records
|
||||
try:
|
||||
await AgentGraphExecution.prisma().delete_many(
|
||||
where={"id": "test_graph_exec_id"}
|
||||
)
|
||||
await AgentGraph.prisma().delete_many(where={"id": "test_graph_456"})
|
||||
await User.prisma().delete_many(where={"id": "test_user_123"})
|
||||
except Exception:
|
||||
# Cleanup may fail if records don't exist
|
||||
pass
|
||||
|
||||
# Cleanup
|
||||
event_loop.call_soon_threadsafe(event_loop.stop)
|
||||
thread.join(timeout=1)
|
||||
|
||||
|
||||
class TestExecutionCreation:
|
||||
"""Test execution creation with proper ID generation and persistence."""
|
||||
|
||||
async def test_execution_creation_with_valid_ids(
|
||||
self, execution_client_with_mock_db
|
||||
):
|
||||
"""Test that execution creation generates and persists valid IDs."""
|
||||
client = execution_client_with_mock_db
|
||||
|
||||
node_id = "test_node_789"
|
||||
input_name = "test_input"
|
||||
input_data = "test_value"
|
||||
block_id = "test_block_abc"
|
||||
|
||||
# This should trigger execution creation since cache is empty
|
||||
exec_id, input_dict = client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name=input_name,
|
||||
input_data=input_data,
|
||||
block_id=block_id,
|
||||
)
|
||||
|
||||
# Verify execution ID is valid UUID
|
||||
try:
|
||||
uuid.UUID(exec_id)
|
||||
except ValueError:
|
||||
pytest.fail(f"Generated execution ID '{exec_id}' is not a valid UUID")
|
||||
|
||||
# Verify execution was created in cache with complete data
|
||||
assert exec_id in client._cache._node_executions
|
||||
cached_execution = client._cache._node_executions[exec_id]
|
||||
|
||||
# Check all required fields have valid values
|
||||
assert cached_execution.user_id == "test_user_123"
|
||||
assert cached_execution.graph_id == "test_graph_456"
|
||||
assert cached_execution.graph_version == 1
|
||||
assert cached_execution.graph_exec_id == "test_graph_exec_id"
|
||||
assert cached_execution.node_exec_id == exec_id
|
||||
assert cached_execution.node_id == node_id
|
||||
assert cached_execution.block_id == block_id
|
||||
assert cached_execution.status == ExecutionStatus.INCOMPLETE
|
||||
assert cached_execution.input_data == {input_name: input_data}
|
||||
assert isinstance(cached_execution.add_time, datetime)
|
||||
|
||||
# Verify execution was persisted to database with our generated ID
|
||||
created_executions = getattr(client, "_created_executions", [])
|
||||
assert len(created_executions) == 1
|
||||
created = created_executions[0]
|
||||
assert created["node_exec_id"] == exec_id # Our generated ID was used
|
||||
assert created["node_id"] == node_id
|
||||
assert created["graph_exec_id"] == "test_graph_exec_id"
|
||||
assert created["input_name"] == input_name
|
||||
assert created["input_data"] == input_data
|
||||
|
||||
# Verify input dict returned correctly
|
||||
assert input_dict == {input_name: input_data}
|
||||
|
||||
async def test_execution_reuse_vs_creation(self, execution_client_with_mock_db):
|
||||
"""Test that execution reuse works and creation only happens when needed."""
|
||||
client = execution_client_with_mock_db
|
||||
|
||||
node_id = "reuse_test_node"
|
||||
block_id = "reuse_test_block"
|
||||
|
||||
# Create first execution
|
||||
exec_id_1, input_dict_1 = client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="input_1",
|
||||
input_data="value_1",
|
||||
block_id=block_id,
|
||||
)
|
||||
|
||||
# This should reuse the existing INCOMPLETE execution
|
||||
exec_id_2, input_dict_2 = client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="input_2",
|
||||
input_data="value_2",
|
||||
block_id=block_id,
|
||||
)
|
||||
|
||||
# Should reuse the same execution
|
||||
assert exec_id_1 == exec_id_2
|
||||
assert input_dict_2 == {"input_1": "value_1", "input_2": "value_2"}
|
||||
|
||||
# Only one execution should be created in database
|
||||
created_executions = getattr(client, "_created_executions", [])
|
||||
assert len(created_executions) == 1
|
||||
|
||||
# Verify cache has the merged inputs
|
||||
cached_execution = client._cache._node_executions[exec_id_1]
|
||||
assert cached_execution.input_data == {
|
||||
"input_1": "value_1",
|
||||
"input_2": "value_2",
|
||||
}
|
||||
|
||||
# Now complete the execution and try to add another input
|
||||
client.update_node_status_and_publish(
|
||||
exec_id=exec_id_1, status=ExecutionStatus.COMPLETED
|
||||
)
|
||||
|
||||
# Verify the execution status was actually updated in the cache
|
||||
updated_execution = client._cache._node_executions[exec_id_1]
|
||||
assert (
|
||||
updated_execution.status == ExecutionStatus.COMPLETED
|
||||
), f"Expected COMPLETED but got {updated_execution.status}"
|
||||
|
||||
# This should create a NEW execution since the first is no longer INCOMPLETE
|
||||
exec_id_3, input_dict_3 = client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="input_3",
|
||||
input_data="value_3",
|
||||
block_id=block_id,
|
||||
)
|
||||
|
||||
# Should be a different execution
|
||||
assert exec_id_3 != exec_id_1
|
||||
assert input_dict_3 == {"input_3": "value_3"}
|
||||
|
||||
# Verify cache behavior: should have two different executions in cache now
|
||||
cached_executions = client._cache._node_executions
|
||||
assert len(cached_executions) == 2
|
||||
assert exec_id_1 in cached_executions
|
||||
assert exec_id_3 in cached_executions
|
||||
|
||||
# First execution should be COMPLETED
|
||||
assert cached_executions[exec_id_1].status == ExecutionStatus.COMPLETED
|
||||
# Third execution should be INCOMPLETE (newly created)
|
||||
assert cached_executions[exec_id_3].status == ExecutionStatus.INCOMPLETE
|
||||
|
||||
async def test_multiple_nodes_get_different_execution_ids(
|
||||
self, execution_client_with_mock_db
|
||||
):
|
||||
"""Test that different nodes get different execution IDs."""
|
||||
client = execution_client_with_mock_db
|
||||
|
||||
# Create executions for different nodes
|
||||
exec_id_a, _ = client.upsert_execution_input(
|
||||
node_id="node_a",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="block_a",
|
||||
)
|
||||
|
||||
exec_id_b, _ = client.upsert_execution_input(
|
||||
node_id="node_b",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="block_b",
|
||||
)
|
||||
|
||||
# Should be different executions with different IDs
|
||||
assert exec_id_a != exec_id_b
|
||||
|
||||
# Both should be valid UUIDs
|
||||
uuid.UUID(exec_id_a)
|
||||
uuid.UUID(exec_id_b)
|
||||
|
||||
# Both should be in cache
|
||||
cached_executions = client._cache._node_executions
|
||||
assert len(cached_executions) == 2
|
||||
assert exec_id_a in cached_executions
|
||||
assert exec_id_b in cached_executions
|
||||
|
||||
# Both should have correct node IDs
|
||||
assert cached_executions[exec_id_a].node_id == "node_a"
|
||||
assert cached_executions[exec_id_b].node_id == "node_b"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
338
autogpt_platform/backend/backend/executor/execution_data.py
Normal file
338
autogpt_platform/backend/backend/executor/execution_data.py
Normal file
@@ -0,0 +1,338 @@
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from concurrent.futures import Executor
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Callable, ParamSpec, TypeVar, cast
|
||||
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
GraphExecutionMeta,
|
||||
NodeExecutionResult,
|
||||
)
|
||||
from backend.data.graph import Node
|
||||
from backend.data.model import GraphExecutionStats
|
||||
from backend.executor.execution_cache import ExecutionCache
|
||||
from backend.util.clients import (
|
||||
get_database_manager_async_client,
|
||||
get_database_manager_client,
|
||||
get_execution_event_bus,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def non_blocking_persist(func: Callable[P, T]) -> Callable[P, None]:
|
||||
from functools import wraps
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
|
||||
# First argument is always self for methods - access through cast for typing
|
||||
self = cast("ExecutionDataClient", args[0])
|
||||
future = self._executor.submit(func, *args, **kwargs)
|
||||
self._pending_tasks.add(future)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ExecutionDataClient:
|
||||
def __init__(
|
||||
self, executor: Executor, graph_exec_id: str, graph_metadata: GraphExecutionMeta
|
||||
):
|
||||
self._executor = executor
|
||||
self._graph_exec_id = graph_exec_id
|
||||
self._cache = ExecutionCache(graph_exec_id, self.db_client_sync)
|
||||
self._pending_tasks = set()
|
||||
self._graph_metadata = graph_metadata
|
||||
self.graph_lock = threading.RLock()
|
||||
|
||||
def finalize_execution(self, timeout: float = 30.0):
|
||||
logger.info(f"Flushing db writes for execution {self._graph_exec_id}")
|
||||
exceptions = []
|
||||
|
||||
# Wait for all pending database operations to complete
|
||||
logger.debug(
|
||||
f"Waiting for {len(self._pending_tasks)} pending database operations"
|
||||
)
|
||||
for future in list(self._pending_tasks):
|
||||
try:
|
||||
future.result(timeout=timeout)
|
||||
except Exception as e:
|
||||
logger.error(f"Background database operation failed: {e}")
|
||||
exceptions.append(e)
|
||||
finally:
|
||||
self._pending_tasks.discard(future)
|
||||
|
||||
self._cache.finalize()
|
||||
|
||||
if exceptions:
|
||||
logger.error(f"Background persistence failed with {len(exceptions)} errors")
|
||||
raise RuntimeError(
|
||||
f"Background persistence failed with {len(exceptions)} errors: {exceptions}"
|
||||
)
|
||||
|
||||
@property
|
||||
def db_client_async(self) -> "DatabaseManagerAsyncClient":
|
||||
return get_database_manager_async_client()
|
||||
|
||||
@property
|
||||
def db_client_sync(self) -> "DatabaseManagerClient":
|
||||
return get_database_manager_client()
|
||||
|
||||
@property
|
||||
def event_bus(self):
|
||||
return get_execution_event_bus()
|
||||
|
||||
async def get_node(self, node_id: str) -> Node:
|
||||
return await self.db_client_async.get_node(node_id)
|
||||
|
||||
def spend_credits(
|
||||
self,
|
||||
user_id: str,
|
||||
cost: int,
|
||||
metadata: UsageTransactionMetadata,
|
||||
) -> int:
|
||||
return self.db_client_sync.spend_credits(
|
||||
user_id=user_id, cost=cost, metadata=metadata
|
||||
)
|
||||
|
||||
def get_graph_execution_meta(
|
||||
self, user_id: str, execution_id: str
|
||||
) -> GraphExecutionMeta | None:
|
||||
return self.db_client_sync.get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=execution_id
|
||||
)
|
||||
|
||||
def get_graph_metadata(
|
||||
self, graph_id: str, graph_version: int | None = None
|
||||
) -> Any:
|
||||
return self.db_client_sync.get_graph_metadata(graph_id, graph_version)
|
||||
|
||||
def get_credits(self, user_id: str) -> int:
|
||||
return self.db_client_sync.get_credits(user_id)
|
||||
|
||||
def get_user_email_by_id(self, user_id: str) -> str | None:
|
||||
return self.db_client_sync.get_user_email_by_id(user_id)
|
||||
|
||||
def get_latest_node_execution(self, node_id: str) -> NodeExecutionResult | None:
|
||||
return self._cache.get_latest_node_execution(node_id)
|
||||
|
||||
def get_node_execution(self, node_exec_id: str) -> NodeExecutionResult | None:
|
||||
return self._cache.get_node_execution(node_exec_id)
|
||||
|
||||
def get_node_executions(
|
||||
self,
|
||||
*,
|
||||
node_id: str | None = None,
|
||||
statuses: list[ExecutionStatus] | None = None,
|
||||
block_ids: list[str] | None = None,
|
||||
) -> list[NodeExecutionResult]:
|
||||
return self._cache.get_node_executions(
|
||||
statuses=statuses, block_ids=block_ids, node_id=node_id
|
||||
)
|
||||
|
||||
def update_node_status_and_publish(
|
||||
self,
|
||||
exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
execution_data: dict | None = None,
|
||||
stats: dict[str, Any] | None = None,
|
||||
):
|
||||
self._cache.update_node_execution_status(exec_id, status, execution_data, stats)
|
||||
self._persist_node_status_to_db(exec_id, status, execution_data, stats)
|
||||
|
||||
def upsert_execution_input(
|
||||
self, node_id: str, input_name: str, input_data: Any, block_id: str
|
||||
) -> tuple[str, dict]:
|
||||
# Validate input parameters to prevent foreign key constraint errors
|
||||
if not node_id or not isinstance(node_id, str):
|
||||
raise ValueError(f"Invalid node_id: {node_id}")
|
||||
if not self._graph_exec_id or not isinstance(self._graph_exec_id, str):
|
||||
raise ValueError(f"Invalid graph_exec_id: {self._graph_exec_id}")
|
||||
if not block_id or not isinstance(block_id, str):
|
||||
raise ValueError(f"Invalid block_id: {block_id}")
|
||||
|
||||
# UPDATE: Try to find an existing incomplete execution for this node and input
|
||||
if result := self._cache.find_incomplete_execution_for_input(
|
||||
node_id, input_name
|
||||
):
|
||||
exec_id, _ = result
|
||||
updated_input_data = self._cache.update_execution_input(
|
||||
exec_id, input_name, input_data
|
||||
)
|
||||
self._persist_add_input_to_db(exec_id, input_name, input_data)
|
||||
return exec_id, updated_input_data
|
||||
|
||||
# CREATE: No suitable execution found, create new one
|
||||
node_exec_id = str(uuid.uuid4())
|
||||
logger.debug(
|
||||
f"Creating new execution {node_exec_id} for node {node_id} "
|
||||
f"in graph execution {self._graph_exec_id}"
|
||||
)
|
||||
|
||||
new_execution = NodeExecutionResult(
|
||||
user_id=self._graph_metadata.user_id,
|
||||
graph_id=self._graph_metadata.graph_id,
|
||||
graph_version=self._graph_metadata.graph_version,
|
||||
graph_exec_id=self._graph_exec_id,
|
||||
node_exec_id=node_exec_id,
|
||||
node_id=node_id,
|
||||
block_id=block_id,
|
||||
status=ExecutionStatus.INCOMPLETE,
|
||||
input_data={input_name: input_data},
|
||||
output_data={},
|
||||
add_time=datetime.now(timezone.utc),
|
||||
)
|
||||
self._cache.add_node_execution(node_exec_id, new_execution)
|
||||
self._persist_new_node_execution_to_db(
|
||||
node_exec_id, node_id, input_name, input_data
|
||||
)
|
||||
|
||||
return node_exec_id, {input_name: input_data}
|
||||
|
||||
def upsert_execution_output(
|
||||
self, node_exec_id: str, output_name: str, output_data: Any
|
||||
):
|
||||
self._cache.upsert_execution_output(node_exec_id, output_name, output_data)
|
||||
self._persist_execution_output_to_db(node_exec_id, output_name, output_data)
|
||||
|
||||
def update_graph_stats_and_publish(
|
||||
self,
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> None:
|
||||
stats_dict = stats.model_dump() if stats else None
|
||||
self._cache.update_graph_stats(status=status, stats=stats_dict)
|
||||
self._persist_graph_stats_to_db(status=status, stats=stats)
|
||||
|
||||
def update_graph_start_time_and_publish(self) -> None:
|
||||
self._cache.update_graph_start_time()
|
||||
self._persist_graph_start_time_to_db()
|
||||
|
||||
@non_blocking_persist
|
||||
def _persist_node_status_to_db(
|
||||
self,
|
||||
exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
execution_data: dict | None = None,
|
||||
stats: dict[str, Any] | None = None,
|
||||
):
|
||||
exec_update = self.db_client_sync.update_node_execution_status(
|
||||
exec_id, status, execution_data, stats
|
||||
)
|
||||
self.event_bus.publish(exec_update)
|
||||
|
||||
@non_blocking_persist
|
||||
def _persist_add_input_to_db(
|
||||
self, node_exec_id: str, input_name: str, input_data: Any
|
||||
):
|
||||
self.db_client_sync.add_input_to_node_execution(
|
||||
node_exec_id=node_exec_id,
|
||||
input_name=input_name,
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
@non_blocking_persist
|
||||
def _persist_execution_output_to_db(
|
||||
self, node_exec_id: str, output_name: str, output_data: Any
|
||||
):
|
||||
self.db_client_sync.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name=output_name,
|
||||
output_data=output_data,
|
||||
)
|
||||
if exec_update := self.get_node_execution(node_exec_id):
|
||||
self.event_bus.publish(exec_update)
|
||||
|
||||
@non_blocking_persist
|
||||
def _persist_graph_stats_to_db(
|
||||
self,
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
):
|
||||
graph_update = self.db_client_sync.update_graph_execution_stats(
|
||||
self._graph_exec_id, status, stats
|
||||
)
|
||||
if not graph_update:
|
||||
raise RuntimeError(
|
||||
f"Failed to update graph execution stats for {self._graph_exec_id}"
|
||||
)
|
||||
self.event_bus.publish(graph_update)
|
||||
|
||||
@non_blocking_persist
|
||||
def _persist_graph_start_time_to_db(self):
|
||||
graph_update = self.db_client_sync.update_graph_execution_start_time(
|
||||
self._graph_exec_id
|
||||
)
|
||||
if not graph_update:
|
||||
raise RuntimeError(
|
||||
f"Failed to update graph execution start time for {self._graph_exec_id}"
|
||||
)
|
||||
self.event_bus.publish(graph_update)
|
||||
|
||||
async def generate_activity_status(
|
||||
self,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_stats: GraphExecutionStats,
|
||||
user_id: str,
|
||||
execution_status: ExecutionStatus,
|
||||
) -> str | None:
|
||||
from backend.executor.activity_status_generator import (
|
||||
generate_activity_status_for_execution,
|
||||
)
|
||||
|
||||
return await generate_activity_status_for_execution(
|
||||
graph_exec_id=self._graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_stats=execution_stats,
|
||||
db_client=self.db_client_async,
|
||||
user_id=user_id,
|
||||
execution_status=execution_status,
|
||||
)
|
||||
|
||||
@non_blocking_persist
|
||||
def _send_execution_update(self, execution: NodeExecutionResult):
|
||||
"""Send execution update to event bus."""
|
||||
try:
|
||||
self.event_bus.publish(execution)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send execution update: {e}")
|
||||
|
||||
@non_blocking_persist
|
||||
def _persist_new_node_execution_to_db(
|
||||
self, node_exec_id: str, node_id: str, input_name: str, input_data: Any
|
||||
):
|
||||
try:
|
||||
self.db_client_sync.create_node_execution(
|
||||
node_exec_id=node_exec_id,
|
||||
node_id=node_id,
|
||||
graph_exec_id=self._graph_exec_id,
|
||||
input_name=input_name,
|
||||
input_data=input_data,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create node execution {node_exec_id} for node {node_id} "
|
||||
f"in graph execution {self._graph_exec_id}: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
def increment_execution_count(self, user_id: str) -> int:
|
||||
r = redis.get_redis()
|
||||
k = f"uec:{user_id}"
|
||||
counter = cast(int, r.incr(k))
|
||||
if counter == 1:
|
||||
r.expire(k, settings.config.execution_counter_expiration_time)
|
||||
return counter
|
||||
668
autogpt_platform/backend/backend/executor/execution_data_test.py
Normal file
668
autogpt_platform/backend/backend/executor/execution_data_test.py
Normal file
@@ -0,0 +1,668 @@
|
||||
"""Test suite for ExecutionDataClient."""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.executor.execution_data import ExecutionDataClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def execution_client(event_loop):
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||
|
||||
mock_graph_meta = GraphExecutionMeta(
|
||||
id="test_graph_exec_id",
|
||||
user_id="test_user_id",
|
||||
graph_id="test_graph_id",
|
||||
graph_version=1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
# Mock all database operations to prevent connection attempts
|
||||
async_mock_client = AsyncMock()
|
||||
sync_mock_client = MagicMock()
|
||||
|
||||
# Mock all database methods to return None or empty results
|
||||
sync_mock_client.get_node_executions.return_value = []
|
||||
sync_mock_client.create_node_execution.return_value = None
|
||||
sync_mock_client.add_input_to_node_execution.return_value = None
|
||||
sync_mock_client.update_node_execution_status.return_value = None
|
||||
sync_mock_client.upsert_execution_output.return_value = None
|
||||
sync_mock_client.update_graph_execution_stats.return_value = mock_graph_meta
|
||||
sync_mock_client.update_graph_execution_start_time.return_value = mock_graph_meta
|
||||
|
||||
# Mock event bus to prevent connection attempts
|
||||
mock_event_bus = MagicMock()
|
||||
mock_event_bus.publish.return_value = None
|
||||
|
||||
thread = threading.Thread(target=event_loop.run_forever, daemon=True)
|
||||
thread.start()
|
||||
|
||||
with patch(
|
||||
"backend.executor.execution_data.get_database_manager_async_client",
|
||||
return_value=async_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_database_manager_client",
|
||||
return_value=sync_mock_client,
|
||||
), patch(
|
||||
"backend.executor.execution_data.get_execution_event_bus",
|
||||
return_value=mock_event_bus,
|
||||
), patch(
|
||||
"backend.executor.execution_data.non_blocking_persist", lambda func: func
|
||||
):
|
||||
|
||||
client = ExecutionDataClient(executor, "test_graph_exec_id", mock_graph_meta)
|
||||
yield client
|
||||
|
||||
event_loop.call_soon_threadsafe(event_loop.stop)
|
||||
thread.join(timeout=1)
|
||||
|
||||
|
||||
class TestExecutionDataClient:
|
||||
|
||||
async def test_update_node_status_writes_to_cache_immediately(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that node status updates are immediately visible in cache."""
|
||||
# First create an execution to update
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
status = ExecutionStatus.RUNNING
|
||||
execution_data = {"step": "processing"}
|
||||
stats = {"duration": 5.2}
|
||||
|
||||
# Update status of existing execution
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=status,
|
||||
execution_data=execution_data,
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
# Verify immediate visibility in cache
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert cached_exec.status == status
|
||||
# execution_data should be merged with existing input_data, not replace it
|
||||
expected_input_data = {"test_input": "test_value", "step": "processing"}
|
||||
assert cached_exec.input_data == expected_input_data
|
||||
|
||||
def test_update_node_status_execution_not_found_raises_error(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that updating non-existent execution raises error instead of creating it."""
|
||||
non_existent_id = "does-not-exist"
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError, match="Execution does-not-exist not found in cache"
|
||||
):
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=non_existent_id, status=ExecutionStatus.COMPLETED
|
||||
)
|
||||
|
||||
async def test_upsert_execution_output_writes_to_cache_immediately(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that output updates are immediately visible in cache."""
|
||||
# First create an execution
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
output_name = "result"
|
||||
output_data = {"answer": 42, "confidence": 0.95}
|
||||
|
||||
# Update to RUNNING status first
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input": "test"},
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id, output_name=output_name, output_data=output_data
|
||||
)
|
||||
# Check output through the node execution
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert output_name in cached_exec.output_data
|
||||
assert cached_exec.output_data[output_name] == [output_data]
|
||||
|
||||
async def test_get_node_execution_reads_from_cache(self, execution_client):
|
||||
"""Test that get_node_execution returns cached data immediately."""
|
||||
# First create an execution to work with
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Then update its status
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
execution_data={"result": "success"},
|
||||
)
|
||||
|
||||
result = execution_client.get_node_execution(node_exec_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == ExecutionStatus.COMPLETED
|
||||
# execution_data gets merged with existing input_data
|
||||
expected_input_data = {"test_input": "test_value", "result": "success"}
|
||||
assert result.input_data == expected_input_data
|
||||
|
||||
async def test_get_latest_node_execution_reads_from_cache(self, execution_client):
|
||||
"""Test that get_latest_node_execution returns cached data."""
|
||||
node_id = "node-1"
|
||||
|
||||
# First create an execution for this node
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Then update its status to make it non-INCOMPLETE (so it's returned by get_latest)
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"from": "cache"},
|
||||
)
|
||||
|
||||
result = execution_client.get_latest_node_execution(node_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == ExecutionStatus.RUNNING
|
||||
# execution_data gets merged with existing input_data
|
||||
expected_input_data = {"test_input": "test_value", "from": "cache"}
|
||||
assert result.input_data == expected_input_data
|
||||
|
||||
async def test_get_node_executions_sync_filters_correctly(self, execution_client):
|
||||
# Create executions with different statuses
|
||||
executions = [
|
||||
(ExecutionStatus.RUNNING, "block-a"),
|
||||
(ExecutionStatus.COMPLETED, "block-a"),
|
||||
(ExecutionStatus.FAILED, "block-b"),
|
||||
(ExecutionStatus.RUNNING, "block-b"),
|
||||
]
|
||||
|
||||
exec_ids = []
|
||||
for i, (status, block_id) in enumerate(executions):
|
||||
# First create the execution
|
||||
exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id=f"node-{i}",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id=block_id,
|
||||
)
|
||||
exec_ids.append(exec_id)
|
||||
|
||||
# Then update its status and metadata
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=exec_id, status=status, execution_data={"block": block_id}
|
||||
)
|
||||
# Update cached execution with graph_exec_id and block_id for filtering
|
||||
# Note: In real implementation, these would be set during creation
|
||||
# For test purposes, we'll skip this manual update since the filtering
|
||||
# logic should work with the data as created
|
||||
|
||||
# Test status filtering
|
||||
running_execs = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.RUNNING]
|
||||
)
|
||||
assert len(running_execs) == 2
|
||||
assert all(e.status == ExecutionStatus.RUNNING for e in running_execs)
|
||||
|
||||
# Test block_id filtering
|
||||
block_a_execs = execution_client.get_node_executions(block_ids=["block-a"])
|
||||
assert len(block_a_execs) == 2
|
||||
assert all(e.block_id == "block-a" for e in block_a_execs)
|
||||
|
||||
# Test combined filtering
|
||||
running_block_b = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.RUNNING], block_ids=["block-b"]
|
||||
)
|
||||
assert len(running_block_b) == 1
|
||||
assert running_block_b[0].status == ExecutionStatus.RUNNING
|
||||
assert running_block_b[0].block_id == "block-b"
|
||||
|
||||
async def test_write_then_read_consistency(self, execution_client):
|
||||
"""Test critical race condition scenario: immediate read after write."""
|
||||
# First create an execution to work with
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="consistency-test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Write status
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"step": 1},
|
||||
)
|
||||
|
||||
# Write output
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="intermediate",
|
||||
output_data={"progress": 50},
|
||||
)
|
||||
|
||||
# Update status again
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
execution_data={"step": 2},
|
||||
)
|
||||
|
||||
# All changes should be immediately visible
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert cached_exec.status == ExecutionStatus.COMPLETED
|
||||
# execution_data gets merged with existing input_data - step 2 overwrites step 1
|
||||
expected_input_data = {"test_input": "test_value", "step": 2}
|
||||
assert cached_exec.input_data == expected_input_data
|
||||
|
||||
# Output should be visible in execution record
|
||||
assert cached_exec.output_data["intermediate"] == [{"progress": 50}]
|
||||
|
||||
async def test_concurrent_operations_are_thread_safe(self, execution_client):
|
||||
"""Test that concurrent operations don't corrupt cache."""
|
||||
num_threads = 3 # Reduced for simpler test
|
||||
operations_per_thread = 5 # Reduced for simpler test
|
||||
|
||||
# Create all executions upfront
|
||||
created_exec_ids = []
|
||||
for thread_id in range(num_threads):
|
||||
for i in range(operations_per_thread):
|
||||
exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id=f"node-{thread_id}-{i}",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id=f"block-{thread_id}-{i}",
|
||||
)
|
||||
created_exec_ids.append((exec_id, thread_id, i))
|
||||
|
||||
def worker(thread_data):
|
||||
"""Perform multiple operations from a thread."""
|
||||
thread_id, ops = thread_data
|
||||
for i, (exec_id, _, _) in enumerate(ops):
|
||||
# Status updates
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"thread": thread_id, "op": i},
|
||||
)
|
||||
|
||||
# Output updates (use just one exec_id per thread for outputs)
|
||||
if i == 0: # Only add outputs to first execution of each thread
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=exec_id,
|
||||
output_name=f"output_{i}",
|
||||
output_data={"thread": thread_id, "value": i},
|
||||
)
|
||||
|
||||
# Organize executions by thread
|
||||
thread_data = []
|
||||
for tid in range(num_threads):
|
||||
thread_ops = [
|
||||
exec_data for exec_data in created_exec_ids if exec_data[1] == tid
|
||||
]
|
||||
thread_data.append((tid, thread_ops))
|
||||
|
||||
# Start multiple threads
|
||||
threads = []
|
||||
for data in thread_data:
|
||||
thread = threading.Thread(target=worker, args=(data,))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Wait for completion
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# Verify data integrity
|
||||
expected_executions = num_threads * operations_per_thread
|
||||
all_executions = execution_client.get_node_executions()
|
||||
assert len(all_executions) == expected_executions
|
||||
|
||||
# Verify outputs - only first execution of each thread should have outputs
|
||||
output_count = 0
|
||||
for execution in all_executions:
|
||||
if execution.output_data:
|
||||
output_count += 1
|
||||
assert output_count == num_threads # One output per thread
|
||||
|
||||
async def test_sync_and_async_versions_consistent(self, execution_client):
|
||||
"""Test that sync and async versions of output operations behave the same."""
|
||||
# First create the execution
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="sync-async-test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input": "test"},
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="sync_result",
|
||||
output_data={"method": "sync"},
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="async_result",
|
||||
output_data={"method": "async"},
|
||||
)
|
||||
|
||||
cached_exec = execution_client.get_node_execution(node_exec_id)
|
||||
assert cached_exec is not None
|
||||
assert "sync_result" in cached_exec.output_data
|
||||
assert "async_result" in cached_exec.output_data
|
||||
assert cached_exec.output_data["sync_result"] == [{"method": "sync"}]
|
||||
assert cached_exec.output_data["async_result"] == [{"method": "async"}]
|
||||
|
||||
async def test_finalize_execution_completes_and_clears_cache(
|
||||
self, execution_client
|
||||
):
|
||||
"""Test that finalize_execution waits for background tasks and clears cache."""
|
||||
# First create the execution
|
||||
node_exec_id, _ = execution_client.upsert_execution_input(
|
||||
node_id="pending-test-node",
|
||||
input_name="test_input",
|
||||
input_data="test_value",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Trigger some background operations
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id, status=ExecutionStatus.RUNNING
|
||||
)
|
||||
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id, output_name="test", output_data={"value": 1}
|
||||
)
|
||||
|
||||
# Wait for background tasks - may fail in test environment due to DB issues
|
||||
try:
|
||||
execution_client.finalize_execution(timeout=5.0)
|
||||
except RuntimeError as e:
|
||||
# In test environment, background DB operations may fail, but cache should still be cleared
|
||||
assert "Background persistence failed" in str(e)
|
||||
|
||||
# Cache should be cleared regardless of background task failures
|
||||
all_executions = execution_client.get_node_executions()
|
||||
assert len(all_executions) == 0 # Cache should be cleared
|
||||
|
||||
async def test_manager_usage_pattern(self, execution_client):
|
||||
# Create executions first
|
||||
node_exec_id_1, _ = execution_client.upsert_execution_input(
|
||||
node_id="node-1",
|
||||
input_name="input1",
|
||||
input_data="data1",
|
||||
block_id="block-1",
|
||||
)
|
||||
|
||||
node_exec_id_2, _ = execution_client.upsert_execution_input(
|
||||
node_id="node-2",
|
||||
input_name="input_from_node1",
|
||||
input_data="value1",
|
||||
block_id="block-2",
|
||||
)
|
||||
|
||||
# Simulate manager.py workflow
|
||||
|
||||
# 1. Start execution
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id_1,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input": "data1"},
|
||||
)
|
||||
|
||||
# 2. Node produces output
|
||||
execution_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id_1,
|
||||
output_name="result",
|
||||
output_data={"output": "value1"},
|
||||
)
|
||||
|
||||
# 3. Complete first node
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id_1, status=ExecutionStatus.COMPLETED
|
||||
)
|
||||
|
||||
# 4. Start second node (would read output from first)
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id_2,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input_from_node1": "value1"},
|
||||
)
|
||||
|
||||
# 5. Manager queries for executions
|
||||
|
||||
all_executions = execution_client.get_node_executions()
|
||||
running_executions = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.RUNNING]
|
||||
)
|
||||
completed_executions = execution_client.get_node_executions(
|
||||
statuses=[ExecutionStatus.COMPLETED]
|
||||
)
|
||||
|
||||
# Verify manager can see all data immediately
|
||||
assert len(all_executions) == 2
|
||||
assert len(running_executions) == 1
|
||||
assert len(completed_executions) == 1
|
||||
|
||||
# Verify output is accessible
|
||||
exec_1 = execution_client.get_node_execution(node_exec_id_1)
|
||||
assert exec_1 is not None
|
||||
assert exec_1.output_data["result"] == [{"output": "value1"}]
|
||||
|
||||
def test_stats_handling_in_update_node_status(self, execution_client):
|
||||
"""Test that stats parameter is properly handled in update_node_status_and_publish."""
|
||||
# Create a fake execution directly in cache to avoid database issues
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from backend.data.execution import NodeExecutionResult
|
||||
|
||||
node_exec_id = "test-stats-exec-id"
|
||||
fake_execution = NodeExecutionResult(
|
||||
user_id="test-user",
|
||||
graph_id="test-graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test-graph-exec",
|
||||
node_exec_id=node_exec_id,
|
||||
node_id="stats-test-node",
|
||||
block_id="test-block",
|
||||
status=ExecutionStatus.INCOMPLETE,
|
||||
input_data={"test_input": "test_value"},
|
||||
output_data={},
|
||||
add_time=datetime.now(timezone.utc),
|
||||
queue_time=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
stats=None,
|
||||
)
|
||||
|
||||
# Add directly to cache
|
||||
execution_client._cache.add_node_execution(node_exec_id, fake_execution)
|
||||
|
||||
stats = {"token_count": 150, "processing_time": 2.5}
|
||||
|
||||
# Update status with stats
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
execution_data={"input": "test"},
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
# Verify execution was updated and stats are stored
|
||||
execution = execution_client.get_node_execution(node_exec_id)
|
||||
assert execution is not None
|
||||
assert execution.status == ExecutionStatus.RUNNING
|
||||
|
||||
# Stats should be stored in proper stats field
|
||||
assert execution.stats is not None
|
||||
stats_dict = execution.stats.model_dump()
|
||||
# Only check the fields we set, ignore defaults
|
||||
assert stats_dict["token_count"] == 150
|
||||
assert stats_dict["processing_time"] == 2.5
|
||||
|
||||
# Update with additional stats
|
||||
additional_stats = {"error_count": 0}
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
stats=additional_stats,
|
||||
)
|
||||
|
||||
# Stats should be merged
|
||||
execution = execution_client.get_node_execution(node_exec_id)
|
||||
assert execution is not None
|
||||
assert execution.status == ExecutionStatus.COMPLETED
|
||||
stats_dict = execution.stats.model_dump()
|
||||
# Check the merged stats
|
||||
assert stats_dict["token_count"] == 150
|
||||
assert stats_dict["processing_time"] == 2.5
|
||||
assert stats_dict["error_count"] == 0
|
||||
|
||||
async def test_upsert_execution_input_scenarios(self, execution_client):
|
||||
"""Test different scenarios of upsert_execution_input - create vs update."""
|
||||
node_id = "test-node"
|
||||
graph_exec_id = (
|
||||
"test_graph_exec_id" # Must match the ExecutionDataClient's scope
|
||||
)
|
||||
|
||||
# Scenario 1: Create new execution when none exists
|
||||
exec_id_1, input_data_1 = execution_client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="first_input",
|
||||
input_data="value1",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Should create new execution
|
||||
execution = execution_client.get_node_execution(exec_id_1)
|
||||
assert execution is not None
|
||||
assert execution.status == ExecutionStatus.INCOMPLETE
|
||||
assert execution.node_id == node_id
|
||||
assert execution.graph_exec_id == graph_exec_id
|
||||
assert input_data_1 == {"first_input": "value1"}
|
||||
|
||||
# Scenario 2: Add input to existing INCOMPLETE execution
|
||||
exec_id_2, input_data_2 = execution_client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="second_input",
|
||||
input_data="value2",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Should use same execution
|
||||
assert exec_id_2 == exec_id_1
|
||||
assert input_data_2 == {"first_input": "value1", "second_input": "value2"}
|
||||
|
||||
# Verify execution has both inputs
|
||||
execution = execution_client.get_node_execution(exec_id_1)
|
||||
assert execution is not None
|
||||
assert execution.input_data == {
|
||||
"first_input": "value1",
|
||||
"second_input": "value2",
|
||||
}
|
||||
|
||||
# Scenario 3: Create new execution when existing is not INCOMPLETE
|
||||
execution_client.update_node_status_and_publish(
|
||||
exec_id=exec_id_1, status=ExecutionStatus.RUNNING
|
||||
)
|
||||
|
||||
exec_id_3, input_data_3 = execution_client.upsert_execution_input(
|
||||
node_id=node_id,
|
||||
input_name="third_input",
|
||||
input_data="value3",
|
||||
block_id="test-block",
|
||||
)
|
||||
|
||||
# Should create new execution
|
||||
assert exec_id_3 != exec_id_1
|
||||
execution_3 = execution_client.get_node_execution(exec_id_3)
|
||||
assert execution_3 is not None
|
||||
assert input_data_3 == {"third_input": "value3"}
|
||||
|
||||
# Verify we now have 2 executions
|
||||
all_executions = execution_client.get_node_executions()
|
||||
assert len(all_executions) == 2
|
||||
|
||||
def test_graph_stats_operations(self, execution_client):
|
||||
"""Test graph-level stats and start time operations."""
|
||||
|
||||
# Test update_graph_stats_and_publish
|
||||
from backend.data.model import GraphExecutionStats
|
||||
|
||||
stats = GraphExecutionStats(
|
||||
walltime=10.5, cputime=8.2, node_count=5, node_error_count=1
|
||||
)
|
||||
|
||||
execution_client.update_graph_stats_and_publish(
|
||||
status=ExecutionStatus.RUNNING, stats=stats
|
||||
)
|
||||
|
||||
# Verify stats are stored in cache
|
||||
cached_stats = execution_client._cache._graph_stats
|
||||
assert cached_stats.walltime == 10.5
|
||||
|
||||
execution_client.update_graph_start_time_and_publish()
|
||||
cached_stats = execution_client._cache._graph_stats
|
||||
assert cached_stats.walltime == 10.5
|
||||
|
||||
def test_public_methods_accessible(self, execution_client):
|
||||
"""Test that public methods are accessible."""
|
||||
assert hasattr(execution_client._cache, "update_node_execution_status")
|
||||
assert hasattr(execution_client._cache, "upsert_execution_output")
|
||||
assert hasattr(execution_client._cache, "add_node_execution")
|
||||
assert hasattr(execution_client._cache, "find_incomplete_execution_for_input")
|
||||
assert hasattr(execution_client._cache, "update_execution_input")
|
||||
assert hasattr(execution_client, "upsert_execution_input")
|
||||
assert hasattr(execution_client, "update_node_status_and_publish")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -5,37 +5,15 @@ import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
|
||||
from typing import Any, Optional, TypeVar, cast
|
||||
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pika.spec import Basic, BasicProperties
|
||||
from pydantic import JsonValue
|
||||
from redis.asyncio.lock import Lock as RedisLock
|
||||
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
from backend.data.model import GraphExecutionStats, NodeExecutionStats
|
||||
from backend.data.notifications import (
|
||||
AgentRunData,
|
||||
LowBalanceData,
|
||||
NotificationEventModel,
|
||||
NotificationType,
|
||||
)
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.activity_status_generator import (
|
||||
generate_activity_status_for_execution,
|
||||
)
|
||||
from backend.executor.utils import LogMetadata
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.util.exceptions import InsufficientBalanceError, ModerationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerClient, DatabaseManagerAsyncClient
|
||||
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
from pydantic import JsonValue
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data import redis_client as redis
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
from backend.data.block import (
|
||||
BlockData,
|
||||
BlockInput,
|
||||
@@ -47,18 +25,28 @@ from backend.data.credit import UsageTransactionMetadata
|
||||
from backend.data.execution import (
|
||||
ExecutionQueue,
|
||||
ExecutionStatus,
|
||||
GraphExecution,
|
||||
GraphExecutionEntry,
|
||||
NodeExecutionEntry,
|
||||
NodeExecutionResult,
|
||||
UserContext,
|
||||
)
|
||||
from backend.data.graph import Link, Node
|
||||
from backend.data.model import GraphExecutionStats, NodeExecutionStats
|
||||
from backend.data.notifications import (
|
||||
AgentRunData,
|
||||
LowBalanceData,
|
||||
NotificationEventModel,
|
||||
NotificationType,
|
||||
ZeroBalanceData,
|
||||
)
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.execution_data import ExecutionDataClient
|
||||
from backend.executor.utils import (
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
CancelExecutionEvent,
|
||||
ExecutionOutputEntry,
|
||||
LogMetadata,
|
||||
NodeExecutionProgress,
|
||||
block_usage_cost,
|
||||
create_execution_queue_config,
|
||||
@@ -67,22 +55,20 @@ from backend.executor.utils import (
|
||||
validate_exec,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.server.v2.AutoMod.manager import automod_manager
|
||||
from backend.util import json
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
get_database_manager_async_client,
|
||||
get_database_manager_client,
|
||||
get_execution_event_bus,
|
||||
)
|
||||
from backend.util.clients import get_notification_manager_client
|
||||
from backend.util.decorator import (
|
||||
async_error_logged,
|
||||
async_time_measured,
|
||||
error_logged,
|
||||
time_measured,
|
||||
)
|
||||
from backend.util.exceptions import InsufficientBalanceError, ModerationError
|
||||
from backend.util.file import clean_exec_files
|
||||
from backend.util.logging import TruncatedLogger, configure_logging
|
||||
from backend.util.metrics import DiscordChannel
|
||||
from backend.util.process import AppProcess, set_service_name
|
||||
from backend.util.retry import continuous_retry, func_retry
|
||||
from backend.util.settings import Settings
|
||||
@@ -134,7 +120,6 @@ async def execute_node(
|
||||
persist the execution result, and return the subsequent node to be executed.
|
||||
|
||||
Args:
|
||||
db_client: The client to send execution updates to the server.
|
||||
creds_manager: The manager to acquire and release credentials.
|
||||
data: The execution data for executing the current node.
|
||||
execution_stats: The execution statistics to be updated.
|
||||
@@ -190,6 +175,9 @@ async def execute_node(
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
# Add user context from NodeExecutionEntry
|
||||
extra_exec_kwargs["user_context"] = data.user_context
|
||||
|
||||
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
|
||||
# changes during execution. ⚠️ This means a set of credentials can only be used by
|
||||
# one (running) block at a time; simultaneous execution of blocks using same
|
||||
@@ -228,7 +216,7 @@ async def execute_node(
|
||||
|
||||
|
||||
async def _enqueue_next_nodes(
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
execution_data_client: ExecutionDataClient,
|
||||
node: Node,
|
||||
output: BlockData,
|
||||
user_id: str,
|
||||
@@ -236,12 +224,12 @@ async def _enqueue_next_nodes(
|
||||
graph_id: str,
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
user_context: UserContext,
|
||||
) -> list[NodeExecutionEntry]:
|
||||
async def add_enqueued_execution(
|
||||
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
|
||||
) -> NodeExecutionEntry:
|
||||
await async_update_node_execution_status(
|
||||
db_client=db_client,
|
||||
execution_data_client.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.QUEUED,
|
||||
execution_data=data,
|
||||
@@ -254,6 +242,7 @@ async def _enqueue_next_nodes(
|
||||
node_id=node_id,
|
||||
block_id=block_id,
|
||||
inputs=data,
|
||||
user_context=user_context,
|
||||
)
|
||||
|
||||
async def register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
|
||||
@@ -273,21 +262,22 @@ async def _enqueue_next_nodes(
|
||||
next_data = parse_execution_output(output, next_output_name)
|
||||
if next_data is None and output_name != next_output_name:
|
||||
return enqueued_executions
|
||||
next_node = await db_client.get_node(next_node_id)
|
||||
next_node = await execution_data_client.get_node(next_node_id)
|
||||
|
||||
# Multiple node can register the same next node, we need this to be atomic
|
||||
# To avoid same execution to be enqueued multiple times,
|
||||
# Or the same input to be consumed multiple times.
|
||||
async with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}"):
|
||||
with execution_data_client.graph_lock:
|
||||
# Add output data to the earliest incomplete execution, or create a new one.
|
||||
next_node_exec_id, next_node_input = await db_client.upsert_execution_input(
|
||||
node_id=next_node_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
input_name=next_input_name,
|
||||
input_data=next_data,
|
||||
next_node_exec_id, next_node_input = (
|
||||
execution_data_client.upsert_execution_input(
|
||||
node_id=next_node_id,
|
||||
input_name=next_input_name,
|
||||
input_data=next_data,
|
||||
block_id=next_node.block_id,
|
||||
)
|
||||
)
|
||||
await async_update_node_execution_status(
|
||||
db_client=db_client,
|
||||
execution_data_client.update_node_status_and_publish(
|
||||
exec_id=next_node_exec_id,
|
||||
status=ExecutionStatus.INCOMPLETE,
|
||||
)
|
||||
@@ -299,8 +289,8 @@ async def _enqueue_next_nodes(
|
||||
if link.is_static and link.sink_name not in next_node_input
|
||||
}
|
||||
if static_link_names and (
|
||||
latest_execution := await db_client.get_latest_node_execution(
|
||||
next_node_id, graph_exec_id
|
||||
latest_execution := execution_data_client.get_latest_node_execution(
|
||||
next_node_id
|
||||
)
|
||||
):
|
||||
for name in static_link_names:
|
||||
@@ -339,9 +329,8 @@ async def _enqueue_next_nodes(
|
||||
|
||||
# If link is static, there could be some incomplete executions waiting for it.
|
||||
# Load and complete the input missing input data, and try to re-enqueue them.
|
||||
for iexec in await db_client.get_node_executions(
|
||||
for iexec in execution_data_client.get_node_executions(
|
||||
node_id=next_node_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[ExecutionStatus.INCOMPLETE],
|
||||
):
|
||||
idata = iexec.input_data
|
||||
@@ -405,6 +394,9 @@ class ExecutionProcessor:
|
||||
9. Node executor enqueues the next executed nodes to the node execution queue.
|
||||
"""
|
||||
|
||||
# Current execution data client (scoped to current graph execution)
|
||||
execution_data: ExecutionDataClient
|
||||
|
||||
@async_error_logged(swallow=True)
|
||||
async def on_node_execution(
|
||||
self,
|
||||
@@ -422,8 +414,7 @@ class ExecutionProcessor:
|
||||
node_id=node_exec.node_id,
|
||||
block_name="-",
|
||||
)
|
||||
db_client = get_db_async_client()
|
||||
node = await db_client.get_node(node_exec.node_id)
|
||||
node = await self.execution_data.get_node(node_exec.node_id)
|
||||
execution_stats = NodeExecutionStats()
|
||||
|
||||
timing_info, status = await self._on_node_execution(
|
||||
@@ -431,7 +422,6 @@ class ExecutionProcessor:
|
||||
node_exec=node_exec,
|
||||
node_exec_progress=node_exec_progress,
|
||||
stats=execution_stats,
|
||||
db_client=db_client,
|
||||
log_metadata=log_metadata,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
@@ -455,15 +445,12 @@ class ExecutionProcessor:
|
||||
if node_error and not isinstance(node_error, str):
|
||||
node_stats["error"] = str(node_error) or node_stats.__class__.__name__
|
||||
|
||||
await async_update_node_execution_status(
|
||||
db_client=db_client,
|
||||
self.execution_data.update_node_status_and_publish(
|
||||
exec_id=node_exec.node_exec_id,
|
||||
status=status,
|
||||
stats=node_stats,
|
||||
)
|
||||
await async_update_graph_execution_state(
|
||||
db_client=db_client,
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
self.execution_data.update_graph_stats_and_publish(
|
||||
stats=graph_stats,
|
||||
)
|
||||
|
||||
@@ -476,22 +463,17 @@ class ExecutionProcessor:
|
||||
node_exec: NodeExecutionEntry,
|
||||
node_exec_progress: NodeExecutionProgress,
|
||||
stats: NodeExecutionStats,
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> ExecutionStatus:
|
||||
status = ExecutionStatus.RUNNING
|
||||
|
||||
async def persist_output(output_name: str, output_data: Any) -> None:
|
||||
await db_client.upsert_execution_output(
|
||||
self.execution_data.upsert_execution_output(
|
||||
node_exec_id=node_exec.node_exec_id,
|
||||
output_name=output_name,
|
||||
output_data=output_data,
|
||||
)
|
||||
if exec_update := await db_client.get_node_execution(
|
||||
node_exec.node_exec_id
|
||||
):
|
||||
await send_async_execution_update(exec_update)
|
||||
|
||||
node_exec_progress.add_output(
|
||||
ExecutionOutputEntry(
|
||||
@@ -503,8 +485,7 @@ class ExecutionProcessor:
|
||||
|
||||
try:
|
||||
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
|
||||
await async_update_node_execution_status(
|
||||
db_client=db_client,
|
||||
self.execution_data.update_node_status_and_publish(
|
||||
exec_id=node_exec.node_exec_id,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
)
|
||||
@@ -565,6 +546,8 @@ class ExecutionProcessor:
|
||||
self.node_evaluation_thread = threading.Thread(
|
||||
target=self.node_evaluation_loop.run_forever, daemon=True
|
||||
)
|
||||
# single thread executor
|
||||
self.execution_data_executor = ThreadPoolExecutor(max_workers=1)
|
||||
self.node_execution_thread.start()
|
||||
self.node_evaluation_thread.start()
|
||||
logger.info(f"[GraphExecutor] {self.tid} started")
|
||||
@@ -584,9 +567,13 @@ class ExecutionProcessor:
|
||||
node_eid="*",
|
||||
block_name="-",
|
||||
)
|
||||
db_client = get_db_client()
|
||||
|
||||
exec_meta = db_client.get_graph_execution_meta(
|
||||
# Get graph execution metadata first via sync client
|
||||
from backend.util.clients import get_database_manager_client
|
||||
|
||||
db_client_sync = get_database_manager_client()
|
||||
|
||||
exec_meta = db_client_sync.get_graph_execution_meta(
|
||||
user_id=graph_exec.user_id,
|
||||
execution_id=graph_exec.graph_exec_id,
|
||||
)
|
||||
@@ -596,12 +583,15 @@ class ExecutionProcessor:
|
||||
)
|
||||
return
|
||||
|
||||
# Create scoped ExecutionDataClient for this graph execution with metadata
|
||||
self.execution_data = ExecutionDataClient(
|
||||
self.execution_data_executor, graph_exec.graph_exec_id, exec_meta
|
||||
)
|
||||
|
||||
if exec_meta.status == ExecutionStatus.QUEUED:
|
||||
log_metadata.info(f"⚙️ Starting graph execution #{graph_exec.graph_exec_id}")
|
||||
exec_meta.status = ExecutionStatus.RUNNING
|
||||
send_execution_update(
|
||||
db_client.update_graph_execution_start_time(graph_exec.graph_exec_id)
|
||||
)
|
||||
self.execution_data.update_graph_start_time_and_publish()
|
||||
elif exec_meta.status == ExecutionStatus.RUNNING:
|
||||
log_metadata.info(
|
||||
f"⚙️ Graph execution #{graph_exec.graph_exec_id} is already running, continuing where it left off."
|
||||
@@ -611,9 +601,7 @@ class ExecutionProcessor:
|
||||
log_metadata.info(
|
||||
f"⚙️ Graph execution #{graph_exec.graph_exec_id} was disturbed, continuing where it left off."
|
||||
)
|
||||
update_graph_execution_state(
|
||||
db_client=db_client,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
self.execution_data.update_graph_stats_and_publish(
|
||||
status=ExecutionStatus.RUNNING,
|
||||
)
|
||||
else:
|
||||
@@ -644,12 +632,10 @@ class ExecutionProcessor:
|
||||
|
||||
# Activity status handling
|
||||
activity_status = asyncio.run_coroutine_threadsafe(
|
||||
generate_activity_status_for_execution(
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
self.execution_data.generate_activity_status(
|
||||
graph_id=graph_exec.graph_id,
|
||||
graph_version=graph_exec.graph_version,
|
||||
execution_stats=exec_stats,
|
||||
db_client=get_db_async_client(),
|
||||
user_id=graph_exec.user_id,
|
||||
execution_status=status,
|
||||
),
|
||||
@@ -664,33 +650,32 @@ class ExecutionProcessor:
|
||||
)
|
||||
|
||||
# Communication handling
|
||||
self._handle_agent_run_notif(db_client, graph_exec, exec_stats)
|
||||
self._handle_agent_run_notif(graph_exec, exec_stats)
|
||||
|
||||
finally:
|
||||
update_graph_execution_state(
|
||||
db_client=db_client,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
self.execution_data.update_graph_stats_and_publish(
|
||||
status=exec_meta.status,
|
||||
stats=exec_stats,
|
||||
)
|
||||
self.execution_data.finalize_execution()
|
||||
|
||||
def _charge_usage(
|
||||
self,
|
||||
node_exec: NodeExecutionEntry,
|
||||
execution_count: int,
|
||||
) -> int:
|
||||
) -> tuple[int, int]:
|
||||
total_cost = 0
|
||||
db_client = get_db_client()
|
||||
remaining_balance = 0
|
||||
block = get_block(node_exec.block_id)
|
||||
if not block:
|
||||
logger.error(f"Block {node_exec.block_id} not found.")
|
||||
return total_cost
|
||||
return total_cost, 0
|
||||
|
||||
cost, matching_filter = block_usage_cost(
|
||||
block=block, input_data=node_exec.inputs
|
||||
)
|
||||
if cost > 0:
|
||||
db_client.spend_credits(
|
||||
remaining_balance = self.execution_data.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
@@ -708,7 +693,7 @@ class ExecutionProcessor:
|
||||
|
||||
cost, usage_count = execution_usage_cost(execution_count)
|
||||
if cost > 0:
|
||||
db_client.spend_credits(
|
||||
remaining_balance = self.execution_data.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
@@ -723,7 +708,7 @@ class ExecutionProcessor:
|
||||
)
|
||||
total_cost += cost
|
||||
|
||||
return total_cost
|
||||
return total_cost, remaining_balance
|
||||
|
||||
@time_measured
|
||||
def _on_graph_execution(
|
||||
@@ -741,7 +726,6 @@ class ExecutionProcessor:
|
||||
"""
|
||||
execution_status: ExecutionStatus = ExecutionStatus.RUNNING
|
||||
error: Exception | None = None
|
||||
db_client = get_db_client()
|
||||
execution_stats_lock = threading.Lock()
|
||||
|
||||
# State holders ----------------------------------------------------
|
||||
@@ -752,7 +736,7 @@ class ExecutionProcessor:
|
||||
execution_queue = ExecutionQueue[NodeExecutionEntry]()
|
||||
|
||||
try:
|
||||
if db_client.get_credits(graph_exec.user_id) <= 0:
|
||||
if self.execution_data.get_credits(graph_exec.user_id) <= 0:
|
||||
raise InsufficientBalanceError(
|
||||
user_id=graph_exec.user_id,
|
||||
message="You have no credits left to run an agent.",
|
||||
@@ -764,7 +748,7 @@ class ExecutionProcessor:
|
||||
try:
|
||||
if moderation_error := asyncio.run_coroutine_threadsafe(
|
||||
automod_manager.moderate_graph_execution_inputs(
|
||||
db_client=get_db_async_client(),
|
||||
db_client=self.execution_data.db_client_async,
|
||||
graph_exec=graph_exec,
|
||||
),
|
||||
self.node_evaluation_loop,
|
||||
@@ -779,15 +763,34 @@ class ExecutionProcessor:
|
||||
# ------------------------------------------------------------
|
||||
# Pre‑populate queue ---------------------------------------
|
||||
# ------------------------------------------------------------
|
||||
for node_exec in db_client.get_node_executions(
|
||||
graph_exec.graph_exec_id,
|
||||
|
||||
queued_executions = self.execution_data.get_node_executions(
|
||||
statuses=[
|
||||
ExecutionStatus.RUNNING,
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.TERMINATED,
|
||||
],
|
||||
):
|
||||
execution_queue.add(node_exec.to_node_execution_entry())
|
||||
)
|
||||
log_metadata.info(
|
||||
f"Pre-populating queue with {len(queued_executions)} executions from cache"
|
||||
)
|
||||
|
||||
for i, node_exec in enumerate(queued_executions):
|
||||
log_metadata.info(
|
||||
f" [{i}] {node_exec.node_exec_id}: status={node_exec.status}, node={node_exec.node_id}"
|
||||
)
|
||||
try:
|
||||
node_entry = node_exec.to_node_execution_entry(
|
||||
graph_exec.user_context
|
||||
)
|
||||
execution_queue.add(node_entry)
|
||||
log_metadata.info(" Added to execution queue successfully")
|
||||
except Exception as e:
|
||||
log_metadata.error(f" Failed to add to execution queue: {e}")
|
||||
|
||||
log_metadata.info(
|
||||
f"Execution queue populated with {len(queued_executions)} executions"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Main dispatch / polling loop -----------------------------
|
||||
@@ -805,31 +808,36 @@ class ExecutionProcessor:
|
||||
|
||||
# Charge usage (may raise) ------------------------------
|
||||
try:
|
||||
cost = self._charge_usage(
|
||||
cost, remaining_balance = self._charge_usage(
|
||||
node_exec=queued_node_exec,
|
||||
execution_count=increment_execution_count(graph_exec.user_id),
|
||||
execution_count=self.execution_data.increment_execution_count(
|
||||
graph_exec.user_id
|
||||
),
|
||||
)
|
||||
with execution_stats_lock:
|
||||
execution_stats.cost += cost
|
||||
# Check if we crossed the low balance threshold
|
||||
self._handle_low_balance(
|
||||
user_id=graph_exec.user_id,
|
||||
current_balance=remaining_balance,
|
||||
transaction_cost=cost,
|
||||
)
|
||||
except InsufficientBalanceError as balance_error:
|
||||
error = balance_error # Set error to trigger FAILED status
|
||||
node_exec_id = queued_node_exec.node_exec_id
|
||||
db_client.upsert_execution_output(
|
||||
self.execution_data.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="error",
|
||||
output_data=str(error),
|
||||
)
|
||||
update_node_execution_status(
|
||||
db_client=db_client,
|
||||
self.execution_data.update_node_status_and_publish(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.FAILED,
|
||||
)
|
||||
|
||||
self._handle_low_balance_notif(
|
||||
db_client,
|
||||
self._handle_insufficient_funds_notif(
|
||||
graph_exec.user_id,
|
||||
graph_exec.graph_id,
|
||||
execution_stats,
|
||||
error,
|
||||
)
|
||||
# Gracefully stop the execution loop
|
||||
@@ -914,12 +922,13 @@ class ExecutionProcessor:
|
||||
time.sleep(0.1)
|
||||
|
||||
# loop done --------------------------------------------------
|
||||
# Background task finalization moved to finally block
|
||||
|
||||
# Output moderation
|
||||
try:
|
||||
if moderation_error := asyncio.run_coroutine_threadsafe(
|
||||
automod_manager.moderate_graph_execution_outputs(
|
||||
db_client=get_db_async_client(),
|
||||
db_client=self.execution_data.db_client_async,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
user_id=graph_exec.user_id,
|
||||
graph_id=graph_exec.graph_id,
|
||||
@@ -973,7 +982,6 @@ class ExecutionProcessor:
|
||||
error=error,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
log_metadata=log_metadata,
|
||||
db_client=db_client,
|
||||
)
|
||||
|
||||
@error_logged(swallow=True)
|
||||
@@ -986,7 +994,6 @@ class ExecutionProcessor:
|
||||
error: Exception | None,
|
||||
graph_exec_id: str,
|
||||
log_metadata: LogMetadata,
|
||||
db_client: "DatabaseManagerClient",
|
||||
) -> None:
|
||||
"""
|
||||
Clean up running node executions and evaluations when graph execution ends.
|
||||
@@ -1020,8 +1027,7 @@ class ExecutionProcessor:
|
||||
)
|
||||
|
||||
while queued_execution := execution_queue.get_or_none():
|
||||
update_node_execution_status(
|
||||
db_client=db_client,
|
||||
self.execution_data.update_node_status_and_publish(
|
||||
exec_id=queued_execution.node_exec_id,
|
||||
status=execution_status,
|
||||
stats={"error": str(error)} if error else None,
|
||||
@@ -1049,11 +1055,10 @@ class ExecutionProcessor:
|
||||
nodes_input_masks: Optional map of node input overrides
|
||||
execution_queue: Queue to add next executions to
|
||||
"""
|
||||
db_client = get_db_async_client()
|
||||
|
||||
log_metadata.debug(f"Enqueue nodes for {node_id}: {output}")
|
||||
|
||||
for next_execution in await _enqueue_next_nodes(
|
||||
db_client=db_client,
|
||||
execution_data_client=self.execution_data,
|
||||
node=output.node,
|
||||
output=output.data,
|
||||
user_id=graph_exec.user_id,
|
||||
@@ -1061,20 +1066,19 @@ class ExecutionProcessor:
|
||||
graph_id=graph_exec.graph_id,
|
||||
log_metadata=log_metadata,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
user_context=graph_exec.user_context,
|
||||
):
|
||||
execution_queue.add(next_execution)
|
||||
|
||||
def _handle_agent_run_notif(
|
||||
self,
|
||||
db_client: "DatabaseManagerClient",
|
||||
graph_exec: GraphExecutionEntry,
|
||||
exec_stats: GraphExecutionStats,
|
||||
):
|
||||
metadata = db_client.get_graph_metadata(
|
||||
metadata = self.execution_data.get_graph_metadata(
|
||||
graph_exec.graph_id, graph_exec.graph_version
|
||||
)
|
||||
outputs = db_client.get_node_executions(
|
||||
graph_exec.graph_exec_id,
|
||||
outputs = self.execution_data.get_node_executions(
|
||||
block_ids=[AgentOutputBlock().id],
|
||||
)
|
||||
|
||||
@@ -1101,25 +1105,24 @@ class ExecutionProcessor:
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_low_balance_notif(
|
||||
def _handle_insufficient_funds_notif(
|
||||
self,
|
||||
db_client: "DatabaseManagerClient",
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
exec_stats: GraphExecutionStats,
|
||||
e: InsufficientBalanceError,
|
||||
):
|
||||
shortfall = e.balance - e.amount
|
||||
metadata = db_client.get_graph_metadata(graph_id)
|
||||
shortfall = abs(e.amount) - e.balance
|
||||
metadata = self.execution_data.get_graph_metadata(graph_id)
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
|
||||
queue_notification(
|
||||
NotificationEventModel(
|
||||
user_id=user_id,
|
||||
type=NotificationType.LOW_BALANCE,
|
||||
data=LowBalanceData(
|
||||
current_balance=exec_stats.cost,
|
||||
type=NotificationType.ZERO_BALANCE,
|
||||
data=ZeroBalanceData(
|
||||
current_balance=e.balance,
|
||||
billing_page_link=f"{base_url}/profile/credits",
|
||||
shortfall=shortfall,
|
||||
agent_name=metadata.name if metadata else "Unknown Agent",
|
||||
@@ -1127,6 +1130,72 @@ class ExecutionProcessor:
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
user_email = self.execution_data.get_user_email_by_id(user_id)
|
||||
|
||||
alert_message = (
|
||||
f"❌ **Insufficient Funds Alert**\n"
|
||||
f"User: {user_email or user_id}\n"
|
||||
f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
|
||||
f"Current balance: ${e.balance/100:.2f}\n"
|
||||
f"Attempted cost: ${abs(e.amount)/100:.2f}\n"
|
||||
f"Shortfall: ${abs(shortfall)/100:.2f}\n"
|
||||
f"[View User Details]({base_url}/admin/spending?search={user_email})"
|
||||
)
|
||||
|
||||
get_notification_manager_client().discord_system_alert(
|
||||
alert_message, DiscordChannel.PRODUCT
|
||||
)
|
||||
except Exception as alert_error:
|
||||
logger.error(
|
||||
f"Failed to send insufficient funds Discord alert: {alert_error}"
|
||||
)
|
||||
|
||||
def _handle_low_balance(
|
||||
self,
|
||||
user_id: str,
|
||||
current_balance: int,
|
||||
transaction_cost: int,
|
||||
):
|
||||
"""Check and handle low balance scenarios after a transaction"""
|
||||
LOW_BALANCE_THRESHOLD = settings.config.low_balance_threshold
|
||||
|
||||
balance_before = current_balance + transaction_cost
|
||||
|
||||
if (
|
||||
current_balance < LOW_BALANCE_THRESHOLD
|
||||
and balance_before >= LOW_BALANCE_THRESHOLD
|
||||
):
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
queue_notification(
|
||||
NotificationEventModel(
|
||||
user_id=user_id,
|
||||
type=NotificationType.LOW_BALANCE,
|
||||
data=LowBalanceData(
|
||||
current_balance=current_balance,
|
||||
billing_page_link=f"{base_url}/profile/credits",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
user_email = self.execution_data.get_user_email_by_id(user_id)
|
||||
alert_message = (
|
||||
f"⚠️ **Low Balance Alert**\n"
|
||||
f"User: {user_email or user_id}\n"
|
||||
f"Balance dropped below ${LOW_BALANCE_THRESHOLD/100:.2f}\n"
|
||||
f"Current balance: ${current_balance/100:.2f}\n"
|
||||
f"Transaction cost: ${transaction_cost/100:.2f}\n"
|
||||
f"[View User Details]({base_url}/admin/spending?search={user_email})"
|
||||
)
|
||||
get_notification_manager_client().discord_system_alert(
|
||||
alert_message, DiscordChannel.PRODUCT
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send low balance Discord alert: {e}")
|
||||
|
||||
|
||||
class ExecutionManager(AppProcess):
|
||||
def __init__(self):
|
||||
@@ -1308,14 +1377,14 @@ class ExecutionManager(AppProcess):
|
||||
delivery_tag = method.delivery_tag
|
||||
|
||||
@func_retry
|
||||
def _ack_message(reject: bool = False):
|
||||
def _ack_message(reject: bool, requeue: bool):
|
||||
"""Acknowledge or reject the message based on execution status."""
|
||||
|
||||
# Connection can be lost, so always get a fresh channel
|
||||
channel = self.run_client.get_channel()
|
||||
if reject:
|
||||
channel.connection.add_callback_threadsafe(
|
||||
lambda: channel.basic_nack(delivery_tag, requeue=True)
|
||||
lambda: channel.basic_nack(delivery_tag, requeue=requeue)
|
||||
)
|
||||
else:
|
||||
channel.connection.add_callback_threadsafe(
|
||||
@@ -1327,13 +1396,13 @@ class ExecutionManager(AppProcess):
|
||||
logger.info(
|
||||
f"[{self.service_name}] Rejecting new execution during shutdown"
|
||||
)
|
||||
_ack_message(reject=True)
|
||||
_ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
# Check if we can accept more runs
|
||||
self._cleanup_completed_runs()
|
||||
if len(self.active_graph_runs) >= self.pool_size:
|
||||
_ack_message(reject=True)
|
||||
_ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -1342,7 +1411,7 @@ class ExecutionManager(AppProcess):
|
||||
logger.error(
|
||||
f"[{self.service_name}] Could not parse run message: {e}, body={body}"
|
||||
)
|
||||
_ack_message(reject=True)
|
||||
_ack_message(reject=True, requeue=False)
|
||||
return
|
||||
|
||||
graph_exec_id = graph_exec_entry.graph_exec_id
|
||||
@@ -1354,7 +1423,7 @@ class ExecutionManager(AppProcess):
|
||||
logger.error(
|
||||
f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
|
||||
)
|
||||
_ack_message(reject=True)
|
||||
_ack_message(reject=True, requeue=False)
|
||||
return
|
||||
|
||||
cancel_event = threading.Event()
|
||||
@@ -1370,9 +1439,9 @@ class ExecutionManager(AppProcess):
|
||||
logger.error(
|
||||
f"[{self.service_name}] Execution for {graph_exec_id} failed: {type(exec_error)} {exec_error}"
|
||||
)
|
||||
_ack_message(reject=True)
|
||||
_ack_message(reject=True, requeue=True)
|
||||
else:
|
||||
_ack_message(reject=False)
|
||||
_ack_message(reject=False, requeue=False)
|
||||
except BaseException as e:
|
||||
logger.exception(
|
||||
f"[{self.service_name}] Error in run completion callback: {e}"
|
||||
@@ -1490,117 +1559,3 @@ class ExecutionManager(AppProcess):
|
||||
)
|
||||
|
||||
logger.info(f"{prefix} ✅ Finished GraphExec cleanup")
|
||||
|
||||
|
||||
# ------- UTILITIES ------- #
|
||||
|
||||
|
||||
def get_db_client() -> "DatabaseManagerClient":
|
||||
return get_database_manager_client()
|
||||
|
||||
|
||||
def get_db_async_client() -> "DatabaseManagerAsyncClient":
|
||||
return get_database_manager_async_client()
|
||||
|
||||
|
||||
@func_retry
|
||||
async def send_async_execution_update(
|
||||
entry: GraphExecution | NodeExecutionResult | None,
|
||||
) -> None:
|
||||
if entry is None:
|
||||
return
|
||||
await get_async_execution_event_bus().publish(entry)
|
||||
|
||||
|
||||
@func_retry
|
||||
def send_execution_update(entry: GraphExecution | NodeExecutionResult | None):
|
||||
if entry is None:
|
||||
return
|
||||
return get_execution_event_bus().publish(entry)
|
||||
|
||||
|
||||
async def async_update_node_execution_status(
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
execution_data: BlockInput | None = None,
|
||||
stats: dict[str, Any] | None = None,
|
||||
) -> NodeExecutionResult:
|
||||
"""Sets status and fetches+broadcasts the latest state of the node execution"""
|
||||
exec_update = await db_client.update_node_execution_status(
|
||||
exec_id, status, execution_data, stats
|
||||
)
|
||||
await send_async_execution_update(exec_update)
|
||||
return exec_update
|
||||
|
||||
|
||||
def update_node_execution_status(
|
||||
db_client: "DatabaseManagerClient",
|
||||
exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
execution_data: BlockInput | None = None,
|
||||
stats: dict[str, Any] | None = None,
|
||||
) -> NodeExecutionResult:
|
||||
"""Sets status and fetches+broadcasts the latest state of the node execution"""
|
||||
exec_update = db_client.update_node_execution_status(
|
||||
exec_id, status, execution_data, stats
|
||||
)
|
||||
send_execution_update(exec_update)
|
||||
return exec_update
|
||||
|
||||
|
||||
async def async_update_graph_execution_state(
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
graph_exec_id: str,
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> GraphExecution | None:
|
||||
"""Sets status and fetches+broadcasts the latest state of the graph execution"""
|
||||
graph_update = await db_client.update_graph_execution_stats(
|
||||
graph_exec_id, status, stats
|
||||
)
|
||||
if graph_update:
|
||||
await send_async_execution_update(graph_update)
|
||||
else:
|
||||
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
|
||||
return graph_update
|
||||
|
||||
|
||||
def update_graph_execution_state(
|
||||
db_client: "DatabaseManagerClient",
|
||||
graph_exec_id: str,
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> GraphExecution | None:
|
||||
"""Sets status and fetches+broadcasts the latest state of the graph execution"""
|
||||
graph_update = db_client.update_graph_execution_stats(graph_exec_id, status, stats)
|
||||
if graph_update:
|
||||
send_execution_update(graph_update)
|
||||
else:
|
||||
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
|
||||
return graph_update
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def synchronized(key: str, timeout: int = 60):
|
||||
r = await redis.get_redis_async()
|
||||
lock: RedisLock = r.lock(f"lock:{key}", timeout=timeout)
|
||||
try:
|
||||
await lock.acquire()
|
||||
yield
|
||||
finally:
|
||||
if await lock.locked() and await lock.owned():
|
||||
await lock.release()
|
||||
|
||||
|
||||
def increment_execution_count(user_id: str) -> int:
|
||||
"""
|
||||
Increment the execution count for a given user,
|
||||
this will be used to charge the user for the execution cost.
|
||||
"""
|
||||
r = redis.get_redis()
|
||||
k = f"uec:{user_id}" # User Execution Count global key
|
||||
counter = cast(int, r.incr(k))
|
||||
if counter == 1:
|
||||
r.expire(k, settings.config.execution_counter_expiration_time)
|
||||
return counter
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import NotificationType
|
||||
|
||||
from backend.data.notifications import LowBalanceData
|
||||
from backend.executor.manager import ExecutionProcessor
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
|
||||
"""Test that _handle_low_balance triggers notification when crossing threshold."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
current_balance = 400 # $4 - below $5 threshold
|
||||
transaction_cost = 600 # $6 transaction
|
||||
|
||||
# Mock dependencies
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.low_balance_threshold = 500 # $5 threshold
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Initialize the execution processor and mock its execution_data
|
||||
execution_processor.on_graph_executor_start()
|
||||
|
||||
# Mock the execution_data attribute since it's created in on_graph_execution
|
||||
mock_execution_data = MagicMock()
|
||||
execution_processor.execution_data = mock_execution_data
|
||||
|
||||
mock_execution_data.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the low balance handler
|
||||
execution_processor._handle_low_balance(
|
||||
user_id=user_id,
|
||||
current_balance=current_balance,
|
||||
transaction_cost=transaction_cost,
|
||||
)
|
||||
|
||||
# Verify notification was queued
|
||||
mock_queue_notif.assert_called_once()
|
||||
notification_call = mock_queue_notif.call_args[0][0]
|
||||
|
||||
# Verify notification details
|
||||
assert notification_call.type == NotificationType.LOW_BALANCE
|
||||
assert notification_call.user_id == user_id
|
||||
assert isinstance(notification_call.data, LowBalanceData)
|
||||
assert notification_call.data.current_balance == current_balance
|
||||
|
||||
# Verify Discord alert was sent
|
||||
mock_client.discord_system_alert.assert_called_once()
|
||||
discord_message = mock_client.discord_system_alert.call_args[0][0]
|
||||
assert "Low Balance Alert" in discord_message
|
||||
assert "test@example.com" in discord_message
|
||||
assert "$4.00" in discord_message
|
||||
assert "$6.00" in discord_message
|
||||
|
||||
# Cleanup execution processor threads
|
||||
try:
|
||||
execution_processor.node_execution_loop.call_soon_threadsafe(
|
||||
execution_processor.node_execution_loop.stop
|
||||
)
|
||||
execution_processor.node_evaluation_loop.call_soon_threadsafe(
|
||||
execution_processor.node_evaluation_loop.stop
|
||||
)
|
||||
execution_processor.node_execution_thread.join(timeout=1)
|
||||
execution_processor.node_evaluation_thread.join(timeout=1)
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_low_balance_no_notification_when_not_crossing(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that no notification is sent when not crossing the threshold."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
current_balance = 600 # $6 - above $5 threshold
|
||||
transaction_cost = (
|
||||
100 # $1 transaction (balance before was $7, still above threshold)
|
||||
)
|
||||
|
||||
# Mock dependencies
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.low_balance_threshold = 500 # $5 threshold
|
||||
|
||||
# Initialize the execution processor and mock its execution_data
|
||||
execution_processor.on_graph_executor_start()
|
||||
|
||||
# Mock the execution_data attribute since it's created in on_graph_execution
|
||||
mock_execution_data = MagicMock()
|
||||
execution_processor.execution_data = mock_execution_data
|
||||
|
||||
mock_execution_data.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the low balance handler
|
||||
execution_processor._handle_low_balance(
|
||||
user_id=user_id,
|
||||
current_balance=current_balance,
|
||||
transaction_cost=transaction_cost,
|
||||
)
|
||||
|
||||
# Verify no notification was sent
|
||||
mock_queue_notif.assert_not_called()
|
||||
mock_client.discord_system_alert.assert_not_called()
|
||||
|
||||
# Cleanup execution processor threads
|
||||
try:
|
||||
execution_processor.node_execution_loop.call_soon_threadsafe(
|
||||
execution_processor.node_execution_loop.stop
|
||||
)
|
||||
execution_processor.node_evaluation_loop.call_soon_threadsafe(
|
||||
execution_processor.node_evaluation_loop.stop
|
||||
)
|
||||
execution_processor.node_execution_thread.join(timeout=1)
|
||||
execution_processor.node_evaluation_thread.join(timeout=1)
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_low_balance_no_duplicate_when_already_below(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that no notification is sent when already below threshold."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
current_balance = 300 # $3 - below $5 threshold
|
||||
transaction_cost = (
|
||||
100 # $1 transaction (balance before was $4, also below threshold)
|
||||
)
|
||||
|
||||
# Mock dependencies
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.low_balance_threshold = 500 # $5 threshold
|
||||
|
||||
# Initialize the execution processor and mock its execution_data
|
||||
execution_processor.on_graph_executor_start()
|
||||
|
||||
# Mock the execution_data attribute since it's created in on_graph_execution
|
||||
mock_execution_data = MagicMock()
|
||||
execution_processor.execution_data = mock_execution_data
|
||||
|
||||
mock_execution_data.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the low balance handler
|
||||
execution_processor._handle_low_balance(
|
||||
user_id=user_id,
|
||||
current_balance=current_balance,
|
||||
transaction_cost=transaction_cost,
|
||||
)
|
||||
|
||||
# Verify no notification was sent (user was already below threshold)
|
||||
mock_queue_notif.assert_not_called()
|
||||
mock_client.discord_system_alert.assert_not_called()
|
||||
|
||||
# Cleanup execution processor threads
|
||||
try:
|
||||
execution_processor.node_execution_loop.call_soon_threadsafe(
|
||||
execution_processor.node_execution_loop.stop
|
||||
)
|
||||
execution_processor.node_evaluation_loop.call_soon_threadsafe(
|
||||
execution_processor.node_evaluation_loop.stop
|
||||
)
|
||||
execution_processor.node_execution_thread.join(timeout=1)
|
||||
execution_processor.node_evaluation_thread.join(timeout=1)
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors
|
||||
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
|
||||
import autogpt_libs.auth.models
|
||||
import fastapi.responses
|
||||
import pytest
|
||||
|
||||
@@ -521,12 +520,7 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
is_approved=True,
|
||||
comments="Test comments",
|
||||
),
|
||||
autogpt_libs.auth.models.User(
|
||||
user_id=admin_user.id,
|
||||
role="admin",
|
||||
email=admin_user.email,
|
||||
phone_number="1234567890",
|
||||
),
|
||||
user_id=admin_user.id,
|
||||
)
|
||||
alt_test_user = admin_user
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from apscheduler.jobstores.memory import MemoryJobStore
|
||||
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.util import ZoneInfo
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import MetaData, create_engine
|
||||
@@ -303,6 +304,7 @@ class Scheduler(AppService):
|
||||
Jobstores.WEEKLY_NOTIFICATIONS.value: MemoryJobStore(),
|
||||
},
|
||||
logger=apscheduler_logger,
|
||||
timezone=ZoneInfo("UTC"),
|
||||
)
|
||||
|
||||
if self.register_system_tasks:
|
||||
@@ -406,6 +408,8 @@ class Scheduler(AppService):
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"Scheduling job for user {user_id} in UTC (cron: {cron})")
|
||||
|
||||
job_args = GraphExecutionJobArgs(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
@@ -418,12 +422,12 @@ class Scheduler(AppService):
|
||||
execute_graph,
|
||||
kwargs=job_args.model_dump(),
|
||||
name=name,
|
||||
trigger=CronTrigger.from_crontab(cron),
|
||||
trigger=CronTrigger.from_crontab(cron, timezone="UTC"),
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
replace_existing=True,
|
||||
)
|
||||
logger.info(
|
||||
f"Added job {job.id} with cron schedule '{cron}' input data: {input_data}"
|
||||
f"Added job {job.id} with cron schedule '{cron}' in UTC, input data: {input_data}"
|
||||
)
|
||||
return GraphExecutionJobInfo.from_db(job_args, job)
|
||||
|
||||
|
||||
@@ -18,10 +18,12 @@ from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
GraphExecutionStats,
|
||||
GraphExecutionWithNodes,
|
||||
UserContext,
|
||||
)
|
||||
from backend.data.graph import GraphModel, Node
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
get_async_execution_queue,
|
||||
@@ -34,6 +36,27 @@ from backend.util.mock import MockObject
|
||||
from backend.util.settings import Config
|
||||
from backend.util.type import convert
|
||||
|
||||
|
||||
async def get_user_context(user_id: str) -> UserContext:
|
||||
"""
|
||||
Get UserContext for a user, always returns a valid context with timezone.
|
||||
Defaults to UTC if user has no timezone set.
|
||||
"""
|
||||
user_context = UserContext(timezone="UTC") # Default to UTC
|
||||
try:
|
||||
user = await get_user_by_id(user_id)
|
||||
if user and user.timezone and user.timezone != "not-set":
|
||||
user_context.timezone = user.timezone
|
||||
logger.debug(f"Retrieved user context: timezone={user.timezone}")
|
||||
else:
|
||||
logger.debug("User has no timezone set, using UTC")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not fetch user timezone: {e}")
|
||||
# Continue with UTC as default
|
||||
|
||||
return user_context
|
||||
|
||||
|
||||
config = Config()
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")
|
||||
|
||||
@@ -877,8 +900,11 @@ async def add_graph_execution(
|
||||
preset_id=preset_id,
|
||||
)
|
||||
|
||||
# Fetch user context for the graph execution
|
||||
user_context = await get_user_context(user_id)
|
||||
|
||||
queue = await get_async_execution_queue()
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry()
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry(user_context)
|
||||
if nodes_input_masks:
|
||||
graph_exec_entry.nodes_input_masks = nodes_input_masks
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.integrations.oauth.todoist import TodoistOAuthHandler
|
||||
|
||||
from .discord import DiscordOAuthHandler
|
||||
from .github import GitHubOAuthHandler
|
||||
from .google import GoogleOAuthHandler
|
||||
from .notion import NotionOAuthHandler
|
||||
@@ -15,6 +16,7 @@ if TYPE_CHECKING:
|
||||
# --8<-- [start:HANDLERS_BY_NAMEExample]
|
||||
# Build handlers dict with string keys for compatibility with SDK auto-registration
|
||||
_ORIGINAL_HANDLERS = [
|
||||
DiscordOAuthHandler,
|
||||
GitHubOAuthHandler,
|
||||
GoogleOAuthHandler,
|
||||
NotionOAuthHandler,
|
||||
|
||||
175
autogpt_platform/backend/backend/integrations/oauth/discord.py
Normal file
175
autogpt_platform/backend/backend/integrations/oauth/discord.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import Requests
|
||||
|
||||
from .base import BaseOAuthHandler
|
||||
|
||||
|
||||
class DiscordOAuthHandler(BaseOAuthHandler):
|
||||
"""
|
||||
Discord OAuth2 handler implementation.
|
||||
|
||||
Based on the documentation at:
|
||||
- https://discord.com/developers/docs/topics/oauth2
|
||||
|
||||
Discord OAuth2 tokens expire after 7 days by default and include refresh tokens.
|
||||
"""
|
||||
|
||||
PROVIDER_NAME = ProviderName.DISCORD
|
||||
DEFAULT_SCOPES = ["identify"] # Basic user information
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
self.auth_base_url = "https://discord.com/oauth2/authorize"
|
||||
self.token_url = "https://discord.com/api/oauth2/token"
|
||||
self.revoke_url = "https://discord.com/api/oauth2/token/revoke"
|
||||
|
||||
def get_login_url(
|
||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||
) -> str:
|
||||
# Handle default scopes
|
||||
scopes = self.handle_default_scopes(scopes)
|
||||
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"response_type": "code",
|
||||
"scope": " ".join(scopes),
|
||||
"state": state,
|
||||
}
|
||||
|
||||
# Discord supports PKCE
|
||||
if code_challenge:
|
||||
params["code_challenge"] = code_challenge
|
||||
params["code_challenge_method"] = "S256"
|
||||
|
||||
return f"{self.auth_base_url}?{urlencode(params)}"
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||
) -> OAuth2Credentials:
|
||||
params = {
|
||||
"code": code,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"grant_type": "authorization_code",
|
||||
}
|
||||
|
||||
# Include PKCE verifier if provided
|
||||
if code_verifier:
|
||||
params["code_verifier"] = code_verifier
|
||||
|
||||
return await self._request_tokens(params)
|
||||
|
||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
||||
if not credentials.access_token:
|
||||
raise ValueError("No access token to revoke")
|
||||
|
||||
# Discord requires client authentication for token revocation
|
||||
data = {
|
||||
"token": credentials.access_token.get_secret_value(),
|
||||
"token_type_hint": "access_token",
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
}
|
||||
|
||||
response = await Requests().post(
|
||||
url=self.revoke_url,
|
||||
data=data,
|
||||
headers=headers,
|
||||
auth=(self.client_id, self.client_secret),
|
||||
)
|
||||
|
||||
# Discord returns 200 OK for successful revocation
|
||||
return response.status == 200
|
||||
|
||||
async def _refresh_tokens(
|
||||
self, credentials: OAuth2Credentials
|
||||
) -> OAuth2Credentials:
|
||||
if not credentials.refresh_token:
|
||||
return credentials
|
||||
|
||||
return await self._request_tokens(
|
||||
{
|
||||
"refresh_token": credentials.refresh_token.get_secret_value(),
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
current_credentials=credentials,
|
||||
)
|
||||
|
||||
async def _request_tokens(
|
||||
self,
|
||||
params: dict[str, str],
|
||||
current_credentials: Optional[OAuth2Credentials] = None,
|
||||
) -> OAuth2Credentials:
|
||||
request_body = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
**params,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
}
|
||||
|
||||
response = await Requests().post(
|
||||
self.token_url, data=request_body, headers=headers
|
||||
)
|
||||
token_data: dict = response.json()
|
||||
|
||||
# Get username if this is a new token request
|
||||
username = None
|
||||
if "access_token" in token_data:
|
||||
username = await self._request_username(token_data["access_token"])
|
||||
|
||||
now = int(time.time())
|
||||
new_credentials = OAuth2Credentials(
|
||||
provider=self.PROVIDER_NAME,
|
||||
title=current_credentials.title if current_credentials else None,
|
||||
username=username,
|
||||
access_token=token_data["access_token"],
|
||||
scopes=token_data.get("scope", "").split()
|
||||
or (current_credentials.scopes if current_credentials else []),
|
||||
refresh_token=token_data.get("refresh_token"),
|
||||
# Discord tokens expire after expires_in seconds (typically 7 days)
|
||||
access_token_expires_at=(
|
||||
now + expires_in
|
||||
if (expires_in := token_data.get("expires_in", None))
|
||||
else None
|
||||
),
|
||||
# Discord doesn't provide separate refresh token expiration
|
||||
refresh_token_expires_at=None,
|
||||
)
|
||||
|
||||
if current_credentials:
|
||||
new_credentials.id = current_credentials.id
|
||||
|
||||
return new_credentials
|
||||
|
||||
async def _request_username(self, access_token: str) -> str | None:
|
||||
"""
|
||||
Fetch the username using the Discord OAuth2 @me endpoint.
|
||||
"""
|
||||
url = "https://discord.com/api/oauth2/@me"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
response = await Requests().get(url, headers=headers)
|
||||
|
||||
if not response.ok:
|
||||
return None
|
||||
|
||||
# Get user info from the response
|
||||
data = response.json()
|
||||
user_info = data.get("user", {})
|
||||
|
||||
# Return username (without discriminator)
|
||||
return user_info.get("username")
|
||||
@@ -29,7 +29,7 @@ from backend.data.user import generate_unsubscribe_link
|
||||
from backend.notifications.email import EmailSender
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.metrics import discord_send_alert
|
||||
from backend.util.metrics import DiscordChannel, discord_send_alert
|
||||
from backend.util.retry import continuous_retry
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
@@ -382,8 +382,10 @@ class NotificationManager(AppService):
|
||||
}
|
||||
|
||||
@expose
|
||||
async def discord_system_alert(self, content: str):
|
||||
await discord_send_alert(content)
|
||||
async def discord_system_alert(
|
||||
self, content: str, channel: DiscordChannel = DiscordChannel.PLATFORM
|
||||
):
|
||||
await discord_send_alert(content, channel)
|
||||
|
||||
async def _queue_scheduled_notification(self, event: SummaryParamsEventModel):
|
||||
"""Queue a scheduled notification - exposed method for other services to call"""
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
{# Agent Approved Notification Email Template #}
|
||||
{#
|
||||
Template variables:
|
||||
data.agent_name: the name of the approved agent
|
||||
data.agent_id: the ID of the agent
|
||||
data.agent_version: the version of the agent
|
||||
data.reviewer_name: the name of the reviewer who approved it
|
||||
data.reviewer_email: the email of the reviewer
|
||||
data.comments: comments from the reviewer
|
||||
data.reviewed_at: when the agent was reviewed
|
||||
data.store_url: URL to view the agent in the store
|
||||
|
||||
Subject: 🎉 Your agent '{{ data.agent_name }}' has been approved!
|
||||
#}
|
||||
|
||||
{% block content %}
|
||||
<h1 style="color: #28a745; font-size: 32px; font-weight: 700; margin: 0 0 24px 0; text-align: center;">
|
||||
🎉 Congratulations!
|
||||
</h1>
|
||||
|
||||
<p style="color: #586069; font-size: 18px; text-align: center; margin: 0 0 24px 0;">
|
||||
Your agent <strong>'{{ data.agent_name }}'</strong> has been approved and is now live in the store!
|
||||
</p>
|
||||
|
||||
<div style="height: 32px; background: transparent;"></div>
|
||||
|
||||
{% if data.comments %}
|
||||
<div style="background: #d4edda; border: 1px solid #c3e6cb; border-radius: 8px; padding: 20px; margin: 0;">
|
||||
<h3 style="color: #155724; font-size: 16px; font-weight: 600; margin: 0 0 12px 0;">
|
||||
💬 Creator feedback area
|
||||
</h3>
|
||||
<p style="color: #155724; margin: 0; font-size: 16px; line-height: 1.5;">
|
||||
{{ data.comments }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div style="height: 40px; background: transparent;"></div>
|
||||
{% endif %}
|
||||
|
||||
<div style="background: #d1ecf1; border: 1px solid #bee5eb; border-radius: 8px; padding: 20px; margin: 0;">
|
||||
<h3 style="color: #0c5460; font-size: 16px; font-weight: 600; margin: 0 0 12px 0;">
|
||||
What's Next?
|
||||
</h3>
|
||||
<ul style="color: #0c5460; margin: 0; padding-left: 18px; font-size: 16px; line-height: 1.6;">
|
||||
<li>Your agent is now live and discoverable in the AutoGPT Store</li>
|
||||
<li>Users can find, install, and run your agent</li>
|
||||
<li>You can update your agent anytime by submitting a new version</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div style="height: 32px; background: transparent;"></div>
|
||||
|
||||
<div style="text-align: center; margin: 24px 0;">
|
||||
<a href="{{ data.store_url }}" style="display: inline-block; background: linear-gradient(135deg, #7c3aed 0%, #5b21b6 100%); color: black; text-decoration: none; padding: 14px 28px; border-radius: 6px; font-weight: 600; font-size: 16px;">
|
||||
View Your Agent in Store
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div style="height: 32px; background: transparent;"></div>
|
||||
|
||||
<div style="background: #fff3cd; border: 1px solid #ffeaa7; border-radius: 6px; padding: 16px; margin: 24px 0; text-align: center;">
|
||||
<p style="margin: 0; color: #856404; font-size: 14px;">
|
||||
<strong>💡 Pro Tip:</strong> Share your agent with the community! Post about it on social media, forums, or your blog to help more users discover and benefit from your creation.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div style="height: 32px; background: transparent;"></div>
|
||||
|
||||
<p style="color: #6a737d; font-size: 14px; text-align: center; margin: 24px 0;">
|
||||
Thank you for contributing to the AutoGPT ecosystem! 🚀
|
||||
</p>
|
||||
|
||||
{% endblock %}
|
||||
@@ -0,0 +1,77 @@
|
||||
{# Agent Rejected Notification Email Template #}
|
||||
{#
|
||||
Template variables:
|
||||
data.agent_name: the name of the rejected agent
|
||||
data.agent_id: the ID of the agent
|
||||
data.agent_version: the version of the agent
|
||||
data.reviewer_name: the name of the reviewer who rejected it
|
||||
data.reviewer_email: the email of the reviewer
|
||||
data.comments: comments from the reviewer explaining the rejection
|
||||
data.reviewed_at: when the agent was reviewed
|
||||
data.resubmit_url: URL to resubmit the agent
|
||||
|
||||
Subject: Your agent '{{ data.agent_name }}' needs some updates
|
||||
#}
|
||||
|
||||
|
||||
{% block content %}
|
||||
<h1 style="color: #d73a49; font-size: 32px; font-weight: 700; margin: 0 0 24px 0; text-align: center;">
|
||||
📝 Review Complete
|
||||
</h1>
|
||||
|
||||
<p style="color: #586069; font-size: 18px; text-align: center; margin: 0 0 24px 0;">
|
||||
Your agent <strong>'{{ data.agent_name }}'</strong> needs some updates before approval.
|
||||
</p>
|
||||
|
||||
<div style="height: 32px; background: transparent;"></div>
|
||||
|
||||
<div style="background: #f8d7da; border: 1px solid #f5c6cb; border-radius: 8px; padding: 20px; margin: 0 0 24px 0;">
|
||||
<h3 style="color: #721c24; font-size: 16px; font-weight: 600; margin: 0 0 12px 0;">
|
||||
💬 Creator feedback area
|
||||
</h3>
|
||||
<p style="color: #721c24; margin: 0; font-size: 16px; line-height: 1.5;">
|
||||
{{ data.comments }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div style="height: 40px; background: transparent;"></div>
|
||||
|
||||
<div style="background: #d4edda; border: 1px solid #c3e6cb; border-radius: 8px; padding: 20px; margin: 0 0 24px 0;">
|
||||
<h3 style="color: #155724; font-size: 16px; font-weight: 600; margin: 0 0 12px 0;">
|
||||
☑ Steps to Resubmit:
|
||||
</h3>
|
||||
<ul style="color: #155724; margin: 0; padding-left: 18px; font-size: 16px; line-height: 1.6;">
|
||||
<li>Review the feedback provided above carefully</li>
|
||||
<li>Make the necessary updates to your agent</li>
|
||||
<li>Test your agent thoroughly to ensure it works as expected</li>
|
||||
<li>Submit your updated agent for review</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div style="height: 32px; background: transparent;"></div>
|
||||
|
||||
<div style="background: #fff3cd; border: 1px solid #ffeaa7; border-radius: 6px; padding: 12px; margin: 0 0 24px 0; text-align: center;">
|
||||
<p style="margin: 0; color: #856404; font-size: 14px;">
|
||||
<strong>💡 Tip:</strong> Address all the points mentioned in the feedback to increase your chances of approval in the next review.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div style="text-align: center; margin: 32px 0;">
|
||||
<a href="{{ data.resubmit_url }}" style="display: inline-block; background: linear-gradient(135deg, #7c3aed 0%, #5b21b6 100%); color: black; text-decoration: none; padding: 14px 28px; border-radius: 6px; font-weight: 600; font-size: 16px;">
|
||||
Update & Resubmit Agent
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div style="background: #d1ecf1; border: 1px solid #bee5eb; border-radius: 6px; padding: 16px; margin: 24px 0;">
|
||||
<p style="margin: 0; color: #0c5460; font-size: 14px; text-align: center;">
|
||||
<strong>🌟 Don't Give Up!</strong> Many successful agents go through multiple iterations before approval. Our review team is here to help you succeed!
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div style="height: 32px; background: transparent;"></div>
|
||||
|
||||
<p style="color: #6a737d; font-size: 14px; text-align: center; margin: 32px 0 24px 0;">
|
||||
We're excited to see your improved agent submission! 🚀
|
||||
</p>
|
||||
|
||||
{% endblock %}
|
||||
@@ -1,9 +1,7 @@
|
||||
{# Low Balance Notification Email Template #}
|
||||
{# Template variables:
|
||||
data.agent_name: the name of the agent
|
||||
data.current_balance: the current balance of the user
|
||||
data.billing_page_link: the link to the billing page
|
||||
data.shortfall: the shortfall amount
|
||||
#}
|
||||
|
||||
<p style="
|
||||
@@ -25,7 +23,7 @@ data.shortfall: the shortfall amount
|
||||
margin-top: 0;
|
||||
margin-bottom: 20px;
|
||||
">
|
||||
Your agent "<strong>{{ data.agent_name }}</strong>" has been stopped due to low balance.
|
||||
Your account balance has dropped below the recommended threshold.
|
||||
</p>
|
||||
|
||||
<div style="
|
||||
@@ -44,15 +42,6 @@ data.shortfall: the shortfall amount
|
||||
">
|
||||
<strong>Current Balance:</strong> ${{ "{:.2f}".format((data.current_balance|float)/100) }}
|
||||
</p>
|
||||
<p style="
|
||||
font-family: 'Poppins', sans-serif;
|
||||
color: #070629;
|
||||
font-size: 16px;
|
||||
margin-top: 0;
|
||||
margin-bottom: 10px;
|
||||
">
|
||||
<strong>Shortfall:</strong> ${{ "{:.2f}".format((data.shortfall|float)/100) }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -79,7 +68,7 @@ data.shortfall: the shortfall amount
|
||||
margin-top: 0;
|
||||
margin-bottom: 5px;
|
||||
">
|
||||
Your agent "<strong>{{ data.agent_name }}</strong>" requires additional credits to continue running. The current operation has been canceled until your balance is replenished.
|
||||
Your account requires additional credits to continue running agents. Please add credits to your account to avoid service interruption.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
@@ -110,5 +99,5 @@ data.shortfall: the shortfall amount
|
||||
margin-bottom: 10px;
|
||||
font-style: italic;
|
||||
">
|
||||
This is an automated notification. Your agent is stopped and will need manually restarted unless set to trigger automatically.
|
||||
This is an automated low balance notification. Consider adding credits soon to avoid service interruption.
|
||||
</p>
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
{# Low Balance Notification Email Template #}
|
||||
{# Template variables:
|
||||
data.agent_name: the name of the agent
|
||||
data.current_balance: the current balance of the user
|
||||
data.billing_page_link: the link to the billing page
|
||||
data.shortfall: the shortfall amount
|
||||
#}
|
||||
|
||||
<p style="
|
||||
font-family: 'Poppins', sans-serif;
|
||||
color: #070629;
|
||||
font-size: 16px;
|
||||
line-height: 165%;
|
||||
margin-top: 0;
|
||||
margin-bottom: 10px;
|
||||
">
|
||||
<strong>Zero Balance Warning</strong>
|
||||
</p>
|
||||
|
||||
<p style="
|
||||
font-family: 'Poppins', sans-serif;
|
||||
color: #070629;
|
||||
font-size: 16px;
|
||||
line-height: 165%;
|
||||
margin-top: 0;
|
||||
margin-bottom: 20px;
|
||||
">
|
||||
Your agent "<strong>{{ data.agent_name }}</strong>" has been stopped due to low balance.
|
||||
</p>
|
||||
|
||||
<div style="
|
||||
margin-left: 15px;
|
||||
margin-bottom: 20px;
|
||||
padding: 15px;
|
||||
border-left: 4px solid #5D23BB;
|
||||
background-color: #f8f8ff;
|
||||
">
|
||||
<p style="
|
||||
font-family: 'Poppins', sans-serif;
|
||||
color: #070629;
|
||||
font-size: 16px;
|
||||
margin-top: 0;
|
||||
margin-bottom: 10px;
|
||||
">
|
||||
<strong>Current Balance:</strong> ${{ "{:.2f}".format((data.current_balance|float)/100) }}
|
||||
</p>
|
||||
<p style="
|
||||
font-family: 'Poppins', sans-serif;
|
||||
color: #070629;
|
||||
font-size: 16px;
|
||||
margin-top: 0;
|
||||
margin-bottom: 10px;
|
||||
">
|
||||
<strong>Shortfall:</strong> ${{ "{:.2f}".format((data.shortfall|float)/100) }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
<div style="
|
||||
margin-left: 15px;
|
||||
margin-bottom: 20px;
|
||||
padding: 15px;
|
||||
border-left: 4px solid #FF6B6B;
|
||||
background-color: #FFF0F0;
|
||||
">
|
||||
<p style="
|
||||
font-family: 'Poppins', sans-serif;
|
||||
color: #070629;
|
||||
font-size: 16px;
|
||||
margin-top: 0;
|
||||
margin-bottom: 10px;
|
||||
">
|
||||
<strong>Low Balance:</strong>
|
||||
</p>
|
||||
<p style="
|
||||
font-family: 'Poppins', sans-serif;
|
||||
color: #070629;
|
||||
font-size: 16px;
|
||||
margin-top: 0;
|
||||
margin-bottom: 5px;
|
||||
">
|
||||
Your agent "<strong>{{ data.agent_name }}</strong>" requires additional credits to continue running. The current operation has been canceled until your balance is replenished.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div style="
|
||||
text-align: center;
|
||||
margin: 30px 0;
|
||||
">
|
||||
<a href="{{ data.billing_page_link }}" style="
|
||||
font-family: 'Poppins', sans-serif;
|
||||
background-color: #5D23BB;
|
||||
color: white;
|
||||
padding: 12px 24px;
|
||||
text-decoration: none;
|
||||
border-radius: 4px;
|
||||
font-weight: 500;
|
||||
display: inline-block;
|
||||
">
|
||||
Manage Billing
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<p style="
|
||||
font-family: 'Poppins', sans-serif;
|
||||
color: #070629;
|
||||
font-size: 16px;
|
||||
line-height: 150%;
|
||||
margin-top: 30px;
|
||||
margin-bottom: 10px;
|
||||
font-style: italic;
|
||||
">
|
||||
This is an automated notification. Your agent is stopped and will need manually restarted unless set to trigger automatically.
|
||||
</p>
|
||||
@@ -14,6 +14,8 @@ from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
CredentialsType,
|
||||
OAuth2Credentials,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
@@ -104,14 +106,39 @@ class Provider:
|
||||
)
|
||||
|
||||
def get_test_credentials(self) -> Credentials:
|
||||
"""Get test credentials for the provider."""
|
||||
return APIKeyCredentials(
|
||||
id=str(self.test_credentials_uuid),
|
||||
provider=self.name,
|
||||
api_key=SecretStr("mock-api-key"),
|
||||
title=f"Mock {self.name.title()} API key",
|
||||
expires_at=None,
|
||||
)
|
||||
"""Get test credentials for the provider based on supported auth types."""
|
||||
test_id = str(self.test_credentials_uuid)
|
||||
|
||||
# Return credentials based on the first supported auth type
|
||||
if "user_password" in self.supported_auth_types:
|
||||
return UserPasswordCredentials(
|
||||
id=test_id,
|
||||
provider=self.name,
|
||||
username=SecretStr(f"mock-{self.name}-username"),
|
||||
password=SecretStr(f"mock-{self.name}-password"),
|
||||
title=f"Mock {self.name.title()} credentials",
|
||||
)
|
||||
elif "oauth2" in self.supported_auth_types:
|
||||
return OAuth2Credentials(
|
||||
id=test_id,
|
||||
provider=self.name,
|
||||
username=f"mock-{self.name}-username",
|
||||
access_token=SecretStr(f"mock-{self.name}-access-token"),
|
||||
access_token_expires_at=None,
|
||||
refresh_token=SecretStr(f"mock-{self.name}-refresh-token"),
|
||||
refresh_token_expires_at=None,
|
||||
scopes=[f"mock-{self.name}-scope"],
|
||||
title=f"Mock {self.name.title()} OAuth credentials",
|
||||
)
|
||||
else:
|
||||
# Default to API key credentials
|
||||
return APIKeyCredentials(
|
||||
id=test_id,
|
||||
provider=self.name,
|
||||
api_key=SecretStr(f"mock-{self.name}-api-key"),
|
||||
title=f"Mock {self.name.title()} API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
def get_api(self, credentials: Credentials) -> Any:
|
||||
"""Get API client instance for the given credentials."""
|
||||
|
||||
@@ -11,7 +11,45 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
|
||||
return snapshot
|
||||
|
||||
|
||||
# Test ID constants
|
||||
TEST_USER_ID = "test-user-id"
|
||||
ADMIN_USER_ID = "admin-user-id"
|
||||
TARGET_USER_ID = "target-user-id"
|
||||
@pytest.fixture
|
||||
def test_user_id() -> str:
|
||||
"""Test user ID fixture."""
|
||||
return "test-user-id"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user_id() -> str:
|
||||
"""Admin user ID fixture."""
|
||||
return "admin-user-id"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_user_id() -> str:
|
||||
"""Target user ID fixture."""
|
||||
return "target-user-id"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_jwt_user(test_user_id):
|
||||
"""Provide mock JWT payload for regular user testing."""
|
||||
import fastapi
|
||||
|
||||
def override_get_jwt_payload(request: fastapi.Request) -> dict[str, str]:
|
||||
return {"sub": test_user_id, "role": "user", "email": "test@example.com"}
|
||||
|
||||
return {"get_jwt_payload": override_get_jwt_payload, "user_id": test_user_id}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_jwt_admin(admin_user_id):
|
||||
"""Provide mock JWT payload for admin user testing."""
|
||||
import fastapi
|
||||
|
||||
def override_get_jwt_payload(request: fastapi.Request) -> dict[str, str]:
|
||||
return {
|
||||
"sub": admin_user_id,
|
||||
"role": "admin",
|
||||
"email": "test-admin@example.com",
|
||||
}
|
||||
|
||||
return {"get_jwt_payload": override_get_jwt_payload, "user_id": admin_user_id}
|
||||
|
||||
@@ -58,17 +58,13 @@ class ProviderConstants(BaseModel):
|
||||
default_factory=lambda: {
|
||||
name.upper().replace("-", "_"): name for name in get_all_provider_names()
|
||||
},
|
||||
)
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"PROVIDER_NAMES": {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"EXA": "exa",
|
||||
"GEM": "gem",
|
||||
"EXAMPLE_SERVICE": "example-service",
|
||||
}
|
||||
examples=[
|
||||
{
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"EXA": "exa",
|
||||
"GEM": "gem",
|
||||
"EXAMPLE_SERVICE": "example-service",
|
||||
}
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
@@ -3,14 +3,15 @@ import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Annotated, Awaitable, List, Literal
|
||||
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Path,
|
||||
Query,
|
||||
Request,
|
||||
Security,
|
||||
status,
|
||||
)
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
@@ -50,8 +51,6 @@ from backend.util.settings import Settings
|
||||
if TYPE_CHECKING:
|
||||
from backend.integrations.oauth import BaseOAuthHandler
|
||||
|
||||
from ..utils import get_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
router = APIRouter()
|
||||
@@ -69,7 +68,7 @@ async def login(
|
||||
provider: Annotated[
|
||||
ProviderName, Path(title="The provider to initiate an OAuth flow for")
|
||||
],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
request: Request,
|
||||
scopes: Annotated[
|
||||
str, Query(title="Comma-separated list of authorization scopes")
|
||||
@@ -109,7 +108,7 @@ async def callback(
|
||||
],
|
||||
code: Annotated[str, Body(title="Authorization code acquired by user login")],
|
||||
state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
request: Request,
|
||||
) -> CredentialsMetaResponse:
|
||||
logger.debug(f"Received OAuth callback for provider: {provider}")
|
||||
@@ -182,7 +181,7 @@ async def callback(
|
||||
|
||||
@router.get("/credentials")
|
||||
async def list_credentials(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
credentials = await creds_manager.store.get_all_creds(user_id)
|
||||
return [
|
||||
@@ -204,7 +203,7 @@ async def list_credentials_by_provider(
|
||||
provider: Annotated[
|
||||
ProviderName, Path(title="The provider to list credentials for")
|
||||
],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
|
||||
return [
|
||||
@@ -227,7 +226,7 @@ async def get_credential(
|
||||
ProviderName, Path(title="The provider to retrieve credentials for")
|
||||
],
|
||||
cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> Credentials:
|
||||
credential = await creds_manager.get(user_id, cred_id)
|
||||
if not credential:
|
||||
@@ -244,7 +243,7 @@ async def get_credential(
|
||||
|
||||
@router.post("/{provider}/credentials", status_code=201)
|
||||
async def create_credentials(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
provider: Annotated[
|
||||
ProviderName, Path(title="The provider to create credentials for")
|
||||
],
|
||||
@@ -288,7 +287,7 @@ async def delete_credentials(
|
||||
ProviderName, Path(title="The provider to delete credentials for")
|
||||
],
|
||||
cred_id: Annotated[str, Path(title="The ID of the credentials to delete")],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
force: Annotated[
|
||||
bool, Query(title="Whether to proceed if any linked webhooks are still in use")
|
||||
] = False,
|
||||
@@ -429,7 +428,7 @@ async def webhook_ingress_generic(
|
||||
@router.post("/webhooks/{webhook_id}/ping")
|
||||
async def webhook_ping(
|
||||
webhook_id: Annotated[str, Path(title="Our ID for the webhook")],
|
||||
user_id: Annotated[str, Depends(get_user_id)], # require auth
|
||||
user_id: Annotated[str, Security(get_user_id)], # require auth
|
||||
):
|
||||
webhook = await get_webhook(webhook_id)
|
||||
webhook_manager = get_webhook_manager(webhook.provider)
|
||||
@@ -568,7 +567,7 @@ def _get_provider_oauth_handler(
|
||||
|
||||
@router.get("/ayrshare/sso_url")
|
||||
async def get_ayrshare_sso_url(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> AyrshareSSOResponse:
|
||||
"""
|
||||
Generate an SSO URL for Ayrshare social media integration.
|
||||
|
||||
@@ -5,6 +5,7 @@ import pydantic
|
||||
|
||||
from backend.data.api_key import APIKeyPermission, APIKeyWithoutHash
|
||||
from backend.data.graph import Graph
|
||||
from backend.util.timezone_name import TimeZoneName
|
||||
|
||||
|
||||
class WSMethod(enum.Enum):
|
||||
@@ -70,3 +71,12 @@ class UploadFileResponse(pydantic.BaseModel):
|
||||
size: int
|
||||
content_type: str
|
||||
expires_in_hours: int
|
||||
|
||||
|
||||
class TimezoneResponse(pydantic.BaseModel):
|
||||
# Allow "not-set" as a special value, or any valid IANA timezone
|
||||
timezone: TimeZoneName | str
|
||||
|
||||
|
||||
class UpdateTimezoneRequest(pydantic.BaseModel):
|
||||
timezone: TimeZoneName
|
||||
|
||||
@@ -3,12 +3,13 @@ import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
import autogpt_libs.auth.models
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
import pydantic
|
||||
import starlette.middleware.cors
|
||||
import uvicorn
|
||||
from autogpt_libs.auth import add_auth_responses_to_openapi
|
||||
from autogpt_libs.auth import verify_settings as verify_auth_settings
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
@@ -20,6 +21,8 @@ import backend.server.routers.postmark.postmark
|
||||
import backend.server.routers.v1
|
||||
import backend.server.v2.admin.credit_admin_routes
|
||||
import backend.server.v2.admin.store_admin_routes
|
||||
import backend.server.v2.builder
|
||||
import backend.server.v2.builder.routes
|
||||
import backend.server.v2.library.db
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.library.routes
|
||||
@@ -59,6 +62,8 @@ def launch_darkly_context():
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan_context(app: fastapi.FastAPI):
|
||||
verify_auth_settings()
|
||||
|
||||
await backend.data.db.connect()
|
||||
|
||||
# Ensure SDK auto-registration is patched before initializing blocks
|
||||
@@ -129,6 +134,9 @@ app = fastapi.FastAPI(
|
||||
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
# Add 401 responses to authenticated endpoints in OpenAPI spec
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
|
||||
def handle_internal_http_error(status_code: int = 500, log_error: bool = True):
|
||||
def handler(request: fastapi.Request, exc: Exception):
|
||||
@@ -195,6 +203,9 @@ app.include_router(backend.server.routers.v1.v1_router, tags=["v1"], prefix="/ap
|
||||
app.include_router(
|
||||
backend.server.v2.store.routes.router, tags=["v2"], prefix="/api/store"
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.builder.routes.router, tags=["v2"], prefix="/api/builder"
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.admin.store_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
@@ -365,10 +376,10 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
@staticmethod
|
||||
async def test_review_store_listing(
|
||||
request: backend.server.v2.store.model.ReviewSubmissionRequest,
|
||||
user: autogpt_libs.auth.models.User,
|
||||
user_id: str,
|
||||
):
|
||||
return await backend.server.v2.admin.store_admin_routes.review_submission(
|
||||
request.store_listing_version_id, request, user
|
||||
request.store_listing_version_id, request, user_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -5,9 +5,9 @@ from typing import Annotated
|
||||
|
||||
import fastapi
|
||||
import pydantic
|
||||
from autogpt_libs.auth import get_user_id
|
||||
|
||||
import backend.data.analytics
|
||||
from backend.server.utils import get_user_id
|
||||
|
||||
router = fastapi.APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -21,7 +21,7 @@ class LogRawMetricRequest(pydantic.BaseModel):
|
||||
|
||||
@router.post(path="/log_raw_metric")
|
||||
async def log_raw_metric(
|
||||
user_id: Annotated[str, fastapi.Depends(get_user_id)],
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
request: LogRawMetricRequest,
|
||||
):
|
||||
try:
|
||||
@@ -47,7 +47,7 @@ async def log_raw_metric(
|
||||
|
||||
@router.post("/log_raw_analytics")
|
||||
async def log_raw_analytics(
|
||||
user_id: Annotated[str, fastapi.Depends(get_user_id)],
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
type: Annotated[str, fastapi.Body(..., embed=True)],
|
||||
data: Annotated[
|
||||
dict,
|
||||
|
||||
@@ -5,18 +5,17 @@ from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.routers.analytics as analytics_routes
|
||||
from backend.server.conftest import TEST_USER_ID
|
||||
from backend.server.test_helpers import (
|
||||
assert_error_response_structure,
|
||||
assert_mock_called_with_partial,
|
||||
assert_response_status,
|
||||
safe_parse_json,
|
||||
)
|
||||
from backend.server.utils import get_user_id
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(analytics_routes.router)
|
||||
@@ -24,17 +23,20 @@ app.include_router(analytics_routes.router)
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
def override_get_user_id() -> str:
|
||||
"""Override get_user_id for testing"""
|
||||
return TEST_USER_ID
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
|
||||
app.dependency_overrides[get_user_id] = override_get_user_id
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_log_raw_metric_success_improved(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw metric logging with improved assertions."""
|
||||
# Mock the analytics function
|
||||
@@ -63,7 +65,7 @@ def test_log_raw_metric_success_improved(
|
||||
# Verify the function was called with correct parameters
|
||||
assert_mock_called_with_partial(
|
||||
mock_log_metric,
|
||||
user_id=TEST_USER_ID,
|
||||
user_id=test_user_id,
|
||||
metric_name="page_load_time",
|
||||
metric_value=2.5,
|
||||
data_string="/dashboard",
|
||||
|
||||
@@ -10,8 +10,6 @@ import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.routers.analytics as analytics_routes
|
||||
from backend.server.conftest import TEST_USER_ID
|
||||
from backend.server.utils import get_user_id
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(analytics_routes.router)
|
||||
@@ -19,12 +17,14 @@ app.include_router(analytics_routes.router)
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
def override_get_user_id() -> str:
|
||||
"""Override get_user_id for testing"""
|
||||
return TEST_USER_ID
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
|
||||
app.dependency_overrides[get_user_id] = override_get_user_id
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -3,12 +3,11 @@ from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.routers.analytics as analytics_routes
|
||||
from backend.server.conftest import TEST_USER_ID
|
||||
from backend.server.utils import get_user_id
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(analytics_routes.router)
|
||||
@@ -16,17 +15,20 @@ app.include_router(analytics_routes.router)
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
def override_get_user_id() -> str:
|
||||
"""Override get_user_id for testing"""
|
||||
return TEST_USER_ID
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
|
||||
app.dependency_overrides[get_user_id] = override_get_user_id
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_log_raw_metric_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw metric logging"""
|
||||
|
||||
@@ -53,7 +55,7 @@ def test_log_raw_metric_success(
|
||||
|
||||
# Verify the function was called with correct parameters
|
||||
mock_log_metric.assert_called_once_with(
|
||||
user_id=TEST_USER_ID,
|
||||
user_id=test_user_id,
|
||||
metric_name="page_load_time",
|
||||
metric_value=2.5,
|
||||
data_string="/dashboard",
|
||||
@@ -121,6 +123,7 @@ def test_log_raw_metric_various_values(
|
||||
def test_log_raw_analytics_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw analytics logging"""
|
||||
|
||||
@@ -155,7 +158,7 @@ def test_log_raw_analytics_success(
|
||||
|
||||
# Verify the function was called with correct parameters
|
||||
mock_log_analytics.assert_called_once_with(
|
||||
TEST_USER_ID,
|
||||
test_user_id,
|
||||
"user_action",
|
||||
request_data["data"],
|
||||
"button_click_submit_form",
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from autogpt_libs.auth.middleware import APIKeyValidator
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Body, HTTPException, Query, Security
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from backend.data.user import (
|
||||
@@ -20,18 +19,18 @@ from backend.server.routers.postmark.models import (
|
||||
PostmarkSubscriptionChangeWebhook,
|
||||
PostmarkWebhook,
|
||||
)
|
||||
from backend.server.utils.api_key_auth import APIKeyAuthenticator
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
postmark_validator = APIKeyValidator(
|
||||
"X-Postmark-Webhook-Token",
|
||||
settings.secrets.postmark_webhook_token,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
postmark_api_key_auth = APIKeyAuthenticator(
|
||||
"X-Postmark-Webhook-Token",
|
||||
settings.secrets.postmark_webhook_token,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/unsubscribe", summary="One Click Email Unsubscribe")
|
||||
@@ -50,7 +49,7 @@ async def unsubscribe_via_one_click(token: Annotated[str, Query()]):
|
||||
|
||||
@router.post(
|
||||
"/",
|
||||
dependencies=[Depends(postmark_validator.get_dependency())],
|
||||
dependencies=[Security(postmark_api_key_auth)],
|
||||
summary="Handle Postmark Email Webhooks",
|
||||
)
|
||||
async def postmark_webhook_handler(
|
||||
|
||||
@@ -7,16 +7,18 @@ from typing import Annotated, Any, Sequence
|
||||
|
||||
import pydantic
|
||||
import stripe
|
||||
from autogpt_libs.auth.middleware import auth_middleware
|
||||
from autogpt_libs.auth import get_user_id, requires_user
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
Depends,
|
||||
File,
|
||||
HTTPException,
|
||||
Path,
|
||||
Query,
|
||||
Request,
|
||||
Response,
|
||||
Security,
|
||||
UploadFile,
|
||||
)
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
@@ -60,9 +62,11 @@ from backend.data.onboarding import (
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_or_create_user,
|
||||
get_user_by_id,
|
||||
get_user_notification_preference,
|
||||
update_user_email,
|
||||
update_user_notification_preference,
|
||||
update_user_timezone,
|
||||
)
|
||||
from backend.executor import scheduler
|
||||
from backend.executor import utils as execution_utils
|
||||
@@ -77,15 +81,20 @@ from backend.server.model import (
|
||||
ExecuteGraphResponse,
|
||||
RequestTopUp,
|
||||
SetGraphActiveVersion,
|
||||
TimezoneResponse,
|
||||
UpdatePermissionsRequest,
|
||||
UpdateTimezoneRequest,
|
||||
UploadFileResponse,
|
||||
)
|
||||
from backend.server.utils import get_user_id
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.cloud_storage import get_cloud_storage_handler
|
||||
from backend.util.exceptions import GraphValidationError, NotFoundError
|
||||
from backend.util.feature_flag import feature_flag
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.timezone_utils import (
|
||||
convert_cron_to_utc,
|
||||
convert_utc_time_to_user_timezone,
|
||||
get_user_timezone_or_utc,
|
||||
)
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
|
||||
@@ -115,7 +124,7 @@ v1_router.include_router(
|
||||
backend.server.routers.analytics.router,
|
||||
prefix="/analytics",
|
||||
tags=["analytics"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
|
||||
|
||||
@@ -128,9 +137,9 @@ v1_router.include_router(
|
||||
"/auth/user",
|
||||
summary="Get or create user",
|
||||
tags=["auth"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_or_create_user_route(user_data: dict = Depends(auth_middleware)):
|
||||
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
user = await get_or_create_user(user_data)
|
||||
return user.model_dump()
|
||||
|
||||
@@ -139,24 +148,53 @@ async def get_or_create_user_route(user_data: dict = Depends(auth_middleware)):
|
||||
"/auth/user/email",
|
||||
summary="Update user email",
|
||||
tags=["auth"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_user_email_route(
|
||||
user_id: Annotated[str, Depends(get_user_id)], email: str = Body(...)
|
||||
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
|
||||
) -> dict[str, str]:
|
||||
await update_user_email(user_id, email)
|
||||
|
||||
return {"email": email}
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/auth/user/timezone",
|
||||
summary="Get user timezone",
|
||||
tags=["auth"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_user_timezone_route(
|
||||
user_data: dict = Security(get_jwt_payload),
|
||||
) -> TimezoneResponse:
|
||||
"""Get user timezone setting."""
|
||||
user = await get_or_create_user(user_data)
|
||||
return TimezoneResponse(timezone=user.timezone)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/auth/user/timezone",
|
||||
summary="Update user timezone",
|
||||
tags=["auth"],
|
||||
dependencies=[Security(requires_user)],
|
||||
response_model=TimezoneResponse,
|
||||
)
|
||||
async def update_user_timezone_route(
|
||||
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
|
||||
) -> TimezoneResponse:
|
||||
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
|
||||
user = await update_user_timezone(user_id, str(request.timezone))
|
||||
return TimezoneResponse(timezone=user.timezone)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/auth/user/preferences",
|
||||
summary="Get notification preferences",
|
||||
tags=["auth"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_preferences(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> NotificationPreference:
|
||||
preferences = await get_user_notification_preference(user_id)
|
||||
return preferences
|
||||
@@ -166,10 +204,10 @@ async def get_preferences(
|
||||
"/auth/user/preferences",
|
||||
summary="Update notification preferences",
|
||||
tags=["auth"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_preferences(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
preferences: NotificationPreferenceDTO = Body(...),
|
||||
) -> NotificationPreference:
|
||||
output = await update_user_notification_preference(user_id, preferences)
|
||||
@@ -185,9 +223,9 @@ async def update_preferences(
|
||||
"/onboarding",
|
||||
summary="Get onboarding status",
|
||||
tags=["onboarding"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_onboarding(user_id: Annotated[str, Depends(get_user_id)]):
|
||||
async def get_onboarding(user_id: Annotated[str, Security(get_user_id)]):
|
||||
return await get_user_onboarding(user_id)
|
||||
|
||||
|
||||
@@ -195,10 +233,10 @@ async def get_onboarding(user_id: Annotated[str, Depends(get_user_id)]):
|
||||
"/onboarding",
|
||||
summary="Update onboarding progress",
|
||||
tags=["onboarding"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_onboarding(
|
||||
user_id: Annotated[str, Depends(get_user_id)], data: UserOnboardingUpdate
|
||||
user_id: Annotated[str, Security(get_user_id)], data: UserOnboardingUpdate
|
||||
):
|
||||
return await update_user_onboarding(user_id, data)
|
||||
|
||||
@@ -207,10 +245,10 @@ async def update_onboarding(
|
||||
"/onboarding/agents",
|
||||
summary="Get recommended agents",
|
||||
tags=["onboarding"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_onboarding_agents(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
):
|
||||
return await get_recommended_agents(user_id)
|
||||
|
||||
@@ -219,7 +257,7 @@ async def get_onboarding_agents(
|
||||
"/onboarding/enabled",
|
||||
summary="Check onboarding enabled",
|
||||
tags=["onboarding", "public"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def is_onboarding_enabled():
|
||||
return await onboarding_enabled()
|
||||
@@ -234,7 +272,7 @@ async def is_onboarding_enabled():
|
||||
path="/blocks",
|
||||
summary="List available blocks",
|
||||
tags=["blocks"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
blocks = [block() for block in get_blocks().values()]
|
||||
@@ -248,7 +286,7 @@ def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
path="/blocks/{block_id}/execute",
|
||||
summary="Execute graph block",
|
||||
tags=["blocks"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlockOutput:
|
||||
obj = get_block(block_id)
|
||||
@@ -265,10 +303,10 @@ async def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlock
|
||||
path="/files/upload",
|
||||
summary="Upload file to cloud storage",
|
||||
tags=["files"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def upload_file(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
file: UploadFile = File(...),
|
||||
provider: str = "gcs",
|
||||
expiration_hours: int = 24,
|
||||
@@ -356,10 +394,10 @@ async def upload_file(
|
||||
path="/credits",
|
||||
tags=["credits"],
|
||||
summary="Get user credits",
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_user_credits(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> dict[str, int]:
|
||||
return {"credits": await _user_credit_model.get_credits(user_id)}
|
||||
|
||||
@@ -368,10 +406,10 @@ async def get_user_credits(
|
||||
path="/credits",
|
||||
summary="Request credit top up",
|
||||
tags=["credits"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def request_top_up(
|
||||
request: RequestTopUp, user_id: Annotated[str, Depends(get_user_id)]
|
||||
request: RequestTopUp, user_id: Annotated[str, Security(get_user_id)]
|
||||
):
|
||||
checkout_url = await _user_credit_model.top_up_intent(
|
||||
user_id, request.credit_amount
|
||||
@@ -383,10 +421,10 @@ async def request_top_up(
|
||||
path="/credits/{transaction_key}/refund",
|
||||
summary="Refund credit transaction",
|
||||
tags=["credits"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def refund_top_up(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
transaction_key: str,
|
||||
metadata: dict[str, str],
|
||||
) -> int:
|
||||
@@ -397,9 +435,9 @@ async def refund_top_up(
|
||||
path="/credits",
|
||||
summary="Fulfill checkout session",
|
||||
tags=["credits"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def fulfill_checkout(user_id: Annotated[str, Depends(get_user_id)]):
|
||||
async def fulfill_checkout(user_id: Annotated[str, Security(get_user_id)]):
|
||||
await _user_credit_model.fulfill_checkout(user_id=user_id)
|
||||
return Response(status_code=200)
|
||||
|
||||
@@ -408,10 +446,10 @@ async def fulfill_checkout(user_id: Annotated[str, Depends(get_user_id)]):
|
||||
path="/credits/auto-top-up",
|
||||
summary="Configure auto top up",
|
||||
tags=["credits"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def configure_user_auto_top_up(
|
||||
request: AutoTopUpConfig, user_id: Annotated[str, Depends(get_user_id)]
|
||||
request: AutoTopUpConfig, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> str:
|
||||
if request.threshold < 0:
|
||||
raise ValueError("Threshold must be greater than 0")
|
||||
@@ -437,10 +475,10 @@ async def configure_user_auto_top_up(
|
||||
path="/credits/auto-top-up",
|
||||
summary="Get auto top up",
|
||||
tags=["credits"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_user_auto_top_up(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> AutoTopUpConfig:
|
||||
return await get_auto_top_up(user_id)
|
||||
|
||||
@@ -490,10 +528,10 @@ async def stripe_webhook(request: Request):
|
||||
path="/credits/manage",
|
||||
tags=["credits"],
|
||||
summary="Manage payment methods",
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def manage_payment_method(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> dict[str, str]:
|
||||
return {"url": await _user_credit_model.create_billing_portal_session(user_id)}
|
||||
|
||||
@@ -502,10 +540,10 @@ async def manage_payment_method(
|
||||
path="/credits/transactions",
|
||||
tags=["credits"],
|
||||
summary="Get credit history",
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_credit_history(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
transaction_time: datetime | None = None,
|
||||
transaction_type: str | None = None,
|
||||
transaction_count_limit: int = 100,
|
||||
@@ -525,10 +563,10 @@ async def get_credit_history(
|
||||
path="/credits/refunds",
|
||||
tags=["credits"],
|
||||
summary="Get refund requests",
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_refund_requests(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[RefundRequest]:
|
||||
return await _user_credit_model.get_refund_requests(user_id)
|
||||
|
||||
@@ -546,10 +584,10 @@ class DeleteGraphResponse(TypedDict):
|
||||
path="/graphs",
|
||||
summary="List user graphs",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def list_graphs(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> Sequence[graph_db.GraphMeta]:
|
||||
return await graph_db.list_graphs(filter_by="active", user_id=user_id)
|
||||
|
||||
@@ -558,17 +596,17 @@ async def list_graphs(
|
||||
path="/graphs/{graph_id}",
|
||||
summary="Get specific graph",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/versions/{version}",
|
||||
summary="Get graph version",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_graph(
|
||||
graph_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
version: int | None = None,
|
||||
for_export: bool = False,
|
||||
) -> graph_db.GraphModel:
|
||||
@@ -588,10 +626,10 @@ async def get_graph(
|
||||
path="/graphs/{graph_id}/versions",
|
||||
summary="Get all graph versions",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_graph_all_versions(
|
||||
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
graph_id: str, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> Sequence[graph_db.GraphModel]:
|
||||
graphs = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
|
||||
if not graphs:
|
||||
@@ -603,11 +641,11 @@ async def get_graph_all_versions(
|
||||
path="/graphs",
|
||||
summary="Create new graph",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def create_new_graph(
|
||||
create_graph: CreateGraph,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> graph_db.GraphModel:
|
||||
graph = graph_db.make_graph_model(create_graph.graph, user_id)
|
||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
||||
@@ -624,10 +662,10 @@ async def create_new_graph(
|
||||
path="/graphs/{graph_id}",
|
||||
summary="Delete graph permanently",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def delete_graph(
|
||||
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
graph_id: str, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> DeleteGraphResponse:
|
||||
if active_version := await graph_db.get_graph(graph_id, user_id=user_id):
|
||||
await on_graph_deactivate(active_version, user_id=user_id)
|
||||
@@ -639,12 +677,12 @@ async def delete_graph(
|
||||
path="/graphs/{graph_id}",
|
||||
summary="Update graph version",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_graph(
|
||||
graph_id: str,
|
||||
graph: graph_db.Graph,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> graph_db.GraphModel:
|
||||
# Sanity check
|
||||
if graph.id and graph.id != graph_id:
|
||||
@@ -695,12 +733,12 @@ async def update_graph(
|
||||
path="/graphs/{graph_id}/versions/active",
|
||||
summary="Set active graph version",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def set_graph_active_version(
|
||||
graph_id: str,
|
||||
request_body: SetGraphActiveVersion,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
):
|
||||
new_active_version = request_body.active_graph_version
|
||||
new_active_graph = await graph_db.get_graph(
|
||||
@@ -734,11 +772,11 @@ async def set_graph_active_version(
|
||||
path="/graphs/{graph_id}/execute/{graph_version}",
|
||||
summary="Execute graph agent",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def execute_graph(
|
||||
graph_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
inputs: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
|
||||
credentials_inputs: Annotated[
|
||||
dict[str, CredentialsMetaInput], Body(..., embed=True, default_factory=dict)
|
||||
@@ -780,10 +818,10 @@ async def execute_graph(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
|
||||
summary="Stop graph execution",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def stop_graph_run(
|
||||
graph_id: str, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
graph_id: str, graph_exec_id: str, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> execution_db.GraphExecutionMeta | None:
|
||||
res = await _stop_graph_run(
|
||||
user_id=user_id,
|
||||
@@ -820,39 +858,48 @@ async def _stop_graph_run(
|
||||
|
||||
@v1_router.get(
|
||||
path="/executions",
|
||||
summary="Get all executions",
|
||||
summary="List all executions",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_graphs_executions(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
async def list_graphs_executions(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[execution_db.GraphExecutionMeta]:
|
||||
return await execution_db.get_graph_executions(user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/executions",
|
||||
summary="Get graph executions",
|
||||
summary="List graph executions",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_graph_executions(
|
||||
async def list_graph_executions(
|
||||
graph_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[execution_db.GraphExecutionMeta]:
|
||||
return await execution_db.get_graph_executions(graph_id=graph_id, user_id=user_id)
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
page: int = Query(1, ge=1, description="Page number (1-indexed)"),
|
||||
page_size: int = Query(
|
||||
25, ge=1, le=100, description="Number of executions per page"
|
||||
),
|
||||
) -> execution_db.GraphExecutionsPaginated:
|
||||
return await execution_db.get_graph_executions_paginated(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}",
|
||||
summary="Get execution details",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_graph_execution(
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> execution_db.GraphExecution | execution_db.GraphExecutionWithNodes:
|
||||
graph = await graph_db.get_graph(graph_id=graph_id, user_id=user_id)
|
||||
if not graph:
|
||||
@@ -877,12 +924,12 @@ async def get_graph_execution(
|
||||
path="/executions/{graph_exec_id}",
|
||||
summary="Delete graph execution",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
status_code=HTTP_204_NO_CONTENT,
|
||||
)
|
||||
async def delete_graph_execution(
|
||||
graph_exec_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> None:
|
||||
await execution_db.delete_graph_execution(
|
||||
graph_exec_id=graph_exec_id, user_id=user_id
|
||||
@@ -906,10 +953,10 @@ class ScheduleCreationRequest(pydantic.BaseModel):
|
||||
path="/graphs/{graph_id}/schedules",
|
||||
summary="Create execution schedule",
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def create_graph_execution_schedule(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
graph_id: str = Path(..., description="ID of the graph to schedule"),
|
||||
schedule_params: ScheduleCreationRequest = Body(),
|
||||
) -> scheduler.GraphExecutionJobInfo:
|
||||
@@ -924,53 +971,99 @@ async def create_graph_execution_schedule(
|
||||
detail=f"Graph #{graph_id} v{schedule_params.graph_version} not found.",
|
||||
)
|
||||
|
||||
return await get_scheduler_client().add_execution_schedule(
|
||||
user = await get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
|
||||
|
||||
# Convert cron expression from user timezone to UTC
|
||||
try:
|
||||
utc_cron = convert_cron_to_utc(schedule_params.cron, user_timezone)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid cron expression for timezone {user_timezone}: {e}",
|
||||
)
|
||||
|
||||
result = await get_scheduler_client().add_execution_schedule(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
name=schedule_params.name,
|
||||
cron=schedule_params.cron,
|
||||
cron=utc_cron, # Send UTC cron to scheduler
|
||||
input_data=schedule_params.inputs,
|
||||
input_credentials=schedule_params.credentials,
|
||||
)
|
||||
|
||||
# Convert the next_run_time back to user timezone for display
|
||||
if result.next_run_time:
|
||||
result.next_run_time = convert_utc_time_to_user_timezone(
|
||||
result.next_run_time, user_timezone
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/schedules",
|
||||
summary="List execution schedules for a graph",
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def list_graph_execution_schedules(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
graph_id: str = Path(),
|
||||
) -> list[scheduler.GraphExecutionJobInfo]:
|
||||
return await get_scheduler_client().get_execution_schedules(
|
||||
schedules = await get_scheduler_client().get_execution_schedules(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
)
|
||||
|
||||
# Get user timezone for conversion
|
||||
user = await get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
|
||||
|
||||
# Convert next_run_time to user timezone for display
|
||||
for schedule in schedules:
|
||||
if schedule.next_run_time:
|
||||
schedule.next_run_time = convert_utc_time_to_user_timezone(
|
||||
schedule.next_run_time, user_timezone
|
||||
)
|
||||
|
||||
return schedules
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/schedules",
|
||||
summary="List execution schedules for a user",
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def list_all_graphs_execution_schedules(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[scheduler.GraphExecutionJobInfo]:
|
||||
return await get_scheduler_client().get_execution_schedules(user_id=user_id)
|
||||
schedules = await get_scheduler_client().get_execution_schedules(user_id=user_id)
|
||||
|
||||
# Get user timezone for conversion
|
||||
user = await get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
|
||||
|
||||
# Convert UTC next_run_time to user timezone for display
|
||||
for schedule in schedules:
|
||||
if schedule.next_run_time:
|
||||
schedule.next_run_time = convert_utc_time_to_user_timezone(
|
||||
schedule.next_run_time, user_timezone
|
||||
)
|
||||
|
||||
return schedules
|
||||
|
||||
|
||||
@v1_router.delete(
|
||||
path="/schedules/{schedule_id}",
|
||||
summary="Delete execution schedule",
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def delete_graph_execution_schedule(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
schedule_id: str = Path(..., description="ID of the schedule to delete"),
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
@@ -993,10 +1086,10 @@ async def delete_graph_execution_schedule(
|
||||
summary="Create new API key",
|
||||
response_model=CreateAPIKeyResponse,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def create_api_key(
|
||||
request: CreateAPIKeyRequest, user_id: Annotated[str, Depends(get_user_id)]
|
||||
request: CreateAPIKeyRequest, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> CreateAPIKeyResponse:
|
||||
"""Create a new API key"""
|
||||
try:
|
||||
@@ -1024,10 +1117,10 @@ async def create_api_key(
|
||||
summary="List user API keys",
|
||||
response_model=list[APIKeyWithoutHash] | dict[str, str],
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_api_keys(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[APIKeyWithoutHash]:
|
||||
"""List all API keys for the user"""
|
||||
try:
|
||||
@@ -1045,10 +1138,10 @@ async def get_api_keys(
|
||||
summary="Get specific API key",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_api_key(
|
||||
key_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
key_id: str, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> APIKeyWithoutHash:
|
||||
"""Get a specific API key"""
|
||||
try:
|
||||
@@ -1069,10 +1162,10 @@ async def get_api_key(
|
||||
summary="Revoke API key",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def delete_api_key(
|
||||
key_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
key_id: str, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> Optional[APIKeyWithoutHash]:
|
||||
"""Revoke an API key"""
|
||||
try:
|
||||
@@ -1097,10 +1190,10 @@ async def delete_api_key(
|
||||
summary="Suspend API key",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def suspend_key(
|
||||
key_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
key_id: str, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> Optional[APIKeyWithoutHash]:
|
||||
"""Suspend an API key"""
|
||||
try:
|
||||
@@ -1122,12 +1215,12 @@ async def suspend_key(
|
||||
summary="Update key permissions",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_permissions(
|
||||
key_id: str,
|
||||
request: UpdatePermissionsRequest,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> Optional[APIKeyWithoutHash]:
|
||||
"""Update API key permissions"""
|
||||
try:
|
||||
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
from io import BytesIO
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import autogpt_libs.auth.depends
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
@@ -14,9 +13,7 @@ from pytest_snapshot.plugin import Snapshot
|
||||
import backend.server.routers.v1 as v1_routes
|
||||
from backend.data.credit import AutoTopUpConfig
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.server.conftest import TEST_USER_ID
|
||||
from backend.server.routers.v1 import upload_file
|
||||
from backend.server.utils import get_user_id
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(v1_routes.v1_router)
|
||||
@@ -24,31 +21,26 @@ app.include_router(v1_routes.v1_router)
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
def override_auth_middleware(request: fastapi.Request) -> dict[str, str]:
|
||||
"""Override auth middleware for testing"""
|
||||
return {"sub": TEST_USER_ID, "role": "user", "email": "test@example.com"}
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
|
||||
def override_get_user_id() -> str:
|
||||
"""Override get_user_id for testing"""
|
||||
return TEST_USER_ID
|
||||
|
||||
|
||||
app.dependency_overrides[autogpt_libs.auth.middleware.auth_middleware] = (
|
||||
override_auth_middleware
|
||||
)
|
||||
app.dependency_overrides[get_user_id] = override_get_user_id
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# Auth endpoints tests
|
||||
def test_get_or_create_user_route(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test get or create user endpoint"""
|
||||
mock_user = Mock()
|
||||
mock_user.model_dump.return_value = {
|
||||
"id": TEST_USER_ID,
|
||||
"id": test_user_id,
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
@@ -263,6 +255,7 @@ def test_get_auto_top_up(
|
||||
def test_get_graphs(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test get graphs endpoint"""
|
||||
mock_graph = GraphModel(
|
||||
@@ -271,7 +264,7 @@ def test_get_graphs(
|
||||
is_active=True,
|
||||
name="Test Graph",
|
||||
description="A test graph",
|
||||
user_id="test-user-id",
|
||||
user_id=test_user_id,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
@@ -296,6 +289,7 @@ def test_get_graphs(
|
||||
def test_get_graph(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test get single graph endpoint"""
|
||||
mock_graph = GraphModel(
|
||||
@@ -304,7 +298,7 @@ def test_get_graph(
|
||||
is_active=True,
|
||||
name="Test Graph",
|
||||
description="A test graph",
|
||||
user_id="test-user-id",
|
||||
user_id=test_user_id,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
@@ -343,6 +337,7 @@ def test_get_graph_not_found(
|
||||
def test_delete_graph(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test delete graph endpoint"""
|
||||
# Mock active graph for deactivation
|
||||
@@ -352,7 +347,7 @@ def test_delete_graph(
|
||||
is_active=True,
|
||||
name="Test Graph",
|
||||
description="A test graph",
|
||||
user_id="test-user-id",
|
||||
user_id=test_user_id,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
@@ -399,7 +394,7 @@ def test_missing_required_field() -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_file_success():
|
||||
async def test_upload_file_success(test_user_id: str):
|
||||
"""Test successful file upload."""
|
||||
# Create mock upload file
|
||||
file_content = b"test file content"
|
||||
@@ -425,7 +420,7 @@ async def test_upload_file_success():
|
||||
|
||||
result = await upload_file(
|
||||
file=upload_file_mock,
|
||||
user_id="test-user-123",
|
||||
user_id=test_user_id,
|
||||
provider="gcs",
|
||||
expiration_hours=24,
|
||||
)
|
||||
@@ -446,12 +441,12 @@ async def test_upload_file_success():
|
||||
filename="test.txt",
|
||||
provider="gcs",
|
||||
expiration_hours=24,
|
||||
user_id="test-user-123",
|
||||
user_id=test_user_id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_file_no_filename():
|
||||
async def test_upload_file_no_filename(test_user_id: str):
|
||||
"""Test file upload without filename."""
|
||||
file_content = b"test content"
|
||||
file_obj = BytesIO(file_content)
|
||||
@@ -476,7 +471,7 @@ async def test_upload_file_no_filename():
|
||||
|
||||
upload_file_mock.read = AsyncMock(return_value=file_content)
|
||||
|
||||
result = await upload_file(file=upload_file_mock, user_id="test-user-123")
|
||||
result = await upload_file(file=upload_file_mock, user_id=test_user_id)
|
||||
|
||||
assert result.file_name == "uploaded_file"
|
||||
assert result.content_type == "application/octet-stream"
|
||||
@@ -486,7 +481,7 @@ async def test_upload_file_no_filename():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_file_invalid_expiration():
|
||||
async def test_upload_file_invalid_expiration(test_user_id: str):
|
||||
"""Test file upload with invalid expiration hours."""
|
||||
file_obj = BytesIO(b"content")
|
||||
upload_file_mock = UploadFile(
|
||||
@@ -498,7 +493,7 @@ async def test_upload_file_invalid_expiration():
|
||||
# Test expiration too short
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await upload_file(
|
||||
file=upload_file_mock, user_id="test-user-123", expiration_hours=0
|
||||
file=upload_file_mock, user_id=test_user_id, expiration_hours=0
|
||||
)
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "between 1 and 48" in exc_info.value.detail
|
||||
@@ -506,14 +501,14 @@ async def test_upload_file_invalid_expiration():
|
||||
# Test expiration too long
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await upload_file(
|
||||
file=upload_file_mock, user_id="test-user-123", expiration_hours=49
|
||||
file=upload_file_mock, user_id=test_user_id, expiration_hours=49
|
||||
)
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "between 1 and 48" in exc_info.value.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_file_virus_scan_failure():
|
||||
async def test_upload_file_virus_scan_failure(test_user_id: str):
|
||||
"""Test file upload when virus scan fails."""
|
||||
file_content = b"malicious content"
|
||||
file_obj = BytesIO(file_content)
|
||||
@@ -530,11 +525,11 @@ async def test_upload_file_virus_scan_failure():
|
||||
upload_file_mock.read = AsyncMock(return_value=file_content)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Virus detected!"):
|
||||
await upload_file(file=upload_file_mock, user_id="test-user-123")
|
||||
await upload_file(file=upload_file_mock, user_id=test_user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_file_cloud_storage_failure():
|
||||
async def test_upload_file_cloud_storage_failure(test_user_id: str):
|
||||
"""Test file upload when cloud storage fails."""
|
||||
file_content = b"test content"
|
||||
file_obj = BytesIO(file_content)
|
||||
@@ -556,11 +551,11 @@ async def test_upload_file_cloud_storage_failure():
|
||||
upload_file_mock.read = AsyncMock(return_value=file_content)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Storage error!"):
|
||||
await upload_file(file=upload_file_mock, user_id="test-user-123")
|
||||
await upload_file(file=upload_file_mock, user_id=test_user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_file_size_limit_exceeded():
|
||||
async def test_upload_file_size_limit_exceeded(test_user_id: str):
|
||||
"""Test file upload when file size exceeds the limit."""
|
||||
# Create a file that exceeds the default 256MB limit
|
||||
large_file_content = b"x" * (257 * 1024 * 1024) # 257MB
|
||||
@@ -574,14 +569,14 @@ async def test_upload_file_size_limit_exceeded():
|
||||
upload_file_mock.read = AsyncMock(return_value=large_file_content)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await upload_file(file=upload_file_mock, user_id="test-user-123")
|
||||
await upload_file(file=upload_file_mock, user_id=test_user_id)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "exceeds the maximum allowed size of 256MB" in exc_info.value.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_file_gcs_not_configured_fallback():
|
||||
async def test_upload_file_gcs_not_configured_fallback(test_user_id: str):
|
||||
"""Test file upload fallback to base64 when GCS is not configured."""
|
||||
file_content = b"test file content"
|
||||
file_obj = BytesIO(file_content)
|
||||
@@ -602,7 +597,7 @@ async def test_upload_file_gcs_not_configured_fallback():
|
||||
|
||||
upload_file_mock.read = AsyncMock(return_value=file_content)
|
||||
|
||||
result = await upload_file(file=upload_file_mock, user_id="test-user-123")
|
||||
result = await upload_file(file=upload_file_mock, user_id=test_user_id)
|
||||
|
||||
# Verify fallback behavior
|
||||
assert result.file_name == "test.txt"
|
||||
|
||||
@@ -1,144 +0,0 @@
|
||||
"""Common test fixtures with proper setup and teardown."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from prisma import Prisma
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_db_connection() -> AsyncGenerator[Prisma, None]:
|
||||
"""Provide a test database connection with proper cleanup.
|
||||
|
||||
This fixture ensures the database connection is properly
|
||||
closed after the test, even if the test fails.
|
||||
"""
|
||||
db = Prisma()
|
||||
try:
|
||||
await db.connect()
|
||||
yield db
|
||||
finally:
|
||||
await db.disconnect()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_transaction():
|
||||
"""Mock database transaction with proper async context manager."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_context(*args, **kwargs):
|
||||
yield None
|
||||
|
||||
with patch("backend.data.db.locked_transaction", side_effect=mock_context) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def isolated_app_state():
|
||||
"""Fixture that ensures app state is isolated between tests."""
|
||||
# Example: Save original state
|
||||
# from backend.server.app import app
|
||||
# original_overrides = app.dependency_overrides.copy()
|
||||
|
||||
# try:
|
||||
# yield app
|
||||
# finally:
|
||||
# # Restore original state
|
||||
# app.dependency_overrides = original_overrides
|
||||
|
||||
# For now, just yield None as this is an example
|
||||
yield None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_files():
|
||||
"""Fixture to track and cleanup files created during tests."""
|
||||
created_files = []
|
||||
|
||||
def track_file(filepath: str):
|
||||
created_files.append(filepath)
|
||||
|
||||
yield track_file
|
||||
|
||||
# Cleanup
|
||||
import os
|
||||
|
||||
for filepath in created_files:
|
||||
try:
|
||||
if os.path.exists(filepath):
|
||||
os.remove(filepath)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to cleanup {filepath}: {e}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_mock_with_cleanup():
|
||||
"""Create async mocks that are properly cleaned up."""
|
||||
mocks = []
|
||||
|
||||
def create_mock(**kwargs):
|
||||
mock = Mock(**kwargs)
|
||||
mocks.append(mock)
|
||||
return mock
|
||||
|
||||
yield create_mock
|
||||
|
||||
# Reset all mocks
|
||||
for mock in mocks:
|
||||
mock.reset_mock()
|
||||
|
||||
|
||||
class TestDatabaseIsolation:
|
||||
"""Example of proper test isolation with database operations."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup_and_teardown(self, test_db_connection):
|
||||
"""Setup and teardown for each test method."""
|
||||
# Setup: Clear test data
|
||||
await test_db_connection.user.delete_many(
|
||||
where={"email": {"contains": "@test.example"}}
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
# Teardown: Clear test data again
|
||||
await test_db_connection.user.delete_many(
|
||||
where={"email": {"contains": "@test.example"}}
|
||||
)
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def test_create_user(self, test_db_connection):
|
||||
"""Test that demonstrates proper isolation."""
|
||||
# This test has access to a clean database
|
||||
user = await test_db_connection.user.create(
|
||||
data={
|
||||
"id": "test-user-id",
|
||||
"email": "test@test.example",
|
||||
"name": "Test User",
|
||||
}
|
||||
)
|
||||
assert user.email == "test@test.example"
|
||||
# User will be cleaned up automatically
|
||||
|
||||
|
||||
@pytest.fixture(scope="function") # Explicitly use function scope
|
||||
def reset_singleton_state():
|
||||
"""Reset singleton state between tests."""
|
||||
# Example: Reset a singleton instance
|
||||
# from backend.data.some_singleton import SingletonClass
|
||||
|
||||
# # Save original state
|
||||
# original_instance = getattr(SingletonClass, "_instance", None)
|
||||
|
||||
# try:
|
||||
# # Clear singleton
|
||||
# SingletonClass._instance = None
|
||||
# yield
|
||||
# finally:
|
||||
# # Restore original state
|
||||
# SingletonClass._instance = original_instance
|
||||
|
||||
# For now, just yield None as this is an example
|
||||
yield None
|
||||
@@ -1,74 +0,0 @@
|
||||
"""Common test utilities and constants for server tests."""
|
||||
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
# Test ID constants
|
||||
TEST_USER_ID = "test-user-id"
|
||||
ADMIN_USER_ID = "admin-user-id"
|
||||
TARGET_USER_ID = "target-user-id"
|
||||
|
||||
# Common test data constants
|
||||
FIXED_TIMESTAMP = "2024-01-01T00:00:00Z"
|
||||
TRANSACTION_UUID = "transaction-123-uuid"
|
||||
METRIC_UUID = "metric-123-uuid"
|
||||
ANALYTICS_UUID = "analytics-123-uuid"
|
||||
|
||||
|
||||
def create_mock_with_id(mock_id: str) -> Mock:
|
||||
"""Create a mock object with an id attribute.
|
||||
|
||||
Args:
|
||||
mock_id: The ID value to set on the mock
|
||||
|
||||
Returns:
|
||||
Mock object with id attribute set
|
||||
"""
|
||||
return Mock(id=mock_id)
|
||||
|
||||
|
||||
def assert_status_and_parse_json(
|
||||
response: Any, expected_status: int = 200
|
||||
) -> Dict[str, Any]:
|
||||
"""Assert response status and return parsed JSON.
|
||||
|
||||
Args:
|
||||
response: The HTTP response object
|
||||
expected_status: Expected status code (default: 200)
|
||||
|
||||
Returns:
|
||||
Parsed JSON response data
|
||||
|
||||
Raises:
|
||||
AssertionError: If status code doesn't match expected
|
||||
"""
|
||||
assert (
|
||||
response.status_code == expected_status
|
||||
), f"Expected status {expected_status}, got {response.status_code}: {response.text}"
|
||||
return response.json()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"metric_value,metric_name,data_string",
|
||||
[
|
||||
(100, "api_calls_count", "external_api"),
|
||||
(0, "error_count", "no_errors"),
|
||||
(-5.2, "temperature_delta", "cooling"),
|
||||
(1.23456789, "precision_test", "float_precision"),
|
||||
(999999999, "large_number", "max_value"),
|
||||
],
|
||||
)
|
||||
def parametrized_metric_values_decorator(func):
|
||||
"""Decorator for parametrized metric value tests."""
|
||||
return pytest.mark.parametrize(
|
||||
"metric_value,metric_name,data_string",
|
||||
[
|
||||
(100, "api_calls_count", "external_api"),
|
||||
(0, "error_count", "no_errors"),
|
||||
(-5.2, "temperature_delta", "cooling"),
|
||||
(1.23456789, "precision_test", "float_precision"),
|
||||
(999999999, "large_number", "max_value"),
|
||||
],
|
||||
)(func)
|
||||
@@ -1,11 +0,0 @@
|
||||
from autogpt_libs.auth.depends import requires_user
|
||||
from autogpt_libs.auth.models import User
|
||||
from fastapi import Depends
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
def get_user_id(user: User = Depends(requires_user)) -> str:
|
||||
return user.user_id
|
||||
109
autogpt_platform/backend/backend/server/utils/api_key_auth.py
Normal file
109
autogpt_platform/backend/backend/server/utils/api_key_auth.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
API Key authentication utilities for FastAPI applications.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import secrets
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.security import APIKeyHeader
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
|
||||
|
||||
class APIKeyAuthenticator(APIKeyHeader):
|
||||
"""
|
||||
Configurable API key authenticator for FastAPI applications,
|
||||
with support for custom validation functions.
|
||||
|
||||
This class provides a flexible way to implement API key authentication with optional
|
||||
custom validation logic. It can be used for simple token matching
|
||||
or more complex validation scenarios like database lookups.
|
||||
|
||||
Examples:
|
||||
Simple token validation:
|
||||
```python
|
||||
api_key_auth = APIKeyAuthenticator(
|
||||
header_name="X-API-Key",
|
||||
expected_token="your-secret-token"
|
||||
)
|
||||
|
||||
@app.get("/protected", dependencies=[Security(api_key_auth)])
|
||||
def protected_endpoint():
|
||||
return {"message": "Access granted"}
|
||||
```
|
||||
|
||||
Custom validation with database lookup:
|
||||
```python
|
||||
async def validate_with_db(api_key: str):
|
||||
api_key_obj = await db.get_api_key(api_key)
|
||||
return api_key_obj if api_key_obj and api_key_obj.is_active else None
|
||||
|
||||
api_key_auth = APIKeyAuthenticator(
|
||||
header_name="X-API-Key",
|
||||
validator=validate_with_db
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
header_name (str): The name of the header containing the API key
|
||||
expected_token (Optional[str]): The expected API key value for simple token matching
|
||||
validator (Optional[Callable]): Custom validation function that takes an API key
|
||||
string and returns a boolean or object. Can be async.
|
||||
status_if_missing (int): HTTP status code to use for validation errors
|
||||
message_if_invalid (str): Error message to return when validation fails
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
header_name: str,
|
||||
expected_token: Optional[str] = None,
|
||||
validator: Optional[Callable[[str], bool]] = None,
|
||||
status_if_missing: int = HTTP_401_UNAUTHORIZED,
|
||||
message_if_invalid: str = "Invalid API key",
|
||||
):
|
||||
super().__init__(
|
||||
name=header_name,
|
||||
scheme_name=f"{__class__.__name__}-{header_name}",
|
||||
auto_error=False,
|
||||
)
|
||||
self.expected_token = expected_token
|
||||
self.custom_validator = validator
|
||||
self.status_if_missing = status_if_missing
|
||||
self.message_if_invalid = message_if_invalid
|
||||
|
||||
async def __call__(self, request: Request) -> Any:
|
||||
api_key = await super()(request)
|
||||
if api_key is None:
|
||||
raise HTTPException(
|
||||
status_code=self.status_if_missing, detail="No API key in request"
|
||||
)
|
||||
|
||||
# Use custom validation if provided, otherwise use default equality check
|
||||
validator = self.custom_validator or self.default_validator
|
||||
result = (
|
||||
await validator(api_key)
|
||||
if inspect.iscoroutinefunction(validator)
|
||||
else validator(api_key)
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=self.status_if_missing, detail=self.message_if_invalid
|
||||
)
|
||||
|
||||
# Store validation result in request state if it's not just a boolean
|
||||
if result is not True:
|
||||
request.state.api_key = result
|
||||
|
||||
return result
|
||||
|
||||
async def default_validator(self, api_key: str) -> bool:
|
||||
if not self.expected_token:
|
||||
raise MissingConfigError(
|
||||
f"{self.__class__.__name__}.expected_token is not set; "
|
||||
"either specify it or provide a custom validator"
|
||||
)
|
||||
return secrets.compare_digest(api_key, self.expected_token)
|
||||
@@ -83,7 +83,7 @@ class AutoModManager:
|
||||
f"Moderating inputs for graph execution {graph_exec.graph_exec_id}"
|
||||
)
|
||||
try:
|
||||
moderation_passed = await self._moderate_content(
|
||||
moderation_passed, content_id = await self._moderate_content(
|
||||
content,
|
||||
{
|
||||
"user_id": graph_exec.user_id,
|
||||
@@ -99,7 +99,7 @@ class AutoModManager:
|
||||
)
|
||||
# Update node statuses for frontend display before raising error
|
||||
await self._update_failed_nodes_for_moderation(
|
||||
db_client, graph_exec.graph_exec_id, "input"
|
||||
db_client, graph_exec.graph_exec_id, "input", content_id
|
||||
)
|
||||
|
||||
return ModerationError(
|
||||
@@ -107,6 +107,7 @@ class AutoModManager:
|
||||
user_id=graph_exec.user_id,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
moderation_type="input",
|
||||
content_id=content_id,
|
||||
)
|
||||
|
||||
return None
|
||||
@@ -146,8 +147,10 @@ class AutoModManager:
|
||||
return None
|
||||
|
||||
# Get completed executions and collect outputs
|
||||
completed_executions = await db_client.get_node_executions(
|
||||
graph_exec_id, statuses=[ExecutionStatus.COMPLETED], include_exec_data=True
|
||||
completed_executions = await db_client.get_node_executions( # type: ignore
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[ExecutionStatus.COMPLETED],
|
||||
include_exec_data=True,
|
||||
)
|
||||
|
||||
if not completed_executions:
|
||||
@@ -167,7 +170,7 @@ class AutoModManager:
|
||||
# Run moderation
|
||||
logger.warning(f"Moderating outputs for graph execution {graph_exec_id}")
|
||||
try:
|
||||
moderation_passed = await self._moderate_content(
|
||||
moderation_passed, content_id = await self._moderate_content(
|
||||
content,
|
||||
{
|
||||
"user_id": user_id,
|
||||
@@ -181,7 +184,7 @@ class AutoModManager:
|
||||
logger.warning(f"Moderation failed for graph execution {graph_exec_id}")
|
||||
# Update node statuses for frontend display before raising error
|
||||
await self._update_failed_nodes_for_moderation(
|
||||
db_client, graph_exec_id, "output"
|
||||
db_client, graph_exec_id, "output", content_id
|
||||
)
|
||||
|
||||
return ModerationError(
|
||||
@@ -189,6 +192,7 @@ class AutoModManager:
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
moderation_type="output",
|
||||
content_id=content_id,
|
||||
)
|
||||
|
||||
return None
|
||||
@@ -212,10 +216,11 @@ class AutoModManager:
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
graph_exec_id: str,
|
||||
moderation_type: Literal["input", "output"],
|
||||
content_id: str | None = None,
|
||||
):
|
||||
"""Update node execution statuses for frontend display when moderation fails"""
|
||||
# Import here to avoid circular imports
|
||||
from backend.executor.manager import send_async_execution_update
|
||||
from backend.util.clients import get_async_execution_event_bus
|
||||
|
||||
if moderation_type == "input":
|
||||
# For input moderation, mark queued/running/incomplete nodes as failed
|
||||
@@ -229,13 +234,20 @@ class AutoModManager:
|
||||
target_statuses = [ExecutionStatus.COMPLETED]
|
||||
|
||||
# Get the executions that need to be updated
|
||||
executions_to_update = await db_client.get_node_executions(
|
||||
graph_exec_id, statuses=target_statuses, include_exec_data=True
|
||||
executions_to_update = await db_client.get_node_executions( # type: ignore
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=target_statuses,
|
||||
include_exec_data=True,
|
||||
)
|
||||
|
||||
if not executions_to_update:
|
||||
return
|
||||
|
||||
# Create error message with content_id if available
|
||||
error_message = "Failed due to content moderation"
|
||||
if content_id:
|
||||
error_message += f" (Moderation ID: {content_id})"
|
||||
|
||||
# Prepare database update tasks
|
||||
exec_updates = []
|
||||
for exec_entry in executions_to_update:
|
||||
@@ -245,11 +257,11 @@ class AutoModManager:
|
||||
|
||||
if exec_entry.input_data:
|
||||
for name in exec_entry.input_data.keys():
|
||||
cleared_inputs[name] = ["Failed due to content moderation"]
|
||||
cleared_inputs[name] = [error_message]
|
||||
|
||||
if exec_entry.output_data:
|
||||
for name in exec_entry.output_data.keys():
|
||||
cleared_outputs[name] = ["Failed due to content moderation"]
|
||||
cleared_outputs[name] = [error_message]
|
||||
|
||||
# Add update task to list
|
||||
exec_updates.append(
|
||||
@@ -257,7 +269,7 @@ class AutoModManager:
|
||||
exec_entry.node_exec_id,
|
||||
status=ExecutionStatus.FAILED,
|
||||
stats={
|
||||
"error": "Failed due to content moderation",
|
||||
"error": error_message,
|
||||
"cleared_inputs": cleared_inputs,
|
||||
"cleared_outputs": cleared_outputs,
|
||||
},
|
||||
@@ -268,19 +280,24 @@ class AutoModManager:
|
||||
updated_execs = await asyncio.gather(*exec_updates)
|
||||
|
||||
# Send all websocket updates in parallel
|
||||
event_bus = get_async_execution_event_bus()
|
||||
await asyncio.gather(
|
||||
*[
|
||||
send_async_execution_update(updated_exec)
|
||||
event_bus.publish(updated_exec)
|
||||
for updated_exec in updated_execs
|
||||
if updated_exec is not None
|
||||
]
|
||||
)
|
||||
|
||||
async def _moderate_content(self, content: str, metadata: dict[str, Any]) -> bool:
|
||||
async def _moderate_content(
|
||||
self, content: str, metadata: dict[str, Any]
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Moderate content using AutoMod API
|
||||
|
||||
Returns:
|
||||
True: Content approved or timeout occurred
|
||||
False: Content rejected by moderation
|
||||
Tuple of (approval_status, content_id)
|
||||
- approval_status: True if approved or timeout occurred, False if rejected
|
||||
- content_id: Reference ID from moderation API, or None if not available
|
||||
|
||||
Raises:
|
||||
asyncio.TimeoutError: When moderation times out (should be bypassed)
|
||||
@@ -298,12 +315,12 @@ class AutoModManager:
|
||||
logger.debug(
|
||||
f"Content approved for {metadata.get('graph_exec_id', 'unknown')}"
|
||||
)
|
||||
return True
|
||||
return True, response.content_id
|
||||
else:
|
||||
reasons = [r.reason for r in response.moderation_results if r.reason]
|
||||
error_msg = f"Content rejected by AutoMod: {'; '.join(reasons)}"
|
||||
logger.warning(f"Content rejected: {error_msg}")
|
||||
return False
|
||||
return False, response.content_id
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# Re-raise timeout to be handled by calling methods
|
||||
@@ -313,7 +330,7 @@ class AutoModManager:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"AutoMod moderation error: {e}")
|
||||
return self.config.fail_open
|
||||
return self.config.fail_open, None
|
||||
|
||||
async def _make_request(self, request_data: AutoModRequest) -> AutoModResponse:
|
||||
"""Make HTTP request to AutoMod API using the standard request utility"""
|
||||
|
||||
@@ -26,6 +26,9 @@ class AutoModResponse(BaseModel):
|
||||
"""Response model for AutoMod API"""
|
||||
|
||||
success: bool = Field(..., description="Whether the request was successful")
|
||||
content_id: str = Field(
|
||||
..., description="Unique reference ID for this moderation request"
|
||||
)
|
||||
status: str = Field(
|
||||
..., description="Overall status: 'approved', 'rejected', 'flagged', 'pending'"
|
||||
)
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import logging
|
||||
import typing
|
||||
|
||||
from autogpt_libs.auth import requires_admin_user
|
||||
from autogpt_libs.auth.depends import get_user_id
|
||||
from fastapi import APIRouter, Body, Depends
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, Body, Security
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import admin_get_user_history, get_user_credit_model
|
||||
@@ -18,7 +17,7 @@ _user_credit_model = get_user_credit_model()
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["credits", "admin"],
|
||||
dependencies=[Depends(requires_admin_user)],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
@@ -29,18 +28,16 @@ async def add_user_credits(
|
||||
user_id: typing.Annotated[str, Body()],
|
||||
amount: typing.Annotated[int, Body()],
|
||||
comments: typing.Annotated[str, Body()],
|
||||
admin_user: typing.Annotated[
|
||||
str,
|
||||
Depends(get_user_id),
|
||||
],
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
):
|
||||
""" """
|
||||
logger.info(f"Admin user {admin_user} is adding {amount} credits to user {user_id}")
|
||||
logger.info(
|
||||
f"Admin user {admin_user_id} is adding {amount} credits to user {user_id}"
|
||||
)
|
||||
new_balance, transaction_key = await _user_credit_model._add_transaction(
|
||||
user_id,
|
||||
amount,
|
||||
transaction_type=CreditTransactionType.GRANT,
|
||||
metadata=SafeJson({"admin_id": admin_user, "reason": comments}),
|
||||
metadata=SafeJson({"admin_id": admin_user_id, "reason": comments}),
|
||||
)
|
||||
return {
|
||||
"new_balance": new_balance,
|
||||
@@ -54,17 +51,14 @@ async def add_user_credits(
|
||||
summary="Get All Users History",
|
||||
)
|
||||
async def admin_get_all_user_history(
|
||||
admin_user: typing.Annotated[
|
||||
str,
|
||||
Depends(get_user_id),
|
||||
],
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
search: typing.Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
transaction_filter: typing.Optional[CreditTransactionType] = None,
|
||||
):
|
||||
""" """
|
||||
logger.info(f"Admin user {admin_user} is getting grant history")
|
||||
logger.info(f"Admin user {admin_user_id} is getting grant history")
|
||||
|
||||
try:
|
||||
resp = await admin_get_user_history(
|
||||
@@ -73,7 +67,7 @@ async def admin_get_all_user_history(
|
||||
search=search,
|
||||
transaction_filter=transaction_filter,
|
||||
)
|
||||
logger.info(f"Admin user {admin_user} got {len(resp.history)} grant history")
|
||||
logger.info(f"Admin user {admin_user_id} got {len(resp.history)} grant history")
|
||||
return resp
|
||||
except Exception as e:
|
||||
logger.exception(f"Error getting grant history: {e}")
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import autogpt_libs.auth
|
||||
import autogpt_libs.auth.depends
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import prisma.enums
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from prisma import Json
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.v2.admin.credit_admin_routes as credit_admin_routes
|
||||
import backend.server.v2.admin.model as admin_model
|
||||
from backend.data.model import UserTransaction
|
||||
from backend.server.conftest import ADMIN_USER_ID, TARGET_USER_ID
|
||||
from backend.util.models import Pagination
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
@@ -22,25 +21,19 @@ app.include_router(credit_admin_routes.router)
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
def override_requires_admin_user() -> dict[str, str]:
|
||||
"""Override admin user check for testing"""
|
||||
return {"sub": ADMIN_USER_ID, "role": "admin"}
|
||||
|
||||
|
||||
def override_get_user_id() -> str:
|
||||
"""Override get_user_id for testing"""
|
||||
return ADMIN_USER_ID
|
||||
|
||||
|
||||
app.dependency_overrides[autogpt_libs.auth.requires_admin_user] = (
|
||||
override_requires_admin_user
|
||||
)
|
||||
app.dependency_overrides[autogpt_libs.auth.depends.get_user_id] = override_get_user_id
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
"""Setup admin auth overrides for all tests in this module"""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_add_user_credits_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
admin_user_id: str,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful credit addition by admin"""
|
||||
# Mock the credit model
|
||||
@@ -52,7 +45,7 @@ def test_add_user_credits_success(
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"user_id": TARGET_USER_ID,
|
||||
"user_id": target_user_id,
|
||||
"amount": 500,
|
||||
"comments": "Test credit grant for debugging",
|
||||
}
|
||||
@@ -67,12 +60,12 @@ def test_add_user_credits_success(
|
||||
# Verify the function was called with correct parameters
|
||||
mock_credit_model._add_transaction.assert_called_once()
|
||||
call_args = mock_credit_model._add_transaction.call_args
|
||||
assert call_args[0] == (TARGET_USER_ID, 500)
|
||||
assert call_args[0] == (target_user_id, 500)
|
||||
assert call_args[1]["transaction_type"] == prisma.enums.CreditTransactionType.GRANT
|
||||
# Check that metadata is a Json object with the expected content
|
||||
assert isinstance(call_args[1]["metadata"], Json)
|
||||
assert call_args[1]["metadata"] == Json(
|
||||
{"admin_id": ADMIN_USER_ID, "reason": "Test credit grant for debugging"}
|
||||
{"admin_id": admin_user_id, "reason": "Test credit grant for debugging"}
|
||||
)
|
||||
|
||||
# Snapshot test the response
|
||||
@@ -290,18 +283,10 @@ def test_add_credits_invalid_request() -> None:
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_admin_endpoints_require_admin_role(mocker: pytest_mock.MockFixture) -> None:
|
||||
def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
|
||||
"""Test that admin endpoints require admin role"""
|
||||
# Clear the admin override to test authorization
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
# Mock requires_admin_user to raise an exception
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.requires_admin_user",
|
||||
side_effect=fastapi.HTTPException(
|
||||
status_code=403, detail="Admin access required"
|
||||
),
|
||||
)
|
||||
# Simulate regular non-admin user
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
|
||||
# Test add_credits endpoint
|
||||
response = client.post(
|
||||
@@ -312,20 +297,8 @@ def test_admin_endpoints_require_admin_role(mocker: pytest_mock.MockFixture) ->
|
||||
"comments": "test",
|
||||
},
|
||||
)
|
||||
assert (
|
||||
response.status_code == 401
|
||||
) # Auth middleware returns 401 when auth is disabled
|
||||
assert response.status_code == 403
|
||||
|
||||
# Test users_history endpoint
|
||||
response = client.get("/admin/users_history")
|
||||
assert (
|
||||
response.status_code == 401
|
||||
) # Auth middleware returns 401 when auth is disabled
|
||||
|
||||
# Restore the override
|
||||
app.dependency_overrides[autogpt_libs.auth.requires_admin_user] = (
|
||||
override_requires_admin_user
|
||||
)
|
||||
app.dependency_overrides[autogpt_libs.auth.depends.get_user_id] = (
|
||||
override_get_user_id
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
@@ -2,26 +2,28 @@ import logging
|
||||
import tempfile
|
||||
import typing
|
||||
|
||||
import autogpt_libs.auth.depends
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
import prisma.enums
|
||||
|
||||
import backend.server.v2.store.db
|
||||
import backend.server.v2.store.exceptions
|
||||
import backend.server.v2.store.model
|
||||
import backend.util.json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = fastapi.APIRouter(prefix="/admin", tags=["store", "admin"])
|
||||
router = fastapi.APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["store", "admin"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/listings",
|
||||
summary="Get Admin Listings History",
|
||||
response_model=backend.server.v2.store.model.StoreListingsWithVersionsResponse,
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
|
||||
)
|
||||
async def get_admin_listings_with_versions(
|
||||
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
|
||||
@@ -66,15 +68,11 @@ async def get_admin_listings_with_versions(
|
||||
"/submissions/{store_listing_version_id}/review",
|
||||
summary="Review Store Submission",
|
||||
response_model=backend.server.v2.store.model.StoreSubmission,
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
|
||||
)
|
||||
async def review_submission(
|
||||
store_listing_version_id: str,
|
||||
request: backend.server.v2.store.model.ReviewSubmissionRequest,
|
||||
user: typing.Annotated[
|
||||
autogpt_libs.auth.models.User,
|
||||
fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user),
|
||||
],
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Review a store listing submission.
|
||||
@@ -82,7 +80,7 @@ async def review_submission(
|
||||
Args:
|
||||
store_listing_version_id: ID of the submission to review
|
||||
request: Review details including approval status and comments
|
||||
user: Authenticated admin user performing the review
|
||||
user_id: Authenticated admin user performing the review
|
||||
|
||||
Returns:
|
||||
StoreSubmission with updated review information
|
||||
@@ -93,7 +91,7 @@ async def review_submission(
|
||||
is_approved=request.is_approved,
|
||||
external_comments=request.comments,
|
||||
internal_comments=request.internal_comments or "",
|
||||
reviewer_id=user.user_id,
|
||||
reviewer_id=user_id,
|
||||
)
|
||||
return submission
|
||||
except Exception as e:
|
||||
@@ -108,13 +106,9 @@ async def review_submission(
|
||||
"/submissions/download/{store_listing_version_id}",
|
||||
summary="Admin Download Agent File",
|
||||
tags=["store", "admin"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
|
||||
)
|
||||
async def admin_download_agent_file(
|
||||
user: typing.Annotated[
|
||||
autogpt_libs.auth.models.User,
|
||||
fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user),
|
||||
],
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
store_listing_version_id: str = fastapi.Path(
|
||||
..., description="The ID of the agent to download"
|
||||
),
|
||||
@@ -132,7 +126,7 @@ async def admin_download_agent_file(
|
||||
HTTPException: If the agent is not found or an unexpected error occurs.
|
||||
"""
|
||||
graph_data = await backend.server.v2.store.db.get_agent_as_admin(
|
||||
user_id=user.user_id,
|
||||
user_id=user_id,
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||
|
||||
376
autogpt_platform/backend/backend/server/v2/builder/db.py
Normal file
376
autogpt_platform/backend/backend/server/v2/builder/db.py
Normal file
@@ -0,0 +1,376 @@
|
||||
import functools
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import prisma
|
||||
|
||||
import backend.data.block
|
||||
from backend.blocks import load_all_blocks
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.block import Block, BlockCategory, BlockSchema
|
||||
from backend.data.credit import get_block_costs
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.v2.builder.model import (
|
||||
BlockCategoryResponse,
|
||||
BlockData,
|
||||
BlockResponse,
|
||||
BlockType,
|
||||
CountResponse,
|
||||
Provider,
|
||||
ProviderResponse,
|
||||
SearchBlocksResponse,
|
||||
)
|
||||
from backend.util.models import Pagination
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
|
||||
_static_counts_cache: dict | None = None
|
||||
_suggested_blocks: list[BlockData] | None = None
|
||||
|
||||
|
||||
def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse]:
|
||||
categories: dict[BlockCategory, BlockCategoryResponse] = {}
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: Block[BlockSchema, BlockSchema] = block_type()
|
||||
# Skip disabled blocks
|
||||
if block.disabled:
|
||||
continue
|
||||
# Skip blocks that don't have categories (all should have at least one)
|
||||
if not block.categories:
|
||||
continue
|
||||
|
||||
# Add block to the categories
|
||||
for category in block.categories:
|
||||
if category not in categories:
|
||||
categories[category] = BlockCategoryResponse(
|
||||
name=category.name.lower(),
|
||||
total_blocks=0,
|
||||
blocks=[],
|
||||
)
|
||||
|
||||
categories[category].total_blocks += 1
|
||||
|
||||
# Append if the category has less than the specified number of blocks
|
||||
if len(categories[category].blocks) < category_blocks:
|
||||
categories[category].blocks.append(block.to_dict())
|
||||
|
||||
# Sort categories by name
|
||||
return sorted(categories.values(), key=lambda x: x.name)
|
||||
|
||||
|
||||
def get_blocks(
|
||||
*,
|
||||
category: str | None = None,
|
||||
type: BlockType | None = None,
|
||||
provider: ProviderName | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> BlockResponse:
|
||||
"""
|
||||
Get blocks based on either category, type or provider.
|
||||
Providing nothing fetches all block types.
|
||||
"""
|
||||
# Only one of category, type, or provider can be specified
|
||||
if (category and type) or (category and provider) or (type and provider):
|
||||
raise ValueError("Only one of category, type, or provider can be specified")
|
||||
|
||||
blocks: list[Block[BlockSchema, BlockSchema]] = []
|
||||
skip = (page - 1) * page_size
|
||||
take = page_size
|
||||
total = 0
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: Block[BlockSchema, BlockSchema] = block_type()
|
||||
# Skip disabled blocks
|
||||
if block.disabled:
|
||||
continue
|
||||
# Skip blocks that don't match the category
|
||||
if category and category not in {c.name.lower() for c in block.categories}:
|
||||
continue
|
||||
# Skip blocks that don't match the type
|
||||
if (
|
||||
(type == "input" and block.block_type.value != "Input")
|
||||
or (type == "output" and block.block_type.value != "Output")
|
||||
or (type == "action" and block.block_type.value in ("Input", "Output"))
|
||||
):
|
||||
continue
|
||||
# Skip blocks that don't match the provider
|
||||
if provider:
|
||||
credentials_info = block.input_schema.get_credentials_fields_info().values()
|
||||
if not any(provider in info.provider for info in credentials_info):
|
||||
continue
|
||||
|
||||
total += 1
|
||||
if skip > 0:
|
||||
skip -= 1
|
||||
continue
|
||||
if take > 0:
|
||||
take -= 1
|
||||
blocks.append(block)
|
||||
|
||||
costs = get_block_costs()
|
||||
|
||||
return BlockResponse(
|
||||
blocks=[{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks],
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def search_blocks(
|
||||
include_blocks: bool = True,
|
||||
include_integrations: bool = True,
|
||||
query: str = "",
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> SearchBlocksResponse:
|
||||
"""
|
||||
Get blocks based on the filter and query.
|
||||
`providers` only applies for `integrations` filter.
|
||||
"""
|
||||
blocks: list[Block[BlockSchema, BlockSchema]] = []
|
||||
query = query.lower()
|
||||
|
||||
total = 0
|
||||
skip = (page - 1) * page_size
|
||||
take = page_size
|
||||
block_count = 0
|
||||
integration_count = 0
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: Block[BlockSchema, BlockSchema] = block_type()
|
||||
# Skip disabled blocks
|
||||
if block.disabled:
|
||||
continue
|
||||
# Skip blocks that don't match the query
|
||||
if (
|
||||
query not in block.name.lower()
|
||||
and query not in block.description.lower()
|
||||
and not _matches_llm_model(block.input_schema, query)
|
||||
):
|
||||
continue
|
||||
keep = False
|
||||
credentials = list(block.input_schema.get_credentials_fields().values())
|
||||
if include_integrations and len(credentials) > 0:
|
||||
keep = True
|
||||
integration_count += 1
|
||||
if include_blocks and len(credentials) == 0:
|
||||
keep = True
|
||||
block_count += 1
|
||||
|
||||
if not keep:
|
||||
continue
|
||||
|
||||
total += 1
|
||||
if skip > 0:
|
||||
skip -= 1
|
||||
continue
|
||||
if take > 0:
|
||||
take -= 1
|
||||
blocks.append(block)
|
||||
|
||||
costs = get_block_costs()
|
||||
|
||||
return SearchBlocksResponse(
|
||||
blocks=BlockResponse(
|
||||
blocks=[{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks],
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
),
|
||||
total_block_count=block_count,
|
||||
total_integration_count=integration_count,
|
||||
)
|
||||
|
||||
|
||||
def get_providers(
|
||||
query: str = "",
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> ProviderResponse:
|
||||
providers = []
|
||||
query = query.lower()
|
||||
|
||||
skip = (page - 1) * page_size
|
||||
take = page_size
|
||||
|
||||
all_providers = _get_all_providers()
|
||||
|
||||
for provider in all_providers.values():
|
||||
if (
|
||||
query not in provider.name.value.lower()
|
||||
and query not in provider.description.lower()
|
||||
):
|
||||
continue
|
||||
if skip > 0:
|
||||
skip -= 1
|
||||
continue
|
||||
if take > 0:
|
||||
take -= 1
|
||||
providers.append(provider)
|
||||
|
||||
total = len(all_providers)
|
||||
|
||||
return ProviderResponse(
|
||||
providers=providers,
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def get_counts(user_id: str) -> CountResponse:
|
||||
my_agents = await prisma.models.LibraryAgent.prisma().count(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
}
|
||||
)
|
||||
counts = await _get_static_counts()
|
||||
return CountResponse(
|
||||
my_agents=my_agents,
|
||||
**counts,
|
||||
)
|
||||
|
||||
|
||||
async def _get_static_counts():
|
||||
"""
|
||||
Get counts of blocks, integrations, and marketplace agents.
|
||||
This is cached to avoid unnecessary database queries and calculations.
|
||||
Can't use functools.cache here because the function is async.
|
||||
"""
|
||||
global _static_counts_cache
|
||||
if _static_counts_cache is not None:
|
||||
return _static_counts_cache
|
||||
|
||||
all_blocks = 0
|
||||
input_blocks = 0
|
||||
action_blocks = 0
|
||||
output_blocks = 0
|
||||
integrations = 0
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: Block[BlockSchema, BlockSchema] = block_type()
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
all_blocks += 1
|
||||
|
||||
if block.block_type.value == "Input":
|
||||
input_blocks += 1
|
||||
elif block.block_type.value == "Output":
|
||||
output_blocks += 1
|
||||
else:
|
||||
action_blocks += 1
|
||||
|
||||
credentials = list(block.input_schema.get_credentials_fields().values())
|
||||
if len(credentials) > 0:
|
||||
integrations += 1
|
||||
|
||||
marketplace_agents = await prisma.models.StoreAgent.prisma().count()
|
||||
|
||||
_static_counts_cache = {
|
||||
"all_blocks": all_blocks,
|
||||
"input_blocks": input_blocks,
|
||||
"action_blocks": action_blocks,
|
||||
"output_blocks": output_blocks,
|
||||
"integrations": integrations,
|
||||
"marketplace_agents": marketplace_agents,
|
||||
}
|
||||
|
||||
return _static_counts_cache
|
||||
|
||||
|
||||
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
|
||||
for field in schema_cls.model_fields.values():
|
||||
if field.annotation == LlmModel:
|
||||
# Check if query matches any value in llm_models
|
||||
if any(query in name for name in llm_models):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _get_all_providers() -> dict[ProviderName, Provider]:
|
||||
providers: dict[ProviderName, Provider] = {}
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: Block[BlockSchema, BlockSchema] = block_type()
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
credentials_info = block.input_schema.get_credentials_fields_info().values()
|
||||
for info in credentials_info:
|
||||
for provider in info.provider: # provider is a ProviderName enum member
|
||||
if provider in providers:
|
||||
providers[provider].integration_count += 1
|
||||
else:
|
||||
providers[provider] = Provider(
|
||||
name=provider, description="", integration_count=1
|
||||
)
|
||||
return providers
|
||||
|
||||
|
||||
async def get_suggested_blocks(count: int = 5) -> list[BlockData]:
|
||||
global _suggested_blocks
|
||||
|
||||
if _suggested_blocks is not None and len(_suggested_blocks) >= count:
|
||||
return _suggested_blocks[:count]
|
||||
|
||||
_suggested_blocks = []
|
||||
# Sum the number of executions for each block type
|
||||
# Prisma cannot group by nested relations, so we do a raw query
|
||||
# Calculate the cutoff timestamp
|
||||
timestamp_threshold = datetime.now(timezone.utc) - timedelta(days=30)
|
||||
|
||||
results = await prisma.get_client().query_raw(
|
||||
"""
|
||||
SELECT
|
||||
agent_node."agentBlockId" AS block_id,
|
||||
COUNT(execution.id) AS execution_count
|
||||
FROM "AgentNodeExecution" execution
|
||||
JOIN "AgentNode" agent_node ON execution."agentNodeId" = agent_node.id
|
||||
WHERE execution."endedTime" >= $1::timestamp
|
||||
GROUP BY agent_node."agentBlockId"
|
||||
ORDER BY execution_count DESC;
|
||||
""",
|
||||
timestamp_threshold,
|
||||
)
|
||||
|
||||
# Get the top blocks based on execution count
|
||||
# But ignore Input and Output blocks
|
||||
blocks: list[tuple[BlockData, int]] = []
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: Block[BlockSchema, BlockSchema] = block_type()
|
||||
if block.disabled or block.block_type in (
|
||||
backend.data.block.BlockType.INPUT,
|
||||
backend.data.block.BlockType.OUTPUT,
|
||||
backend.data.block.BlockType.AGENT,
|
||||
):
|
||||
continue
|
||||
# Find the execution count for this block
|
||||
execution_count = next(
|
||||
(row["execution_count"] for row in results if row["block_id"] == block.id),
|
||||
0,
|
||||
)
|
||||
blocks.append((block.to_dict(), execution_count))
|
||||
# Sort blocks by execution count
|
||||
blocks.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
_suggested_blocks = [block[0] for block in blocks]
|
||||
|
||||
# Return the top blocks
|
||||
return _suggested_blocks[:count]
|
||||
87
autogpt_platform/backend/backend/server/v2/builder/model.py
Normal file
87
autogpt_platform/backend/backend/server/v2/builder/model.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.model as store_model
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.models import Pagination
|
||||
|
||||
FilterType = Literal[
|
||||
"blocks",
|
||||
"integrations",
|
||||
"marketplace_agents",
|
||||
"my_agents",
|
||||
]
|
||||
|
||||
BlockType = Literal["all", "input", "action", "output"]
|
||||
|
||||
BlockData = dict[str, Any]
|
||||
|
||||
|
||||
# Suggestions
|
||||
class SuggestionsResponse(BaseModel):
|
||||
otto_suggestions: list[str]
|
||||
recent_searches: list[str]
|
||||
providers: list[ProviderName]
|
||||
top_blocks: list[BlockData]
|
||||
|
||||
|
||||
# All blocks
|
||||
class BlockCategoryResponse(BaseModel):
|
||||
name: str
|
||||
total_blocks: int
|
||||
blocks: list[BlockData]
|
||||
|
||||
model_config = {"use_enum_values": False} # <== use enum names like "AI"
|
||||
|
||||
|
||||
# Input/Action/Output and see all for block categories
|
||||
class BlockResponse(BaseModel):
|
||||
blocks: list[BlockData]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
# Providers
|
||||
class Provider(BaseModel):
|
||||
name: ProviderName
|
||||
description: str
|
||||
integration_count: int
|
||||
|
||||
|
||||
class ProviderResponse(BaseModel):
|
||||
providers: list[Provider]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
# Search
|
||||
class SearchRequest(BaseModel):
|
||||
search_query: str | None = None
|
||||
filter: list[FilterType] | None = None
|
||||
by_creator: list[str] | None = None
|
||||
search_id: str | None = None
|
||||
page: int | None = None
|
||||
page_size: int | None = None
|
||||
|
||||
|
||||
class SearchBlocksResponse(BaseModel):
|
||||
blocks: BlockResponse
|
||||
total_block_count: int
|
||||
total_integration_count: int
|
||||
|
||||
|
||||
class SearchResponse(BaseModel):
|
||||
items: list[BlockData | library_model.LibraryAgent | store_model.StoreAgent]
|
||||
total_items: dict[FilterType, int]
|
||||
page: int
|
||||
more_pages: bool
|
||||
|
||||
|
||||
class CountResponse(BaseModel):
|
||||
all_blocks: int
|
||||
input_blocks: int
|
||||
action_blocks: int
|
||||
output_blocks: int
|
||||
integrations: int
|
||||
marketplace_agents: int
|
||||
my_agents: int
|
||||
233
autogpt_platform/backend/backend/server/v2/builder/routes.py
Normal file
233
autogpt_platform/backend/backend/server/v2/builder/routes.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import logging
|
||||
from typing import Annotated, Sequence
|
||||
|
||||
import fastapi
|
||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||
|
||||
import backend.server.v2.builder.db as builder_db
|
||||
import backend.server.v2.builder.model as builder_model
|
||||
import backend.server.v2.library.db as library_db
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.db as store_db
|
||||
import backend.server.v2.store.model as store_model
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.models import Pagination
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = fastapi.APIRouter(
|
||||
dependencies=[fastapi.Security(requires_user)],
|
||||
)
|
||||
|
||||
|
||||
# Taken from backend/server/v2/store/db.py
|
||||
def sanitize_query(query: str | None) -> str | None:
|
||||
if query is None:
|
||||
return query
|
||||
query = query.strip()[:100]
|
||||
return (
|
||||
query.replace("\\", "\\\\")
|
||||
.replace("%", "\\%")
|
||||
.replace("_", "\\_")
|
||||
.replace("[", "\\[")
|
||||
.replace("]", "\\]")
|
||||
.replace("'", "\\'")
|
||||
.replace('"', '\\"')
|
||||
.replace(";", "\\;")
|
||||
.replace("--", "\\--")
|
||||
.replace("/*", "\\/*")
|
||||
.replace("*/", "\\*/")
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/suggestions",
|
||||
summary="Get Builder suggestions",
|
||||
response_model=builder_model.SuggestionsResponse,
|
||||
)
|
||||
async def get_suggestions() -> builder_model.SuggestionsResponse:
|
||||
"""
|
||||
Get all suggestions for the Blocks Menu.
|
||||
"""
|
||||
return builder_model.SuggestionsResponse(
|
||||
otto_suggestions=[
|
||||
"What blocks do I need to get started?",
|
||||
"Help me create a list",
|
||||
"Help me feed my data to Google Maps",
|
||||
],
|
||||
recent_searches=[
|
||||
"image generation",
|
||||
"deepfake",
|
||||
"competitor analysis",
|
||||
],
|
||||
providers=[
|
||||
ProviderName.TWITTER,
|
||||
ProviderName.GITHUB,
|
||||
ProviderName.NOTION,
|
||||
ProviderName.GOOGLE,
|
||||
ProviderName.DISCORD,
|
||||
ProviderName.GOOGLE_MAPS,
|
||||
],
|
||||
top_blocks=await builder_db.get_suggested_blocks(),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/categories",
|
||||
summary="Get Builder block categories",
|
||||
response_model=Sequence[builder_model.BlockCategoryResponse],
|
||||
)
|
||||
async def get_block_categories(
|
||||
blocks_per_category: Annotated[int, fastapi.Query()] = 3,
|
||||
) -> Sequence[builder_model.BlockCategoryResponse]:
|
||||
"""
|
||||
Get all block categories with a specified number of blocks per category.
|
||||
"""
|
||||
return builder_db.get_block_categories(blocks_per_category)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/blocks",
|
||||
summary="Get Builder blocks",
|
||||
response_model=builder_model.BlockResponse,
|
||||
)
|
||||
async def get_blocks(
|
||||
category: Annotated[str | None, fastapi.Query()] = None,
|
||||
type: Annotated[builder_model.BlockType | None, fastapi.Query()] = None,
|
||||
provider: Annotated[ProviderName | None, fastapi.Query()] = None,
|
||||
page: Annotated[int, fastapi.Query()] = 1,
|
||||
page_size: Annotated[int, fastapi.Query()] = 50,
|
||||
) -> builder_model.BlockResponse:
|
||||
"""
|
||||
Get blocks based on either category, type, or provider.
|
||||
"""
|
||||
return builder_db.get_blocks(
|
||||
category=category,
|
||||
type=type,
|
||||
provider=provider,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/providers",
|
||||
summary="Get Builder integration providers",
|
||||
response_model=builder_model.ProviderResponse,
|
||||
)
|
||||
async def get_providers(
|
||||
page: Annotated[int, fastapi.Query()] = 1,
|
||||
page_size: Annotated[int, fastapi.Query()] = 50,
|
||||
) -> builder_model.ProviderResponse:
|
||||
"""
|
||||
Get all integration providers with their block counts.
|
||||
"""
|
||||
return builder_db.get_providers(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/search",
|
||||
summary="Builder search",
|
||||
tags=["store", "private"],
|
||||
response_model=builder_model.SearchResponse,
|
||||
)
|
||||
async def search(
|
||||
options: builder_model.SearchRequest,
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
) -> builder_model.SearchResponse:
|
||||
"""
|
||||
Search for blocks (including integrations), marketplace agents, and user library agents.
|
||||
"""
|
||||
# If no filters are provided, then we will return all types
|
||||
if not options.filter:
|
||||
options.filter = [
|
||||
"blocks",
|
||||
"integrations",
|
||||
"marketplace_agents",
|
||||
"my_agents",
|
||||
]
|
||||
options.search_query = sanitize_query(options.search_query)
|
||||
options.page = options.page or 1
|
||||
options.page_size = options.page_size or 50
|
||||
|
||||
# Blocks&Integrations
|
||||
blocks = builder_model.SearchBlocksResponse(
|
||||
blocks=builder_model.BlockResponse(
|
||||
blocks=[],
|
||||
pagination=Pagination.empty(),
|
||||
),
|
||||
total_block_count=0,
|
||||
total_integration_count=0,
|
||||
)
|
||||
if "blocks" in options.filter or "integrations" in options.filter:
|
||||
blocks = builder_db.search_blocks(
|
||||
include_blocks="blocks" in options.filter,
|
||||
include_integrations="integrations" in options.filter,
|
||||
query=options.search_query or "",
|
||||
page=options.page,
|
||||
page_size=options.page_size,
|
||||
)
|
||||
|
||||
# Library Agents
|
||||
my_agents = library_model.LibraryAgentResponse(
|
||||
agents=[],
|
||||
pagination=Pagination.empty(),
|
||||
)
|
||||
if "my_agents" in options.filter:
|
||||
my_agents = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=options.search_query,
|
||||
page=options.page,
|
||||
page_size=options.page_size,
|
||||
)
|
||||
|
||||
# Marketplace Agents
|
||||
marketplace_agents = store_model.StoreAgentsResponse(
|
||||
agents=[],
|
||||
pagination=Pagination.empty(),
|
||||
)
|
||||
if "marketplace_agents" in options.filter:
|
||||
marketplace_agents = await store_db.get_store_agents(
|
||||
creators=options.by_creator,
|
||||
search_query=options.search_query,
|
||||
page=options.page,
|
||||
page_size=options.page_size,
|
||||
)
|
||||
|
||||
more_pages = False
|
||||
if (
|
||||
blocks.blocks.pagination.current_page < blocks.blocks.pagination.total_pages
|
||||
or my_agents.pagination.current_page < my_agents.pagination.total_pages
|
||||
or marketplace_agents.pagination.current_page
|
||||
< marketplace_agents.pagination.total_pages
|
||||
):
|
||||
more_pages = True
|
||||
|
||||
return builder_model.SearchResponse(
|
||||
items=blocks.blocks.blocks + my_agents.agents + marketplace_agents.agents,
|
||||
total_items={
|
||||
"blocks": blocks.total_block_count,
|
||||
"integrations": blocks.total_integration_count,
|
||||
"marketplace_agents": marketplace_agents.pagination.total_items,
|
||||
"my_agents": my_agents.pagination.total_items,
|
||||
},
|
||||
page=options.page,
|
||||
more_pages=more_pages,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/counts",
|
||||
summary="Get Builder item counts",
|
||||
response_model=builder_model.CountResponse,
|
||||
)
|
||||
async def get_counts(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
) -> builder_model.CountResponse:
|
||||
"""
|
||||
Get item counts for the menu categories in the Blocks Menu.
|
||||
"""
|
||||
return await builder_db.get_counts(user_id)
|
||||
@@ -51,6 +51,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
description: str
|
||||
|
||||
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
|
||||
output_schema: dict[str, Any]
|
||||
credentials_input_schema: dict[str, Any] | None = pydantic.Field(
|
||||
description="Input schema for credentials required by the agent",
|
||||
)
|
||||
@@ -126,6 +127,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
input_schema=graph.input_schema,
|
||||
output_schema=graph.output_schema,
|
||||
credentials_input_schema=(
|
||||
graph.credentials_input_schema if sub_graphs is not None else None
|
||||
),
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user