mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-11 16:18:07 -05:00
Compare commits
109 Commits
update-ins
...
toggle-cor
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
50f39e55f7 | ||
|
|
e70c970ab6 | ||
|
|
3bbce71678 | ||
|
|
34fbf4377f | ||
|
|
f682ef885a | ||
|
|
2ffd249aac | ||
|
|
986245ec43 | ||
|
|
f89717153f | ||
|
|
5da41e0753 | ||
|
|
cddeb185a8 | ||
|
|
08a3fd6d26 | ||
|
|
39b30bc82c | ||
|
|
2df0e2b750 | ||
|
|
925f249ce1 | ||
|
|
e8cf3edbf4 | ||
|
|
dc03ea718c | ||
|
|
dbee580d80 | ||
|
|
0325ec0a2c | ||
|
|
3952a1a226 | ||
|
|
cfc975d39b | ||
|
|
46e0f6cc45 | ||
|
|
c03af5c196 | ||
|
|
00cbfb8f80 | ||
|
|
3beafae955 | ||
|
|
9cd186a2f3 | ||
|
|
dcf26bd3d4 | ||
|
|
b97f097c9d | ||
|
|
ce24975a9d | ||
|
|
75c90e49ce | ||
|
|
2e38f132e7 | ||
|
|
5be6987d58 | ||
|
|
b59592be9b | ||
|
|
7ea17df9ed | ||
|
|
22e692bdda | ||
|
|
901e9eba5d | ||
|
|
bbd6709bd6 | ||
|
|
3ec1721d6d | ||
|
|
8c8a2ab0c2 | ||
|
|
4041e1f39c | ||
|
|
7cbdc1ad1a | ||
|
|
2a19aa0ed3 | ||
|
|
6d39dfe382 | ||
|
|
57ecc10535 | ||
|
|
4928ce3f90 | ||
|
|
e16e69ca55 | ||
|
|
4bcc73f784 | ||
|
|
c8240a4d6b | ||
|
|
f669db4a10 | ||
|
|
0e755a5c85 | ||
|
|
dfdc71f97f | ||
|
|
def008408c | ||
|
|
916d0adabb | ||
|
|
417ee7f0e1 | ||
|
|
ae4c9897b4 | ||
|
|
7544028b94 | ||
|
|
ad49946890 | ||
|
|
b333d56492 | ||
|
|
b23cb14e49 | ||
|
|
e71a44521a | ||
|
|
15aa091f65 | ||
|
|
3e4ca19036 | ||
|
|
a595da02f7 | ||
|
|
12cdd45551 | ||
|
|
df3c81a7a6 | ||
|
|
0f477e2392 | ||
|
|
82618aede0 | ||
|
|
b713093276 | ||
|
|
3718b948ea | ||
|
|
44d739386b | ||
|
|
533d2d0277 | ||
|
|
c6821484c7 | ||
|
|
fecbd3042d | ||
|
|
c0172c93aa | ||
|
|
8a68e03eb1 | ||
|
|
da585a34e1 | ||
|
|
bc43d05cac | ||
|
|
469b1fccbb | ||
|
|
1aa7e10cbd | ||
|
|
890bb3b8b4 | ||
|
|
2bb8e91040 | ||
|
|
76090f0ba2 | ||
|
|
5b12e02c4e | ||
|
|
e0520f5e0a | ||
|
|
a9530b7304 | ||
|
|
8a9c165faf | ||
|
|
476bfc6c84 | ||
|
|
61f17e5b97 | ||
|
|
a54bed6d68 | ||
|
|
5502256bea | ||
|
|
bd97727763 | ||
|
|
aa256f21cd | ||
|
|
2848e62f8a | ||
|
|
fa5ff9ca3c | ||
|
|
7c908c10b8 | ||
|
|
68cd1cb398 | ||
|
|
f4538d6f5a | ||
|
|
4589b15450 | ||
|
|
ccc4d0dc6c | ||
|
|
c4483fa6c7 | ||
|
|
c2af8c1a6a | ||
|
|
2610c4579f | ||
|
|
0c09b0c459 | ||
|
|
1105e6c0d2 | ||
|
|
483c399812 | ||
|
|
260dd526c9 | ||
|
|
75a159db01 | ||
|
|
62032e6584 | ||
|
|
105d5dc7e9 | ||
|
|
92515b3683 |
244
.github/copilot-instructions.md
vendored
Normal file
244
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1,244 @@
|
||||
# GitHub Copilot Instructions for AutoGPT
|
||||
|
||||
This file provides comprehensive onboarding information for GitHub Copilot coding agent to work efficiently with the AutoGPT repository.
|
||||
|
||||
## Repository Overview
|
||||
|
||||
**AutoGPT** is a powerful platform for creating, deploying, and managing continuous AI agents that automate complex workflows. This is a large monorepo (~150MB) containing multiple components:
|
||||
|
||||
- **AutoGPT Platform** (`autogpt_platform/`) - Main focus: Modern AI agent platform (Polyform Shield License)
|
||||
- **Classic AutoGPT** (`classic/`) - Legacy agent system (MIT License)
|
||||
- **Documentation** (`docs/`) - MkDocs-based documentation site
|
||||
- **Infrastructure** - Docker configurations, CI/CD, and development tools
|
||||
|
||||
**Primary Languages & Frameworks:**
|
||||
- **Backend**: Python 3.10-3.13, FastAPI, Prisma ORM, PostgreSQL, RabbitMQ
|
||||
- **Frontend**: TypeScript, Next.js 15, React, Tailwind CSS, Radix UI
|
||||
- **Development**: Docker, Poetry, pnpm, Playwright, Storybook
|
||||
|
||||
## Build and Validation Instructions
|
||||
|
||||
### Essential Setup Commands
|
||||
|
||||
**Always run these commands in the correct directory and in this order:**
|
||||
|
||||
1. **Initial Setup** (required once):
|
||||
```bash
|
||||
# Clone and enter repository
|
||||
git clone <repo> && cd AutoGPT
|
||||
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
cd autogpt_platform && docker compose --profile local up deps --build --detach
|
||||
```
|
||||
|
||||
2. **Backend Setup** (always run before backend development):
|
||||
```bash
|
||||
cd autogpt_platform/backend
|
||||
poetry install # Install dependencies
|
||||
poetry run prisma migrate dev # Run database migrations
|
||||
poetry run prisma generate # Generate Prisma client
|
||||
```
|
||||
|
||||
3. **Frontend Setup** (always run before frontend development):
|
||||
```bash
|
||||
cd autogpt_platform/frontend
|
||||
pnpm install # Install dependencies
|
||||
```
|
||||
|
||||
### Runtime Requirements
|
||||
|
||||
**Critical:** Always ensure Docker services are running before starting development:
|
||||
```bash
|
||||
cd autogpt_platform && docker compose --profile local up deps --build --detach
|
||||
```
|
||||
|
||||
**Python Version:** Use Python 3.11 (required; managed by Poetry via pyproject.toml)
|
||||
**Node.js Version:** Use Node.js 21+ with pnpm package manager
|
||||
|
||||
### Development Commands
|
||||
|
||||
**Backend Development:**
|
||||
```bash
|
||||
cd autogpt_platform/backend
|
||||
poetry run serve # Start development server (port 8000)
|
||||
poetry run test # Run all tests (requires ~5 minutes)
|
||||
poetry run pytest path/to/test.py # Run specific test
|
||||
poetry run format # Format code (Black + isort) - always run first
|
||||
poetry run lint # Lint code (ruff) - run after format
|
||||
```
|
||||
|
||||
**Frontend Development:**
|
||||
```bash
|
||||
cd autogpt_platform/frontend
|
||||
pnpm dev # Start development server (port 3000) - use for active development
|
||||
pnpm build # Build for production (only needed for E2E tests or deployment)
|
||||
pnpm test # Run Playwright E2E tests (requires build first)
|
||||
pnpm test-ui # Run tests with UI
|
||||
pnpm format # Format and lint code
|
||||
pnpm storybook # Start component development server
|
||||
```
|
||||
|
||||
### Testing Strategy
|
||||
|
||||
**Backend Tests:**
|
||||
- **Block Tests**: `poetry run pytest backend/blocks/test/test_block.py -xvs` (validates all blocks)
|
||||
- **Specific Block**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[BlockName]' -xvs`
|
||||
- **Snapshot Tests**: Use `--snapshot-update` when output changes, always review with `git diff`
|
||||
|
||||
**Frontend Tests:**
|
||||
- **E2E Tests**: Always run `pnpm dev` before `pnpm test` (Playwright requires running instance)
|
||||
- **Component Tests**: Use Storybook for isolated component development
|
||||
|
||||
### Critical Validation Steps
|
||||
|
||||
**Before committing changes:**
|
||||
1. Run `poetry run format` (backend) and `pnpm format` (frontend)
|
||||
2. Ensure all tests pass in modified areas
|
||||
3. Verify Docker services are still running
|
||||
4. Check that database migrations apply cleanly
|
||||
|
||||
**Common Issues & Workarounds:**
|
||||
- **Prisma issues**: Run `poetry run prisma generate` after schema changes
|
||||
- **Permission errors**: Ensure Docker has proper permissions
|
||||
- **Port conflicts**: Check the `docker-compose.yml` file for the current list of exposed ports. You can list all mapped ports with:
|
||||
- **Test timeouts**: Backend tests can take 5+ minutes, use `-x` flag to stop on first failure
|
||||
|
||||
## Project Layout & Architecture
|
||||
|
||||
### Core Architecture
|
||||
|
||||
**AutoGPT Platform** (`autogpt_platform/`):
|
||||
- `backend/` - FastAPI server with async support
|
||||
- `backend/backend/` - Core API logic
|
||||
- `backend/blocks/` - Agent execution blocks
|
||||
- `backend/data/` - Database models and schemas
|
||||
- `schema.prisma` - Database schema definition
|
||||
- `frontend/` - Next.js application
|
||||
- `src/app/` - App Router pages and layouts
|
||||
- `src/components/` - Reusable React components
|
||||
- `src/lib/` - Utilities and configurations
|
||||
- `autogpt_libs/` - Shared Python utilities
|
||||
- `docker-compose.yml` - Development stack orchestration
|
||||
|
||||
**Key Configuration Files:**
|
||||
- `pyproject.toml` - Python dependencies and tooling
|
||||
- `package.json` - Node.js dependencies and scripts
|
||||
- `schema.prisma` - Database schema and migrations
|
||||
- `next.config.mjs` - Next.js configuration
|
||||
- `tailwind.config.ts` - Styling configuration
|
||||
|
||||
### Security & Middleware
|
||||
|
||||
**Cache Protection**: Backend includes middleware preventing sensitive data caching in browsers/proxies
|
||||
**Authentication**: JWT-based with Supabase integration
|
||||
**User ID Validation**: All data access requires user ID checks - verify this for any `data/*.py` changes
|
||||
|
||||
### Development Workflow
|
||||
|
||||
**GitHub Actions**: Multiple CI/CD workflows in `.github/workflows/`
|
||||
- `platform-backend-ci.yml` - Backend testing and validation
|
||||
- `platform-frontend-ci.yml` - Frontend testing and validation
|
||||
- `platform-fullstack-ci.yml` - End-to-end integration tests
|
||||
|
||||
**Pre-commit Hooks**: Run linting and formatting checks
|
||||
**Conventional Commits**: Use format `type(scope): description` (e.g., `feat(backend): add API`)
|
||||
|
||||
### Key Source Files
|
||||
|
||||
**Backend Entry Points:**
|
||||
- `backend/backend/server/server.py` - FastAPI application setup
|
||||
- `backend/backend/data/` - Database models and user management
|
||||
- `backend/blocks/` - Agent execution blocks and logic
|
||||
|
||||
**Frontend Entry Points:**
|
||||
- `frontend/src/app/layout.tsx` - Root application layout
|
||||
- `frontend/src/app/page.tsx` - Home page
|
||||
- `frontend/src/lib/supabase/` - Authentication and database client
|
||||
|
||||
**Protected Routes**: Update `frontend/lib/supabase/middleware.ts` when adding protected routes
|
||||
|
||||
### Agent Block System
|
||||
|
||||
Agents are built using a visual block-based system where each block performs a single action. Blocks are defined in `backend/blocks/` and must include:
|
||||
- Block definition with input/output schemas
|
||||
- Execution logic with proper error handling
|
||||
- Tests validating functionality
|
||||
|
||||
### Database & ORM
|
||||
|
||||
**Prisma ORM** with PostgreSQL backend including pgvector for embeddings:
|
||||
- Schema in `schema.prisma`
|
||||
- Migrations in `backend/migrations/`
|
||||
- Always run `prisma migrate dev` and `prisma generate` after schema changes
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
### Configuration Files Priority Order
|
||||
1. **Backend**: `/backend/.env.default` → `/backend/.env` (user overrides)
|
||||
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
|
||||
3. **Platform**: `/.env.default` (Supabase/shared) → `/.env` (user overrides)
|
||||
4. Docker Compose `environment:` sections override file-based config
|
||||
5. Shell environment variables have highest precedence
|
||||
|
||||
### Docker Environment Setup
|
||||
- All services use hardcoded defaults (no `${VARIABLE}` substitutions)
|
||||
- The `env_file` directive loads variables INTO containers at runtime
|
||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||
- Copy `.env.default` files to `.env` for local development customization
|
||||
|
||||
## Advanced Development Patterns
|
||||
|
||||
### Adding New Blocks
|
||||
1. Create file in `/backend/backend/blocks/`
|
||||
2. Inherit from `Block` base class with input/output schemas
|
||||
3. Implement `run` method with proper error handling
|
||||
4. Generate block UUID using `uuid.uuid4()`
|
||||
5. Register in block registry
|
||||
6. Write tests alongside block implementation
|
||||
7. Consider how inputs/outputs connect with other blocks in graph editor
|
||||
|
||||
### API Development
|
||||
1. Update routes in `/backend/backend/server/routers/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside route files
|
||||
4. For `data/*.py` changes, validate user ID checks
|
||||
5. Run `poetry run test` to verify changes
|
||||
|
||||
### Frontend Development
|
||||
1. Components in `/frontend/src/components/`
|
||||
2. Use existing UI components from `/frontend/src/components/ui/`
|
||||
3. Add Storybook stories for component development
|
||||
4. Test user-facing features with Playwright E2E tests
|
||||
5. Update protected routes in middleware when needed
|
||||
|
||||
### Security Guidelines
|
||||
**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):
|
||||
- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)
|
||||
- Prevents sensitive data caching in browsers/proxies
|
||||
- Add new cacheable endpoints to `CACHEABLE_PATHS`
|
||||
|
||||
### CI/CD Alignment
|
||||
The repository has comprehensive CI workflows that test:
|
||||
- **Backend**: Python 3.11-3.13, services (Redis/RabbitMQ/ClamAV), Prisma migrations, Poetry lock validation
|
||||
- **Frontend**: Node.js 21, pnpm, Playwright with Docker Compose stack, API schema validation
|
||||
- **Integration**: Full-stack type checking and E2E testing
|
||||
|
||||
Match these patterns when developing locally - the copilot setup environment mirrors these CI configurations.
|
||||
|
||||
## Collaboration with Other AI Assistants
|
||||
|
||||
This repository is actively developed with assistance from Claude (via CLAUDE.md files). When working on this codebase:
|
||||
- Check for existing CLAUDE.md files that provide additional context
|
||||
- Follow established patterns and conventions already in the codebase
|
||||
- Maintain consistency with existing code style and architecture
|
||||
- Consider that changes may be reviewed and extended by both human developers and AI assistants
|
||||
|
||||
## Trust These Instructions
|
||||
|
||||
These instructions are comprehensive and tested. Only perform additional searches if:
|
||||
1. Information here is incomplete for your specific task
|
||||
2. You encounter errors not covered by the workarounds
|
||||
3. You need to understand implementation details not covered above
|
||||
|
||||
For detailed platform development patterns, refer to `autogpt_platform/CLAUDE.md` and `AGENTS.md` in the repository root.
|
||||
97
.github/workflows/claude-ci-failure-auto-fix.yml
vendored
Normal file
97
.github/workflows/claude-ci-failure-auto-fix.yml
vendored
Normal file
@@ -0,0 +1,97 @@
|
||||
name: Auto Fix CI Failures
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: ["CI"]
|
||||
types:
|
||||
- completed
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
actions: read
|
||||
issues: write
|
||||
id-token: write # Required for OIDC token exchange
|
||||
|
||||
jobs:
|
||||
auto-fix:
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'failure' &&
|
||||
github.event.workflow_run.pull_requests[0] &&
|
||||
!startsWith(github.event.workflow_run.head_branch, 'claude-auto-fix-ci-')
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.workflow_run.head_branch }}
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Setup git identity
|
||||
run: |
|
||||
git config --global user.email "claude[bot]@users.noreply.github.com"
|
||||
git config --global user.name "claude[bot]"
|
||||
|
||||
- name: Create fix branch
|
||||
id: branch
|
||||
run: |
|
||||
BRANCH_NAME="claude-auto-fix-ci-${{ github.event.workflow_run.head_branch }}-${{ github.run_id }}"
|
||||
git checkout -b "$BRANCH_NAME"
|
||||
echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Get CI failure details
|
||||
id: failure_details
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const run = await github.rest.actions.getWorkflowRun({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
run_id: ${{ github.event.workflow_run.id }}
|
||||
});
|
||||
|
||||
const jobs = await github.rest.actions.listJobsForWorkflowRun({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
run_id: ${{ github.event.workflow_run.id }}
|
||||
});
|
||||
|
||||
const failedJobs = jobs.data.jobs.filter(job => job.conclusion === 'failure');
|
||||
|
||||
let errorLogs = [];
|
||||
for (const job of failedJobs) {
|
||||
const logs = await github.rest.actions.downloadJobLogsForWorkflowRun({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
job_id: job.id
|
||||
});
|
||||
errorLogs.push({
|
||||
jobName: job.name,
|
||||
logs: logs.data
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
runUrl: run.data.html_url,
|
||||
failedJobs: failedJobs.map(j => j.name),
|
||||
errorLogs: errorLogs
|
||||
};
|
||||
|
||||
- name: Fix CI failures with Claude
|
||||
id: claude
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
prompt: |
|
||||
/fix-ci
|
||||
Failed CI Run: ${{ fromJSON(steps.failure_details.outputs.result).runUrl }}
|
||||
Failed Jobs: ${{ join(fromJSON(steps.failure_details.outputs.result).failedJobs, ', ') }}
|
||||
PR Number: ${{ github.event.workflow_run.pull_requests[0].number }}
|
||||
Branch Name: ${{ steps.branch.outputs.branch_name }}
|
||||
Base Branch: ${{ github.event.workflow_run.head_branch }}
|
||||
Repository: ${{ github.repository }}
|
||||
|
||||
Error logs:
|
||||
${{ toJSON(fromJSON(steps.failure_details.outputs.result).errorLogs) }}
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: "--allowedTools 'Edit,MultiEdit,Write,Read,Glob,Grep,LS,Bash(git:*),Bash(bun:*),Bash(npm:*),Bash(npx:*),Bash(gh:*)'"
|
||||
379
.github/workflows/claude-dependabot.yml
vendored
Normal file
379
.github/workflows/claude-dependabot.yml
vendored
Normal file
@@ -0,0 +1,379 @@
|
||||
# Claude Dependabot PR Review Workflow
|
||||
#
|
||||
# This workflow automatically runs Claude analysis on Dependabot PRs to:
|
||||
# - Identify dependency changes and their versions
|
||||
# - Look up changelogs for updated packages
|
||||
# - Assess breaking changes and security impacts
|
||||
# - Provide actionable recommendations for the development team
|
||||
#
|
||||
# Triggered on: Dependabot PRs (opened, synchronize)
|
||||
# Requirements: ANTHROPIC_API_KEY secret must be configured
|
||||
|
||||
name: Claude Dependabot PR Review
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize]
|
||||
|
||||
jobs:
|
||||
dependabot-review:
|
||||
# Only run on Dependabot PRs
|
||||
if: github.actor == 'dependabot[bot]'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: read
|
||||
issues: read
|
||||
id-token: write
|
||||
actions: read # Required for CI access
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
# Extract Poetry version from backend/poetry.lock (matches CI)
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
|
||||
|
||||
# Install Poetry
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
# Add Poetry to PATH
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Check poetry.lock
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
poetry lock
|
||||
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
|
||||
echo "Warning: poetry.lock not up to date, but continuing for setup"
|
||||
git checkout poetry.lock # Reset for clean setup
|
||||
fi
|
||||
|
||||
- name: Install Python dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set pnpm store directory
|
||||
run: |
|
||||
pnpm config set store-dir ~/.pnpm-store
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install JavaScript dependencies
|
||||
working-directory: autogpt_platform/frontend
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
# Install Playwright browsers for frontend testing
|
||||
# NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
|
||||
# - name: Install Playwright browsers
|
||||
# working-directory: autogpt_platform/frontend
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Copy default environment files
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Copy default environment files for development
|
||||
cp .env.default .env
|
||||
cp backend/.env.default backend/.env
|
||||
cp frontend/.env.default frontend/.env
|
||||
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
|
||||
restore-keys: |
|
||||
docker-images-v2-${{ runner.os }}-
|
||||
docker-images-v1-${{ runner.os }}-
|
||||
|
||||
- name: Load or pull Docker images
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
mkdir -p ~/docker-cache
|
||||
|
||||
# Define image list for easy maintenance
|
||||
IMAGES=(
|
||||
"redis:latest"
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
|
||||
echo "Docker cache found, loading images in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
if [ -f ~/docker-cache/${filename}.tar ]; then
|
||||
echo "Loading $image..."
|
||||
docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
|
||||
fi
|
||||
done
|
||||
wait
|
||||
echo "All cached images loaded"
|
||||
else
|
||||
echo "No Docker cache found, pulling images in parallel..."
|
||||
# Pull all images in parallel
|
||||
for image in "${IMAGES[@]}"; do
|
||||
docker pull "$image" &
|
||||
done
|
||||
wait
|
||||
|
||||
# Only save cache on main branches (not PRs) to avoid cache pollution
|
||||
if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
|
||||
echo "Saving Docker images to cache in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
echo "Saving $image..."
|
||||
docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
|
||||
done
|
||||
wait
|
||||
echo "Docker image cache saved"
|
||||
else
|
||||
echo "Skipping cache save for PR/feature branch"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Docker images ready for use"
|
||||
|
||||
# Phase 2: Build migrate service with GitHub Actions cache
|
||||
- name: Build migrate Docker image with cache
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Build the migrate image with buildx for GHA caching
|
||||
docker buildx build \
|
||||
--cache-from type=gha \
|
||||
--cache-to type=gha,mode=max \
|
||||
--target migrate \
|
||||
--tag autogpt_platform-migrate:latest \
|
||||
--load \
|
||||
-f backend/Dockerfile \
|
||||
..
|
||||
|
||||
# Start services using pre-built images
|
||||
- name: Start Docker services for development
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Start essential services (migrate image already built with correct tag)
|
||||
docker compose --profile local up deps --no-build --detach
|
||||
echo "Waiting for services to be ready..."
|
||||
|
||||
# Wait for database to be ready
|
||||
echo "Checking database readiness..."
|
||||
timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
|
||||
echo " Waiting for database..."
|
||||
sleep 2
|
||||
done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."
|
||||
|
||||
# Check migrate service status
|
||||
echo "Checking migration status..."
|
||||
docker compose ps migrate || echo " Migrate service not visible in ps output"
|
||||
|
||||
# Wait for migrate service to complete
|
||||
echo "Waiting for migrations to complete..."
|
||||
timeout 30 bash -c '
|
||||
ATTEMPTS=0
|
||||
while [ $ATTEMPTS -lt 15 ]; do
|
||||
ATTEMPTS=$((ATTEMPTS + 1))
|
||||
|
||||
# Check using docker directly (more reliable than docker compose ps)
|
||||
CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)
|
||||
|
||||
if [ -z "$CONTAINER_STATUS" ]; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container not found yet..."
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
|
||||
echo "✅ Migrations completed successfully"
|
||||
docker compose logs migrate --tail=5 2>/dev/null || true
|
||||
exit 0
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
|
||||
EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
|
||||
echo "❌ Migrations failed with exit code: $EXIT_CODE"
|
||||
echo "Migration logs:"
|
||||
docker compose logs migrate --tail=20 2>/dev/null || true
|
||||
exit 1
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
|
||||
else
|
||||
echo " Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
|
||||
fi
|
||||
|
||||
sleep 2
|
||||
done
|
||||
|
||||
echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
|
||||
echo "Final container check:"
|
||||
docker ps -a --filter "label=com.docker.compose.service=migrate" || true
|
||||
echo "Migration logs (if available):"
|
||||
docker compose logs migrate --tail=10 2>/dev/null || echo " No logs available"
|
||||
' || echo "⚠️ Migration check completed with warnings, continuing..."
|
||||
|
||||
# Brief wait for other services to stabilize
|
||||
echo "Waiting 5 seconds for other services to stabilize..."
|
||||
sleep 5
|
||||
|
||||
# Verify installations and provide environment info
|
||||
- name: Verify setup and show environment info
|
||||
run: |
|
||||
echo "=== Python Setup ==="
|
||||
python --version
|
||||
poetry --version
|
||||
|
||||
echo "=== Node.js Setup ==="
|
||||
node --version
|
||||
pnpm --version
|
||||
|
||||
echo "=== Additional Tools ==="
|
||||
docker --version
|
||||
docker compose version
|
||||
gh --version || true
|
||||
|
||||
echo "=== Services Status ==="
|
||||
cd autogpt_platform
|
||||
docker compose ps || true
|
||||
|
||||
echo "=== Backend Dependencies ==="
|
||||
cd backend
|
||||
poetry show | head -10 || true
|
||||
|
||||
echo "=== Frontend Dependencies ==="
|
||||
cd ../frontend
|
||||
pnpm list --depth=0 | head -10 || true
|
||||
|
||||
echo "=== Environment Files ==="
|
||||
ls -la ../.env* || true
|
||||
ls -la .env* || true
|
||||
ls -la ../backend/.env* || true
|
||||
|
||||
echo "✅ AutoGPT Platform development environment setup complete!"
|
||||
echo "🚀 Ready for development with Docker services running"
|
||||
echo "📝 Backend server: poetry run serve (port 8000)"
|
||||
echo "🌐 Frontend server: pnpm dev (port 3000)"
|
||||
|
||||
|
||||
- name: Run Claude Dependabot Analysis
|
||||
id: claude_review
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: |
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
||||
prompt: |
|
||||
You are Claude, an AI assistant specialized in reviewing Dependabot dependency update PRs.
|
||||
|
||||
Your primary tasks are:
|
||||
1. **Analyze the dependency changes** in this Dependabot PR
|
||||
2. **Look up changelogs** for all updated dependencies to understand what changed
|
||||
3. **Identify breaking changes** and assess potential impact on the AutoGPT codebase
|
||||
4. **Provide actionable recommendations** for the development team
|
||||
|
||||
## Analysis Process:
|
||||
|
||||
1. **Identify Changed Dependencies**:
|
||||
- Use git diff to see what dependencies were updated
|
||||
- Parse package.json, poetry.lock, requirements files, etc.
|
||||
- List all package versions: old → new
|
||||
|
||||
2. **Changelog Research**:
|
||||
- For each updated dependency, look up its changelog/release notes
|
||||
- Use WebFetch to access GitHub releases, NPM package pages, PyPI project pages. The pr should also have some details
|
||||
- Focus on versions between the old and new versions
|
||||
- Identify: breaking changes, deprecations, security fixes, new features
|
||||
|
||||
3. **Breaking Change Assessment**:
|
||||
- Categorize changes: BREAKING, MAJOR, MINOR, PATCH, SECURITY
|
||||
- Assess impact on AutoGPT's usage patterns
|
||||
- Check if AutoGPT uses affected APIs/features
|
||||
- Look for migration guides or upgrade instructions
|
||||
|
||||
4. **Codebase Impact Analysis**:
|
||||
- Search the AutoGPT codebase for usage of changed APIs
|
||||
- Identify files that might be affected by breaking changes
|
||||
- Check test files for deprecated usage patterns
|
||||
- Look for configuration changes needed
|
||||
|
||||
## Output Format:
|
||||
|
||||
Provide a comprehensive review comment with:
|
||||
|
||||
### 🔍 Dependency Analysis Summary
|
||||
- List of updated packages with version changes
|
||||
- Overall risk assessment (LOW/MEDIUM/HIGH)
|
||||
|
||||
### 📋 Detailed Changelog Review
|
||||
For each updated dependency:
|
||||
- **Package**: name (old_version → new_version)
|
||||
- **Changes**: Summary of key changes
|
||||
- **Breaking Changes**: List any breaking changes
|
||||
- **Security Fixes**: Note security improvements
|
||||
- **Migration Notes**: Any upgrade steps needed
|
||||
|
||||
### ⚠️ Impact Assessment
|
||||
- **Breaking Changes Found**: Yes/No with details
|
||||
- **Affected Files**: List AutoGPT files that may need updates
|
||||
- **Test Impact**: Any tests that may need updating
|
||||
- **Configuration Changes**: Required config updates
|
||||
|
||||
### 🛠️ Recommendations
|
||||
- **Action Required**: What the team should do
|
||||
- **Testing Focus**: Areas to test thoroughly
|
||||
- **Follow-up Tasks**: Any additional work needed
|
||||
- **Merge Recommendation**: APPROVE/REVIEW_NEEDED/HOLD
|
||||
|
||||
### 📚 Useful Links
|
||||
- Links to relevant changelogs, migration guides, documentation
|
||||
|
||||
Be thorough but concise. Focus on actionable insights that help the development team make informed decisions about the dependency updates.
|
||||
284
.github/workflows/claude.yml
vendored
284
.github/workflows/claude.yml
vendored
@@ -30,18 +30,296 @@ jobs:
|
||||
github.event.issue.author_association == 'COLLABORATOR'
|
||||
)
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
contents: write
|
||||
pull-requests: read
|
||||
issues: read
|
||||
id-token: write
|
||||
actions: read # Required for CI access
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
# Extract Poetry version from backend/poetry.lock (matches CI)
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
|
||||
|
||||
# Install Poetry
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
# Add Poetry to PATH
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Check poetry.lock
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
poetry lock
|
||||
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
|
||||
echo "Warning: poetry.lock not up to date, but continuing for setup"
|
||||
git checkout poetry.lock # Reset for clean setup
|
||||
fi
|
||||
|
||||
- name: Install Python dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set pnpm store directory
|
||||
run: |
|
||||
pnpm config set store-dir ~/.pnpm-store
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install JavaScript dependencies
|
||||
working-directory: autogpt_platform/frontend
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
# Install Playwright browsers for frontend testing
|
||||
# NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
|
||||
# - name: Install Playwright browsers
|
||||
# working-directory: autogpt_platform/frontend
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Copy default environment files
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Copy default environment files for development
|
||||
cp .env.default .env
|
||||
cp backend/.env.default backend/.env
|
||||
cp frontend/.env.default frontend/.env
|
||||
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
|
||||
restore-keys: |
|
||||
docker-images-v2-${{ runner.os }}-
|
||||
docker-images-v1-${{ runner.os }}-
|
||||
|
||||
- name: Load or pull Docker images
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
mkdir -p ~/docker-cache
|
||||
|
||||
# Define image list for easy maintenance
|
||||
IMAGES=(
|
||||
"redis:latest"
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
|
||||
echo "Docker cache found, loading images in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
if [ -f ~/docker-cache/${filename}.tar ]; then
|
||||
echo "Loading $image..."
|
||||
docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
|
||||
fi
|
||||
done
|
||||
wait
|
||||
echo "All cached images loaded"
|
||||
else
|
||||
echo "No Docker cache found, pulling images in parallel..."
|
||||
# Pull all images in parallel
|
||||
for image in "${IMAGES[@]}"; do
|
||||
docker pull "$image" &
|
||||
done
|
||||
wait
|
||||
|
||||
# Only save cache on main branches (not PRs) to avoid cache pollution
|
||||
if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
|
||||
echo "Saving Docker images to cache in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
echo "Saving $image..."
|
||||
docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
|
||||
done
|
||||
wait
|
||||
echo "Docker image cache saved"
|
||||
else
|
||||
echo "Skipping cache save for PR/feature branch"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Docker images ready for use"
|
||||
|
||||
# Phase 2: Build migrate service with GitHub Actions cache
|
||||
- name: Build migrate Docker image with cache
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Build the migrate image with buildx for GHA caching
|
||||
docker buildx build \
|
||||
--cache-from type=gha \
|
||||
--cache-to type=gha,mode=max \
|
||||
--target migrate \
|
||||
--tag autogpt_platform-migrate:latest \
|
||||
--load \
|
||||
-f backend/Dockerfile \
|
||||
..
|
||||
|
||||
# Start services using pre-built images
|
||||
- name: Start Docker services for development
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Start essential services (migrate image already built with correct tag)
|
||||
docker compose --profile local up deps --no-build --detach
|
||||
echo "Waiting for services to be ready..."
|
||||
|
||||
# Wait for database to be ready
|
||||
echo "Checking database readiness..."
|
||||
timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
|
||||
echo " Waiting for database..."
|
||||
sleep 2
|
||||
done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."
|
||||
|
||||
# Check migrate service status
|
||||
echo "Checking migration status..."
|
||||
docker compose ps migrate || echo " Migrate service not visible in ps output"
|
||||
|
||||
# Wait for migrate service to complete
|
||||
echo "Waiting for migrations to complete..."
|
||||
timeout 30 bash -c '
|
||||
ATTEMPTS=0
|
||||
while [ $ATTEMPTS -lt 15 ]; do
|
||||
ATTEMPTS=$((ATTEMPTS + 1))
|
||||
|
||||
# Check using docker directly (more reliable than docker compose ps)
|
||||
CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)
|
||||
|
||||
if [ -z "$CONTAINER_STATUS" ]; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container not found yet..."
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
|
||||
echo "✅ Migrations completed successfully"
|
||||
docker compose logs migrate --tail=5 2>/dev/null || true
|
||||
exit 0
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
|
||||
EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
|
||||
echo "❌ Migrations failed with exit code: $EXIT_CODE"
|
||||
echo "Migration logs:"
|
||||
docker compose logs migrate --tail=20 2>/dev/null || true
|
||||
exit 1
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
|
||||
else
|
||||
echo " Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
|
||||
fi
|
||||
|
||||
sleep 2
|
||||
done
|
||||
|
||||
echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
|
||||
echo "Final container check:"
|
||||
docker ps -a --filter "label=com.docker.compose.service=migrate" || true
|
||||
echo "Migration logs (if available):"
|
||||
docker compose logs migrate --tail=10 2>/dev/null || echo " No logs available"
|
||||
' || echo "⚠️ Migration check completed with warnings, continuing..."
|
||||
|
||||
# Brief wait for other services to stabilize
|
||||
echo "Waiting 5 seconds for other services to stabilize..."
|
||||
sleep 5
|
||||
|
||||
# Verify installations and provide environment info
|
||||
- name: Verify setup and show environment info
|
||||
run: |
|
||||
echo "=== Python Setup ==="
|
||||
python --version
|
||||
poetry --version
|
||||
|
||||
echo "=== Node.js Setup ==="
|
||||
node --version
|
||||
pnpm --version
|
||||
|
||||
echo "=== Additional Tools ==="
|
||||
docker --version
|
||||
docker compose version
|
||||
gh --version || true
|
||||
|
||||
echo "=== Services Status ==="
|
||||
cd autogpt_platform
|
||||
docker compose ps || true
|
||||
|
||||
echo "=== Backend Dependencies ==="
|
||||
cd backend
|
||||
poetry show | head -10 || true
|
||||
|
||||
echo "=== Frontend Dependencies ==="
|
||||
cd ../frontend
|
||||
pnpm list --depth=0 | head -10 || true
|
||||
|
||||
echo "=== Environment Files ==="
|
||||
ls -la ../.env* || true
|
||||
ls -la .env* || true
|
||||
ls -la ../backend/.env* || true
|
||||
|
||||
echo "✅ AutoGPT Platform development environment setup complete!"
|
||||
echo "🚀 Ready for development with Docker services running"
|
||||
echo "📝 Backend server: poetry run serve (port 8000)"
|
||||
echo "🌐 Frontend server: pnpm dev (port 3000)"
|
||||
|
||||
- name: Run Claude Code
|
||||
id: claude
|
||||
uses: anthropics/claude-code-action@beta
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: |
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*), Bash(gh pr edit:*)"
|
||||
--model opus
|
||||
additional_permissions: |
|
||||
actions: read
|
||||
|
||||
302
.github/workflows/copilot-setup-steps.yml
vendored
Normal file
302
.github/workflows/copilot-setup-steps.yml
vendored
Normal file
@@ -0,0 +1,302 @@
|
||||
name: "Copilot Setup Steps"
|
||||
|
||||
# Automatically run the setup steps when they are changed to allow for easy validation, and
|
||||
# allow manual testing through the repository's "Actions" tab
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
paths:
|
||||
- .github/workflows/copilot-setup-steps.yml
|
||||
pull_request:
|
||||
paths:
|
||||
- .github/workflows/copilot-setup-steps.yml
|
||||
|
||||
jobs:
|
||||
# The job MUST be called `copilot-setup-steps` or it will not be picked up by Copilot.
|
||||
copilot-setup-steps:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
|
||||
# Set the permissions to the lowest permissions possible needed for your steps.
|
||||
# Copilot will be given its own token for its operations.
|
||||
permissions:
|
||||
# If you want to clone the repository as part of your setup steps, for example to install dependencies, you'll need the `contents: read` permission. If you don't clone the repository in your setup steps, Copilot will do this for you automatically after the steps complete.
|
||||
contents: read
|
||||
|
||||
# You can define any steps you want, and they will run before the agent starts.
|
||||
# If you do not check out your code, Copilot will do this for you.
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
# Extract Poetry version from backend/poetry.lock (matches CI)
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
|
||||
|
||||
# Install Poetry
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
# Add Poetry to PATH
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Check poetry.lock
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
poetry lock
|
||||
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
|
||||
echo "Warning: poetry.lock not up to date, but continuing for setup"
|
||||
git checkout poetry.lock # Reset for clean setup
|
||||
fi
|
||||
|
||||
- name: Install Python dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set pnpm store directory
|
||||
run: |
|
||||
pnpm config set store-dir ~/.pnpm-store
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install JavaScript dependencies
|
||||
working-directory: autogpt_platform/frontend
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
# Install Playwright browsers for frontend testing
|
||||
# NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
|
||||
# - name: Install Playwright browsers
|
||||
# working-directory: autogpt_platform/frontend
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Copy default environment files
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Copy default environment files for development
|
||||
cp .env.default .env
|
||||
cp backend/.env.default backend/.env
|
||||
cp frontend/.env.default frontend/.env
|
||||
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
|
||||
restore-keys: |
|
||||
docker-images-v2-${{ runner.os }}-
|
||||
docker-images-v1-${{ runner.os }}-
|
||||
|
||||
- name: Load or pull Docker images
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
mkdir -p ~/docker-cache
|
||||
|
||||
# Define image list for easy maintenance
|
||||
IMAGES=(
|
||||
"redis:latest"
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
|
||||
echo "Docker cache found, loading images in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
if [ -f ~/docker-cache/${filename}.tar ]; then
|
||||
echo "Loading $image..."
|
||||
docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
|
||||
fi
|
||||
done
|
||||
wait
|
||||
echo "All cached images loaded"
|
||||
else
|
||||
echo "No Docker cache found, pulling images in parallel..."
|
||||
# Pull all images in parallel
|
||||
for image in "${IMAGES[@]}"; do
|
||||
docker pull "$image" &
|
||||
done
|
||||
wait
|
||||
|
||||
# Only save cache on main branches (not PRs) to avoid cache pollution
|
||||
if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
|
||||
echo "Saving Docker images to cache in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
echo "Saving $image..."
|
||||
docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
|
||||
done
|
||||
wait
|
||||
echo "Docker image cache saved"
|
||||
else
|
||||
echo "Skipping cache save for PR/feature branch"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Docker images ready for use"
|
||||
|
||||
# Phase 2: Build migrate service with GitHub Actions cache
|
||||
- name: Build migrate Docker image with cache
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Build the migrate image with buildx for GHA caching
|
||||
docker buildx build \
|
||||
--cache-from type=gha \
|
||||
--cache-to type=gha,mode=max \
|
||||
--target migrate \
|
||||
--tag autogpt_platform-migrate:latest \
|
||||
--load \
|
||||
-f backend/Dockerfile \
|
||||
..
|
||||
|
||||
# Start services using pre-built images
|
||||
- name: Start Docker services for development
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Start essential services (migrate image already built with correct tag)
|
||||
docker compose --profile local up deps --no-build --detach
|
||||
echo "Waiting for services to be ready..."
|
||||
|
||||
# Wait for database to be ready
|
||||
echo "Checking database readiness..."
|
||||
timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
|
||||
echo " Waiting for database..."
|
||||
sleep 2
|
||||
done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."
|
||||
|
||||
# Check migrate service status
|
||||
echo "Checking migration status..."
|
||||
docker compose ps migrate || echo " Migrate service not visible in ps output"
|
||||
|
||||
# Wait for migrate service to complete
|
||||
echo "Waiting for migrations to complete..."
|
||||
timeout 30 bash -c '
|
||||
ATTEMPTS=0
|
||||
while [ $ATTEMPTS -lt 15 ]; do
|
||||
ATTEMPTS=$((ATTEMPTS + 1))
|
||||
|
||||
# Check using docker directly (more reliable than docker compose ps)
|
||||
CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)
|
||||
|
||||
if [ -z "$CONTAINER_STATUS" ]; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container not found yet..."
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
|
||||
echo "✅ Migrations completed successfully"
|
||||
docker compose logs migrate --tail=5 2>/dev/null || true
|
||||
exit 0
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
|
||||
EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
|
||||
echo "❌ Migrations failed with exit code: $EXIT_CODE"
|
||||
echo "Migration logs:"
|
||||
docker compose logs migrate --tail=20 2>/dev/null || true
|
||||
exit 1
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
|
||||
else
|
||||
echo " Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
|
||||
fi
|
||||
|
||||
sleep 2
|
||||
done
|
||||
|
||||
echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
|
||||
echo "Final container check:"
|
||||
docker ps -a --filter "label=com.docker.compose.service=migrate" || true
|
||||
echo "Migration logs (if available):"
|
||||
docker compose logs migrate --tail=10 2>/dev/null || echo " No logs available"
|
||||
' || echo "⚠️ Migration check completed with warnings, continuing..."
|
||||
|
||||
# Brief wait for other services to stabilize
|
||||
echo "Waiting 5 seconds for other services to stabilize..."
|
||||
sleep 5
|
||||
|
||||
# Verify installations and provide environment info
|
||||
- name: Verify setup and show environment info
|
||||
run: |
|
||||
echo "=== Python Setup ==="
|
||||
python --version
|
||||
poetry --version
|
||||
|
||||
echo "=== Node.js Setup ==="
|
||||
node --version
|
||||
pnpm --version
|
||||
|
||||
echo "=== Additional Tools ==="
|
||||
docker --version
|
||||
docker compose version
|
||||
gh --version || true
|
||||
|
||||
echo "=== Services Status ==="
|
||||
cd autogpt_platform
|
||||
docker compose ps || true
|
||||
|
||||
echo "=== Backend Dependencies ==="
|
||||
cd backend
|
||||
poetry show | head -10 || true
|
||||
|
||||
echo "=== Frontend Dependencies ==="
|
||||
cd ../frontend
|
||||
pnpm list --depth=0 | head -10 || true
|
||||
|
||||
echo "=== Environment Files ==="
|
||||
ls -la ../.env* || true
|
||||
ls -la .env* || true
|
||||
ls -la ../backend/.env* || true
|
||||
|
||||
echo "✅ AutoGPT Platform development environment setup complete!"
|
||||
echo "🚀 Ready for development with Docker services running"
|
||||
echo "📝 Backend server: poetry run serve (port 8000)"
|
||||
echo "🌐 Frontend server: pnpm dev (port 3000)"
|
||||
4
.github/workflows/platform-backend-ci.yml
vendored
4
.github/workflows/platform-backend-ci.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.11"]
|
||||
python-version: ["3.11", "3.12", "3.13"]
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
@@ -201,7 +201,7 @@ jobs:
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
|
||||
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
|
||||
SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
|
||||
JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
|
||||
REDIS_HOST: "localhost"
|
||||
REDIS_PORT: "6379"
|
||||
REDIS_PASSWORD: "testpassword"
|
||||
|
||||
2
.github/workflows/platform-frontend-ci.yml
vendored
2
.github/workflows/platform-frontend-ci.yml
vendored
@@ -160,7 +160,7 @@ jobs:
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml up -d
|
||||
NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d
|
||||
env:
|
||||
DOCKER_BUILDKIT: 1
|
||||
BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
|
||||
|
||||
@@ -61,24 +61,27 @@ poetry run pytest path/to/test.py --snapshot-update
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
cd frontend && npm install
|
||||
cd frontend && pnpm i
|
||||
|
||||
# Start development server
|
||||
npm run dev
|
||||
pnpm dev
|
||||
|
||||
# Run E2E tests
|
||||
npm run test
|
||||
pnpm test
|
||||
|
||||
# Run Storybook for component development
|
||||
npm run storybook
|
||||
pnpm storybook
|
||||
|
||||
# Build production
|
||||
npm run build
|
||||
pnpm build
|
||||
|
||||
# Type checking
|
||||
npm run types
|
||||
pnpm types
|
||||
```
|
||||
|
||||
We have a components library in autogpt_platform/frontend/src/components/atoms that should be used when adding new pages and components.
|
||||
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Backend Architecture
|
||||
@@ -149,14 +152,23 @@ Key models (defined in `/backend/schema.prisma`):
|
||||
|
||||
**Adding a new block:**
|
||||
|
||||
1. Create new file in `/backend/backend/blocks/`
|
||||
2. Inherit from `Block` base class
|
||||
3. Define input/output schemas
|
||||
4. Implement `run` method
|
||||
5. Register in block registry
|
||||
6. Generate the block uuid using `uuid.uuid4()`
|
||||
Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:
|
||||
- Provider configuration with `ProviderBuilder`
|
||||
- Block schema definition
|
||||
- Authentication (API keys, OAuth, webhooks)
|
||||
- Testing and validation
|
||||
- File organization
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blcoks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
|
||||
Quick steps:
|
||||
1. Create new file in `/backend/backend/blocks/`
|
||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||
3. Inherit from `Block` base class
|
||||
4. Define input/output schemas using `BlockSchema`
|
||||
5. Implement async `run` method
|
||||
6. Generate unique block ID using `uuid.uuid4()`
|
||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
**Modifying the API:**
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
class APIKeyContainer(NamedTuple):
|
||||
"""Container for API key parts."""
|
||||
|
||||
raw: str
|
||||
prefix: str
|
||||
postfix: str
|
||||
hash: str
|
||||
|
||||
|
||||
class APIKeyManager:
|
||||
PREFIX: str = "agpt_"
|
||||
PREFIX_LENGTH: int = 8
|
||||
POSTFIX_LENGTH: int = 8
|
||||
|
||||
def generate_api_key(self) -> APIKeyContainer:
|
||||
"""Generate a new API key with all its parts."""
|
||||
raw_key = f"{self.PREFIX}{secrets.token_urlsafe(32)}"
|
||||
return APIKeyContainer(
|
||||
raw=raw_key,
|
||||
prefix=raw_key[: self.PREFIX_LENGTH],
|
||||
postfix=raw_key[-self.POSTFIX_LENGTH :],
|
||||
hash=hashlib.sha256(raw_key.encode()).hexdigest(),
|
||||
)
|
||||
|
||||
def verify_api_key(self, provided_key: str, stored_hash: str) -> bool:
|
||||
"""Verify if a provided API key matches the stored hash."""
|
||||
if not provided_key.startswith(self.PREFIX):
|
||||
return False
|
||||
provided_hash = hashlib.sha256(provided_key.encode()).hexdigest()
|
||||
return secrets.compare_digest(provided_hash, stored_hash)
|
||||
@@ -0,0 +1,78 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
from typing import NamedTuple
|
||||
|
||||
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt
|
||||
|
||||
|
||||
class APIKeyContainer(NamedTuple):
|
||||
"""Container for API key parts."""
|
||||
|
||||
key: str
|
||||
head: str
|
||||
tail: str
|
||||
hash: str
|
||||
salt: str
|
||||
|
||||
|
||||
class APIKeySmith:
|
||||
PREFIX: str = "agpt_"
|
||||
HEAD_LENGTH: int = 8
|
||||
TAIL_LENGTH: int = 8
|
||||
|
||||
def generate_key(self) -> APIKeyContainer:
|
||||
"""Generate a new API key with secure hashing."""
|
||||
raw_key = f"{self.PREFIX}{secrets.token_urlsafe(32)}"
|
||||
hash, salt = self.hash_key(raw_key)
|
||||
|
||||
return APIKeyContainer(
|
||||
key=raw_key,
|
||||
head=raw_key[: self.HEAD_LENGTH],
|
||||
tail=raw_key[-self.TAIL_LENGTH :],
|
||||
hash=hash,
|
||||
salt=salt,
|
||||
)
|
||||
|
||||
def verify_key(
|
||||
self, provided_key: str, known_hash: str, known_salt: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Verify an API key against a known hash (+ salt).
|
||||
Supports verifying both legacy SHA256 and secure Scrypt hashes.
|
||||
"""
|
||||
if not provided_key.startswith(self.PREFIX):
|
||||
return False
|
||||
|
||||
# Handle legacy SHA256 hashes (migration support)
|
||||
if known_salt is None:
|
||||
legacy_hash = hashlib.sha256(provided_key.encode()).hexdigest()
|
||||
return secrets.compare_digest(legacy_hash, known_hash)
|
||||
|
||||
try:
|
||||
salt_bytes = bytes.fromhex(known_salt)
|
||||
provided_hash = self._hash_key_with_salt(provided_key, salt_bytes)
|
||||
return secrets.compare_digest(provided_hash, known_hash)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
def hash_key(self, raw_key: str) -> tuple[str, str]:
|
||||
"""Migrate a legacy hash to secure hash format."""
|
||||
salt = self._generate_salt()
|
||||
hash = self._hash_key_with_salt(raw_key, salt)
|
||||
return hash, salt.hex()
|
||||
|
||||
def _generate_salt(self) -> bytes:
|
||||
"""Generate a random salt for hashing."""
|
||||
return secrets.token_bytes(32)
|
||||
|
||||
def _hash_key_with_salt(self, raw_key: str, salt: bytes) -> str:
|
||||
"""Hash API key using Scrypt with salt."""
|
||||
kdf = Scrypt(
|
||||
length=32,
|
||||
salt=salt,
|
||||
n=2**14, # CPU/memory cost parameter
|
||||
r=8, # Block size parameter
|
||||
p=1, # Parallelization parameter
|
||||
)
|
||||
key_hash = kdf.derive(raw_key.encode())
|
||||
return key_hash.hex()
|
||||
@@ -0,0 +1,79 @@
|
||||
import hashlib
|
||||
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
|
||||
|
||||
def test_generate_api_key():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
assert key.key.startswith(keysmith.PREFIX)
|
||||
assert key.head == key.key[: keysmith.HEAD_LENGTH]
|
||||
assert key.tail == key.key[-keysmith.TAIL_LENGTH :]
|
||||
assert len(key.hash) == 64 # 32 bytes hex encoded
|
||||
assert len(key.salt) == 64 # 32 bytes hex encoded
|
||||
|
||||
|
||||
def test_verify_new_secure_key():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Test correct key validates
|
||||
assert keysmith.verify_key(key.key, key.hash, key.salt) is True
|
||||
|
||||
# Test wrong key fails
|
||||
wrong_key = f"{keysmith.PREFIX}wrongkey123"
|
||||
assert keysmith.verify_key(wrong_key, key.hash, key.salt) is False
|
||||
|
||||
|
||||
def test_verify_legacy_key():
|
||||
keysmith = APIKeySmith()
|
||||
legacy_key = f"{keysmith.PREFIX}legacykey123"
|
||||
legacy_hash = hashlib.sha256(legacy_key.encode()).hexdigest()
|
||||
|
||||
# Test legacy key validates without salt
|
||||
assert keysmith.verify_key(legacy_key, legacy_hash) is True
|
||||
|
||||
# Test wrong legacy key fails
|
||||
wrong_key = f"{keysmith.PREFIX}wronglegacy"
|
||||
assert keysmith.verify_key(wrong_key, legacy_hash) is False
|
||||
|
||||
|
||||
def test_rehash_existing_key():
|
||||
keysmith = APIKeySmith()
|
||||
legacy_key = f"{keysmith.PREFIX}migratekey123"
|
||||
|
||||
# Migrate the legacy key
|
||||
new_hash, new_salt = keysmith.hash_key(legacy_key)
|
||||
|
||||
# Verify migrated key works
|
||||
assert keysmith.verify_key(legacy_key, new_hash, new_salt) is True
|
||||
|
||||
# Verify different key fails with migrated hash
|
||||
wrong_key = f"{keysmith.PREFIX}wrongkey"
|
||||
assert keysmith.verify_key(wrong_key, new_hash, new_salt) is False
|
||||
|
||||
|
||||
def test_invalid_key_prefix():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Test key without proper prefix fails
|
||||
invalid_key = "invalid_prefix_key"
|
||||
assert keysmith.verify_key(invalid_key, key.hash, key.salt) is False
|
||||
|
||||
|
||||
def test_secure_hash_requires_salt():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Secure hash without salt should fail
|
||||
assert keysmith.verify_key(key.key, key.hash) is False
|
||||
|
||||
|
||||
def test_invalid_salt_format():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Invalid salt format should fail gracefully
|
||||
assert keysmith.verify_key(key.key, key.hash, "invalid_hex") is False
|
||||
@@ -1,13 +1,13 @@
|
||||
from .depends import requires_admin_user, requires_user
|
||||
from .jwt_utils import parse_jwt_token
|
||||
from .middleware import APIKeyValidator, auth_middleware
|
||||
from .config import verify_settings
|
||||
from .dependencies import get_user_id, requires_admin_user, requires_user
|
||||
from .helpers import add_auth_responses_to_openapi
|
||||
from .models import User
|
||||
|
||||
__all__ = [
|
||||
"parse_jwt_token",
|
||||
"requires_user",
|
||||
"verify_settings",
|
||||
"get_user_id",
|
||||
"requires_admin_user",
|
||||
"APIKeyValidator",
|
||||
"auth_middleware",
|
||||
"requires_user",
|
||||
"add_auth_responses_to_openapi",
|
||||
"User",
|
||||
]
|
||||
|
||||
@@ -1,11 +1,90 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from jwt.algorithms import get_default_algorithms, has_crypto
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthConfigError(ValueError):
|
||||
"""Raised when authentication configuration is invalid."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
ALGO_RECOMMENDATION = (
|
||||
"We highly recommend using an asymmetric algorithm such as ES256, "
|
||||
"because when leaked, a shared secret would allow anyone to "
|
||||
"forge valid tokens and impersonate users. "
|
||||
"More info: https://supabase.com/docs/guides/auth/signing-keys#choosing-the-right-signing-algorithm" # noqa
|
||||
)
|
||||
|
||||
|
||||
class Settings:
|
||||
def __init__(self):
|
||||
self.JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
|
||||
self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
|
||||
self.JWT_ALGORITHM: str = "HS256"
|
||||
self.JWT_VERIFY_KEY: str = os.getenv(
|
||||
"JWT_VERIFY_KEY", os.getenv("SUPABASE_JWT_SECRET", "")
|
||||
).strip()
|
||||
self.JWT_ALGORITHM: str = os.getenv("JWT_SIGN_ALGORITHM", "HS256").strip()
|
||||
|
||||
self.validate()
|
||||
|
||||
def validate(self):
|
||||
if not self.JWT_VERIFY_KEY:
|
||||
raise AuthConfigError(
|
||||
"JWT_VERIFY_KEY must be set. "
|
||||
"An empty JWT secret would allow anyone to forge valid tokens."
|
||||
)
|
||||
|
||||
if len(self.JWT_VERIFY_KEY) < 32:
|
||||
logger.warning(
|
||||
"⚠️ JWT_VERIFY_KEY appears weak (less than 32 characters). "
|
||||
"Consider using a longer, cryptographically secure secret."
|
||||
)
|
||||
|
||||
supported_algorithms = get_default_algorithms().keys()
|
||||
|
||||
if not has_crypto:
|
||||
logger.warning(
|
||||
"⚠️ Asymmetric JWT verification is not available "
|
||||
"because the 'cryptography' package is not installed. "
|
||||
+ ALGO_RECOMMENDATION
|
||||
)
|
||||
|
||||
if (
|
||||
self.JWT_ALGORITHM not in supported_algorithms
|
||||
or self.JWT_ALGORITHM == "none"
|
||||
):
|
||||
raise AuthConfigError(
|
||||
f"Invalid JWT_SIGN_ALGORITHM: '{self.JWT_ALGORITHM}'. "
|
||||
"Supported algorithms are listed on "
|
||||
"https://pyjwt.readthedocs.io/en/stable/algorithms.html"
|
||||
)
|
||||
|
||||
if self.JWT_ALGORITHM.startswith("HS"):
|
||||
logger.warning(
|
||||
f"⚠️ JWT_SIGN_ALGORITHM is set to '{self.JWT_ALGORITHM}', "
|
||||
"a symmetric shared-key signature algorithm. " + ALGO_RECOMMENDATION
|
||||
)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
_settings: Settings = None # type: ignore
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
global _settings
|
||||
|
||||
if not _settings:
|
||||
_settings = Settings()
|
||||
|
||||
return _settings
|
||||
|
||||
|
||||
def verify_settings() -> None:
|
||||
global _settings
|
||||
|
||||
if not _settings:
|
||||
_settings = Settings() # calls validation indirectly
|
||||
return
|
||||
|
||||
_settings.validate()
|
||||
|
||||
306
autogpt_platform/autogpt_libs/autogpt_libs/auth/config_test.py
Normal file
306
autogpt_platform/autogpt_libs/autogpt_libs/auth/config_test.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""
|
||||
Comprehensive tests for auth configuration to ensure 100% line and branch coverage.
|
||||
These tests verify critical security checks preventing JWT token forgery.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from autogpt_libs.auth.config import AuthConfigError, Settings
|
||||
|
||||
|
||||
def test_environment_variable_precedence(mocker: MockerFixture):
|
||||
"""Test that environment variables take precedence over defaults."""
|
||||
secret = "environment-secret-key-with-proper-length-123456"
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == secret
|
||||
|
||||
|
||||
def test_environment_variable_backwards_compatible(mocker: MockerFixture):
|
||||
"""Test that SUPABASE_JWT_SECRET is read if JWT_VERIFY_KEY is not set."""
|
||||
secret = "environment-secret-key-with-proper-length-123456"
|
||||
mocker.patch.dict(os.environ, {"SUPABASE_JWT_SECRET": secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == secret
|
||||
|
||||
|
||||
def test_auth_config_error_inheritance():
|
||||
"""Test that AuthConfigError is properly defined as an Exception."""
|
||||
assert issubclass(AuthConfigError, Exception)
|
||||
error = AuthConfigError("test message")
|
||||
assert str(error) == "test message"
|
||||
|
||||
|
||||
def test_settings_static_after_creation(mocker: MockerFixture):
|
||||
"""Test that settings maintain their values after creation."""
|
||||
secret = "immutable-secret-key-with-proper-length-12345"
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
original_secret = settings.JWT_VERIFY_KEY
|
||||
|
||||
# Changing environment after creation shouldn't affect settings
|
||||
os.environ["JWT_VERIFY_KEY"] = "different-secret"
|
||||
|
||||
assert settings.JWT_VERIFY_KEY == original_secret
|
||||
|
||||
|
||||
def test_settings_load_with_valid_secret(mocker: MockerFixture):
|
||||
"""Test auth enabled with a valid JWT secret."""
|
||||
valid_secret = "a" * 32 # 32 character secret
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": valid_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == valid_secret
|
||||
|
||||
|
||||
def test_settings_load_with_strong_secret(mocker: MockerFixture):
|
||||
"""Test auth enabled with a cryptographically strong secret."""
|
||||
strong_secret = "super-secret-jwt-token-with-at-least-32-characters-long"
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": strong_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == strong_secret
|
||||
assert len(settings.JWT_VERIFY_KEY) >= 32
|
||||
|
||||
|
||||
def test_secret_empty_raises_error(mocker: MockerFixture):
|
||||
"""Test that auth enabled with empty secret raises AuthConfigError."""
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": ""}, clear=True)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
Settings()
|
||||
assert "JWT_VERIFY_KEY" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_secret_missing_raises_error(mocker: MockerFixture):
|
||||
"""Test that auth enabled without secret env var raises AuthConfigError."""
|
||||
mocker.patch.dict(os.environ, {}, clear=True)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
Settings()
|
||||
assert "JWT_VERIFY_KEY" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("secret", [" ", " ", "\t", "\n", " \t\n "])
|
||||
def test_secret_only_whitespace_raises_error(mocker: MockerFixture, secret: str):
|
||||
"""Test that auth enabled with whitespace-only secret raises error."""
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Settings()
|
||||
|
||||
|
||||
def test_secret_weak_logs_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
"""Test that weak JWT secret triggers warning log."""
|
||||
weak_secret = "short" # Less than 32 characters
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": weak_secret}, clear=True)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == weak_secret
|
||||
assert "key appears weak" in caplog.text.lower()
|
||||
assert "less than 32 characters" in caplog.text
|
||||
|
||||
|
||||
def test_secret_31_char_logs_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
"""Test that 31-character secret triggers warning (boundary test)."""
|
||||
secret_31 = "a" * 31 # Exactly 31 characters
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret_31}, clear=True)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
assert len(settings.JWT_VERIFY_KEY) == 31
|
||||
assert "key appears weak" in caplog.text.lower()
|
||||
|
||||
|
||||
def test_secret_32_char_no_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
"""Test that 32-character secret does not trigger warning (boundary test)."""
|
||||
secret_32 = "a" * 32 # Exactly 32 characters
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret_32}, clear=True)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
assert len(settings.JWT_VERIFY_KEY) == 32
|
||||
assert "JWT secret appears weak" not in caplog.text
|
||||
|
||||
|
||||
def test_secret_whitespace_stripped(mocker: MockerFixture):
|
||||
"""Test that JWT secret whitespace is stripped."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": f" {secret} "}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == secret
|
||||
|
||||
|
||||
def test_secret_with_special_characters(mocker: MockerFixture):
|
||||
"""Test JWT secret with special characters."""
|
||||
special_secret = "!@#$%^&*()_+-=[]{}|;:,.<>?`~" + "a" * 10 # 40 chars total
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": special_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == special_secret
|
||||
|
||||
|
||||
def test_secret_with_unicode(mocker: MockerFixture):
|
||||
"""Test JWT secret with unicode characters."""
|
||||
unicode_secret = "秘密🔐キー" + "a" * 25 # Ensure >32 bytes
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": unicode_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == unicode_secret
|
||||
|
||||
|
||||
def test_secret_very_long(mocker: MockerFixture):
|
||||
"""Test JWT secret with excessive length."""
|
||||
long_secret = "a" * 1000 # 1000 character secret
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": long_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == long_secret
|
||||
assert len(settings.JWT_VERIFY_KEY) == 1000
|
||||
|
||||
|
||||
def test_secret_with_newline(mocker: MockerFixture):
|
||||
"""Test JWT secret containing newlines."""
|
||||
multiline_secret = "secret\nwith\nnewlines" + "a" * 20
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": multiline_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == multiline_secret
|
||||
|
||||
|
||||
def test_secret_base64_encoded(mocker: MockerFixture):
|
||||
"""Test JWT secret that looks like base64."""
|
||||
base64_secret = "dGhpc19pc19hX3NlY3JldF9rZXlfd2l0aF9wcm9wZXJfbGVuZ3Ro"
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": base64_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == base64_secret
|
||||
|
||||
|
||||
def test_secret_numeric_only(mocker: MockerFixture):
|
||||
"""Test JWT secret with only numbers."""
|
||||
numeric_secret = "1234567890" * 4 # 40 character numeric secret
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": numeric_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == numeric_secret
|
||||
|
||||
|
||||
def test_algorithm_default_hs256(mocker: MockerFixture):
|
||||
"""Test that JWT algorithm defaults to HS256."""
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": "a" * 32}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_ALGORITHM == "HS256"
|
||||
|
||||
|
||||
def test_algorithm_whitespace_stripped(mocker: MockerFixture):
|
||||
"""Test that JWT algorithm whitespace is stripped."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": " HS256 "},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_ALGORITHM == "HS256"
|
||||
|
||||
|
||||
def test_no_crypto_warning(mocker: MockerFixture, caplog: pytest.LogCaptureFixture):
|
||||
"""Test warning when crypto package is not available."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
|
||||
|
||||
# Mock has_crypto to return False
|
||||
mocker.patch("autogpt_libs.auth.config.has_crypto", False)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
Settings()
|
||||
assert "Asymmetric JWT verification is not available" in caplog.text
|
||||
assert "cryptography" in caplog.text
|
||||
|
||||
|
||||
def test_algorithm_invalid_raises_error(mocker: MockerFixture):
|
||||
"""Test that invalid JWT algorithm raises AuthConfigError."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": "INVALID_ALG"},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
with pytest.raises(AuthConfigError) as exc_info:
|
||||
Settings()
|
||||
assert "Invalid JWT_SIGN_ALGORITHM" in str(exc_info.value)
|
||||
assert "INVALID_ALG" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_algorithm_none_raises_error(mocker: MockerFixture):
|
||||
"""Test that 'none' algorithm raises AuthConfigError."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": "none"},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
with pytest.raises(AuthConfigError) as exc_info:
|
||||
Settings()
|
||||
assert "Invalid JWT_SIGN_ALGORITHM" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("algorithm", ["HS256", "HS384", "HS512"])
|
||||
def test_algorithm_symmetric_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture, algorithm: str
|
||||
):
|
||||
"""Test warning for symmetric algorithms (HS256, HS384, HS512)."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": algorithm},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
assert algorithm in caplog.text
|
||||
assert "symmetric shared-key signature algorithm" in caplog.text
|
||||
assert settings.JWT_ALGORITHM == algorithm
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"algorithm",
|
||||
["ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512"],
|
||||
)
|
||||
def test_algorithm_asymmetric_no_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture, algorithm: str
|
||||
):
|
||||
"""Test that asymmetric algorithms do not trigger warning."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": algorithm},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
# Should not contain the symmetric algorithm warning
|
||||
assert "symmetric shared-key signature algorithm" not in caplog.text
|
||||
assert settings.JWT_ALGORITHM == algorithm
|
||||
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
FastAPI dependency functions for JWT-based authentication and authorization.
|
||||
|
||||
These are the high-level dependency functions used in route definitions.
|
||||
"""
|
||||
|
||||
import fastapi
|
||||
|
||||
from .jwt_utils import get_jwt_payload, verify_user
|
||||
from .models import User
|
||||
|
||||
|
||||
def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
|
||||
"""
|
||||
FastAPI dependency that requires a valid authenticated user.
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 for authentication failures
|
||||
"""
|
||||
return verify_user(jwt_payload, admin_only=False)
|
||||
|
||||
|
||||
def requires_admin_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
|
||||
"""
|
||||
FastAPI dependency that requires a valid admin user.
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 for authentication failures, 403 for insufficient permissions
|
||||
"""
|
||||
return verify_user(jwt_payload, admin_only=True)
|
||||
|
||||
|
||||
def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
|
||||
"""
|
||||
FastAPI dependency that returns the ID of the authenticated user.
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 for authentication failures or missing user ID
|
||||
"""
|
||||
user_id = jwt_payload.get("sub")
|
||||
if not user_id:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=401, detail="User ID not found in token"
|
||||
)
|
||||
return user_id
|
||||
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Comprehensive integration tests for authentication dependencies.
|
||||
Tests the full authentication flow from HTTP requests to user validation.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException, Security
|
||||
from fastapi.testclient import TestClient
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from autogpt_libs.auth.dependencies import (
|
||||
get_user_id,
|
||||
requires_admin_user,
|
||||
requires_user,
|
||||
)
|
||||
from autogpt_libs.auth.models import User
|
||||
|
||||
|
||||
class TestAuthDependencies:
|
||||
"""Test suite for authentication dependency functions."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create a test FastAPI application."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/user")
|
||||
def get_user_endpoint(user: User = Security(requires_user)):
|
||||
return {"user_id": user.user_id, "role": user.role}
|
||||
|
||||
@app.get("/admin")
|
||||
def get_admin_endpoint(user: User = Security(requires_admin_user)):
|
||||
return {"user_id": user.user_id, "role": user.role}
|
||||
|
||||
@app.get("/user-id")
|
||||
def get_user_id_endpoint(user_id: str = Security(get_user_id)):
|
||||
return {"user_id": user_id}
|
||||
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def client(self, app):
|
||||
"""Create a test client."""
|
||||
return TestClient(app)
|
||||
|
||||
def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
|
||||
"""Test requires_user with valid JWT payload."""
|
||||
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
|
||||
|
||||
# Mock get_jwt_payload to return our test payload
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = requires_user(jwt_payload)
|
||||
assert isinstance(user, User)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.role == "user"
|
||||
|
||||
def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
|
||||
"""Test requires_user accepts admin users."""
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = requires_user(jwt_payload)
|
||||
assert user.user_id == "admin-456"
|
||||
assert user.role == "admin"
|
||||
|
||||
def test_requires_user_missing_sub(self):
|
||||
"""Test requires_user with missing user ID."""
|
||||
jwt_payload = {"role": "user", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
requires_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found" in exc_info.value.detail
|
||||
|
||||
def test_requires_user_empty_sub(self):
|
||||
"""Test requires_user with empty user ID."""
|
||||
jwt_payload = {"sub": "", "role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
requires_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
|
||||
"""Test requires_admin_user with admin role."""
|
||||
jwt_payload = {
|
||||
"sub": "admin-789",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = requires_admin_user(jwt_payload)
|
||||
assert user.user_id == "admin-789"
|
||||
assert user.role == "admin"
|
||||
|
||||
def test_requires_admin_user_with_regular_user(self):
|
||||
"""Test requires_admin_user rejects regular users."""
|
||||
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
requires_admin_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Admin access required" in exc_info.value.detail
|
||||
|
||||
def test_requires_admin_user_missing_role(self):
|
||||
"""Test requires_admin_user with missing role."""
|
||||
jwt_payload = {"sub": "user-123", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
requires_admin_user(jwt_payload)
|
||||
|
||||
def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
|
||||
"""Test get_user_id extracts user ID correctly."""
|
||||
jwt_payload = {"sub": "user-id-xyz", "role": "user"}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user_id = get_user_id(jwt_payload)
|
||||
assert user_id == "user-id-xyz"
|
||||
|
||||
def test_get_user_id_missing_sub(self):
|
||||
"""Test get_user_id with missing user ID."""
|
||||
jwt_payload = {"role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_user_id(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found" in exc_info.value.detail
|
||||
|
||||
def test_get_user_id_none_sub(self):
|
||||
"""Test get_user_id with None user ID."""
|
||||
jwt_payload = {"sub": None, "role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_user_id(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
class TestAuthDependenciesIntegration:
|
||||
"""Integration tests for auth dependencies with FastAPI."""
|
||||
|
||||
acceptable_jwt_secret = "test-secret-with-proper-length-123456"
|
||||
|
||||
@pytest.fixture
|
||||
def create_token(self, mocker: MockerFixture):
|
||||
"""Helper to create JWT tokens."""
|
||||
import jwt
|
||||
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": self.acceptable_jwt_secret},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
def _create_token(payload, secret=self.acceptable_jwt_secret):
|
||||
return jwt.encode(payload, secret, algorithm="HS256")
|
||||
|
||||
return _create_token
|
||||
|
||||
def test_endpoint_auth_enabled_no_token(self):
|
||||
"""Test endpoints require token when auth is enabled."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint(user: User = Security(requires_user)):
|
||||
return {"user_id": user.user_id}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Should fail without auth header
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_endpoint_with_valid_token(self, create_token):
|
||||
"""Test endpoint with valid JWT token."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint(user: User = Security(requires_user)):
|
||||
return {"user_id": user.user_id, "role": user.role}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
token = create_token(
|
||||
{"sub": "test-user", "role": "user", "aud": "authenticated"},
|
||||
secret=self.acceptable_jwt_secret,
|
||||
)
|
||||
|
||||
response = client.get("/test", headers={"Authorization": f"Bearer {token}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["user_id"] == "test-user"
|
||||
|
||||
def test_admin_endpoint_requires_admin_role(self, create_token):
|
||||
"""Test admin endpoint rejects non-admin users."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/admin")
|
||||
def admin_endpoint(user: User = Security(requires_admin_user)):
|
||||
return {"user_id": user.user_id}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Regular user token
|
||||
user_token = create_token(
|
||||
{"sub": "regular-user", "role": "user", "aud": "authenticated"},
|
||||
secret=self.acceptable_jwt_secret,
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin", headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Admin token
|
||||
admin_token = create_token(
|
||||
{"sub": "admin-user", "role": "admin", "aud": "authenticated"},
|
||||
secret=self.acceptable_jwt_secret,
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin", headers={"Authorization": f"Bearer {admin_token}"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["user_id"] == "admin-user"
|
||||
|
||||
|
||||
class TestAuthDependenciesEdgeCases:
|
||||
"""Edge case tests for authentication dependencies."""
|
||||
|
||||
def test_dependency_with_complex_payload(self):
|
||||
"""Test dependencies handle complex JWT payloads."""
|
||||
complex_payload = {
|
||||
"sub": "user-123",
|
||||
"role": "admin",
|
||||
"email": "test@example.com",
|
||||
"app_metadata": {"provider": "email", "providers": ["email"]},
|
||||
"user_metadata": {
|
||||
"full_name": "Test User",
|
||||
"avatar_url": "https://example.com/avatar.jpg",
|
||||
},
|
||||
"aud": "authenticated",
|
||||
"iat": 1234567890,
|
||||
"exp": 9999999999,
|
||||
}
|
||||
|
||||
user = requires_user(complex_payload)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.email == "test@example.com"
|
||||
|
||||
admin = requires_admin_user(complex_payload)
|
||||
assert admin.role == "admin"
|
||||
|
||||
def test_dependency_with_unicode_in_payload(self):
|
||||
"""Test dependencies handle unicode in JWT payloads."""
|
||||
unicode_payload = {
|
||||
"sub": "user-😀-123",
|
||||
"role": "user",
|
||||
"email": "测试@example.com",
|
||||
"name": "日本語",
|
||||
}
|
||||
|
||||
user = requires_user(unicode_payload)
|
||||
assert "😀" in user.user_id
|
||||
assert user.email == "测试@example.com"
|
||||
|
||||
def test_dependency_with_null_values(self):
|
||||
"""Test dependencies handle null values in payload."""
|
||||
null_payload = {
|
||||
"sub": "user-123",
|
||||
"role": "user",
|
||||
"email": None,
|
||||
"phone": None,
|
||||
"metadata": None,
|
||||
}
|
||||
|
||||
user = requires_user(null_payload)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.email is None
|
||||
|
||||
def test_concurrent_requests_isolation(self):
|
||||
"""Test that concurrent requests don't interfere with each other."""
|
||||
payload1 = {"sub": "user-1", "role": "user"}
|
||||
payload2 = {"sub": "user-2", "role": "admin"}
|
||||
|
||||
# Simulate concurrent processing
|
||||
user1 = requires_user(payload1)
|
||||
user2 = requires_admin_user(payload2)
|
||||
|
||||
assert user1.user_id == "user-1"
|
||||
assert user2.user_id == "user-2"
|
||||
assert user1.role == "user"
|
||||
assert user2.role == "admin"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"payload,expected_error,admin_only",
|
||||
[
|
||||
(None, "Authorization header is missing", False),
|
||||
({}, "User ID not found", False),
|
||||
({"sub": ""}, "User ID not found", False),
|
||||
({"role": "user"}, "User ID not found", False),
|
||||
({"sub": "user", "role": "user"}, "Admin access required", True),
|
||||
],
|
||||
)
|
||||
def test_dependency_error_cases(
|
||||
self, payload, expected_error: str, admin_only: bool
|
||||
):
|
||||
"""Test that errors propagate correctly through dependencies."""
|
||||
# Import verify_user to test it directly since dependencies use FastAPI Security
|
||||
from autogpt_libs.auth.jwt_utils import verify_user
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
verify_user(payload, admin_only=admin_only)
|
||||
assert expected_error in exc_info.value.detail
|
||||
|
||||
def test_dependency_valid_user(self):
|
||||
"""Test valid user case for dependency."""
|
||||
# Import verify_user to test it directly since dependencies use FastAPI Security
|
||||
from autogpt_libs.auth.jwt_utils import verify_user
|
||||
|
||||
# Valid case
|
||||
user = verify_user({"sub": "user", "role": "user"}, admin_only=False)
|
||||
assert user.user_id == "user"
|
||||
@@ -1,46 +0,0 @@
|
||||
import fastapi
|
||||
|
||||
from .config import settings
|
||||
from .middleware import auth_middleware
|
||||
from .models import DEFAULT_USER_ID, User
|
||||
|
||||
|
||||
def requires_user(payload: dict = fastapi.Depends(auth_middleware)) -> User:
|
||||
return verify_user(payload, admin_only=False)
|
||||
|
||||
|
||||
def requires_admin_user(
|
||||
payload: dict = fastapi.Depends(auth_middleware),
|
||||
) -> User:
|
||||
return verify_user(payload, admin_only=True)
|
||||
|
||||
|
||||
def verify_user(payload: dict | None, admin_only: bool) -> User:
|
||||
if not payload:
|
||||
if settings.ENABLE_AUTH:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=401, detail="Authorization header is missing"
|
||||
)
|
||||
# This handles the case when authentication is disabled
|
||||
payload = {"sub": DEFAULT_USER_ID, "role": "admin"}
|
||||
|
||||
user_id = payload.get("sub")
|
||||
|
||||
if not user_id:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=401, detail="User ID not found in token"
|
||||
)
|
||||
|
||||
if admin_only and payload["role"] != "admin":
|
||||
raise fastapi.HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
return User.from_payload(payload)
|
||||
|
||||
|
||||
def get_user_id(payload: dict = fastapi.Depends(auth_middleware)) -> str:
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=401, detail="User ID not found in token"
|
||||
)
|
||||
return user_id
|
||||
@@ -1,68 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from .depends import requires_admin_user, requires_user, verify_user
|
||||
|
||||
|
||||
def test_verify_user_no_payload():
|
||||
user = verify_user(None, admin_only=False)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "admin"
|
||||
|
||||
|
||||
def test_verify_user_no_user_id():
|
||||
with pytest.raises(Exception):
|
||||
verify_user({"role": "admin"}, admin_only=False)
|
||||
|
||||
|
||||
def test_verify_user_not_admin():
|
||||
with pytest.raises(Exception):
|
||||
verify_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"},
|
||||
admin_only=True,
|
||||
)
|
||||
|
||||
|
||||
def test_verify_user_with_admin_role():
|
||||
user = verify_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "admin"},
|
||||
admin_only=True,
|
||||
)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "admin"
|
||||
|
||||
|
||||
def test_verify_user_with_user_role():
|
||||
user = verify_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"},
|
||||
admin_only=False,
|
||||
)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "user"
|
||||
|
||||
|
||||
def test_requires_user():
|
||||
user = requires_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"}
|
||||
)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "user"
|
||||
|
||||
|
||||
def test_requires_user_no_user_id():
|
||||
with pytest.raises(Exception):
|
||||
requires_user({"role": "user"})
|
||||
|
||||
|
||||
def test_requires_admin_user():
|
||||
user = requires_admin_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "admin"}
|
||||
)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "admin"
|
||||
|
||||
|
||||
def test_requires_admin_user_not_admin():
|
||||
with pytest.raises(Exception):
|
||||
requires_admin_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"}
|
||||
)
|
||||
68
autogpt_platform/autogpt_libs/autogpt_libs/auth/helpers.py
Normal file
68
autogpt_platform/autogpt_libs/autogpt_libs/auth/helpers.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from .jwt_utils import bearer_jwt_auth
|
||||
|
||||
|
||||
def add_auth_responses_to_openapi(app: FastAPI) -> None:
|
||||
"""
|
||||
Set up custom OpenAPI schema generation that adds 401 responses
|
||||
to all authenticated endpoints.
|
||||
|
||||
This is needed when using HTTPBearer with auto_error=False to get proper
|
||||
401 responses instead of 403, but FastAPI only automatically adds security
|
||||
responses when auto_error=True.
|
||||
"""
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
|
||||
# Add 401 response to all endpoints that have security requirements
|
||||
for path, methods in openapi_schema["paths"].items():
|
||||
for method, details in methods.items():
|
||||
security_schemas = [
|
||||
schema
|
||||
for auth_option in details.get("security", [])
|
||||
for schema in auth_option.keys()
|
||||
]
|
||||
if bearer_jwt_auth.scheme_name not in security_schemas:
|
||||
continue
|
||||
|
||||
if "responses" not in details:
|
||||
details["responses"] = {}
|
||||
|
||||
details["responses"]["401"] = {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
}
|
||||
|
||||
# Ensure #/components/responses exists
|
||||
if "components" not in openapi_schema:
|
||||
openapi_schema["components"] = {}
|
||||
if "responses" not in openapi_schema["components"]:
|
||||
openapi_schema["components"]["responses"] = {}
|
||||
|
||||
# Define 401 response
|
||||
openapi_schema["components"]["responses"]["HTTP401NotAuthenticatedError"] = {
|
||||
"description": "Authentication required",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"detail": {"type": "string"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
app.openapi = custom_openapi
|
||||
435
autogpt_platform/autogpt_libs/autogpt_libs/auth/helpers_test.py
Normal file
435
autogpt_platform/autogpt_libs/autogpt_libs/auth/helpers_test.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""
|
||||
Comprehensive tests for auth helpers module to achieve 100% coverage.
|
||||
Tests OpenAPI schema generation and authentication response handling.
|
||||
"""
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from autogpt_libs.auth.helpers import add_auth_responses_to_openapi
|
||||
from autogpt_libs.auth.jwt_utils import bearer_jwt_auth
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_basic():
|
||||
"""Test adding 401 responses to OpenAPI schema."""
|
||||
app = FastAPI(title="Test App", version="1.0.0")
|
||||
|
||||
# Add some test endpoints with authentication
|
||||
from fastapi import Depends
|
||||
|
||||
from autogpt_libs.auth.dependencies import requires_user
|
||||
|
||||
@app.get("/protected", dependencies=[Depends(requires_user)])
|
||||
def protected_endpoint():
|
||||
return {"message": "Protected"}
|
||||
|
||||
@app.get("/public")
|
||||
def public_endpoint():
|
||||
return {"message": "Public"}
|
||||
|
||||
# Apply the OpenAPI customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get the OpenAPI schema
|
||||
schema = app.openapi()
|
||||
|
||||
# Verify basic schema properties
|
||||
assert schema["info"]["title"] == "Test App"
|
||||
assert schema["info"]["version"] == "1.0.0"
|
||||
|
||||
# Verify 401 response component is added
|
||||
assert "components" in schema
|
||||
assert "responses" in schema["components"]
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
# Verify 401 response structure
|
||||
error_response = schema["components"]["responses"]["HTTP401NotAuthenticatedError"]
|
||||
assert error_response["description"] == "Authentication required"
|
||||
assert "application/json" in error_response["content"]
|
||||
assert "schema" in error_response["content"]["application/json"]
|
||||
|
||||
# Verify schema properties
|
||||
response_schema = error_response["content"]["application/json"]["schema"]
|
||||
assert response_schema["type"] == "object"
|
||||
assert "detail" in response_schema["properties"]
|
||||
assert response_schema["properties"]["detail"]["type"] == "string"
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_with_security():
|
||||
"""Test that 401 responses are added only to secured endpoints."""
|
||||
app = FastAPI()
|
||||
|
||||
# Mock endpoint with security
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.dependencies import get_user_id
|
||||
|
||||
@app.get("/secured")
|
||||
def secured_endpoint(user_id: str = Security(get_user_id)):
|
||||
return {"user_id": user_id}
|
||||
|
||||
@app.post("/also-secured")
|
||||
def another_secured(user_id: str = Security(get_user_id)):
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/unsecured")
|
||||
def unsecured_endpoint():
|
||||
return {"public": True}
|
||||
|
||||
# Apply OpenAPI customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get schema
|
||||
schema = app.openapi()
|
||||
|
||||
# Check that secured endpoints have 401 responses
|
||||
if "/secured" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/secured"]:
|
||||
secured_get = schema["paths"]["/secured"]["get"]
|
||||
if "responses" in secured_get:
|
||||
assert "401" in secured_get["responses"]
|
||||
assert (
|
||||
secured_get["responses"]["401"]["$ref"]
|
||||
== "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
)
|
||||
|
||||
if "/also-secured" in schema["paths"]:
|
||||
if "post" in schema["paths"]["/also-secured"]:
|
||||
secured_post = schema["paths"]["/also-secured"]["post"]
|
||||
if "responses" in secured_post:
|
||||
assert "401" in secured_post["responses"]
|
||||
|
||||
# Check that unsecured endpoint does not have 401 response
|
||||
if "/unsecured" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/unsecured"]:
|
||||
unsecured_get = schema["paths"]["/unsecured"]["get"]
|
||||
if "responses" in unsecured_get:
|
||||
assert "401" not in unsecured_get.get("responses", {})
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_cached_schema():
|
||||
"""Test that OpenAPI schema is cached after first generation."""
|
||||
app = FastAPI()
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get schema twice
|
||||
schema1 = app.openapi()
|
||||
schema2 = app.openapi()
|
||||
|
||||
# Should return the same cached object
|
||||
assert schema1 is schema2
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_existing_responses():
|
||||
"""Test handling endpoints that already have responses defined."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
@app.get(
|
||||
"/with-responses",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
404: {"description": "Not found"},
|
||||
},
|
||||
)
|
||||
def endpoint_with_responses(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"data": "test"}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Check that existing responses are preserved and 401 is added
|
||||
if "/with-responses" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/with-responses"]:
|
||||
responses = schema["paths"]["/with-responses"]["get"].get("responses", {})
|
||||
# Original responses should be preserved
|
||||
if "200" in responses:
|
||||
assert responses["200"]["description"] == "Success"
|
||||
if "404" in responses:
|
||||
assert responses["404"]["description"] == "Not found"
|
||||
# 401 should be added
|
||||
if "401" in responses:
|
||||
assert (
|
||||
responses["401"]["$ref"]
|
||||
== "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
)
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_no_security_endpoints():
|
||||
"""Test with app that has no secured endpoints."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/public1")
|
||||
def public1():
|
||||
return {"message": "public1"}
|
||||
|
||||
@app.post("/public2")
|
||||
def public2():
|
||||
return {"message": "public2"}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Component should still be added for consistency
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
# But no endpoints should have 401 responses
|
||||
for path in schema["paths"].values():
|
||||
for method in path.values():
|
||||
if isinstance(method, dict) and "responses" in method:
|
||||
assert "401" not in method["responses"]
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_multiple_security_schemes():
|
||||
"""Test endpoints with multiple security requirements."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.dependencies import requires_admin_user, requires_user
|
||||
from autogpt_libs.auth.models import User
|
||||
|
||||
@app.get("/multi-auth")
|
||||
def multi_auth(
|
||||
user: User = Security(requires_user),
|
||||
admin: User = Security(requires_admin_user),
|
||||
):
|
||||
return {"status": "super secure"}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Should have 401 response
|
||||
if "/multi-auth" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/multi-auth"]:
|
||||
responses = schema["paths"]["/multi-auth"]["get"].get("responses", {})
|
||||
if "401" in responses:
|
||||
assert (
|
||||
responses["401"]["$ref"]
|
||||
== "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
)
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_empty_components():
|
||||
"""Test when OpenAPI schema has no components section initially."""
|
||||
app = FastAPI()
|
||||
|
||||
# Mock get_openapi to return schema without components
|
||||
original_get_openapi = get_openapi
|
||||
|
||||
def mock_get_openapi(*args, **kwargs):
|
||||
schema = original_get_openapi(*args, **kwargs)
|
||||
# Remove components if it exists
|
||||
if "components" in schema:
|
||||
del schema["components"]
|
||||
return schema
|
||||
|
||||
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Components should be created
|
||||
assert "components" in schema
|
||||
assert "responses" in schema["components"]
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_all_http_methods():
|
||||
"""Test that all HTTP methods are handled correctly."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
@app.get("/resource")
|
||||
def get_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "GET"}
|
||||
|
||||
@app.post("/resource")
|
||||
def post_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "POST"}
|
||||
|
||||
@app.put("/resource")
|
||||
def put_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "PUT"}
|
||||
|
||||
@app.patch("/resource")
|
||||
def patch_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "PATCH"}
|
||||
|
||||
@app.delete("/resource")
|
||||
def delete_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "DELETE"}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# All methods should have 401 response
|
||||
if "/resource" in schema["paths"]:
|
||||
for method in ["get", "post", "put", "patch", "delete"]:
|
||||
if method in schema["paths"]["/resource"]:
|
||||
method_spec = schema["paths"]["/resource"][method]
|
||||
if "responses" in method_spec:
|
||||
assert "401" in method_spec["responses"]
|
||||
|
||||
|
||||
def test_bearer_jwt_auth_scheme_config():
|
||||
"""Test that bearer_jwt_auth is configured correctly."""
|
||||
assert bearer_jwt_auth.scheme_name == "HTTPBearerJWT"
|
||||
assert bearer_jwt_auth.auto_error is False
|
||||
|
||||
|
||||
def test_add_auth_responses_with_no_routes():
|
||||
"""Test OpenAPI generation with app that has no routes."""
|
||||
app = FastAPI(title="Empty App")
|
||||
|
||||
# Apply customization to empty app
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Should still have basic structure
|
||||
assert schema["info"]["title"] == "Empty App"
|
||||
assert "components" in schema
|
||||
assert "responses" in schema["components"]
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
|
||||
def test_custom_openapi_function_replacement():
|
||||
"""Test that the custom openapi function properly replaces the default."""
|
||||
app = FastAPI()
|
||||
|
||||
# Store original function
|
||||
original_openapi = app.openapi
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Function should be replaced
|
||||
assert app.openapi != original_openapi
|
||||
assert callable(app.openapi)
|
||||
|
||||
|
||||
def test_endpoint_without_responses_section():
|
||||
"""Test endpoint that has security but no responses section initially."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
from fastapi.openapi.utils import get_openapi as original_get_openapi
|
||||
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
# Create endpoint
|
||||
@app.get("/no-responses")
|
||||
def endpoint_without_responses(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"data": "test"}
|
||||
|
||||
# Mock get_openapi to remove responses from the endpoint
|
||||
def mock_get_openapi(*args, **kwargs):
|
||||
schema = original_get_openapi(*args, **kwargs)
|
||||
# Remove responses from our endpoint to trigger line 40
|
||||
if "/no-responses" in schema.get("paths", {}):
|
||||
if "get" in schema["paths"]["/no-responses"]:
|
||||
# Delete responses to force the code to create it
|
||||
if "responses" in schema["paths"]["/no-responses"]["get"]:
|
||||
del schema["paths"]["/no-responses"]["get"]["responses"]
|
||||
return schema
|
||||
|
||||
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get schema and verify 401 was added
|
||||
schema = app.openapi()
|
||||
|
||||
# The endpoint should now have 401 response
|
||||
if "/no-responses" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/no-responses"]:
|
||||
responses = schema["paths"]["/no-responses"]["get"].get("responses", {})
|
||||
assert "401" in responses
|
||||
assert (
|
||||
responses["401"]["$ref"]
|
||||
== "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
)
|
||||
|
||||
|
||||
def test_components_with_existing_responses():
|
||||
"""Test when components already has a responses section."""
|
||||
app = FastAPI()
|
||||
|
||||
# Mock get_openapi to return schema with existing components/responses
|
||||
from fastapi.openapi.utils import get_openapi as original_get_openapi
|
||||
|
||||
def mock_get_openapi(*args, **kwargs):
|
||||
schema = original_get_openapi(*args, **kwargs)
|
||||
# Add existing components/responses
|
||||
if "components" not in schema:
|
||||
schema["components"] = {}
|
||||
schema["components"]["responses"] = {
|
||||
"ExistingResponse": {"description": "An existing response"}
|
||||
}
|
||||
return schema
|
||||
|
||||
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Both responses should exist
|
||||
assert "ExistingResponse" in schema["components"]["responses"]
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
# Verify our 401 response structure
|
||||
error_response = schema["components"]["responses"][
|
||||
"HTTP401NotAuthenticatedError"
|
||||
]
|
||||
assert error_response["description"] == "Authentication required"
|
||||
|
||||
|
||||
def test_openapi_schema_persistence():
|
||||
"""Test that modifications to OpenAPI schema persist correctly."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"test": True}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get schema multiple times
|
||||
schema1 = app.openapi()
|
||||
|
||||
# Modify the cached schema (shouldn't affect future calls)
|
||||
schema1["info"]["title"] = "Modified Title"
|
||||
|
||||
# Clear cache and get again
|
||||
app.openapi_schema = None
|
||||
schema2 = app.openapi()
|
||||
|
||||
# Should regenerate with original title
|
||||
assert schema2["info"]["title"] == app.title
|
||||
assert schema2["info"]["title"] != "Modified Title"
|
||||
@@ -1,11 +1,48 @@
|
||||
from typing import Any, Dict
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from .config import settings
|
||||
from .config import get_settings
|
||||
from .models import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Bearer token authentication scheme
|
||||
bearer_jwt_auth = HTTPBearer(
|
||||
bearerFormat="jwt", scheme_name="HTTPBearerJWT", auto_error=False
|
||||
)
|
||||
|
||||
|
||||
def parse_jwt_token(token: str) -> Dict[str, Any]:
|
||||
def get_jwt_payload(
|
||||
credentials: HTTPAuthorizationCredentials | None = Security(bearer_jwt_auth),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Extract and validate JWT payload from HTTP Authorization header.
|
||||
|
||||
This is the core authentication function that handles:
|
||||
- Reading the `Authorization` header to obtain the JWT token
|
||||
- Verifying the JWT token's signature
|
||||
- Decoding the JWT token's payload
|
||||
|
||||
:param credentials: HTTP Authorization credentials from bearer token
|
||||
:return: JWT payload dictionary
|
||||
:raises HTTPException: 401 if authentication fails
|
||||
"""
|
||||
if not credentials:
|
||||
raise HTTPException(status_code=401, detail="Authorization header is missing")
|
||||
|
||||
try:
|
||||
payload = parse_jwt_token(credentials.credentials)
|
||||
logger.debug("Token decoded successfully")
|
||||
return payload
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
|
||||
|
||||
def parse_jwt_token(token: str) -> dict[str, Any]:
|
||||
"""
|
||||
Parse and validate a JWT token.
|
||||
|
||||
@@ -13,10 +50,11 @@ def parse_jwt_token(token: str) -> Dict[str, Any]:
|
||||
:return: The decoded payload
|
||||
:raises ValueError: If the token is invalid or expired
|
||||
"""
|
||||
settings = get_settings()
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.JWT_SECRET_KEY,
|
||||
settings.JWT_VERIFY_KEY,
|
||||
algorithms=[settings.JWT_ALGORITHM],
|
||||
audience="authenticated",
|
||||
)
|
||||
@@ -25,3 +63,18 @@ def parse_jwt_token(token: str) -> Dict[str, Any]:
|
||||
raise ValueError("Token has expired")
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise ValueError(f"Invalid token: {str(e)}")
|
||||
|
||||
|
||||
def verify_user(jwt_payload: dict | None, admin_only: bool) -> User:
|
||||
if jwt_payload is None:
|
||||
raise HTTPException(status_code=401, detail="Authorization header is missing")
|
||||
|
||||
user_id = jwt_payload.get("sub")
|
||||
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found in token")
|
||||
|
||||
if admin_only and jwt_payload["role"] != "admin":
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
return User.from_payload(jwt_payload)
|
||||
|
||||
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
Comprehensive tests for JWT token parsing and validation.
|
||||
Ensures 100% line and branch coverage for JWT security functions.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from autogpt_libs.auth import config, jwt_utils
|
||||
from autogpt_libs.auth.config import Settings
|
||||
from autogpt_libs.auth.models import User
|
||||
|
||||
MOCK_JWT_SECRET = "test-secret-key-with-at-least-32-characters"
|
||||
TEST_USER_PAYLOAD = {
|
||||
"sub": "test-user-id",
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
TEST_ADMIN_PAYLOAD = {
|
||||
"sub": "admin-user-id",
|
||||
"role": "admin",
|
||||
"aud": "authenticated",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_config(mocker: MockerFixture):
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": MOCK_JWT_SECRET}, clear=True)
|
||||
mocker.patch.object(config, "_settings", Settings())
|
||||
yield
|
||||
|
||||
|
||||
def create_token(payload, secret=None, algorithm="HS256"):
|
||||
"""Helper to create JWT tokens."""
|
||||
if secret is None:
|
||||
secret = MOCK_JWT_SECRET
|
||||
return jwt.encode(payload, secret, algorithm=algorithm)
|
||||
|
||||
|
||||
def test_parse_jwt_token_valid():
|
||||
"""Test parsing a valid JWT token."""
|
||||
token = create_token(TEST_USER_PAYLOAD)
|
||||
result = jwt_utils.parse_jwt_token(token)
|
||||
|
||||
assert result["sub"] == "test-user-id"
|
||||
assert result["role"] == "user"
|
||||
assert result["aud"] == "authenticated"
|
||||
|
||||
|
||||
def test_parse_jwt_token_expired():
|
||||
"""Test parsing an expired JWT token."""
|
||||
expired_payload = {
|
||||
**TEST_USER_PAYLOAD,
|
||||
"exp": datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
}
|
||||
token = create_token(expired_payload)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Token has expired" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_parse_jwt_token_invalid_signature():
|
||||
"""Test parsing a token with invalid signature."""
|
||||
# Create token with different secret
|
||||
token = create_token(TEST_USER_PAYLOAD, secret="wrong-secret")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_parse_jwt_token_malformed():
|
||||
"""Test parsing a malformed token."""
|
||||
malformed_tokens = [
|
||||
"not.a.token",
|
||||
"invalid",
|
||||
"",
|
||||
# Header only
|
||||
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9",
|
||||
# No signature
|
||||
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0",
|
||||
]
|
||||
|
||||
for token in malformed_tokens:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_parse_jwt_token_wrong_audience():
|
||||
"""Test parsing a token with wrong audience."""
|
||||
wrong_aud_payload = {**TEST_USER_PAYLOAD, "aud": "wrong-audience"}
|
||||
token = create_token(wrong_aud_payload)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_parse_jwt_token_missing_audience():
|
||||
"""Test parsing a token without audience claim."""
|
||||
no_aud_payload = {k: v for k, v in TEST_USER_PAYLOAD.items() if k != "aud"}
|
||||
token = create_token(no_aud_payload)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_get_jwt_payload_with_valid_token():
|
||||
"""Test extracting JWT payload with valid bearer token."""
|
||||
token = create_token(TEST_USER_PAYLOAD)
|
||||
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
|
||||
|
||||
result = jwt_utils.get_jwt_payload(credentials)
|
||||
assert result["sub"] == "test-user-id"
|
||||
assert result["role"] == "user"
|
||||
|
||||
|
||||
def test_get_jwt_payload_no_credentials():
|
||||
"""Test JWT payload when no credentials provided."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.get_jwt_payload(None)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Authorization header is missing" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_get_jwt_payload_invalid_token():
|
||||
"""Test JWT payload extraction with invalid token."""
|
||||
credentials = HTTPAuthorizationCredentials(
|
||||
scheme="Bearer", credentials="invalid.token.here"
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.get_jwt_payload(credentials)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid token" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_with_valid_user():
|
||||
"""Test verifying a valid user."""
|
||||
user = jwt_utils.verify_user(TEST_USER_PAYLOAD, admin_only=False)
|
||||
assert isinstance(user, User)
|
||||
assert user.user_id == "test-user-id"
|
||||
assert user.role == "user"
|
||||
assert user.email == "test@example.com"
|
||||
|
||||
|
||||
def test_verify_user_with_admin():
|
||||
"""Test verifying an admin user."""
|
||||
user = jwt_utils.verify_user(TEST_ADMIN_PAYLOAD, admin_only=True)
|
||||
assert isinstance(user, User)
|
||||
assert user.user_id == "admin-user-id"
|
||||
assert user.role == "admin"
|
||||
|
||||
|
||||
def test_verify_user_admin_only_with_regular_user():
|
||||
"""Test verifying regular user when admin is required."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(TEST_USER_PAYLOAD, admin_only=True)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Admin access required" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_no_payload():
|
||||
"""Test verifying user with no payload."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(None, admin_only=False)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Authorization header is missing" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_missing_sub():
|
||||
"""Test verifying user with payload missing 'sub' field."""
|
||||
invalid_payload = {"role": "user", "email": "test@example.com"}
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(invalid_payload, admin_only=False)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found in token" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_empty_sub():
|
||||
"""Test verifying user with empty 'sub' field."""
|
||||
invalid_payload = {"sub": "", "role": "user"}
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(invalid_payload, admin_only=False)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found in token" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_none_sub():
|
||||
"""Test verifying user with None 'sub' field."""
|
||||
invalid_payload = {"sub": None, "role": "user"}
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(invalid_payload, admin_only=False)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found in token" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_missing_role_admin_check():
|
||||
"""Test verifying admin when role field is missing."""
|
||||
no_role_payload = {"sub": "user-id"}
|
||||
with pytest.raises(KeyError):
|
||||
# This will raise KeyError when checking payload["role"]
|
||||
jwt_utils.verify_user(no_role_payload, admin_only=True)
|
||||
|
||||
|
||||
# ======================== EDGE CASES ======================== #
|
||||
|
||||
|
||||
def test_jwt_with_additional_claims():
|
||||
"""Test JWT token with additional custom claims."""
|
||||
extra_claims_payload = {
|
||||
"sub": "user-id",
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
"custom_claim": "custom_value",
|
||||
"permissions": ["read", "write"],
|
||||
"metadata": {"key": "value"},
|
||||
}
|
||||
token = create_token(extra_claims_payload)
|
||||
|
||||
result = jwt_utils.parse_jwt_token(token)
|
||||
assert result["sub"] == "user-id"
|
||||
assert result["custom_claim"] == "custom_value"
|
||||
assert result["permissions"] == ["read", "write"]
|
||||
|
||||
|
||||
def test_jwt_with_numeric_sub():
|
||||
"""Test JWT token with numeric user ID."""
|
||||
payload = {
|
||||
"sub": 12345, # Numeric ID
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
}
|
||||
# Should convert to string internally
|
||||
user = jwt_utils.verify_user(payload, admin_only=False)
|
||||
assert user.user_id == 12345
|
||||
|
||||
|
||||
def test_jwt_with_very_long_sub():
|
||||
"""Test JWT token with very long user ID."""
|
||||
long_id = "a" * 1000
|
||||
payload = {
|
||||
"sub": long_id,
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
}
|
||||
user = jwt_utils.verify_user(payload, admin_only=False)
|
||||
assert user.user_id == long_id
|
||||
|
||||
|
||||
def test_jwt_with_special_characters_in_claims():
|
||||
"""Test JWT token with special characters in claims."""
|
||||
payload = {
|
||||
"sub": "user@example.com/special-chars!@#$%",
|
||||
"role": "admin",
|
||||
"aud": "authenticated",
|
||||
"email": "test+special@example.com",
|
||||
}
|
||||
user = jwt_utils.verify_user(payload, admin_only=True)
|
||||
assert "special-chars!@#$%" in user.user_id
|
||||
|
||||
|
||||
def test_jwt_with_future_iat():
|
||||
"""Test JWT token with issued-at time in future."""
|
||||
future_payload = {
|
||||
"sub": "user-id",
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
"iat": datetime.now(timezone.utc) + timedelta(hours=1),
|
||||
}
|
||||
token = create_token(future_payload)
|
||||
|
||||
# PyJWT validates iat claim and should reject future tokens
|
||||
with pytest.raises(ValueError, match="not yet valid"):
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
|
||||
|
||||
def test_jwt_with_different_algorithms():
|
||||
"""Test that only HS256 algorithm is accepted."""
|
||||
payload = {
|
||||
"sub": "user-id",
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
}
|
||||
|
||||
# Try different algorithms
|
||||
algorithms = ["HS384", "HS512", "none"]
|
||||
for algo in algorithms:
|
||||
if algo == "none":
|
||||
# Special case for 'none' algorithm (security vulnerability if accepted)
|
||||
token = create_token(payload, "", algorithm="none")
|
||||
else:
|
||||
token = create_token(payload, algorithm=algo)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
@@ -1,140 +0,0 @@
|
||||
import inspect
|
||||
import logging
|
||||
import secrets
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from fastapi import HTTPException, Request, Security
|
||||
from fastapi.security import APIKeyHeader, HTTPBearer
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
|
||||
from .config import settings
|
||||
from .jwt_utils import parse_jwt_token
|
||||
|
||||
security = HTTPBearer()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def auth_middleware(request: Request):
|
||||
if not settings.ENABLE_AUTH:
|
||||
# If authentication is disabled, allow the request to proceed
|
||||
logger.warning("Auth disabled")
|
||||
return {}
|
||||
|
||||
security = HTTPBearer()
|
||||
credentials = await security(request)
|
||||
|
||||
if not credentials:
|
||||
raise HTTPException(status_code=401, detail="Authorization header is missing")
|
||||
|
||||
try:
|
||||
payload = parse_jwt_token(credentials.credentials)
|
||||
request.state.user = payload
|
||||
logger.debug("Token decoded successfully")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
return payload
|
||||
|
||||
|
||||
class APIKeyValidator:
|
||||
"""
|
||||
Configurable API key validator that supports custom validation functions
|
||||
for FastAPI applications.
|
||||
|
||||
This class provides a flexible way to implement API key authentication with optional
|
||||
custom validation logic. It can be used for simple token matching
|
||||
or more complex validation scenarios like database lookups.
|
||||
|
||||
Examples:
|
||||
Simple token validation:
|
||||
```python
|
||||
validator = APIKeyValidator(
|
||||
header_name="X-API-Key",
|
||||
expected_token="your-secret-token"
|
||||
)
|
||||
|
||||
@app.get("/protected", dependencies=[Depends(validator.get_dependency())])
|
||||
def protected_endpoint():
|
||||
return {"message": "Access granted"}
|
||||
```
|
||||
|
||||
Custom validation with database lookup:
|
||||
```python
|
||||
async def validate_with_db(api_key: str):
|
||||
api_key_obj = await db.get_api_key(api_key)
|
||||
return api_key_obj if api_key_obj and api_key_obj.is_active else None
|
||||
|
||||
validator = APIKeyValidator(
|
||||
header_name="X-API-Key",
|
||||
validate_fn=validate_with_db
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
header_name (str): The name of the header containing the API key
|
||||
expected_token (Optional[str]): The expected API key value for simple token matching
|
||||
validate_fn (Optional[Callable]): Custom validation function that takes an API key
|
||||
string and returns a boolean or object. Can be async.
|
||||
error_status (int): HTTP status code to use for validation errors
|
||||
error_message (str): Error message to return when validation fails
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
header_name: str,
|
||||
expected_token: Optional[str] = None,
|
||||
validate_fn: Optional[Callable[[str], bool]] = None,
|
||||
error_status: int = HTTP_401_UNAUTHORIZED,
|
||||
error_message: str = "Invalid API key",
|
||||
):
|
||||
# Create the APIKeyHeader as a class property
|
||||
self.security_scheme = APIKeyHeader(name=header_name)
|
||||
self.expected_token = expected_token
|
||||
self.custom_validate_fn = validate_fn
|
||||
self.error_status = error_status
|
||||
self.error_message = error_message
|
||||
|
||||
async def default_validator(self, api_key: str) -> bool:
|
||||
if not self.expected_token:
|
||||
raise ValueError(
|
||||
"Expected Token Required to be set when uisng API Key Validator default validation"
|
||||
)
|
||||
return secrets.compare_digest(api_key, self.expected_token)
|
||||
|
||||
async def __call__(
|
||||
self, request: Request, api_key: str = Security(APIKeyHeader)
|
||||
) -> Any:
|
||||
if api_key is None:
|
||||
raise HTTPException(status_code=self.error_status, detail="Missing API key")
|
||||
|
||||
# Use custom validation if provided, otherwise use default equality check
|
||||
validator = self.custom_validate_fn or self.default_validator
|
||||
result = (
|
||||
await validator(api_key)
|
||||
if inspect.iscoroutinefunction(validator)
|
||||
else validator(api_key)
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=self.error_status, detail=self.error_message
|
||||
)
|
||||
|
||||
# Store validation result in request state if it's not just a boolean
|
||||
if result is not True:
|
||||
request.state.api_key = result
|
||||
|
||||
return result
|
||||
|
||||
def get_dependency(self):
|
||||
"""
|
||||
Returns a callable dependency that FastAPI will recognize as a security scheme
|
||||
"""
|
||||
|
||||
async def validate_api_key(
|
||||
request: Request, api_key: str = Security(self.security_scheme)
|
||||
) -> Any:
|
||||
return await self(request, api_key)
|
||||
|
||||
# This helps FastAPI recognize it as a security dependency
|
||||
validate_api_key.__name__ = f"validate_{self.security_scheme.model.name}"
|
||||
return validate_api_key
|
||||
379
autogpt_platform/autogpt_libs/poetry.lock
generated
379
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -54,7 +54,7 @@ version = "1.2.0"
|
||||
description = "Backport of asyncio.Runner, a context manager that controls event loop life cycle."
|
||||
optional = false
|
||||
python-versions = "<3.11,>=3.8"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
files = [
|
||||
{file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
|
||||
@@ -85,6 +85,87 @@ files = [
|
||||
{file = "certifi-2025.7.14.tar.gz", hash = "sha256:8ea99dbdfaaf2ba2f9bac77b9249ef62ec5218e7c2b2e903378ed5fccf765995"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cffi"
|
||||
version = "1.17.1"
|
||||
description = "Foreign Function Interface for Python calling C code."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
markers = "platform_python_implementation != \"PyPy\""
|
||||
files = [
|
||||
{file = "cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5da5719280082ac6bd9aa7becb3938dc9f9cbd57fac7d2871717b1feb0902ab6"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb1a08b8008b281856e5971307cc386a8e9c5b625ac297e853d36da6efe9c17"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-win32.whl", hash = "sha256:c9c3d058ebabb74db66e431095118094d06abf53284d9c81f27300d0e0d8bc7c"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:0f048dcf80db46f0098ccac01132761580d28e28bc0f78ae0d58048063317e15"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-win32.whl", hash = "sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:805b4371bf7197c329fcb3ead37e710d1bca9da5d583f5073b799d5c5bd1eee4"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:733e99bc2df47476e3848417c5a4540522f234dfd4ef3ab7fafdf555b082ec0c"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f79fc4fc25f1c8698ff97788206bb3c2598949bfe0fef03d299eb1b5356ada99"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-win32.whl", hash = "sha256:a08d7e755f8ed21095a310a693525137cfe756ce62d066e53f502a83dc550f65"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:51392eae71afec0d0c8fb1a53b204dbb3bcabcb3c9b807eedf3e1e6ccf2de903"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f3a2b4222ce6b60e2e8b337bb9596923045681d71e5a082783484d845390938e"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0984a4925a435b1da406122d4d7968dd861c1385afe3b45ba82b750f229811e2"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-win32.whl", hash = "sha256:e03eab0a8677fa80d646b5ddece1cbeaf556c313dcfac435ba11f107ba117b5d"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:636062ea65bd0195bc012fea9321aca499c0504409f413dc88af450b57ffd03b"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c7eac2ef9b63c79431bc4b25f1cd649d7f061a28808cbc6c47b534bd789ef964"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e221cf152cff04059d011ee126477f0d9588303eb57e88923578ace7baad17f9"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:31000ec67d4221a71bd3f67df918b1f88f676f1c3b535a7eb473255fdc0b83fc"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f17be4345073b0a7b8ea599688f692ac3ef23ce28e5df79c04de519dbc4912c"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e2b1fac190ae3ebfe37b979cc1ce69c81f4e4fe5746bb401dca63a9062cdaf1"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-win32.whl", hash = "sha256:7596d6620d3fa590f677e9ee430df2958d2d6d6de2feeae5b20e82c00b76fbf8"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-win_amd64.whl", hash = "sha256:78122be759c3f8a014ce010908ae03364d00a1f81ab5c7f4a7a5120607ea56e1"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b2ab587605f4ba0bf81dc0cb08a41bd1c0a5906bd59243d56bad7668a6fc6c16"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:28b16024becceed8c6dfbc75629e27788d8a3f9030691a1dbf9821a128b22c36"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d599671f396c4723d016dbddb72fe8e0397082b0a77a4fab8028923bec050e8"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca74b8dbe6e8e8263c0ffd60277de77dcee6c837a3d0881d8c1ead7268c9e576"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7f5baafcc48261359e14bcd6d9bff6d4b28d9103847c9e136694cb0501aef87"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:98e3969bcff97cae1b2def8ba499ea3d6f31ddfdb7635374834cf89a1a08ecf0"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cdf5ce3acdfd1661132f2a9c19cac174758dc2352bfe37d98aa7512c6b7178b3"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9755e4345d1ec879e3849e62222a18c7174d65a6a92d5b346b1863912168b595"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f1e22e8c4419538cb197e4dd60acc919d7696e5ef98ee4da4e01d3f8cfa4cc5a"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c03e868a0b3bc35839ba98e74211ed2b05d2119be4e8a0f224fba9384f1fe02e"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-win32.whl", hash = "sha256:e31ae45bc2e29f6b2abd0de1cc3b9d5205aa847cafaecb8af1476a609a2f6eb7"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662"},
|
||||
{file = "cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pycparser = "*"
|
||||
|
||||
[[package]]
|
||||
name = "charset-normalizer"
|
||||
version = "3.4.2"
|
||||
@@ -208,12 +289,176 @@ version = "0.4.6"
|
||||
description = "Cross-platform colored terminal text."
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
||||
groups = ["main"]
|
||||
groups = ["main", "dev"]
|
||||
files = [
|
||||
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
|
||||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "coverage"
|
||||
version = "7.10.5"
|
||||
description = "Code coverage measurement for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "coverage-7.10.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c6a5c3414bfc7451b879141ce772c546985163cf553f08e0f135f0699a911801"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bc8e4d99ce82f1710cc3c125adc30fd1487d3cf6c2cd4994d78d68a47b16989a"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:02252dc1216e512a9311f596b3169fad54abcb13827a8d76d5630c798a50a754"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:73269df37883e02d460bee0cc16be90509faea1e3bd105d77360b512d5bb9c33"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1f8a81b0614642f91c9effd53eec284f965577591f51f547a1cbeb32035b4c2f"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6a29f8e0adb7f8c2b95fa2d4566a1d6e6722e0a637634c6563cb1ab844427dd9"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:fcf6ab569436b4a647d4e91accba12509ad9f2554bc93d3aee23cc596e7f99c3"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:90dc3d6fb222b194a5de60af8d190bedeeddcbc7add317e4a3cd333ee6b7c879"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-win32.whl", hash = "sha256:414a568cd545f9dc75f0686a0049393de8098414b58ea071e03395505b73d7a8"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-win_amd64.whl", hash = "sha256:e551f9d03347196271935fd3c0c165f0e8c049220280c1120de0084d65e9c7ff"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c177e6ffe2ebc7c410785307758ee21258aa8e8092b44d09a2da767834f075f2"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:14d6071c51ad0f703d6440827eaa46386169b5fdced42631d5a5ac419616046f"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:61f78c7c3bc272a410c5ae3fde7792b4ffb4acc03d35a7df73ca8978826bb7ab"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f39071caa126f69d63f99b324fb08c7b1da2ec28cbb1fe7b5b1799926492f65c"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:343a023193f04d46edc46b2616cdbee68c94dd10208ecd3adc56fcc54ef2baa1"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:585ffe93ae5894d1ebdee69fc0b0d4b7c75d8007983692fb300ac98eed146f78"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b0ef4e66f006ed181df29b59921bd8fc7ed7cd6a9289295cd8b2824b49b570df"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:eb7b0bbf7cc1d0453b843eca7b5fa017874735bef9bfdfa4121373d2cc885ed6"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-win32.whl", hash = "sha256:1d043a8a06987cc0c98516e57c4d3fc2c1591364831e9deb59c9e1b4937e8caf"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-win_amd64.whl", hash = "sha256:fefafcca09c3ac56372ef64a40f5fe17c5592fab906e0fdffd09543f3012ba50"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-win_arm64.whl", hash = "sha256:7e78b767da8b5fc5b2faa69bb001edafcd6f3995b42a331c53ef9572c55ceb82"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c2d05c7e73c60a4cecc7d9b60dbfd603b4ebc0adafaef371445b47d0f805c8a9"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:32ddaa3b2c509778ed5373b177eb2bf5662405493baeff52278a0b4f9415188b"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:dd382410039fe062097aa0292ab6335a3f1e7af7bba2ef8d27dcda484918f20c"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7fa22800f3908df31cea6fb230f20ac49e343515d968cc3a42b30d5c3ebf9b5a"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f366a57ac81f5e12797136552f5b7502fa053c861a009b91b80ed51f2ce651c6"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5f1dc8f1980a272ad4a6c84cba7981792344dad33bf5869361576b7aef42733a"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:2285c04ee8676f7938b02b4936d9b9b672064daab3187c20f73a55f3d70e6b4a"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c2492e4dd9daab63f5f56286f8a04c51323d237631eb98505d87e4c4ff19ec34"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-win32.whl", hash = "sha256:38a9109c4ee8135d5df5505384fc2f20287a47ccbe0b3f04c53c9a1989c2bbaf"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-win_amd64.whl", hash = "sha256:6b87f1ad60b30bc3c43c66afa7db6b22a3109902e28c5094957626a0143a001f"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-win_arm64.whl", hash = "sha256:672a6c1da5aea6c629819a0e1461e89d244f78d7b60c424ecf4f1f2556c041d8"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ef3b83594d933020f54cf65ea1f4405d1f4e41a009c46df629dd964fcb6e907c"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2b96bfdf7c0ea9faebce088a3ecb2382819da4fbc05c7b80040dbc428df6af44"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:63df1fdaffa42d914d5c4d293e838937638bf75c794cf20bee12978fc8c4e3bc"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8002dc6a049aac0e81ecec97abfb08c01ef0c1fbf962d0c98da3950ace89b869"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:63d4bb2966d6f5f705a6b0c6784c8969c468dbc4bcf9d9ded8bff1c7e092451f"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:1f672efc0731a6846b157389b6e6d5d5e9e59d1d1a23a5c66a99fd58339914d5"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:3f39cef43d08049e8afc1fde4a5da8510fc6be843f8dea350ee46e2a26b2f54c"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2968647e3ed5a6c019a419264386b013979ff1fb67dd11f5c9886c43d6a31fc2"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-win32.whl", hash = "sha256:0d511dda38595b2b6934c2b730a1fd57a3635c6aa2a04cb74714cdfdd53846f4"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-win_amd64.whl", hash = "sha256:9a86281794a393513cf117177fd39c796b3f8e3759bb2764259a2abba5cce54b"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-win_arm64.whl", hash = "sha256:cebd8e906eb98bb09c10d1feed16096700b1198d482267f8bf0474e63a7b8d84"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0520dff502da5e09d0d20781df74d8189ab334a1e40d5bafe2efaa4158e2d9e7"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:d9cd64aca68f503ed3f1f18c7c9174cbb797baba02ca8ab5112f9d1c0328cd4b"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0913dd1613a33b13c4f84aa6e3f4198c1a21ee28ccb4f674985c1f22109f0aae"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:1b7181c0feeb06ed8a02da02792f42f829a7b29990fef52eff257fef0885d760"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:36d42b7396b605f774d4372dd9c49bed71cbabce4ae1ccd074d155709dd8f235"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b4fdc777e05c4940b297bf47bf7eedd56a39a61dc23ba798e4b830d585486ca5"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:42144e8e346de44a6f1dbd0a56575dd8ab8dfa7e9007da02ea5b1c30ab33a7db"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:66c644cbd7aed8fe266d5917e2c9f65458a51cfe5eeff9c05f15b335f697066e"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-win32.whl", hash = "sha256:2d1b73023854068c44b0c554578a4e1ef1b050ed07cf8b431549e624a29a66ee"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-win_amd64.whl", hash = "sha256:54a1532c8a642d8cc0bd5a9a51f5a9dcc440294fd06e9dda55e743c5ec1a8f14"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-win_arm64.whl", hash = "sha256:74d5b63fe3f5f5d372253a4ef92492c11a4305f3550631beaa432fc9df16fcff"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:68c5e0bc5f44f68053369fa0d94459c84548a77660a5f2561c5e5f1e3bed7031"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:cf33134ffae93865e32e1e37df043bef15a5e857d8caebc0099d225c579b0fa3"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:ad8fa9d5193bafcf668231294241302b5e683a0518bf1e33a9a0dfb142ec3031"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:146fa1531973d38ab4b689bc764592fe6c2f913e7e80a39e7eeafd11f0ef6db2"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6013a37b8a4854c478d3219ee8bc2392dea51602dd0803a12d6f6182a0061762"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:eb90fe20db9c3d930fa2ad7a308207ab5b86bf6a76f54ab6a40be4012d88fcae"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:384b34482272e960c438703cafe63316dfbea124ac62006a455c8410bf2a2262"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:467dc74bd0a1a7de2bedf8deaf6811f43602cb532bd34d81ffd6038d6d8abe99"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-win32.whl", hash = "sha256:556d23d4e6393ca898b2e63a5bca91e9ac2d5fb13299ec286cd69a09a7187fde"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-win_amd64.whl", hash = "sha256:f4446a9547681533c8fa3e3c6cf62121eeee616e6a92bd9201c6edd91beffe13"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-win_arm64.whl", hash = "sha256:5e78bd9cf65da4c303bf663de0d73bf69f81e878bf72a94e9af67137c69b9fe9"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:5661bf987d91ec756a47c7e5df4fbcb949f39e32f9334ccd3f43233bbb65e508"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:a46473129244db42a720439a26984f8c6f834762fc4573616c1f37f13994b357"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:1f64b8d3415d60f24b058b58d859e9512624bdfa57a2d1f8aff93c1ec45c429b"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:44d43de99a9d90b20e0163f9770542357f58860a26e24dc1d924643bd6aa7cb4"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a931a87e5ddb6b6404e65443b742cb1c14959622777f2a4efd81fba84f5d91ba"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:f9559b906a100029274448f4c8b8b0a127daa4dade5661dfd821b8c188058842"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:b08801e25e3b4526ef9ced1aa29344131a8f5213c60c03c18fe4c6170ffa2874"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ed9749bb8eda35f8b636fb7632f1c62f735a236a5d4edadd8bbcc5ea0542e732"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-win32.whl", hash = "sha256:609b60d123fc2cc63ccee6d17e4676699075db72d14ac3c107cc4976d516f2df"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-win_amd64.whl", hash = "sha256:0666cf3d2c1626b5a3463fd5b05f5e21f99e6aec40a3192eee4d07a15970b07f"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-win_arm64.whl", hash = "sha256:bc85eb2d35e760120540afddd3044a5bf69118a91a296a8b3940dfc4fdcfe1e2"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:62835c1b00c4a4ace24c1a88561a5a59b612fbb83a525d1c70ff5720c97c0610"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5255b3bbcc1d32a4069d6403820ac8e6dbcc1d68cb28a60a1ebf17e47028e898"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3876385722e335d6e991c430302c24251ef9c2a9701b2b390f5473199b1b8ebf"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8048ce4b149c93447a55d279078c8ae98b08a6951a3c4d2d7e87f4efc7bfe100"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4028e7558e268dd8bcf4d9484aad393cafa654c24b4885f6f9474bf53183a82a"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:03f47dc870eec0367fcdd603ca6a01517d2504e83dc18dbfafae37faec66129a"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:2d488d7d42b6ded7ea0704884f89dcabd2619505457de8fc9a6011c62106f6e5"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b3dcf2ead47fa8be14224ee817dfc1df98043af568fe120a22f81c0eb3c34ad2"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-win32.whl", hash = "sha256:02650a11324b80057b8c9c29487020073d5e98a498f1857f37e3f9b6ea1b2426"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-win_amd64.whl", hash = "sha256:b45264dd450a10f9e03237b41a9a24e85cbb1e278e5a32adb1a303f58f0017f3"},
|
||||
{file = "coverage-7.10.5-py3-none-any.whl", hash = "sha256:0be24d35e4db1d23d0db5c0f6a74a962e2ec83c426b5cac09f4234aadef38e4a"},
|
||||
{file = "coverage-7.10.5.tar.gz", hash = "sha256:f2e57716a78bc3ae80b2207be0709a3b2b63b9f2dcf9740ee6ac03588a2015b6"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}
|
||||
|
||||
[package.extras]
|
||||
toml = ["tomli ; python_full_version <= \"3.11.0a6\""]
|
||||
|
||||
[[package]]
|
||||
name = "cryptography"
|
||||
version = "45.0.6"
|
||||
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
|
||||
optional = false
|
||||
python-versions = "!=3.9.0,!=3.9.1,>=3.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "cryptography-45.0.6-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:048e7ad9e08cf4c0ab07ff7f36cc3115924e22e2266e034450a890d9e312dd74"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:44647c5d796f5fc042bbc6d61307d04bf29bccb74d188f18051b635f20a9c75f"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e40b80ecf35ec265c452eea0ba94c9587ca763e739b8e559c128d23bff7ebbbf"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:00e8724bdad672d75e6f069b27970883179bd472cd24a63f6e620ca7e41cc0c5"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7a3085d1b319d35296176af31c90338eeb2ddac8104661df79f80e1d9787b8b2"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1b7fa6a1c1188c7ee32e47590d16a5a0646270921f8020efc9a511648e1b2e08"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:275ba5cc0d9e320cd70f8e7b96d9e59903c815ca579ab96c1e37278d231fc402"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:f4028f29a9f38a2025abedb2e409973709c660d44319c61762202206ed577c42"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ee411a1b977f40bd075392c80c10b58025ee5c6b47a822a33c1198598a7a5f05"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e2a21a8eda2d86bb604934b6b37691585bd095c1f788530c1fcefc53a82b3453"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-win32.whl", hash = "sha256:d063341378d7ee9c91f9d23b431a3502fc8bfacd54ef0a27baa72a0843b29159"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-win_amd64.whl", hash = "sha256:833dc32dfc1e39b7376a87b9a6a4288a10aae234631268486558920029b086ec"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:3436128a60a5e5490603ab2adbabc8763613f638513ffa7d311c900a8349a2a0"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0d9ef57b6768d9fa58e92f4947cea96ade1233c0e236db22ba44748ffedca394"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea3c42f2016a5bbf71825537c2ad753f2870191134933196bee408aac397b3d9"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:20ae4906a13716139d6d762ceb3e0e7e110f7955f3bc3876e3a07f5daadec5f3"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:2dac5ec199038b8e131365e2324c03d20e97fe214af051d20c49db129844e8b3"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:18f878a34b90d688982e43f4b700408b478102dd58b3e39de21b5ebf6509c301"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:5bd6020c80c5b2b2242d6c48487d7b85700f5e0038e67b29d706f98440d66eb5"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:eccddbd986e43014263eda489abbddfbc287af5cddfd690477993dbb31e31016"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:550ae02148206beb722cfe4ef0933f9352bab26b087af00e48fdfb9ade35c5b3"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5b64e668fc3528e77efa51ca70fadcd6610e8ab231e3e06ae2bab3b31c2b8ed9"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-win32.whl", hash = "sha256:780c40fb751c7d2b0c6786ceee6b6f871e86e8718a8ff4bc35073ac353c7cd02"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-win_amd64.whl", hash = "sha256:20d15aed3ee522faac1a39fbfdfee25d17b1284bafd808e1640a74846d7c4d1b"},
|
||||
{file = "cryptography-45.0.6-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:705bb7c7ecc3d79a50f236adda12ca331c8e7ecfbea51edd931ce5a7a7c4f012"},
|
||||
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:826b46dae41a1155a0c0e66fafba43d0ede1dc16570b95e40c4d83bfcf0a451d"},
|
||||
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:cc4d66f5dc4dc37b89cfef1bd5044387f7a1f6f0abb490815628501909332d5d"},
|
||||
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:f68f833a9d445cc49f01097d95c83a850795921b3f7cc6488731e69bde3288da"},
|
||||
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:3b5bf5267e98661b9b888a9250d05b063220dfa917a8203744454573c7eb79db"},
|
||||
{file = "cryptography-45.0.6-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2384f2ab18d9be88a6e4f8972923405e2dbb8d3e16c6b43f15ca491d7831bd18"},
|
||||
{file = "cryptography-45.0.6-pp311-pypy311_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fc022c1fa5acff6def2fc6d7819bbbd31ccddfe67d075331a65d9cfb28a20983"},
|
||||
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:3de77e4df42ac8d4e4d6cdb342d989803ad37707cf8f3fbf7b088c9cbdd46427"},
|
||||
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:599c8d7df950aa68baa7e98f7b73f4f414c9f02d0e8104a30c0182a07732638b"},
|
||||
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:31a2b9a10530a1cb04ffd6aa1cd4d3be9ed49f7d77a4dafe198f3b382f41545c"},
|
||||
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:e5b3dda1b00fb41da3af4c5ef3f922a200e33ee5ba0f0bc9ecf0b0c173958385"},
|
||||
{file = "cryptography-45.0.6-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:629127cfdcdc6806dfe234734d7cb8ac54edaf572148274fa377a7d3405b0043"},
|
||||
{file = "cryptography-45.0.6.tar.gz", hash = "sha256:5c966c732cf6e4a276ce83b6e4c729edda2df6929083a952cc7da973c539c719"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
cffi = {version = ">=1.14", markers = "platform_python_implementation != \"PyPy\""}
|
||||
|
||||
[package.extras]
|
||||
docs = ["sphinx (>=5.3.0)", "sphinx-inline-tabs ; python_full_version >= \"3.8.0\"", "sphinx-rtd-theme (>=3.0.0) ; python_full_version >= \"3.8.0\""]
|
||||
docstest = ["pyenchant (>=3)", "readme-renderer (>=30.0)", "sphinxcontrib-spelling (>=7.3.1)"]
|
||||
nox = ["nox (>=2024.4.15)", "nox[uv] (>=2024.3.2) ; python_full_version >= \"3.8.0\""]
|
||||
pep8test = ["check-sdist ; python_full_version >= \"3.8.0\"", "click (>=8.0.1)", "mypy (>=1.4)", "ruff (>=0.3.6)"]
|
||||
sdist = ["build (>=1.0.0)"]
|
||||
ssh = ["bcrypt (>=3.1.5)"]
|
||||
test = ["certifi (>=2024)", "cryptography-vectors (==45.0.6)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
|
||||
test-randomorder = ["pytest-randomly"]
|
||||
|
||||
[[package]]
|
||||
name = "deprecation"
|
||||
version = "2.1.0"
|
||||
@@ -235,7 +480,7 @@ version = "1.3.0"
|
||||
description = "Backport of PEP 654 (exception groups)"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main"]
|
||||
groups = ["main", "dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
files = [
|
||||
{file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"},
|
||||
@@ -710,7 +955,7 @@ version = "2.1.0"
|
||||
description = "brain-dead simple config-ini parsing"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"},
|
||||
{file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"},
|
||||
@@ -757,6 +1002,18 @@ dynamodb = ["boto3 (>=1.9.71)"]
|
||||
redis = ["redis (>=2.10.5)"]
|
||||
test-filesource = ["pyyaml (>=5.3.1)", "watchdog (>=3.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "nodeenv"
|
||||
version = "1.9.1"
|
||||
description = "Node.js virtual environment builder"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"},
|
||||
{file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-api"
|
||||
version = "1.35.0"
|
||||
@@ -779,7 +1036,7 @@ version = "25.0"
|
||||
description = "Core utilities for Python packages"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
groups = ["main", "dev"]
|
||||
files = [
|
||||
{file = "packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484"},
|
||||
{file = "packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f"},
|
||||
@@ -791,7 +1048,7 @@ version = "1.6.0"
|
||||
description = "plugin and hook calling mechanisms for python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746"},
|
||||
{file = "pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3"},
|
||||
@@ -883,6 +1140,19 @@ files = [
|
||||
[package.dependencies]
|
||||
pyasn1 = ">=0.6.1,<0.7.0"
|
||||
|
||||
[[package]]
|
||||
name = "pycparser"
|
||||
version = "2.22"
|
||||
description = "C parser in Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
markers = "platform_python_implementation != \"PyPy\""
|
||||
files = [
|
||||
{file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"},
|
||||
{file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "2.11.7"
|
||||
@@ -1047,7 +1317,7 @@ version = "2.19.2"
|
||||
description = "Pygments is a syntax highlighting package written in Python."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b"},
|
||||
{file = "pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887"},
|
||||
@@ -1068,6 +1338,9 @@ files = [
|
||||
{file = "pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
cryptography = {version = ">=3.4.0", optional = true, markers = "extra == \"crypto\""}
|
||||
|
||||
[package.extras]
|
||||
crypto = ["cryptography (>=3.4.0)"]
|
||||
dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx", "sphinx-rtd-theme", "zope.interface"]
|
||||
@@ -1086,13 +1359,34 @@ files = [
|
||||
{file = "pyrfc3339-2.0.1.tar.gz", hash = "sha256:e47843379ea35c1296c3b6c67a948a1a490ae0584edfcbdea0eaffb5dd29960b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyright"
|
||||
version = "1.1.404"
|
||||
description = "Command line wrapper for pyright"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pyright-1.1.404-py3-none-any.whl", hash = "sha256:c7b7ff1fdb7219c643079e4c3e7d4125f0dafcc19d253b47e898d130ea426419"},
|
||||
{file = "pyright-1.1.404.tar.gz", hash = "sha256:455e881a558ca6be9ecca0b30ce08aa78343ecc031d37a198ffa9a7a1abeb63e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
nodeenv = ">=1.6.0"
|
||||
typing-extensions = ">=4.1"
|
||||
|
||||
[package.extras]
|
||||
all = ["nodejs-wheel-binaries", "twine (>=3.4.1)"]
|
||||
dev = ["twine (>=3.4.1)"]
|
||||
nodejs = ["nodejs-wheel-binaries"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "8.4.1"
|
||||
description = "pytest: simple powerful testing with Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7"},
|
||||
{file = "pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c"},
|
||||
@@ -1116,7 +1410,7 @@ version = "1.1.0"
|
||||
description = "Pytest support for asyncio"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest_asyncio-1.1.0-py3-none-any.whl", hash = "sha256:5fe2d69607b0bd75c656d1211f969cadba035030156745ee09e7d71740e58ecf"},
|
||||
{file = "pytest_asyncio-1.1.0.tar.gz", hash = "sha256:796aa822981e01b68c12e4827b8697108f7205020f24b5793b3c41555dab68ea"},
|
||||
@@ -1130,13 +1424,33 @@ pytest = ">=8.2,<9"
|
||||
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"]
|
||||
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "6.2.1"
|
||||
description = "Pytest plugin for measuring coverage."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest_cov-6.2.1-py3-none-any.whl", hash = "sha256:f5bc4c23f42f1cdd23c70b1dab1bbaef4fc505ba950d53e0081d0730dd7e86d5"},
|
||||
{file = "pytest_cov-6.2.1.tar.gz", hash = "sha256:25cc6cc0a5358204b8108ecedc51a9b57b34cc6b8c967cc2c01a4e00d8a67da2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
coverage = {version = ">=7.5", extras = ["toml"]}
|
||||
pluggy = ">=1.2"
|
||||
pytest = ">=6.2.5"
|
||||
|
||||
[package.extras]
|
||||
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-mock"
|
||||
version = "3.14.1"
|
||||
description = "Thin-wrapper around the mock package for easier use with pytest"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest_mock-3.14.1-py3-none-any.whl", hash = "sha256:178aefcd11307d874b4cd3100344e7e2d888d9791a6a1d9bfe90fbc1b74fd1d0"},
|
||||
{file = "pytest_mock-3.14.1.tar.gz", hash = "sha256:159e9edac4c451ce77a5cdb9fc5d1100708d2dd4ba3c3df572f14097351af80e"},
|
||||
@@ -1253,30 +1567,31 @@ pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.12.3"
|
||||
version = "0.12.11"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "ruff-0.12.3-py3-none-linux_armv6l.whl", hash = "sha256:47552138f7206454eaf0c4fe827e546e9ddac62c2a3d2585ca54d29a890137a2"},
|
||||
{file = "ruff-0.12.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:0a9153b000c6fe169bb307f5bd1b691221c4286c133407b8827c406a55282041"},
|
||||
{file = "ruff-0.12.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa6b24600cf3b750e48ddb6057e901dd5b9aa426e316addb2a1af185a7509882"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2506961bf6ead54887ba3562604d69cb430f59b42133d36976421bc8bd45901"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c4faaff1f90cea9d3033cbbcdf1acf5d7fb11d8180758feb31337391691f3df0"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40dced4a79d7c264389de1c59467d5d5cefd79e7e06d1dfa2c75497b5269a5a6"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:0262d50ba2767ed0fe212aa7e62112a1dcbfd46b858c5bf7bbd11f326998bafc"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12371aec33e1a3758597c5c631bae9a5286f3c963bdfb4d17acdd2d395406687"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:560f13b6baa49785665276c963edc363f8ad4b4fc910a883e2625bdb14a83a9e"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:023040a3499f6f974ae9091bcdd0385dd9e9eb4942f231c23c57708147b06311"},
|
||||
{file = "ruff-0.12.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:883d844967bffff5ab28bba1a4d246c1a1b2933f48cb9840f3fdc5111c603b07"},
|
||||
{file = "ruff-0.12.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2120d3aa855ff385e0e562fdee14d564c9675edbe41625c87eeab744a7830d12"},
|
||||
{file = "ruff-0.12.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6b16647cbb470eaf4750d27dddc6ebf7758b918887b56d39e9c22cce2049082b"},
|
||||
{file = "ruff-0.12.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e1417051edb436230023575b149e8ff843a324557fe0a265863b7602df86722f"},
|
||||
{file = "ruff-0.12.3-py3-none-win32.whl", hash = "sha256:dfd45e6e926deb6409d0616078a666ebce93e55e07f0fb0228d4b2608b2c248d"},
|
||||
{file = "ruff-0.12.3-py3-none-win_amd64.whl", hash = "sha256:a946cf1e7ba3209bdef039eb97647f1c77f6f540e5845ec9c114d3af8df873e7"},
|
||||
{file = "ruff-0.12.3-py3-none-win_arm64.whl", hash = "sha256:5f9c7c9c8f84c2d7f27e93674d27136fbf489720251544c4da7fb3d742e011b1"},
|
||||
{file = "ruff-0.12.3.tar.gz", hash = "sha256:f1b5a4b6668fd7b7ea3697d8d98857390b40c1320a63a178eee6be0899ea2d77"},
|
||||
{file = "ruff-0.12.11-py3-none-linux_armv6l.whl", hash = "sha256:93fce71e1cac3a8bf9200e63a38ac5c078f3b6baebffb74ba5274fb2ab276065"},
|
||||
{file = "ruff-0.12.11-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b8e33ac7b28c772440afa80cebb972ffd823621ded90404f29e5ab6d1e2d4b93"},
|
||||
{file = "ruff-0.12.11-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d69fb9d4937aa19adb2e9f058bc4fbfe986c2040acb1a4a9747734834eaa0bfd"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:411954eca8464595077a93e580e2918d0a01a19317af0a72132283e28ae21bee"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6a2c0a2e1a450f387bf2c6237c727dd22191ae8c00e448e0672d624b2bbd7fb0"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ca4c3a7f937725fd2413c0e884b5248a19369ab9bdd850b5781348ba283f644"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:4d1df0098124006f6a66ecf3581a7f7e754c4df7644b2e6704cd7ca80ff95211"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5a8dd5f230efc99a24ace3b77e3555d3fbc0343aeed3fc84c8d89e75ab2ff793"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4dc75533039d0ed04cd33fb8ca9ac9620b99672fe7ff1533b6402206901c34ee"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4fc58f9266d62c6eccc75261a665f26b4ef64840887fc6cbc552ce5b29f96cc8"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:5a0113bd6eafd545146440225fe60b4e9489f59eb5f5f107acd715ba5f0b3d2f"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0d737b4059d66295c3ea5720e6efc152623bb83fde5444209b69cd33a53e2000"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:916fc5defee32dbc1fc1650b576a8fed68f5e8256e2180d4d9855aea43d6aab2"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c984f07d7adb42d3ded5be894fb4007f30f82c87559438b4879fe7aa08c62b39"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e07fbb89f2e9249f219d88331c833860489b49cdf4b032b8e4432e9b13e8a4b9"},
|
||||
{file = "ruff-0.12.11-py3-none-win32.whl", hash = "sha256:c792e8f597c9c756e9bcd4d87cf407a00b60af77078c96f7b6366ea2ce9ba9d3"},
|
||||
{file = "ruff-0.12.11-py3-none-win_amd64.whl", hash = "sha256:a3283325960307915b6deb3576b96919ee89432ebd9c48771ca12ee8afe4a0fd"},
|
||||
{file = "ruff-0.12.11-py3-none-win_arm64.whl", hash = "sha256:bae4d6e6a2676f8fb0f98b74594a048bae1b944aab17e9f5d504062303c6dbea"},
|
||||
{file = "ruff-0.12.11.tar.gz", hash = "sha256:c6b09ae8426a65bbee5425b9d0b82796dbb07cb1af045743c79bfb163001165d"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1410,7 +1725,7 @@ version = "2.2.1"
|
||||
description = "A lil' TOML parser"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
files = [
|
||||
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
|
||||
@@ -1453,7 +1768,7 @@ version = "4.14.1"
|
||||
description = "Backported and Experimental Type Hints for Python 3.9+"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
groups = ["main", "dev"]
|
||||
files = [
|
||||
{file = "typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76"},
|
||||
{file = "typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36"},
|
||||
@@ -1614,4 +1929,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "f67db13e6f68b1d67a55eee908c1c560bfa44da8509f98f842889a7570a9830f"
|
||||
content-hash = "0c40b63c3c921846cf05ccfb4e685d4959854b29c2c302245f9832e20aac6954"
|
||||
|
||||
@@ -9,21 +9,25 @@ packages = [{ include = "autogpt_libs" }]
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<4.0"
|
||||
colorama = "^0.4.6"
|
||||
cryptography = "^45.0"
|
||||
expiringdict = "^1.2.2"
|
||||
fastapi = "^0.116.1"
|
||||
google-cloud-logging = "^3.12.1"
|
||||
launchdarkly-server-sdk = "^9.12.0"
|
||||
pydantic = "^2.11.7"
|
||||
pydantic-settings = "^2.10.1"
|
||||
pyjwt = "^2.10.1"
|
||||
pytest-asyncio = "^1.1.0"
|
||||
pytest-mock = "^3.14.1"
|
||||
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
|
||||
redis = "^6.2.0"
|
||||
supabase = "^2.16.0"
|
||||
uvicorn = "^0.35.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "^0.12.3"
|
||||
pyright = "^1.1.404"
|
||||
pytest = "^8.4.1"
|
||||
pytest-asyncio = "^1.1.0"
|
||||
pytest-mock = "^3.14.1"
|
||||
pytest-cov = "^6.2.1"
|
||||
ruff = "^0.12.11"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -16,7 +16,6 @@ DB_SCHEMA=platform
|
||||
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
|
||||
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
|
||||
PRISMA_SCHEMA="postgres/schema.prisma"
|
||||
ENABLE_AUTH=true
|
||||
|
||||
## ===== REQUIRED SERVICE CREDENTIALS ===== ##
|
||||
# Redis Configuration
|
||||
@@ -31,7 +30,7 @@ RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
|
||||
# Supabase Authentication
|
||||
SUPABASE_URL=http://localhost:8000
|
||||
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
|
||||
## ===== REQUIRED SECURITY KEYS ===== ##
|
||||
# Generate using: from cryptography.fernet import Fernet;Fernet.generate_key().decode()
|
||||
@@ -106,6 +105,15 @@ TODOIST_CLIENT_SECRET=
|
||||
|
||||
NOTION_CLIENT_ID=
|
||||
NOTION_CLIENT_SECRET=
|
||||
|
||||
# Discord OAuth App credentials
|
||||
# 1. Go to https://discord.com/developers/applications
|
||||
# 2. Create a new application
|
||||
# 3. Go to OAuth2 section and add redirect URI: http://localhost:3000/auth/integrations/oauth_callback
|
||||
# 4. Copy Client ID and Client Secret below
|
||||
DISCORD_CLIENT_ID=
|
||||
DISCORD_CLIENT_SECRET=
|
||||
|
||||
REDDIT_CLIENT_ID=
|
||||
REDDIT_CLIENT_SECRET=
|
||||
|
||||
@@ -166,4 +174,4 @@ SMARTLEAD_API_KEY=
|
||||
ZEROBOUNCE_API_KEY=
|
||||
|
||||
# Other Services
|
||||
AUTOMOD_API_KEY=
|
||||
AUTOMOD_API_KEY=
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
FROM python:3.11.10-slim-bookworm AS builder
|
||||
FROM debian:13-slim AS builder
|
||||
|
||||
# Set environment variables
|
||||
ENV PYTHONDONTWRITEBYTECODE 1
|
||||
ENV PYTHONUNBUFFERED 1
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy
|
||||
|
||||
# Update package list and install build dependencies in a single layer
|
||||
# Update package list and install Python and build dependencies
|
||||
RUN apt-get update --allow-releaseinfo-change --fix-missing \
|
||||
&& apt-get install -y \
|
||||
python3.13 \
|
||||
python3.13-dev \
|
||||
python3.13-venv \
|
||||
python3-pip \
|
||||
build-essential \
|
||||
libpq5 \
|
||||
libz-dev \
|
||||
@@ -19,13 +24,11 @@ RUN apt-get update --allow-releaseinfo-change --fix-missing \
|
||||
|
||||
ENV POETRY_HOME=/opt/poetry
|
||||
ENV POETRY_NO_INTERACTION=1
|
||||
ENV POETRY_VIRTUALENVS_CREATE=false
|
||||
ENV POETRY_VIRTUALENVS_CREATE=true
|
||||
ENV POETRY_VIRTUALENVS_IN_PROJECT=true
|
||||
ENV PATH=/opt/poetry/bin:$PATH
|
||||
|
||||
# Upgrade pip and setuptools to fix security vulnerabilities
|
||||
RUN pip3 install --upgrade pip setuptools
|
||||
|
||||
RUN pip3 install poetry
|
||||
RUN pip3 install poetry --break-system-packages
|
||||
|
||||
# Copy and install dependencies
|
||||
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
|
||||
@@ -37,27 +40,30 @@ RUN poetry install --no-ansi --no-root
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
RUN poetry run prisma generate
|
||||
|
||||
FROM python:3.11.10-slim-bookworm AS server_dependencies
|
||||
FROM debian:13-slim AS server_dependencies
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
ENV POETRY_HOME=/opt/poetry \
|
||||
POETRY_NO_INTERACTION=1 \
|
||||
POETRY_VIRTUALENVS_CREATE=false
|
||||
POETRY_VIRTUALENVS_CREATE=true \
|
||||
POETRY_VIRTUALENVS_IN_PROJECT=true \
|
||||
DEBIAN_FRONTEND=noninteractive
|
||||
ENV PATH=/opt/poetry/bin:$PATH
|
||||
|
||||
# Upgrade pip and setuptools to fix security vulnerabilities
|
||||
RUN pip3 install --upgrade pip setuptools
|
||||
# Install Python without upgrading system-managed packages
|
||||
RUN apt-get update && apt-get install -y \
|
||||
python3.13 \
|
||||
python3-pip
|
||||
|
||||
# Copy only necessary files from builder
|
||||
COPY --from=builder /app /app
|
||||
COPY --from=builder /usr/local/lib/python3.11 /usr/local/lib/python3.11
|
||||
COPY --from=builder /usr/local/bin /usr/local/bin
|
||||
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
||||
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
|
||||
# Copy Prisma binaries
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
|
||||
|
||||
RUN mkdir -p /app/autogpt_platform/autogpt_libs
|
||||
RUN mkdir -p /app/autogpt_platform/backend
|
||||
|
||||
@@ -132,17 +132,58 @@ def test_endpoint_success(snapshot: Snapshot):
|
||||
|
||||
### Testing with Authentication
|
||||
|
||||
For the main API routes that use JWT authentication, auth is provided by the `autogpt_libs.auth` module. If the test actually uses the `user_id`, the recommended approach for testing is to mock the `get_jwt_payload` function, which underpins all higher-level auth functions used in the API (`requires_user`, `requires_admin_user`, `get_user_id`).
|
||||
|
||||
If the test doesn't need the `user_id` specifically, mocking is not necessary as during tests auth is disabled anyway (see `conftest.py`).
|
||||
|
||||
#### Using Global Auth Fixtures
|
||||
|
||||
Two global auth fixtures are provided by `backend/server/conftest.py`:
|
||||
|
||||
- `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id")
|
||||
- `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id")
|
||||
|
||||
These provide the easiest way to set up authentication mocking in test modules:
|
||||
|
||||
```python
|
||||
def override_auth_middleware():
|
||||
return {"sub": "test-user-id"}
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
from backend.server.v2.myroute import router
|
||||
|
||||
def override_get_user_id():
|
||||
return "test-user-id"
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
app.dependency_overrides[auth_middleware] = override_auth_middleware
|
||||
app.dependency_overrides[get_user_id] = override_get_user_id
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user['get_jwt_payload']
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
```
|
||||
|
||||
For admin-only endpoints, use `mock_jwt_admin` instead:
|
||||
|
||||
```python
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_admin):
|
||||
"""Setup auth overrides for admin tests"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin['get_jwt_payload']
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
```
|
||||
|
||||
The IDs are also available separately as fixtures:
|
||||
|
||||
- `test_user_id`
|
||||
- `admin_user_id`
|
||||
- `target_user_id` (for admin <-> user operations)
|
||||
|
||||
### Mocking External Services
|
||||
|
||||
```python
|
||||
@@ -153,10 +194,10 @@ def test_external_api_call(mocker, snapshot):
|
||||
"backend.services.external_api.call",
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
|
||||
response = client.post("/api/process")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
snapshot.assert_match(
|
||||
json.dumps(response.json(), indent=2, sort_keys=True),
|
||||
@@ -187,6 +228,17 @@ def test_external_api_call(mocker, snapshot):
|
||||
- Use `async def` with `@pytest.mark.asyncio` for testing async functions directly
|
||||
|
||||
### 5. Fixtures
|
||||
|
||||
#### Global Fixtures (conftest.py)
|
||||
|
||||
Authentication fixtures are available globally from `conftest.py`:
|
||||
|
||||
- `mock_jwt_user` - Standard user authentication
|
||||
- `mock_jwt_admin` - Admin user authentication
|
||||
- `configured_snapshot` - Pre-configured snapshot fixture
|
||||
|
||||
#### Custom Fixtures
|
||||
|
||||
Create reusable fixtures for common test data:
|
||||
|
||||
```python
|
||||
@@ -202,9 +254,18 @@ def test_create_user(sample_user, snapshot):
|
||||
# ... test implementation
|
||||
```
|
||||
|
||||
#### Test Isolation
|
||||
|
||||
All tests must use fixtures that ensure proper isolation:
|
||||
|
||||
- Authentication overrides are automatically cleaned up after each test
|
||||
- Database connections are properly managed with cleanup
|
||||
- Mock objects are reset between tests
|
||||
|
||||
## CI/CD Integration
|
||||
|
||||
The GitHub Actions workflow automatically runs tests on:
|
||||
|
||||
- Pull requests
|
||||
- Pushes to main branch
|
||||
|
||||
@@ -216,16 +277,19 @@ Snapshot tests work in CI by:
|
||||
## Troubleshooting
|
||||
|
||||
### Snapshot Mismatches
|
||||
|
||||
- Review the diff carefully
|
||||
- If changes are expected: `poetry run pytest --snapshot-update`
|
||||
- If changes are unexpected: Fix the code causing the difference
|
||||
|
||||
### Async Test Issues
|
||||
|
||||
- Ensure async functions use `@pytest.mark.asyncio`
|
||||
- Use `AsyncMock` for mocking async functions
|
||||
- FastAPI TestClient handles async automatically
|
||||
|
||||
### Import Errors
|
||||
|
||||
- Check that all dependencies are in `pyproject.toml`
|
||||
- Run `poetry install` to ensure dependencies are installed
|
||||
- Verify import paths are correct
|
||||
@@ -234,4 +298,4 @@ Snapshot tests work in CI by:
|
||||
|
||||
Snapshot testing provides a powerful way to ensure API responses remain consistent. Combined with traditional assertions, it creates a robust test suite that catches regressions while remaining maintainable.
|
||||
|
||||
Remember: Good tests are as important as good code!
|
||||
Remember: Good tests are as important as good code!
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -12,7 +10,7 @@ from backend.data.block import (
|
||||
BlockType,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.execution import ExecutionStatus, NodesInputMasks
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util.json import validate_with_jsonschema
|
||||
from backend.util.retry import func_retry
|
||||
@@ -25,12 +23,15 @@ class AgentExecutorBlock(Block):
|
||||
user_id: str = SchemaField(description="User ID")
|
||||
graph_id: str = SchemaField(description="Graph ID")
|
||||
graph_version: int = SchemaField(description="Graph Version")
|
||||
agent_name: Optional[str] = SchemaField(
|
||||
default=None, description="Name to display in the Builder UI"
|
||||
)
|
||||
|
||||
inputs: BlockInput = SchemaField(description="Input data for the graph")
|
||||
input_schema: dict = SchemaField(description="Input schema for the graph")
|
||||
output_schema: dict = SchemaField(description="Output schema for the graph")
|
||||
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = SchemaField(
|
||||
nodes_input_masks: Optional[NodesInputMasks] = SchemaField(
|
||||
default=None, hidden=True
|
||||
)
|
||||
|
||||
|
||||
154
autogpt_platform/backend/backend/blocks/ai_image_customizer.py
Normal file
154
autogpt_platform/backend/backend/blocks/ai_image_customizer.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
from replicate.client import Client as ReplicateClient
|
||||
from replicate.helpers import FileOutput
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.file import MediaFileType
|
||||
|
||||
|
||||
class GeminiImageModel(str, Enum):
|
||||
NANO_BANANA = "google/nano-banana"
|
||||
|
||||
|
||||
class OutputFormat(str, Enum):
|
||||
JPG = "jpg"
|
||||
PNG = "png"
|
||||
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="replicate",
|
||||
api_key=SecretStr("mock-replicate-api-key"),
|
||||
title="Mock Replicate API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class AIImageCustomizerBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REPLICATE], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Replicate API key with permissions for Google Gemini image models",
|
||||
)
|
||||
prompt: str = SchemaField(
|
||||
description="A text description of the image you want to generate",
|
||||
title="Prompt",
|
||||
)
|
||||
model: GeminiImageModel = SchemaField(
|
||||
description="The AI model to use for image generation and editing",
|
||||
default=GeminiImageModel.NANO_BANANA,
|
||||
title="Model",
|
||||
)
|
||||
images: list[MediaFileType] = SchemaField(
|
||||
description="Optional list of input images to reference or modify",
|
||||
default=[],
|
||||
title="Input Images",
|
||||
)
|
||||
output_format: OutputFormat = SchemaField(
|
||||
description="Format of the output image",
|
||||
default=OutputFormat.PNG,
|
||||
title="Output Format",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
image_url: MediaFileType = SchemaField(description="URL of the generated image")
|
||||
error: str = SchemaField(description="Error message if generation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d76bbe4c-930e-4894-8469-b66775511f71",
|
||||
description=(
|
||||
"Generate and edit custom images using Google's Nano-Banana model from Gemini 2.5. "
|
||||
"Provide a prompt and optional reference images to create or modify images."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
input_schema=AIImageCustomizerBlock.Input,
|
||||
output_schema=AIImageCustomizerBlock.Output,
|
||||
test_input={
|
||||
"prompt": "Make the scene more vibrant and colorful",
|
||||
"model": GeminiImageModel.NANO_BANANA,
|
||||
"images": [],
|
||||
"output_format": OutputFormat.JPG,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
("image_url", "https://replicate.delivery/generated-image.jpg"),
|
||||
],
|
||||
test_mock={
|
||||
"run_model": lambda *args, **kwargs: MediaFileType(
|
||||
"https://replicate.delivery/generated-image.jpg"
|
||||
),
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
result = await self.run_model(
|
||||
api_key=credentials.api_key,
|
||||
model_name=input_data.model.value,
|
||||
prompt=input_data.prompt,
|
||||
images=input_data.images,
|
||||
output_format=input_data.output_format.value,
|
||||
)
|
||||
yield "image_url", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
async def run_model(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
images: list[MediaFileType],
|
||||
output_format: str,
|
||||
) -> MediaFileType:
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
|
||||
input_params: dict = {
|
||||
"prompt": prompt,
|
||||
"output_format": output_format,
|
||||
}
|
||||
|
||||
# Add images to input if provided (API expects "image_input" parameter)
|
||||
if images:
|
||||
input_params["image_input"] = [str(img) for img in images]
|
||||
|
||||
output: FileOutput | str = await client.async_run( # type: ignore
|
||||
model_name,
|
||||
input=input_params,
|
||||
wait=False,
|
||||
)
|
||||
|
||||
if isinstance(output, FileOutput):
|
||||
return MediaFileType(output.url)
|
||||
if isinstance(output, str):
|
||||
return MediaFileType(output)
|
||||
|
||||
raise ValueError("No output received from the model")
|
||||
@@ -166,7 +166,7 @@ class AIMusicGeneratorBlock(Block):
|
||||
output_format=input_data.output_format,
|
||||
normalization_strategy=input_data.normalization_strategy,
|
||||
)
|
||||
if result and result != "No output received":
|
||||
if result and isinstance(result, str) and result.startswith("http"):
|
||||
yield "result", result
|
||||
return
|
||||
else:
|
||||
|
||||
@@ -661,6 +661,167 @@ async def update_field(
|
||||
#################################################################
|
||||
|
||||
|
||||
async def get_table_schema(
|
||||
credentials: Credentials,
|
||||
base_id: str,
|
||||
table_id_or_name: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Get the schema for a specific table, including all field definitions.
|
||||
|
||||
Args:
|
||||
credentials: Airtable API credentials
|
||||
base_id: The base ID
|
||||
table_id_or_name: The table ID or name
|
||||
|
||||
Returns:
|
||||
Dict containing table schema with fields information
|
||||
"""
|
||||
# First get all tables to find the right one
|
||||
response = await Requests().get(
|
||||
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
tables = data.get("tables", [])
|
||||
|
||||
# Find the matching table
|
||||
for table in tables:
|
||||
if table.get("id") == table_id_or_name or table.get("name") == table_id_or_name:
|
||||
return table
|
||||
|
||||
raise ValueError(f"Table '{table_id_or_name}' not found in base '{base_id}'")
|
||||
|
||||
|
||||
def get_empty_value_for_field(field_type: str) -> Any:
|
||||
"""
|
||||
Return the appropriate empty value for a given Airtable field type.
|
||||
|
||||
Args:
|
||||
field_type: The Airtable field type
|
||||
|
||||
Returns:
|
||||
The appropriate empty value for that field type
|
||||
"""
|
||||
# Fields that should be false when empty
|
||||
if field_type == "checkbox":
|
||||
return False
|
||||
|
||||
# Fields that should be empty arrays
|
||||
if field_type in [
|
||||
"multipleSelects",
|
||||
"multipleRecordLinks",
|
||||
"multipleAttachments",
|
||||
"multipleLookupValues",
|
||||
"multipleCollaborators",
|
||||
]:
|
||||
return []
|
||||
|
||||
# Fields that should be 0 when empty (numeric types)
|
||||
if field_type in [
|
||||
"number",
|
||||
"percent",
|
||||
"currency",
|
||||
"rating",
|
||||
"duration",
|
||||
"count",
|
||||
"autoNumber",
|
||||
]:
|
||||
return 0
|
||||
|
||||
# Fields that should be empty strings
|
||||
if field_type in [
|
||||
"singleLineText",
|
||||
"multilineText",
|
||||
"email",
|
||||
"url",
|
||||
"phoneNumber",
|
||||
"richText",
|
||||
"barcode",
|
||||
]:
|
||||
return ""
|
||||
|
||||
# Everything else gets null (dates, single selects, formulas, etc.)
|
||||
return None
|
||||
|
||||
|
||||
async def normalize_records(
|
||||
records: list[dict],
|
||||
table_schema: dict,
|
||||
include_field_metadata: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Normalize Airtable records to include all fields with proper empty values.
|
||||
|
||||
Args:
|
||||
records: List of record objects from Airtable API
|
||||
table_schema: Table schema containing field definitions
|
||||
include_field_metadata: Whether to include field metadata in response
|
||||
|
||||
Returns:
|
||||
Dict with normalized records and optionally field metadata
|
||||
"""
|
||||
fields = table_schema.get("fields", [])
|
||||
|
||||
# Normalize each record
|
||||
normalized_records = []
|
||||
for record in records:
|
||||
normalized = {
|
||||
"id": record.get("id"),
|
||||
"createdTime": record.get("createdTime"),
|
||||
"fields": {},
|
||||
}
|
||||
|
||||
# Add existing fields
|
||||
existing_fields = record.get("fields", {})
|
||||
|
||||
# Add all fields from schema, using empty values for missing ones
|
||||
for field in fields:
|
||||
field_name = field["name"]
|
||||
field_type = field["type"]
|
||||
|
||||
if field_name in existing_fields:
|
||||
# Field exists, use its value
|
||||
normalized["fields"][field_name] = existing_fields[field_name]
|
||||
else:
|
||||
# Field is missing, add appropriate empty value
|
||||
normalized["fields"][field_name] = get_empty_value_for_field(field_type)
|
||||
|
||||
normalized_records.append(normalized)
|
||||
|
||||
# Build result dictionary
|
||||
if include_field_metadata:
|
||||
field_metadata = {}
|
||||
for field in fields:
|
||||
metadata = {"type": field["type"], "id": field["id"]}
|
||||
|
||||
# Add type-specific metadata
|
||||
options = field.get("options", {})
|
||||
if field["type"] == "currency" and "symbol" in options:
|
||||
metadata["symbol"] = options["symbol"]
|
||||
metadata["precision"] = options.get("precision", 2)
|
||||
elif field["type"] == "duration" and "durationFormat" in options:
|
||||
metadata["format"] = options["durationFormat"]
|
||||
elif field["type"] == "percent" and "precision" in options:
|
||||
metadata["precision"] = options["precision"]
|
||||
elif (
|
||||
field["type"] in ["singleSelect", "multipleSelects"]
|
||||
and "choices" in options
|
||||
):
|
||||
metadata["choices"] = [choice["name"] for choice in options["choices"]]
|
||||
elif field["type"] == "rating" and "max" in options:
|
||||
metadata["max"] = options["max"]
|
||||
metadata["icon"] = options.get("icon", "star")
|
||||
metadata["color"] = options.get("color", "yellowBright")
|
||||
|
||||
field_metadata[field["name"]] = metadata
|
||||
|
||||
return {"records": normalized_records, "field_metadata": field_metadata}
|
||||
else:
|
||||
return {"records": normalized_records}
|
||||
|
||||
|
||||
async def list_records(
|
||||
credentials: Credentials,
|
||||
base_id: str,
|
||||
@@ -1249,3 +1410,26 @@ async def list_bases(
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
|
||||
async def get_base_tables(
|
||||
credentials: Credentials,
|
||||
base_id: str,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Get all tables for a specific base.
|
||||
|
||||
Args:
|
||||
credentials: Airtable API credentials
|
||||
base_id: The ID of the base
|
||||
|
||||
Returns:
|
||||
list[dict]: List of table objects with their schemas
|
||||
"""
|
||||
response = await Requests().get(
|
||||
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
return data.get("tables", [])
|
||||
|
||||
@@ -14,13 +14,13 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import create_base, list_bases
|
||||
from ._api import create_base, get_base_tables, list_bases
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableCreateBaseBlock(Block):
|
||||
"""
|
||||
Creates a new base in an Airtable workspace.
|
||||
Creates a new base in an Airtable workspace, or returns existing base if one with the same name exists.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
@@ -31,6 +31,10 @@ class AirtableCreateBaseBlock(Block):
|
||||
description="The workspace ID where the base will be created"
|
||||
)
|
||||
name: str = SchemaField(description="The name of the new base")
|
||||
find_existing: bool = SchemaField(
|
||||
description="If true, return existing base with same name instead of creating duplicate",
|
||||
default=True,
|
||||
)
|
||||
tables: list[dict] = SchemaField(
|
||||
description="At least one table and field must be specified. Array of table objects to create in the base. Each table should have 'name' and 'fields' properties",
|
||||
default=[
|
||||
@@ -50,14 +54,18 @@ class AirtableCreateBaseBlock(Block):
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
base_id: str = SchemaField(description="The ID of the created base")
|
||||
base_id: str = SchemaField(description="The ID of the created or found base")
|
||||
tables: list[dict] = SchemaField(description="Array of table objects")
|
||||
table: dict = SchemaField(description="A single table object")
|
||||
was_created: bool = SchemaField(
|
||||
description="True if a new base was created, False if existing was found",
|
||||
default=True,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f59b88a8-54ce-4676-a508-fd614b4e8dce",
|
||||
description="Create a new base in Airtable",
|
||||
description="Create or find a base in Airtable",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
@@ -66,6 +74,31 @@ class AirtableCreateBaseBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# If find_existing is true, check if a base with this name already exists
|
||||
if input_data.find_existing:
|
||||
# List all bases to check for existing one with same name
|
||||
# Note: Airtable API doesn't have a direct search, so we need to list and filter
|
||||
existing_bases = await list_bases(credentials)
|
||||
|
||||
for base in existing_bases.get("bases", []):
|
||||
if base.get("name") == input_data.name:
|
||||
# Base already exists, return it
|
||||
base_id = base.get("id")
|
||||
yield "base_id", base_id
|
||||
yield "was_created", False
|
||||
|
||||
# Get the tables for this base
|
||||
try:
|
||||
tables = await get_base_tables(credentials, base_id)
|
||||
yield "tables", tables
|
||||
for table in tables:
|
||||
yield "table", table
|
||||
except Exception:
|
||||
# If we can't get tables, return empty list
|
||||
yield "tables", []
|
||||
return
|
||||
|
||||
# No existing base found or find_existing is false, create new one
|
||||
data = await create_base(
|
||||
credentials,
|
||||
input_data.workspace_id,
|
||||
@@ -74,6 +107,7 @@ class AirtableCreateBaseBlock(Block):
|
||||
)
|
||||
|
||||
yield "base_id", data.get("id", None)
|
||||
yield "was_created", True
|
||||
yield "tables", data.get("tables", [])
|
||||
for table in data.get("tables", []):
|
||||
yield "table", table
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Airtable record operation blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
@@ -18,7 +18,9 @@ from ._api import (
|
||||
create_record,
|
||||
delete_multiple_records,
|
||||
get_record,
|
||||
get_table_schema,
|
||||
list_records,
|
||||
normalize_records,
|
||||
update_multiple_records,
|
||||
)
|
||||
from ._config import airtable
|
||||
@@ -54,12 +56,24 @@ class AirtableListRecordsBlock(Block):
|
||||
return_fields: list[str] = SchemaField(
|
||||
description="Specific fields to return (comma-separated)", default=[]
|
||||
)
|
||||
normalize_output: bool = SchemaField(
|
||||
description="Normalize output to include all fields with proper empty values (disable to skip schema fetch and get raw Airtable response)",
|
||||
default=True,
|
||||
)
|
||||
include_field_metadata: bool = SchemaField(
|
||||
description="Include field type and configuration metadata (requires normalize_output=true)",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
records: list[dict] = SchemaField(description="Array of record objects")
|
||||
offset: Optional[str] = SchemaField(
|
||||
description="Offset for next page (null if no more records)", default=None
|
||||
)
|
||||
field_metadata: Optional[dict] = SchemaField(
|
||||
description="Field type and configuration metadata (only when include_field_metadata=true)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -73,6 +87,7 @@ class AirtableListRecordsBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
data = await list_records(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
@@ -88,8 +103,33 @@ class AirtableListRecordsBlock(Block):
|
||||
fields=input_data.return_fields if input_data.return_fields else None,
|
||||
)
|
||||
|
||||
yield "records", data.get("records", [])
|
||||
yield "offset", data.get("offset", None)
|
||||
records = data.get("records", [])
|
||||
|
||||
# Normalize output if requested
|
||||
if input_data.normalize_output:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the records
|
||||
normalized_data = await normalize_records(
|
||||
records,
|
||||
table_schema,
|
||||
include_field_metadata=input_data.include_field_metadata,
|
||||
)
|
||||
|
||||
yield "records", normalized_data["records"]
|
||||
yield "offset", data.get("offset", None)
|
||||
|
||||
if (
|
||||
input_data.include_field_metadata
|
||||
and "field_metadata" in normalized_data
|
||||
):
|
||||
yield "field_metadata", normalized_data["field_metadata"]
|
||||
else:
|
||||
yield "records", records
|
||||
yield "offset", data.get("offset", None)
|
||||
|
||||
|
||||
class AirtableGetRecordBlock(Block):
|
||||
@@ -104,11 +144,23 @@ class AirtableGetRecordBlock(Block):
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name")
|
||||
record_id: str = SchemaField(description="The record ID to retrieve")
|
||||
normalize_output: bool = SchemaField(
|
||||
description="Normalize output to include all fields with proper empty values (disable to skip schema fetch and get raw Airtable response)",
|
||||
default=True,
|
||||
)
|
||||
include_field_metadata: bool = SchemaField(
|
||||
description="Include field type and configuration metadata (requires normalize_output=true)",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
id: str = SchemaField(description="The record ID")
|
||||
fields: dict = SchemaField(description="The record fields")
|
||||
created_time: str = SchemaField(description="The record created time")
|
||||
field_metadata: Optional[dict] = SchemaField(
|
||||
description="Field type and configuration metadata (only when include_field_metadata=true)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -122,6 +174,7 @@ class AirtableGetRecordBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
record = await get_record(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
@@ -129,9 +182,34 @@ class AirtableGetRecordBlock(Block):
|
||||
input_data.record_id,
|
||||
)
|
||||
|
||||
yield "id", record.get("id", None)
|
||||
yield "fields", record.get("fields", None)
|
||||
yield "created_time", record.get("createdTime", None)
|
||||
# Normalize output if requested
|
||||
if input_data.normalize_output:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the single record (wrap in list and unwrap result)
|
||||
normalized_data = await normalize_records(
|
||||
[record],
|
||||
table_schema,
|
||||
include_field_metadata=input_data.include_field_metadata,
|
||||
)
|
||||
|
||||
normalized_record = normalized_data["records"][0]
|
||||
yield "id", normalized_record.get("id", None)
|
||||
yield "fields", normalized_record.get("fields", None)
|
||||
yield "created_time", normalized_record.get("createdTime", None)
|
||||
|
||||
if (
|
||||
input_data.include_field_metadata
|
||||
and "field_metadata" in normalized_data
|
||||
):
|
||||
yield "field_metadata", normalized_data["field_metadata"]
|
||||
else:
|
||||
yield "id", record.get("id", None)
|
||||
yield "fields", record.get("fields", None)
|
||||
yield "created_time", record.get("createdTime", None)
|
||||
|
||||
|
||||
class AirtableCreateRecordsBlock(Block):
|
||||
@@ -148,6 +226,10 @@ class AirtableCreateRecordsBlock(Block):
|
||||
records: list[dict] = SchemaField(
|
||||
description="Array of records to create (each with 'fields' object)"
|
||||
)
|
||||
skip_normalization: bool = SchemaField(
|
||||
description="Skip output normalization to get raw Airtable response (faster but may have missing fields)",
|
||||
default=False,
|
||||
)
|
||||
typecast: bool = SchemaField(
|
||||
description="Automatically convert string values to appropriate types",
|
||||
default=False,
|
||||
@@ -159,7 +241,6 @@ class AirtableCreateRecordsBlock(Block):
|
||||
|
||||
class Output(BlockSchema):
|
||||
records: list[dict] = SchemaField(description="Array of created record objects")
|
||||
details: dict = SchemaField(description="Details of the created records")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -173,7 +254,7 @@ class AirtableCreateRecordsBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# The create_record API expects records in a specific format
|
||||
|
||||
data = await create_record(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
@@ -182,11 +263,22 @@ class AirtableCreateRecordsBlock(Block):
|
||||
typecast=input_data.typecast if input_data.typecast else None,
|
||||
return_fields_by_field_id=input_data.return_fields_by_field_id,
|
||||
)
|
||||
result_records = cast(list[dict], data.get("records", []))
|
||||
|
||||
yield "records", data.get("records", [])
|
||||
details = data.get("details", None)
|
||||
if details:
|
||||
yield "details", details
|
||||
# Normalize output unless explicitly disabled
|
||||
if not input_data.skip_normalization and result_records:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the records
|
||||
normalized_data = await normalize_records(
|
||||
result_records, table_schema, include_field_metadata=False
|
||||
)
|
||||
result_records = normalized_data["records"]
|
||||
|
||||
yield "records", result_records
|
||||
|
||||
|
||||
class AirtableUpdateRecordsBlock(Block):
|
||||
|
||||
205
autogpt_platform/backend/backend/blocks/baas/_api.py
Normal file
205
autogpt_platform/backend/backend/blocks/baas/_api.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""
|
||||
Meeting BaaS API client module.
|
||||
All API calls centralized for consistency and maintainability.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.sdk import Requests
|
||||
|
||||
|
||||
class MeetingBaasAPI:
|
||||
"""Client for Meeting BaaS API endpoints."""
|
||||
|
||||
BASE_URL = "https://api.meetingbaas.com"
|
||||
|
||||
def __init__(self, api_key: str):
|
||||
"""Initialize API client with authentication key."""
|
||||
self.api_key = api_key
|
||||
self.headers = {"x-meeting-baas-api-key": api_key}
|
||||
self.requests = Requests()
|
||||
|
||||
# Bot Management Endpoints
|
||||
|
||||
async def join_meeting(
|
||||
self,
|
||||
bot_name: str,
|
||||
meeting_url: str,
|
||||
reserved: bool = False,
|
||||
bot_image: Optional[str] = None,
|
||||
entry_message: Optional[str] = None,
|
||||
start_time: Optional[int] = None,
|
||||
speech_to_text: Optional[Dict[str, Any]] = None,
|
||||
webhook_url: Optional[str] = None,
|
||||
automatic_leave: Optional[Dict[str, Any]] = None,
|
||||
extra: Optional[Dict[str, Any]] = None,
|
||||
recording_mode: str = "speaker_view",
|
||||
streaming: Optional[Dict[str, Any]] = None,
|
||||
deduplication_key: Optional[str] = None,
|
||||
zoom_sdk_id: Optional[str] = None,
|
||||
zoom_sdk_pwd: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Deploy a bot to join and record a meeting.
|
||||
|
||||
POST /bots
|
||||
"""
|
||||
body = {
|
||||
"bot_name": bot_name,
|
||||
"meeting_url": meeting_url,
|
||||
"reserved": reserved,
|
||||
"recording_mode": recording_mode,
|
||||
}
|
||||
|
||||
# Add optional fields if provided
|
||||
if bot_image is not None:
|
||||
body["bot_image"] = bot_image
|
||||
if entry_message is not None:
|
||||
body["entry_message"] = entry_message
|
||||
if start_time is not None:
|
||||
body["start_time"] = start_time
|
||||
if speech_to_text is not None:
|
||||
body["speech_to_text"] = speech_to_text
|
||||
if webhook_url is not None:
|
||||
body["webhook_url"] = webhook_url
|
||||
if automatic_leave is not None:
|
||||
body["automatic_leave"] = automatic_leave
|
||||
if extra is not None:
|
||||
body["extra"] = extra
|
||||
if streaming is not None:
|
||||
body["streaming"] = streaming
|
||||
if deduplication_key is not None:
|
||||
body["deduplication_key"] = deduplication_key
|
||||
if zoom_sdk_id is not None:
|
||||
body["zoom_sdk_id"] = zoom_sdk_id
|
||||
if zoom_sdk_pwd is not None:
|
||||
body["zoom_sdk_pwd"] = zoom_sdk_pwd
|
||||
|
||||
response = await self.requests.post(
|
||||
f"{self.BASE_URL}/bots",
|
||||
headers=self.headers,
|
||||
json=body,
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def leave_meeting(self, bot_id: str) -> bool:
|
||||
"""
|
||||
Remove a bot from an ongoing meeting.
|
||||
|
||||
DELETE /bots/{uuid}
|
||||
"""
|
||||
response = await self.requests.delete(
|
||||
f"{self.BASE_URL}/bots/{bot_id}",
|
||||
headers=self.headers,
|
||||
)
|
||||
return response.status in [200, 204]
|
||||
|
||||
async def retranscribe(
|
||||
self,
|
||||
bot_uuid: str,
|
||||
speech_to_text: Optional[Dict[str, Any]] = None,
|
||||
webhook_url: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Re-run transcription on a bot's audio.
|
||||
|
||||
POST /bots/retranscribe
|
||||
"""
|
||||
body: Dict[str, Any] = {"bot_uuid": bot_uuid}
|
||||
|
||||
if speech_to_text is not None:
|
||||
body["speech_to_text"] = speech_to_text
|
||||
if webhook_url is not None:
|
||||
body["webhook_url"] = webhook_url
|
||||
|
||||
response = await self.requests.post(
|
||||
f"{self.BASE_URL}/bots/retranscribe",
|
||||
headers=self.headers,
|
||||
json=body,
|
||||
)
|
||||
|
||||
if response.status == 202:
|
||||
return {"accepted": True}
|
||||
return response.json()
|
||||
|
||||
# Data Retrieval Endpoints
|
||||
|
||||
async def get_meeting_data(
|
||||
self, bot_id: str, include_transcripts: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieve meeting data including recording and transcripts.
|
||||
|
||||
GET /bots/meeting_data
|
||||
"""
|
||||
params = {
|
||||
"bot_id": bot_id,
|
||||
"include_transcripts": str(include_transcripts).lower(),
|
||||
}
|
||||
|
||||
response = await self.requests.get(
|
||||
f"{self.BASE_URL}/bots/meeting_data",
|
||||
headers=self.headers,
|
||||
params=params,
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def get_screenshots(self, bot_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve screenshots captured during a meeting.
|
||||
|
||||
GET /bots/{uuid}/screenshots
|
||||
"""
|
||||
response = await self.requests.get(
|
||||
f"{self.BASE_URL}/bots/{bot_id}/screenshots",
|
||||
headers=self.headers,
|
||||
)
|
||||
result = response.json()
|
||||
# Ensure we return a list
|
||||
if isinstance(result, list):
|
||||
return result
|
||||
return []
|
||||
|
||||
async def delete_data(self, bot_id: str) -> bool:
|
||||
"""
|
||||
Delete a bot's recorded data.
|
||||
|
||||
POST /bots/{uuid}/delete_data
|
||||
"""
|
||||
response = await self.requests.post(
|
||||
f"{self.BASE_URL}/bots/{bot_id}/delete_data",
|
||||
headers=self.headers,
|
||||
)
|
||||
return response.status == 200
|
||||
|
||||
async def list_bots_with_metadata(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
sort_by: Optional[str] = None,
|
||||
sort_order: Optional[str] = None,
|
||||
filter_by: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
List bots with metadata including IDs, names, and meeting details.
|
||||
|
||||
GET /bots/bots_with_metadata
|
||||
"""
|
||||
params = {}
|
||||
if limit is not None:
|
||||
params["limit"] = limit
|
||||
if offset is not None:
|
||||
params["offset"] = offset
|
||||
if sort_by is not None:
|
||||
params["sort_by"] = sort_by
|
||||
if sort_order is not None:
|
||||
params["sort_order"] = sort_order
|
||||
if filter_by is not None:
|
||||
params.update(filter_by)
|
||||
|
||||
response = await self.requests.get(
|
||||
f"{self.BASE_URL}/bots/bots_with_metadata",
|
||||
headers=self.headers,
|
||||
params=params,
|
||||
)
|
||||
return response.json()
|
||||
13
autogpt_platform/backend/backend/blocks/baas/_config.py
Normal file
13
autogpt_platform/backend/backend/blocks/baas/_config.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Shared configuration for all Meeting BaaS blocks using the SDK pattern.
|
||||
"""
|
||||
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
# Configure the Meeting BaaS provider with API key authentication
|
||||
baas = (
|
||||
ProviderBuilder("baas")
|
||||
.with_api_key("MEETING_BAAS_API_KEY", "Meeting BaaS API Key")
|
||||
.with_base_cost(5, BlockCostType.RUN) # Higher cost for meeting recording service
|
||||
.build()
|
||||
)
|
||||
217
autogpt_platform/backend/backend/blocks/baas/bots.py
Normal file
217
autogpt_platform/backend/backend/blocks/baas/bots.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Meeting BaaS bot (recording) blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import MeetingBaasAPI
|
||||
from ._config import baas
|
||||
|
||||
|
||||
class BaasBotJoinMeetingBlock(Block):
|
||||
"""
|
||||
Deploy a bot immediately or at a scheduled start_time to join and record a meeting.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
meeting_url: str = SchemaField(
|
||||
description="The URL of the meeting the bot should join"
|
||||
)
|
||||
bot_name: str = SchemaField(
|
||||
description="Display name for the bot in the meeting"
|
||||
)
|
||||
bot_image: str = SchemaField(
|
||||
description="URL to an image for the bot's avatar (16:9 ratio recommended)",
|
||||
default="",
|
||||
)
|
||||
entry_message: str = SchemaField(
|
||||
description="Chat message the bot will post upon entry", default=""
|
||||
)
|
||||
reserved: bool = SchemaField(
|
||||
description="Use a reserved bot slot (joins 4 min before meeting)",
|
||||
default=False,
|
||||
)
|
||||
start_time: Optional[int] = SchemaField(
|
||||
description="Unix timestamp (ms) when bot should join", default=None
|
||||
)
|
||||
webhook_url: str | None = SchemaField(
|
||||
description="URL to receive webhook events for this bot", default=None
|
||||
)
|
||||
timeouts: dict = SchemaField(
|
||||
description="Automatic leave timeouts configuration", default={}
|
||||
)
|
||||
extra: dict = SchemaField(
|
||||
description="Custom metadata to attach to the bot", default={}
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
bot_id: str = SchemaField(description="UUID of the deployed bot")
|
||||
join_response: dict = SchemaField(
|
||||
description="Full response from join operation"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="377d1a6a-a99b-46cf-9af3-1d1b12758e04",
|
||||
description="Deploy a bot to join and record a meeting",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
api = MeetingBaasAPI(api_key)
|
||||
|
||||
# Call API with all parameters
|
||||
data = await api.join_meeting(
|
||||
bot_name=input_data.bot_name,
|
||||
meeting_url=input_data.meeting_url,
|
||||
reserved=input_data.reserved,
|
||||
bot_image=input_data.bot_image if input_data.bot_image else None,
|
||||
entry_message=(
|
||||
input_data.entry_message if input_data.entry_message else None
|
||||
),
|
||||
start_time=input_data.start_time,
|
||||
speech_to_text={"provider": "Default"},
|
||||
webhook_url=input_data.webhook_url if input_data.webhook_url else None,
|
||||
automatic_leave=input_data.timeouts if input_data.timeouts else None,
|
||||
extra=input_data.extra if input_data.extra else None,
|
||||
)
|
||||
|
||||
yield "bot_id", data.get("bot_id", "")
|
||||
yield "join_response", data
|
||||
|
||||
|
||||
class BaasBotLeaveMeetingBlock(Block):
|
||||
"""
|
||||
Force the bot to exit the call.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(description="UUID of the bot to remove from meeting")
|
||||
|
||||
class Output(BlockSchema):
|
||||
left: bool = SchemaField(description="Whether the bot successfully left")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bf77d128-8b25-4280-b5c7-2d553ba7e482",
|
||||
description="Remove a bot from an ongoing meeting",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
api = MeetingBaasAPI(api_key)
|
||||
|
||||
# Leave meeting
|
||||
left = await api.leave_meeting(input_data.bot_id)
|
||||
|
||||
yield "left", left
|
||||
|
||||
|
||||
class BaasBotFetchMeetingDataBlock(Block):
|
||||
"""
|
||||
Pull MP4 URL, transcript & metadata for a completed meeting.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(description="UUID of the bot whose data to fetch")
|
||||
include_transcripts: bool = SchemaField(
|
||||
description="Include transcript data in response", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
mp4_url: str = SchemaField(
|
||||
description="URL to download the meeting recording (time-limited)"
|
||||
)
|
||||
transcript: list = SchemaField(description="Meeting transcript data")
|
||||
metadata: dict = SchemaField(description="Meeting metadata and bot information")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ea7c1309-303c-4da1-893f-89c0e9d64e78",
|
||||
description="Retrieve recorded meeting data",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
api = MeetingBaasAPI(api_key)
|
||||
|
||||
# Fetch meeting data
|
||||
data = await api.get_meeting_data(
|
||||
bot_id=input_data.bot_id,
|
||||
include_transcripts=input_data.include_transcripts,
|
||||
)
|
||||
|
||||
yield "mp4_url", data.get("mp4", "")
|
||||
yield "transcript", data.get("bot_data", {}).get("transcripts", [])
|
||||
yield "metadata", data.get("bot_data", {}).get("bot", {})
|
||||
|
||||
|
||||
class BaasBotDeleteRecordingBlock(Block):
|
||||
"""
|
||||
Purge MP4 + transcript data for privacy or storage management.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(description="UUID of the bot whose data to delete")
|
||||
|
||||
class Output(BlockSchema):
|
||||
deleted: bool = SchemaField(
|
||||
description="Whether the data was successfully deleted"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bf8d1aa6-42d8-4944-b6bd-6bac554c0d3b",
|
||||
description="Permanently delete a meeting's recorded data",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
api = MeetingBaasAPI(api_key)
|
||||
|
||||
# Delete recording data
|
||||
deleted = await api.delete_data(input_data.bot_id)
|
||||
|
||||
yield "deleted", deleted
|
||||
@@ -0,0 +1,3 @@
|
||||
from .text_overlay import BannerbearTextOverlayBlock
|
||||
|
||||
__all__ = ["BannerbearTextOverlayBlock"]
|
||||
@@ -0,0 +1,8 @@
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
bannerbear = (
|
||||
ProviderBuilder("bannerbear")
|
||||
.with_api_key("BANNERBEAR_API_KEY", "Bannerbear API Key")
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
@@ -0,0 +1,239 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import bannerbear
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="bannerbear",
|
||||
api_key=SecretStr("mock-bannerbear-api-key"),
|
||||
title="Mock Bannerbear API Key",
|
||||
)
|
||||
|
||||
|
||||
class TextModification(BlockSchema):
|
||||
name: str = SchemaField(
|
||||
description="The name of the layer to modify in the template"
|
||||
)
|
||||
text: str = SchemaField(description="The text content to add to this layer")
|
||||
color: str = SchemaField(
|
||||
description="Hex color code for the text (e.g., '#FF0000')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
font_family: str = SchemaField(
|
||||
description="Font family to use for the text",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
font_size: int = SchemaField(
|
||||
description="Font size in pixels",
|
||||
default=0,
|
||||
advanced=True,
|
||||
)
|
||||
font_weight: str = SchemaField(
|
||||
description="Font weight (e.g., 'bold', 'normal')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
text_align: str = SchemaField(
|
||||
description="Text alignment (left, center, right)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
|
||||
class BannerbearTextOverlayBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = bannerbear.credentials_field(
|
||||
description="API credentials for Bannerbear"
|
||||
)
|
||||
template_id: str = SchemaField(
|
||||
description="The unique ID of your Bannerbear template"
|
||||
)
|
||||
project_id: str = SchemaField(
|
||||
description="Optional: Project ID (required when using Master API Key)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
text_modifications: List[TextModification] = SchemaField(
|
||||
description="List of text layers to modify in the template"
|
||||
)
|
||||
image_url: str = SchemaField(
|
||||
description="Optional: URL of an image to use in the template",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
image_layer_name: str = SchemaField(
|
||||
description="Optional: Name of the image layer in the template",
|
||||
default="photo",
|
||||
advanced=True,
|
||||
)
|
||||
webhook_url: str = SchemaField(
|
||||
description="Optional: URL to receive webhook notification when image is ready",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
metadata: str = SchemaField(
|
||||
description="Optional: Custom metadata to attach to the image",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the image generation was successfully initiated"
|
||||
)
|
||||
image_url: str = SchemaField(
|
||||
description="URL of the generated image (if synchronous) or placeholder"
|
||||
)
|
||||
uid: str = SchemaField(description="Unique identifier for the generated image")
|
||||
status: str = SchemaField(description="Status of the image generation")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c7d3a5c2-05fc-450e-8dce-3b0e04626009",
|
||||
description="Add text overlay to images using Bannerbear templates. Perfect for creating social media graphics, marketing materials, and dynamic image content.",
|
||||
categories={BlockCategory.PRODUCTIVITY, BlockCategory.AI},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"template_id": "jJWBKNELpQPvbX5R93Gk",
|
||||
"text_modifications": [
|
||||
{
|
||||
"name": "headline",
|
||||
"text": "Amazing Product Launch!",
|
||||
"color": "#FF0000",
|
||||
},
|
||||
{
|
||||
"name": "subtitle",
|
||||
"text": "50% OFF Today Only",
|
||||
},
|
||||
],
|
||||
"credentials": {
|
||||
"provider": "bannerbear",
|
||||
"id": str(uuid.uuid4()),
|
||||
"type": "api_key",
|
||||
},
|
||||
},
|
||||
test_output=[
|
||||
("success", True),
|
||||
("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
|
||||
("uid", "test-uid-123"),
|
||||
("status", "completed"),
|
||||
],
|
||||
test_mock={
|
||||
"_make_api_request": lambda *args, **kwargs: {
|
||||
"uid": "test-uid-123",
|
||||
"status": "completed",
|
||||
"image_url": "https://cdn.bannerbear.com/test-image.jpg",
|
||||
}
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def _make_api_request(self, payload: dict, api_key: str) -> dict:
|
||||
"""Make the actual API request to Bannerbear. This is separated for easy mocking in tests."""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = await Requests().post(
|
||||
"https://sync.api.bannerbear.com/v2/images",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if response.status in [200, 201, 202]:
|
||||
return response.json()
|
||||
else:
|
||||
error_msg = f"API request failed with status {response.status}"
|
||||
if response.text:
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_msg = (
|
||||
f"{error_msg}: {error_data.get('message', response.text)}"
|
||||
)
|
||||
except Exception:
|
||||
error_msg = f"{error_msg}: {response.text}"
|
||||
raise Exception(error_msg)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Build the modifications array
|
||||
modifications = []
|
||||
|
||||
# Add text modifications
|
||||
for text_mod in input_data.text_modifications:
|
||||
mod_data: Dict[str, Any] = {
|
||||
"name": text_mod.name,
|
||||
"text": text_mod.text,
|
||||
}
|
||||
|
||||
# Add optional text styling parameters only if they have values
|
||||
if text_mod.color and text_mod.color.strip():
|
||||
mod_data["color"] = text_mod.color
|
||||
if text_mod.font_family and text_mod.font_family.strip():
|
||||
mod_data["font_family"] = text_mod.font_family
|
||||
if text_mod.font_size and text_mod.font_size > 0:
|
||||
mod_data["font_size"] = text_mod.font_size
|
||||
if text_mod.font_weight and text_mod.font_weight.strip():
|
||||
mod_data["font_weight"] = text_mod.font_weight
|
||||
if text_mod.text_align and text_mod.text_align.strip():
|
||||
mod_data["text_align"] = text_mod.text_align
|
||||
|
||||
modifications.append(mod_data)
|
||||
|
||||
# Add image modification if provided and not empty
|
||||
if input_data.image_url and input_data.image_url.strip():
|
||||
modifications.append(
|
||||
{
|
||||
"name": input_data.image_layer_name,
|
||||
"image_url": input_data.image_url,
|
||||
}
|
||||
)
|
||||
|
||||
# Build the request payload - only include non-empty optional fields
|
||||
payload = {
|
||||
"template": input_data.template_id,
|
||||
"modifications": modifications,
|
||||
}
|
||||
|
||||
# Add project_id if provided (required for Master API keys)
|
||||
if input_data.project_id and input_data.project_id.strip():
|
||||
payload["project_id"] = input_data.project_id
|
||||
|
||||
if input_data.webhook_url and input_data.webhook_url.strip():
|
||||
payload["webhook_url"] = input_data.webhook_url
|
||||
if input_data.metadata and input_data.metadata.strip():
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
# Make the API request using the private method
|
||||
data = await self._make_api_request(
|
||||
payload, credentials.api_key.get_secret_value()
|
||||
)
|
||||
|
||||
# Synchronous request - image should be ready
|
||||
yield "success", True
|
||||
yield "image_url", data.get("image_url", "")
|
||||
yield "uid", data.get("uid", "")
|
||||
yield "status", data.get("status", "completed")
|
||||
178
autogpt_platform/backend/backend/blocks/dataforseo/_api.py
Normal file
178
autogpt_platform/backend/backend/blocks/dataforseo/_api.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
DataForSEO API client with async support using the SDK patterns.
|
||||
"""
|
||||
|
||||
import base64
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.sdk import Requests, UserPasswordCredentials
|
||||
|
||||
|
||||
class DataForSeoClient:
|
||||
"""Client for the DataForSEO API using async requests."""
|
||||
|
||||
API_URL = "https://api.dataforseo.com"
|
||||
|
||||
def __init__(self, credentials: UserPasswordCredentials):
|
||||
self.credentials = credentials
|
||||
self.requests = Requests(
|
||||
trusted_origins=["https://api.dataforseo.com"],
|
||||
raise_for_status=False,
|
||||
)
|
||||
|
||||
def _get_headers(self) -> Dict[str, str]:
|
||||
"""Generate the authorization header using Basic Auth."""
|
||||
username = self.credentials.username.get_secret_value()
|
||||
password = self.credentials.password.get_secret_value()
|
||||
credentials_str = f"{username}:{password}"
|
||||
encoded = base64.b64encode(credentials_str.encode("ascii")).decode("ascii")
|
||||
return {
|
||||
"Authorization": f"Basic {encoded}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async def keyword_suggestions(
|
||||
self,
|
||||
keyword: str,
|
||||
location_code: Optional[int] = None,
|
||||
language_code: Optional[str] = None,
|
||||
include_seed_keyword: bool = True,
|
||||
include_serp_info: bool = False,
|
||||
include_clickstream_data: bool = False,
|
||||
limit: int = 100,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get keyword suggestions from DataForSEO Labs.
|
||||
|
||||
Args:
|
||||
keyword: Seed keyword
|
||||
location_code: Location code for targeting
|
||||
language_code: Language code (e.g., "en")
|
||||
include_seed_keyword: Include seed keyword in results
|
||||
include_serp_info: Include SERP data
|
||||
include_clickstream_data: Include clickstream metrics
|
||||
limit: Maximum number of results (up to 3000)
|
||||
|
||||
Returns:
|
||||
API response with keyword suggestions
|
||||
"""
|
||||
endpoint = f"{self.API_URL}/v3/dataforseo_labs/google/keyword_suggestions/live"
|
||||
|
||||
# Build payload only with non-None values to avoid sending null fields
|
||||
task_data: dict[str, Any] = {
|
||||
"keyword": keyword,
|
||||
}
|
||||
|
||||
if location_code is not None:
|
||||
task_data["location_code"] = location_code
|
||||
if language_code is not None:
|
||||
task_data["language_code"] = language_code
|
||||
if include_seed_keyword is not None:
|
||||
task_data["include_seed_keyword"] = include_seed_keyword
|
||||
if include_serp_info is not None:
|
||||
task_data["include_serp_info"] = include_serp_info
|
||||
if include_clickstream_data is not None:
|
||||
task_data["include_clickstream_data"] = include_clickstream_data
|
||||
if limit is not None:
|
||||
task_data["limit"] = limit
|
||||
|
||||
payload = [task_data]
|
||||
|
||||
response = await self.requests.post(
|
||||
endpoint,
|
||||
headers=self._get_headers(),
|
||||
json=payload,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Check for API errors
|
||||
if response.status != 200:
|
||||
error_message = data.get("status_message", "Unknown error")
|
||||
raise Exception(
|
||||
f"DataForSEO API error ({response.status}): {error_message}"
|
||||
)
|
||||
|
||||
# Extract the results from the response
|
||||
if data.get("tasks") and len(data["tasks"]) > 0:
|
||||
task = data["tasks"][0]
|
||||
if task.get("status_code") == 20000: # Success code
|
||||
return task.get("result", [])
|
||||
else:
|
||||
error_msg = task.get("status_message", "Task failed")
|
||||
raise Exception(f"DataForSEO task error: {error_msg}")
|
||||
|
||||
return []
|
||||
|
||||
async def related_keywords(
|
||||
self,
|
||||
keyword: str,
|
||||
location_code: Optional[int] = None,
|
||||
language_code: Optional[str] = None,
|
||||
include_seed_keyword: bool = True,
|
||||
include_serp_info: bool = False,
|
||||
include_clickstream_data: bool = False,
|
||||
limit: int = 100,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get related keywords from DataForSEO Labs.
|
||||
|
||||
Args:
|
||||
keyword: Seed keyword
|
||||
location_code: Location code for targeting
|
||||
language_code: Language code (e.g., "en")
|
||||
include_seed_keyword: Include seed keyword in results
|
||||
include_serp_info: Include SERP data
|
||||
include_clickstream_data: Include clickstream metrics
|
||||
limit: Maximum number of results (up to 3000)
|
||||
|
||||
Returns:
|
||||
API response with related keywords
|
||||
"""
|
||||
endpoint = f"{self.API_URL}/v3/dataforseo_labs/google/related_keywords/live"
|
||||
|
||||
# Build payload only with non-None values to avoid sending null fields
|
||||
task_data: dict[str, Any] = {
|
||||
"keyword": keyword,
|
||||
}
|
||||
|
||||
if location_code is not None:
|
||||
task_data["location_code"] = location_code
|
||||
if language_code is not None:
|
||||
task_data["language_code"] = language_code
|
||||
if include_seed_keyword is not None:
|
||||
task_data["include_seed_keyword"] = include_seed_keyword
|
||||
if include_serp_info is not None:
|
||||
task_data["include_serp_info"] = include_serp_info
|
||||
if include_clickstream_data is not None:
|
||||
task_data["include_clickstream_data"] = include_clickstream_data
|
||||
if limit is not None:
|
||||
task_data["limit"] = limit
|
||||
|
||||
payload = [task_data]
|
||||
|
||||
response = await self.requests.post(
|
||||
endpoint,
|
||||
headers=self._get_headers(),
|
||||
json=payload,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Check for API errors
|
||||
if response.status != 200:
|
||||
error_message = data.get("status_message", "Unknown error")
|
||||
raise Exception(
|
||||
f"DataForSEO API error ({response.status}): {error_message}"
|
||||
)
|
||||
|
||||
# Extract the results from the response
|
||||
if data.get("tasks") and len(data["tasks"]) > 0:
|
||||
task = data["tasks"][0]
|
||||
if task.get("status_code") == 20000: # Success code
|
||||
return task.get("result", [])
|
||||
else:
|
||||
error_msg = task.get("status_message", "Task failed")
|
||||
raise Exception(f"DataForSEO task error: {error_msg}")
|
||||
|
||||
return []
|
||||
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Configuration for all DataForSEO blocks using the new SDK pattern.
|
||||
"""
|
||||
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
# Build the DataForSEO provider with username/password authentication
|
||||
dataforseo = (
|
||||
ProviderBuilder("dataforseo")
|
||||
.with_user_password(
|
||||
username_env_var="DATAFORSEO_USERNAME",
|
||||
password_env_var="DATAFORSEO_PASSWORD",
|
||||
title="DataForSEO Credentials",
|
||||
)
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
DataForSEO Google Keyword Suggestions block.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
|
||||
from ._api import DataForSeoClient
|
||||
from ._config import dataforseo
|
||||
|
||||
|
||||
class KeywordSuggestion(BlockSchema):
|
||||
"""Schema for a keyword suggestion result."""
|
||||
|
||||
keyword: str = SchemaField(description="The keyword suggestion")
|
||||
search_volume: Optional[int] = SchemaField(
|
||||
description="Monthly search volume", default=None
|
||||
)
|
||||
competition: Optional[float] = SchemaField(
|
||||
description="Competition level (0-1)", default=None
|
||||
)
|
||||
cpc: Optional[float] = SchemaField(
|
||||
description="Cost per click in USD", default=None
|
||||
)
|
||||
keyword_difficulty: Optional[int] = SchemaField(
|
||||
description="Keyword difficulty score", default=None
|
||||
)
|
||||
serp_info: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="data from SERP for each keyword", default=None
|
||||
)
|
||||
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="Clickstream data metrics", default=None
|
||||
)
|
||||
|
||||
|
||||
class DataForSeoKeywordSuggestionsBlock(Block):
|
||||
"""Block for getting keyword suggestions from DataForSEO Labs."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = dataforseo.credentials_field(
|
||||
description="DataForSEO credentials (username and password)"
|
||||
)
|
||||
keyword: str = SchemaField(description="Seed keyword to get suggestions for")
|
||||
location_code: Optional[int] = SchemaField(
|
||||
description="Location code for targeting (e.g., 2840 for USA)",
|
||||
default=2840, # USA
|
||||
)
|
||||
language_code: Optional[str] = SchemaField(
|
||||
description="Language code (e.g., 'en' for English)",
|
||||
default="en",
|
||||
)
|
||||
include_seed_keyword: bool = SchemaField(
|
||||
description="Include the seed keyword in results",
|
||||
default=True,
|
||||
)
|
||||
include_serp_info: bool = SchemaField(
|
||||
description="Include SERP information",
|
||||
default=False,
|
||||
)
|
||||
include_clickstream_data: bool = SchemaField(
|
||||
description="Include clickstream metrics",
|
||||
default=False,
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of results (up to 3000)",
|
||||
default=100,
|
||||
ge=1,
|
||||
le=3000,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
suggestions: List[KeywordSuggestion] = SchemaField(
|
||||
description="List of keyword suggestions with metrics"
|
||||
)
|
||||
suggestion: KeywordSuggestion = SchemaField(
|
||||
description="A single keyword suggestion with metrics"
|
||||
)
|
||||
total_count: int = SchemaField(
|
||||
description="Total number of suggestions returned"
|
||||
)
|
||||
seed_keyword: str = SchemaField(
|
||||
description="The seed keyword used for the query"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="73c3e7c4-2b3f-4e9f-9e3e-8f7a5c3e2d45",
|
||||
description="Get keyword suggestions from DataForSEO Labs Google API",
|
||||
categories={BlockCategory.SEARCH, BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"credentials": dataforseo.get_test_credentials().model_dump(),
|
||||
"keyword": "digital marketing",
|
||||
"location_code": 2840,
|
||||
"language_code": "en",
|
||||
"limit": 1,
|
||||
},
|
||||
test_credentials=dataforseo.get_test_credentials(),
|
||||
test_output=[
|
||||
(
|
||||
"suggestion",
|
||||
lambda x: hasattr(x, "keyword")
|
||||
and x.keyword == "digital marketing strategy",
|
||||
),
|
||||
("suggestions", lambda x: isinstance(x, list) and len(x) == 1),
|
||||
("total_count", 1),
|
||||
("seed_keyword", "digital marketing"),
|
||||
],
|
||||
test_mock={
|
||||
"_fetch_keyword_suggestions": lambda *args, **kwargs: [
|
||||
{
|
||||
"items": [
|
||||
{
|
||||
"keyword": "digital marketing strategy",
|
||||
"keyword_info": {
|
||||
"search_volume": 10000,
|
||||
"competition": 0.5,
|
||||
"cpc": 2.5,
|
||||
},
|
||||
"keyword_properties": {
|
||||
"keyword_difficulty": 50,
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
async def _fetch_keyword_suggestions(
|
||||
self,
|
||||
client: DataForSeoClient,
|
||||
input_data: Input,
|
||||
) -> Any:
|
||||
"""Private method to fetch keyword suggestions - can be mocked for testing."""
|
||||
return await client.keyword_suggestions(
|
||||
keyword=input_data.keyword,
|
||||
location_code=input_data.location_code,
|
||||
language_code=input_data.language_code,
|
||||
include_seed_keyword=input_data.include_seed_keyword,
|
||||
include_serp_info=input_data.include_serp_info,
|
||||
include_clickstream_data=input_data.include_clickstream_data,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: UserPasswordCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the keyword suggestions query."""
|
||||
client = DataForSeoClient(credentials)
|
||||
|
||||
results = await self._fetch_keyword_suggestions(client, input_data)
|
||||
|
||||
# Process and format the results
|
||||
suggestions = []
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
items = (
|
||||
first_result.get("items", []) if isinstance(first_result, dict) else []
|
||||
)
|
||||
for item in items:
|
||||
# Create the KeywordSuggestion object
|
||||
suggestion = KeywordSuggestion(
|
||||
keyword=item.get("keyword", ""),
|
||||
search_volume=item.get("keyword_info", {}).get("search_volume"),
|
||||
competition=item.get("keyword_info", {}).get("competition"),
|
||||
cpc=item.get("keyword_info", {}).get("cpc"),
|
||||
keyword_difficulty=item.get("keyword_properties", {}).get(
|
||||
"keyword_difficulty"
|
||||
),
|
||||
serp_info=(
|
||||
item.get("serp_info") if input_data.include_serp_info else None
|
||||
),
|
||||
clickstream_data=(
|
||||
item.get("clickstream_keyword_info")
|
||||
if input_data.include_clickstream_data
|
||||
else None
|
||||
),
|
||||
)
|
||||
yield "suggestion", suggestion
|
||||
suggestions.append(suggestion)
|
||||
|
||||
yield "suggestions", suggestions
|
||||
yield "total_count", len(suggestions)
|
||||
yield "seed_keyword", input_data.keyword
|
||||
|
||||
|
||||
class KeywordSuggestionExtractorBlock(Block):
|
||||
"""Extracts individual fields from a KeywordSuggestion object."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
suggestion: KeywordSuggestion = SchemaField(
|
||||
description="The keyword suggestion object to extract fields from"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
keyword: str = SchemaField(description="The keyword suggestion")
|
||||
search_volume: Optional[int] = SchemaField(
|
||||
description="Monthly search volume", default=None
|
||||
)
|
||||
competition: Optional[float] = SchemaField(
|
||||
description="Competition level (0-1)", default=None
|
||||
)
|
||||
cpc: Optional[float] = SchemaField(
|
||||
description="Cost per click in USD", default=None
|
||||
)
|
||||
keyword_difficulty: Optional[int] = SchemaField(
|
||||
description="Keyword difficulty score", default=None
|
||||
)
|
||||
serp_info: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="data from SERP for each keyword", default=None
|
||||
)
|
||||
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="Clickstream data metrics", default=None
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4193cb94-677c-48b0-9eec-6ac72fffd0f2",
|
||||
description="Extract individual fields from a KeywordSuggestion object",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"suggestion": KeywordSuggestion(
|
||||
keyword="test keyword",
|
||||
search_volume=1000,
|
||||
competition=0.5,
|
||||
cpc=2.5,
|
||||
keyword_difficulty=60,
|
||||
).model_dump()
|
||||
},
|
||||
test_output=[
|
||||
("keyword", "test keyword"),
|
||||
("search_volume", 1000),
|
||||
("competition", 0.5),
|
||||
("cpc", 2.5),
|
||||
("keyword_difficulty", 60),
|
||||
("serp_info", None),
|
||||
("clickstream_data", None),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Extract fields from the KeywordSuggestion object."""
|
||||
suggestion = input_data.suggestion
|
||||
|
||||
yield "keyword", suggestion.keyword
|
||||
yield "search_volume", suggestion.search_volume
|
||||
yield "competition", suggestion.competition
|
||||
yield "cpc", suggestion.cpc
|
||||
yield "keyword_difficulty", suggestion.keyword_difficulty
|
||||
yield "serp_info", suggestion.serp_info
|
||||
yield "clickstream_data", suggestion.clickstream_data
|
||||
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
DataForSEO Google Related Keywords block.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
|
||||
from ._api import DataForSeoClient
|
||||
from ._config import dataforseo
|
||||
|
||||
|
||||
class RelatedKeyword(BlockSchema):
|
||||
"""Schema for a related keyword result."""
|
||||
|
||||
keyword: str = SchemaField(description="The related keyword")
|
||||
search_volume: Optional[int] = SchemaField(
|
||||
description="Monthly search volume", default=None
|
||||
)
|
||||
competition: Optional[float] = SchemaField(
|
||||
description="Competition level (0-1)", default=None
|
||||
)
|
||||
cpc: Optional[float] = SchemaField(
|
||||
description="Cost per click in USD", default=None
|
||||
)
|
||||
keyword_difficulty: Optional[int] = SchemaField(
|
||||
description="Keyword difficulty score", default=None
|
||||
)
|
||||
serp_info: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="SERP data for the keyword", default=None
|
||||
)
|
||||
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="Clickstream data metrics", default=None
|
||||
)
|
||||
|
||||
|
||||
class DataForSeoRelatedKeywordsBlock(Block):
|
||||
"""Block for getting related keywords from DataForSEO Labs."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = dataforseo.credentials_field(
|
||||
description="DataForSEO credentials (username and password)"
|
||||
)
|
||||
keyword: str = SchemaField(
|
||||
description="Seed keyword to find related keywords for"
|
||||
)
|
||||
location_code: Optional[int] = SchemaField(
|
||||
description="Location code for targeting (e.g., 2840 for USA)",
|
||||
default=2840, # USA
|
||||
)
|
||||
language_code: Optional[str] = SchemaField(
|
||||
description="Language code (e.g., 'en' for English)",
|
||||
default="en",
|
||||
)
|
||||
include_seed_keyword: bool = SchemaField(
|
||||
description="Include the seed keyword in results",
|
||||
default=True,
|
||||
)
|
||||
include_serp_info: bool = SchemaField(
|
||||
description="Include SERP information",
|
||||
default=False,
|
||||
)
|
||||
include_clickstream_data: bool = SchemaField(
|
||||
description="Include clickstream metrics",
|
||||
default=False,
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of results (up to 3000)",
|
||||
default=100,
|
||||
ge=1,
|
||||
le=3000,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
related_keywords: List[RelatedKeyword] = SchemaField(
|
||||
description="List of related keywords with metrics"
|
||||
)
|
||||
related_keyword: RelatedKeyword = SchemaField(
|
||||
description="A related keyword with metrics"
|
||||
)
|
||||
total_count: int = SchemaField(
|
||||
description="Total number of related keywords returned"
|
||||
)
|
||||
seed_keyword: str = SchemaField(
|
||||
description="The seed keyword used for the query"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8f2e4d6a-1b3c-4a5e-9d7f-2c8e6a4b3f1d",
|
||||
description="Get related keywords from DataForSEO Labs Google API",
|
||||
categories={BlockCategory.SEARCH, BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"credentials": dataforseo.get_test_credentials().model_dump(),
|
||||
"keyword": "content marketing",
|
||||
"location_code": 2840,
|
||||
"language_code": "en",
|
||||
"limit": 1,
|
||||
},
|
||||
test_credentials=dataforseo.get_test_credentials(),
|
||||
test_output=[
|
||||
(
|
||||
"related_keyword",
|
||||
lambda x: hasattr(x, "keyword") and x.keyword == "content strategy",
|
||||
),
|
||||
("related_keywords", lambda x: isinstance(x, list) and len(x) == 1),
|
||||
("total_count", 1),
|
||||
("seed_keyword", "content marketing"),
|
||||
],
|
||||
test_mock={
|
||||
"_fetch_related_keywords": lambda *args, **kwargs: [
|
||||
{
|
||||
"items": [
|
||||
{
|
||||
"keyword_data": {
|
||||
"keyword": "content strategy",
|
||||
"keyword_info": {
|
||||
"search_volume": 8000,
|
||||
"competition": 0.4,
|
||||
"cpc": 3.0,
|
||||
},
|
||||
"keyword_properties": {
|
||||
"keyword_difficulty": 45,
|
||||
},
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
async def _fetch_related_keywords(
|
||||
self,
|
||||
client: DataForSeoClient,
|
||||
input_data: Input,
|
||||
) -> Any:
|
||||
"""Private method to fetch related keywords - can be mocked for testing."""
|
||||
return await client.related_keywords(
|
||||
keyword=input_data.keyword,
|
||||
location_code=input_data.location_code,
|
||||
language_code=input_data.language_code,
|
||||
include_seed_keyword=input_data.include_seed_keyword,
|
||||
include_serp_info=input_data.include_serp_info,
|
||||
include_clickstream_data=input_data.include_clickstream_data,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: UserPasswordCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the related keywords query."""
|
||||
client = DataForSeoClient(credentials)
|
||||
|
||||
results = await self._fetch_related_keywords(client, input_data)
|
||||
|
||||
# Process and format the results
|
||||
related_keywords = []
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
items = (
|
||||
first_result.get("items", []) if isinstance(first_result, dict) else []
|
||||
)
|
||||
for item in items:
|
||||
# Extract keyword_data from the item
|
||||
keyword_data = item.get("keyword_data", {})
|
||||
|
||||
# Create the RelatedKeyword object
|
||||
keyword = RelatedKeyword(
|
||||
keyword=keyword_data.get("keyword", ""),
|
||||
search_volume=keyword_data.get("keyword_info", {}).get(
|
||||
"search_volume"
|
||||
),
|
||||
competition=keyword_data.get("keyword_info", {}).get("competition"),
|
||||
cpc=keyword_data.get("keyword_info", {}).get("cpc"),
|
||||
keyword_difficulty=keyword_data.get("keyword_properties", {}).get(
|
||||
"keyword_difficulty"
|
||||
),
|
||||
serp_info=(
|
||||
keyword_data.get("serp_info")
|
||||
if input_data.include_serp_info
|
||||
else None
|
||||
),
|
||||
clickstream_data=(
|
||||
keyword_data.get("clickstream_keyword_info")
|
||||
if input_data.include_clickstream_data
|
||||
else None
|
||||
),
|
||||
)
|
||||
yield "related_keyword", keyword
|
||||
related_keywords.append(keyword)
|
||||
|
||||
yield "related_keywords", related_keywords
|
||||
yield "total_count", len(related_keywords)
|
||||
yield "seed_keyword", input_data.keyword
|
||||
|
||||
|
||||
class RelatedKeywordExtractorBlock(Block):
|
||||
"""Extracts individual fields from a RelatedKeyword object."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
related_keyword: RelatedKeyword = SchemaField(
|
||||
description="The related keyword object to extract fields from"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
keyword: str = SchemaField(description="The related keyword")
|
||||
search_volume: Optional[int] = SchemaField(
|
||||
description="Monthly search volume", default=None
|
||||
)
|
||||
competition: Optional[float] = SchemaField(
|
||||
description="Competition level (0-1)", default=None
|
||||
)
|
||||
cpc: Optional[float] = SchemaField(
|
||||
description="Cost per click in USD", default=None
|
||||
)
|
||||
keyword_difficulty: Optional[int] = SchemaField(
|
||||
description="Keyword difficulty score", default=None
|
||||
)
|
||||
serp_info: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="SERP data for the keyword", default=None
|
||||
)
|
||||
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="Clickstream data metrics", default=None
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="98342061-09d2-4952-bf77-0761fc8cc9a8",
|
||||
description="Extract individual fields from a RelatedKeyword object",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"related_keyword": RelatedKeyword(
|
||||
keyword="test related keyword",
|
||||
search_volume=800,
|
||||
competition=0.4,
|
||||
cpc=3.0,
|
||||
keyword_difficulty=55,
|
||||
).model_dump()
|
||||
},
|
||||
test_output=[
|
||||
("keyword", "test related keyword"),
|
||||
("search_volume", 800),
|
||||
("competition", 0.4),
|
||||
("cpc", 3.0),
|
||||
("keyword_difficulty", 55),
|
||||
("serp_info", None),
|
||||
("clickstream_data", None),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Extract fields from the RelatedKeyword object."""
|
||||
related_keyword = input_data.related_keyword
|
||||
|
||||
yield "keyword", related_keyword.keyword
|
||||
yield "search_volume", related_keyword.search_volume
|
||||
yield "competition", related_keyword.competition
|
||||
yield "cpc", related_keyword.cpc
|
||||
yield "keyword_difficulty", related_keyword.keyword_difficulty
|
||||
yield "serp_info", related_keyword.serp_info
|
||||
yield "clickstream_data", related_keyword.clickstream_data
|
||||
117
autogpt_platform/backend/backend/blocks/discord/_api.py
Normal file
117
autogpt_platform/backend/backend/blocks/discord/_api.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
Discord API helper functions for making authenticated requests.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.util.request import Requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DiscordAPIException(Exception):
|
||||
"""Exception raised for Discord API errors."""
|
||||
|
||||
def __init__(self, message: str, status_code: int):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class DiscordOAuthUser(BaseModel):
|
||||
"""Model for Discord OAuth user response."""
|
||||
|
||||
user_id: str
|
||||
username: str
|
||||
avatar_url: str
|
||||
banner: Optional[str] = None
|
||||
accent_color: Optional[int] = None
|
||||
|
||||
|
||||
def get_api(credentials: OAuth2Credentials) -> Requests:
|
||||
"""
|
||||
Create a Requests instance configured for Discord API calls with OAuth2 credentials.
|
||||
|
||||
Args:
|
||||
credentials: The OAuth2 credentials containing the access token.
|
||||
|
||||
Returns:
|
||||
A configured Requests instance for Discord API calls.
|
||||
"""
|
||||
return Requests(
|
||||
trusted_origins=[],
|
||||
extra_headers={
|
||||
"Authorization": f"Bearer {credentials.access_token.get_secret_value()}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
raise_for_status=False,
|
||||
)
|
||||
|
||||
|
||||
async def get_current_user(credentials: OAuth2Credentials) -> DiscordOAuthUser:
|
||||
"""
|
||||
Fetch the current user's information using Discord OAuth2 API.
|
||||
|
||||
Reference: https://discord.com/developers/docs/resources/user#get-current-user
|
||||
|
||||
Args:
|
||||
credentials: The OAuth2 credentials.
|
||||
|
||||
Returns:
|
||||
A model containing user data with avatar URL.
|
||||
|
||||
Raises:
|
||||
DiscordAPIException: If the API request fails.
|
||||
"""
|
||||
api = get_api(credentials)
|
||||
response = await api.get("https://discord.com/api/oauth2/@me")
|
||||
|
||||
if not response.ok:
|
||||
error_text = response.text()
|
||||
raise DiscordAPIException(
|
||||
f"Failed to fetch user info: {response.status} - {error_text}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
logger.info(f"Discord OAuth2 API Response: {data}")
|
||||
|
||||
# The /api/oauth2/@me endpoint returns a user object nested in the response
|
||||
user_info = data.get("user", {})
|
||||
logger.info(f"User info extracted: {user_info}")
|
||||
|
||||
# Build avatar URL
|
||||
user_id = user_info.get("id")
|
||||
avatar_hash = user_info.get("avatar")
|
||||
if avatar_hash:
|
||||
# Custom avatar
|
||||
avatar_ext = "gif" if avatar_hash.startswith("a_") else "png"
|
||||
avatar_url = (
|
||||
f"https://cdn.discordapp.com/avatars/{user_id}/{avatar_hash}.{avatar_ext}"
|
||||
)
|
||||
else:
|
||||
# Default avatar based on discriminator or user ID
|
||||
discriminator = user_info.get("discriminator", "0")
|
||||
if discriminator == "0":
|
||||
# New username system - use user ID for default avatar
|
||||
default_avatar_index = (int(user_id) >> 22) % 6
|
||||
else:
|
||||
# Legacy discriminator system
|
||||
default_avatar_index = int(discriminator) % 5
|
||||
avatar_url = (
|
||||
f"https://cdn.discordapp.com/embed/avatars/{default_avatar_index}.png"
|
||||
)
|
||||
|
||||
result = DiscordOAuthUser(
|
||||
user_id=user_id,
|
||||
username=user_info.get("username", ""),
|
||||
avatar_url=avatar_url,
|
||||
banner=user_info.get("banner"),
|
||||
accent_color=user_info.get("accent_color"),
|
||||
)
|
||||
|
||||
logger.info(f"Returning user data: {result.model_dump()}")
|
||||
return result
|
||||
74
autogpt_platform/backend/backend/blocks/discord/_auth.py
Normal file
74
autogpt_platform/backend/backend/blocks/discord/_auth.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
OAuth2Credentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.settings import Secrets
|
||||
|
||||
secrets = Secrets()
|
||||
DISCORD_OAUTH_IS_CONFIGURED = bool(
|
||||
secrets.discord_client_id and secrets.discord_client_secret
|
||||
)
|
||||
|
||||
# Bot token credentials (existing)
|
||||
DiscordBotCredentials = APIKeyCredentials
|
||||
DiscordBotCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.DISCORD], Literal["api_key"]
|
||||
]
|
||||
|
||||
# OAuth2 credentials (new)
|
||||
DiscordOAuthCredentials = OAuth2Credentials
|
||||
DiscordOAuthCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.DISCORD], Literal["oauth2"]
|
||||
]
|
||||
|
||||
|
||||
def DiscordBotCredentialsField() -> DiscordBotCredentialsInput:
|
||||
"""Creates a Discord bot token credentials field."""
|
||||
return CredentialsField(description="Discord bot token")
|
||||
|
||||
|
||||
def DiscordOAuthCredentialsField(scopes: list[str]) -> DiscordOAuthCredentialsInput:
|
||||
"""Creates a Discord OAuth2 credentials field."""
|
||||
return CredentialsField(
|
||||
description="Discord OAuth2 credentials",
|
||||
required_scopes=set(scopes) | {"identify"}, # Basic user info scope
|
||||
)
|
||||
|
||||
|
||||
# Test credentials for bot tokens
|
||||
TEST_BOT_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="discord",
|
||||
api_key=SecretStr("test_api_key"),
|
||||
title="Mock Discord API key",
|
||||
expires_at=None,
|
||||
)
|
||||
TEST_BOT_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_BOT_CREDENTIALS.provider,
|
||||
"id": TEST_BOT_CREDENTIALS.id,
|
||||
"type": TEST_BOT_CREDENTIALS.type,
|
||||
"title": TEST_BOT_CREDENTIALS.type,
|
||||
}
|
||||
|
||||
# Test credentials for OAuth2
|
||||
TEST_OAUTH_CREDENTIALS = OAuth2Credentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="discord",
|
||||
access_token=SecretStr("test_access_token"),
|
||||
title="Mock Discord OAuth",
|
||||
scopes=["identify"],
|
||||
username="testuser",
|
||||
)
|
||||
TEST_OAUTH_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_OAUTH_CREDENTIALS.provider,
|
||||
"id": TEST_OAUTH_CREDENTIALS.id,
|
||||
"type": TEST_OAUTH_CREDENTIALS.type,
|
||||
"title": TEST_OAUTH_CREDENTIALS.type,
|
||||
}
|
||||
@@ -2,45 +2,29 @@ import base64
|
||||
import io
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import discord
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.data.model import APIKeyCredentials, SchemaField
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
DiscordCredentials = CredentialsMetaInput[
|
||||
Literal[ProviderName.DISCORD], Literal["api_key"]
|
||||
]
|
||||
|
||||
|
||||
def DiscordCredentialsField() -> DiscordCredentials:
|
||||
return CredentialsField(description="Discord bot token")
|
||||
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="discord",
|
||||
api_key=SecretStr("test_api_key"),
|
||||
title="Mock Discord API key",
|
||||
expires_at=None,
|
||||
from ._auth import (
|
||||
TEST_BOT_CREDENTIALS,
|
||||
TEST_BOT_CREDENTIALS_INPUT,
|
||||
DiscordBotCredentialsField,
|
||||
DiscordBotCredentialsInput,
|
||||
)
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.type,
|
||||
}
|
||||
|
||||
# Keep backward compatibility alias
|
||||
DiscordCredentials = DiscordBotCredentialsInput
|
||||
DiscordCredentialsField = DiscordBotCredentialsField
|
||||
TEST_CREDENTIALS = TEST_BOT_CREDENTIALS
|
||||
TEST_CREDENTIALS_INPUT = TEST_BOT_CREDENTIALS_INPUT
|
||||
|
||||
|
||||
class ReadDiscordMessagesBlock(Block):
|
||||
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Discord OAuth-based blocks.
|
||||
"""
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import OAuth2Credentials, SchemaField
|
||||
|
||||
from ._api import DiscordOAuthUser, get_current_user
|
||||
from ._auth import (
|
||||
DISCORD_OAUTH_IS_CONFIGURED,
|
||||
TEST_OAUTH_CREDENTIALS,
|
||||
TEST_OAUTH_CREDENTIALS_INPUT,
|
||||
DiscordOAuthCredentialsField,
|
||||
DiscordOAuthCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class DiscordGetCurrentUserBlock(Block):
|
||||
"""
|
||||
Gets information about the currently authenticated Discord user using OAuth2.
|
||||
This block requires Discord OAuth2 credentials (not bot tokens).
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: DiscordOAuthCredentialsInput = DiscordOAuthCredentialsField(
|
||||
["identify"]
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
user_id: str = SchemaField(description="The authenticated user's Discord ID")
|
||||
username: str = SchemaField(description="The user's username")
|
||||
avatar_url: str = SchemaField(description="URL to the user's avatar image")
|
||||
banner_url: str = SchemaField(
|
||||
description="URL to the user's banner image (if set)", default=""
|
||||
)
|
||||
accent_color: int = SchemaField(
|
||||
description="The user's accent color as an integer", default=0
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8c7e39b8-4e9d-4f3a-b4e1-2a8c9d5f6e3b",
|
||||
input_schema=DiscordGetCurrentUserBlock.Input,
|
||||
output_schema=DiscordGetCurrentUserBlock.Output,
|
||||
description="Gets information about the currently authenticated Discord user using OAuth2 credentials.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
disabled=not DISCORD_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"credentials": TEST_OAUTH_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_OAUTH_CREDENTIALS,
|
||||
test_output=[
|
||||
("user_id", "123456789012345678"),
|
||||
("username", "testuser"),
|
||||
(
|
||||
"avatar_url",
|
||||
"https://cdn.discordapp.com/avatars/123456789012345678/avatar.png",
|
||||
),
|
||||
("banner_url", ""),
|
||||
("accent_color", 0),
|
||||
],
|
||||
test_mock={
|
||||
"get_user": lambda _: DiscordOAuthUser(
|
||||
user_id="123456789012345678",
|
||||
username="testuser",
|
||||
avatar_url="https://cdn.discordapp.com/avatars/123456789012345678/avatar.png",
|
||||
banner=None,
|
||||
accent_color=0,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_user(credentials: OAuth2Credentials) -> DiscordOAuthUser:
|
||||
user_info = await get_current_user(credentials)
|
||||
return user_info
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: OAuth2Credentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
result = await self.get_user(credentials)
|
||||
|
||||
# Yield each output field
|
||||
yield "user_id", result.user_id
|
||||
yield "username", result.username
|
||||
yield "avatar_url", result.avatar_url
|
||||
|
||||
# Handle banner URL if banner hash exists
|
||||
if result.banner:
|
||||
banner_url = f"https://cdn.discordapp.com/banners/{result.user_id}/{result.banner}.png"
|
||||
yield "banner_url", banner_url
|
||||
else:
|
||||
yield "banner_url", ""
|
||||
|
||||
yield "accent_color", result.accent_color or 0
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to get Discord user info: {e}")
|
||||
@@ -93,11 +93,11 @@ class Webset(BaseModel):
|
||||
"""
|
||||
Set of key-value pairs you want to associate with this object.
|
||||
"""
|
||||
created_at: Annotated[datetime, Field(alias="createdAt")] | None = None
|
||||
created_at: Annotated[datetime | None, Field(alias="createdAt")] = None
|
||||
"""
|
||||
The date and time the webset was created
|
||||
"""
|
||||
updated_at: Annotated[datetime, Field(alias="updatedAt")] | None = None
|
||||
updated_at: Annotated[datetime | None, Field(alias="updatedAt")] = None
|
||||
"""
|
||||
The date and time the webset was last updated
|
||||
"""
|
||||
|
||||
@@ -1094,6 +1094,117 @@ class GmailGetThreadBlock(GmailBase):
|
||||
return thread
|
||||
|
||||
|
||||
async def _build_reply_message(
|
||||
service, input_data, graph_exec_id: str, user_id: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Builds a reply MIME message for Gmail threads.
|
||||
|
||||
Returns:
|
||||
tuple: (base64-encoded raw message, threadId)
|
||||
"""
|
||||
# Get parent message for reply context
|
||||
parent = await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.messages()
|
||||
.get(
|
||||
userId="me",
|
||||
id=input_data.parentMessageId,
|
||||
format="metadata",
|
||||
metadataHeaders=[
|
||||
"Subject",
|
||||
"References",
|
||||
"Message-ID",
|
||||
"From",
|
||||
"To",
|
||||
"Cc",
|
||||
"Reply-To",
|
||||
],
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Build headers dictionary, preserving all values for duplicate headers
|
||||
headers = {}
|
||||
for h in parent.get("payload", {}).get("headers", []):
|
||||
name = h["name"].lower()
|
||||
value = h["value"]
|
||||
if name in headers:
|
||||
# For duplicate headers, keep the first occurrence (most relevant for reply context)
|
||||
continue
|
||||
headers[name] = value
|
||||
|
||||
# Determine recipients if not specified
|
||||
if not (input_data.to or input_data.cc or input_data.bcc):
|
||||
if input_data.replyAll:
|
||||
recipients = [parseaddr(headers.get("from", ""))[1]]
|
||||
recipients += [addr for _, addr in getaddresses([headers.get("to", "")])]
|
||||
recipients += [addr for _, addr in getaddresses([headers.get("cc", "")])]
|
||||
# Use dict.fromkeys() for O(n) deduplication while preserving order
|
||||
input_data.to = list(dict.fromkeys(filter(None, recipients)))
|
||||
else:
|
||||
# Check Reply-To header first, fall back to From header
|
||||
reply_to = headers.get("reply-to", "")
|
||||
from_addr = headers.get("from", "")
|
||||
sender = parseaddr(reply_to if reply_to else from_addr)[1]
|
||||
input_data.to = [sender] if sender else []
|
||||
|
||||
# Set subject with Re: prefix if not already present
|
||||
if input_data.subject:
|
||||
subject = input_data.subject
|
||||
else:
|
||||
parent_subject = headers.get("subject", "").strip()
|
||||
# Only add "Re:" if not already present (case-insensitive check)
|
||||
if parent_subject.lower().startswith("re:"):
|
||||
subject = parent_subject
|
||||
else:
|
||||
subject = f"Re: {parent_subject}" if parent_subject else "Re:"
|
||||
|
||||
# Build references header for proper threading
|
||||
references = headers.get("references", "").split()
|
||||
if headers.get("message-id"):
|
||||
references.append(headers["message-id"])
|
||||
|
||||
# Create MIME message
|
||||
msg = MIMEMultipart()
|
||||
if input_data.to:
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
if headers.get("message-id"):
|
||||
msg["In-Reply-To"] = headers["message-id"]
|
||||
if references:
|
||||
msg["References"] = " ".join(references)
|
||||
|
||||
# Use the helper function for consistent content type handling
|
||||
msg.attach(_make_mime_text(input_data.body, input_data.content_type))
|
||||
|
||||
# Handle attachments
|
||||
for attach in input_data.attachments:
|
||||
local_path = await store_media_file(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=attach,
|
||||
return_content=False,
|
||||
)
|
||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
||||
part = MIMEBase("application", "octet-stream")
|
||||
with open(abs_path, "rb") as f:
|
||||
part.set_payload(f.read())
|
||||
encoders.encode_base64(part)
|
||||
part.add_header(
|
||||
"Content-Disposition", f"attachment; filename={Path(abs_path).name}"
|
||||
)
|
||||
msg.attach(part)
|
||||
|
||||
# Encode message
|
||||
raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
|
||||
return raw, input_data.threadId
|
||||
|
||||
|
||||
class GmailReplyBlock(GmailBase):
|
||||
"""
|
||||
Replies to Gmail threads with intelligent content type detection.
|
||||
@@ -1230,93 +1341,146 @@ class GmailReplyBlock(GmailBase):
|
||||
async def _reply(
|
||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
||||
) -> dict:
|
||||
parent = await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.messages()
|
||||
.get(
|
||||
userId="me",
|
||||
id=input_data.parentMessageId,
|
||||
format="metadata",
|
||||
metadataHeaders=[
|
||||
"Subject",
|
||||
"References",
|
||||
"Message-ID",
|
||||
"From",
|
||||
"To",
|
||||
"Cc",
|
||||
"Reply-To",
|
||||
],
|
||||
)
|
||||
.execute()
|
||||
# Build the reply message using the shared helper
|
||||
raw, thread_id = await _build_reply_message(
|
||||
service, input_data, graph_exec_id, user_id
|
||||
)
|
||||
|
||||
headers = {
|
||||
h["name"].lower(): h["value"]
|
||||
for h in parent.get("payload", {}).get("headers", [])
|
||||
}
|
||||
if not (input_data.to or input_data.cc or input_data.bcc):
|
||||
if input_data.replyAll:
|
||||
recipients = [parseaddr(headers.get("from", ""))[1]]
|
||||
recipients += [
|
||||
addr for _, addr in getaddresses([headers.get("to", "")])
|
||||
]
|
||||
recipients += [
|
||||
addr for _, addr in getaddresses([headers.get("cc", "")])
|
||||
]
|
||||
dedup: list[str] = []
|
||||
for r in recipients:
|
||||
if r and r not in dedup:
|
||||
dedup.append(r)
|
||||
input_data.to = dedup
|
||||
else:
|
||||
sender = parseaddr(headers.get("reply-to", headers.get("from", "")))[1]
|
||||
input_data.to = [sender] if sender else []
|
||||
subject = input_data.subject or (f"Re: {headers.get('subject', '')}".strip())
|
||||
references = headers.get("references", "").split()
|
||||
if headers.get("message-id"):
|
||||
references.append(headers["message-id"])
|
||||
|
||||
msg = MIMEMultipart()
|
||||
if input_data.to:
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
if headers.get("message-id"):
|
||||
msg["In-Reply-To"] = headers["message-id"]
|
||||
if references:
|
||||
msg["References"] = " ".join(references)
|
||||
# Use the new helper function for consistent content type handling
|
||||
msg.attach(_make_mime_text(input_data.body, input_data.content_type))
|
||||
|
||||
for attach in input_data.attachments:
|
||||
local_path = await store_media_file(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=attach,
|
||||
return_content=False,
|
||||
)
|
||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
||||
part = MIMEBase("application", "octet-stream")
|
||||
with open(abs_path, "rb") as f:
|
||||
part.set_payload(f.read())
|
||||
encoders.encode_base64(part)
|
||||
part.add_header(
|
||||
"Content-Disposition", f"attachment; filename={Path(abs_path).name}"
|
||||
)
|
||||
msg.attach(part)
|
||||
|
||||
raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
|
||||
# Send the message
|
||||
return await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.messages()
|
||||
.send(userId="me", body={"threadId": input_data.threadId, "raw": raw})
|
||||
.send(userId="me", body={"threadId": thread_id, "raw": raw})
|
||||
.execute()
|
||||
)
|
||||
|
||||
|
||||
class GmailDraftReplyBlock(GmailBase):
|
||||
"""
|
||||
Creates draft replies to Gmail threads with intelligent content type detection.
|
||||
|
||||
Features:
|
||||
- Automatic HTML detection: Draft replies containing HTML tags are formatted as text/html
|
||||
- No hard-wrap for plain text: Plain text draft replies preserve natural line flow
|
||||
- Manual content type override: Use content_type parameter to force specific format
|
||||
- Reply-all functionality: Option to reply to all original recipients
|
||||
- Thread preservation: Maintains proper email threading with headers
|
||||
- Full Unicode/emoji support with UTF-8 encoding
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: GoogleCredentialsInput = GoogleCredentialsField(
|
||||
[
|
||||
"https://www.googleapis.com/auth/gmail.modify",
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
]
|
||||
)
|
||||
threadId: str = SchemaField(description="Thread ID to reply in")
|
||||
parentMessageId: str = SchemaField(
|
||||
description="ID of the message being replied to"
|
||||
)
|
||||
to: list[str] = SchemaField(description="To recipients", default_factory=list)
|
||||
cc: list[str] = SchemaField(description="CC recipients", default_factory=list)
|
||||
bcc: list[str] = SchemaField(description="BCC recipients", default_factory=list)
|
||||
replyAll: bool = SchemaField(
|
||||
description="Reply to all original recipients", default=False
|
||||
)
|
||||
subject: str = SchemaField(description="Email subject", default="")
|
||||
body: str = SchemaField(description="Email body (plain text or HTML)")
|
||||
content_type: Optional[Literal["auto", "plain", "html"]] = SchemaField(
|
||||
description="Content type: 'auto' (default - detects HTML), 'plain', or 'html'",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
attachments: list[MediaFileType] = SchemaField(
|
||||
description="Files to attach", default_factory=list, advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
draftId: str = SchemaField(description="Created draft ID")
|
||||
messageId: str = SchemaField(description="Draft message ID")
|
||||
threadId: str = SchemaField(description="Thread ID")
|
||||
status: str = SchemaField(description="Draft creation status")
|
||||
error: str = SchemaField(description="Error message if any")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d7a9f3e2-8b4c-4d6f-9e1a-3c5b7f8d2a6e",
|
||||
description="Create draft replies to Gmail threads with automatic HTML detection and proper text formatting. Plain text draft replies maintain natural paragraph flow without 78-character line wrapping. HTML content is automatically detected and formatted correctly.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=GmailDraftReplyBlock.Input,
|
||||
output_schema=GmailDraftReplyBlock.Output,
|
||||
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"threadId": "t1",
|
||||
"parentMessageId": "m1",
|
||||
"body": "Thanks for your message. I'll review and get back to you.",
|
||||
"replyAll": False,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("draftId", "draft1"),
|
||||
("messageId", "m2"),
|
||||
("threadId", "t1"),
|
||||
("status", "draft_created"),
|
||||
],
|
||||
test_mock={
|
||||
"_create_draft_reply": lambda *args, **kwargs: {
|
||||
"id": "draft1",
|
||||
"message": {"id": "m2", "threadId": "t1"},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GoogleCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
service = self._build_service(credentials, **kwargs)
|
||||
draft = await self._create_draft_reply(
|
||||
service,
|
||||
input_data,
|
||||
graph_exec_id,
|
||||
user_id,
|
||||
)
|
||||
yield "draftId", draft["id"]
|
||||
yield "messageId", draft["message"]["id"]
|
||||
yield "threadId", draft["message"].get("threadId", input_data.threadId)
|
||||
yield "status", "draft_created"
|
||||
|
||||
async def _create_draft_reply(
|
||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
||||
) -> dict:
|
||||
# Build the reply message using the shared helper
|
||||
raw, thread_id = await _build_reply_message(
|
||||
service, input_data, graph_exec_id, user_id
|
||||
)
|
||||
|
||||
# Create draft with proper thread association
|
||||
draft = await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.drafts()
|
||||
.create(
|
||||
userId="me",
|
||||
body={
|
||||
"message": {
|
||||
"threadId": thread_id,
|
||||
"raw": raw,
|
||||
}
|
||||
},
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
return draft
|
||||
|
||||
|
||||
class GmailGetProfileBlock(GmailBase):
|
||||
class Input(BlockSchema):
|
||||
credentials: GoogleCredentialsInput = GoogleCredentialsField(
|
||||
|
||||
@@ -30,6 +30,7 @@ TEST_CREDENTIALS_INPUT = {
|
||||
|
||||
|
||||
class IdeogramModelName(str, Enum):
|
||||
V3 = "V_3"
|
||||
V2 = "V_2"
|
||||
V1 = "V_1"
|
||||
V1_TURBO = "V_1_TURBO"
|
||||
@@ -95,8 +96,8 @@ class IdeogramModelBlock(Block):
|
||||
title="Prompt",
|
||||
)
|
||||
ideogram_model_name: IdeogramModelName = SchemaField(
|
||||
description="The name of the Image Generation Model, e.g., V_2",
|
||||
default=IdeogramModelName.V2,
|
||||
description="The name of the Image Generation Model, e.g., V_3",
|
||||
default=IdeogramModelName.V3,
|
||||
title="Image Generation Model",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -236,6 +237,111 @@ class IdeogramModelBlock(Block):
|
||||
negative_prompt: Optional[str],
|
||||
color_palette_name: str,
|
||||
custom_colors: Optional[list[str]],
|
||||
):
|
||||
# Use V3 endpoint for V3 model, legacy endpoint for others
|
||||
if model_name == "V_3":
|
||||
return await self._run_model_v3(
|
||||
api_key,
|
||||
prompt,
|
||||
seed,
|
||||
aspect_ratio,
|
||||
magic_prompt_option,
|
||||
style_type,
|
||||
negative_prompt,
|
||||
color_palette_name,
|
||||
custom_colors,
|
||||
)
|
||||
else:
|
||||
return await self._run_model_legacy(
|
||||
api_key,
|
||||
model_name,
|
||||
prompt,
|
||||
seed,
|
||||
aspect_ratio,
|
||||
magic_prompt_option,
|
||||
style_type,
|
||||
negative_prompt,
|
||||
color_palette_name,
|
||||
custom_colors,
|
||||
)
|
||||
|
||||
async def _run_model_v3(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
prompt: str,
|
||||
seed: Optional[int],
|
||||
aspect_ratio: str,
|
||||
magic_prompt_option: str,
|
||||
style_type: str,
|
||||
negative_prompt: Optional[str],
|
||||
color_palette_name: str,
|
||||
custom_colors: Optional[list[str]],
|
||||
):
|
||||
url = "https://api.ideogram.ai/v1/ideogram-v3/generate"
|
||||
headers = {
|
||||
"Api-Key": api_key.get_secret_value(),
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# Map legacy aspect ratio values to V3 format
|
||||
aspect_ratio_map = {
|
||||
"ASPECT_10_16": "10x16",
|
||||
"ASPECT_16_10": "16x10",
|
||||
"ASPECT_9_16": "9x16",
|
||||
"ASPECT_16_9": "16x9",
|
||||
"ASPECT_3_2": "3x2",
|
||||
"ASPECT_2_3": "2x3",
|
||||
"ASPECT_4_3": "4x3",
|
||||
"ASPECT_3_4": "3x4",
|
||||
"ASPECT_1_1": "1x1",
|
||||
"ASPECT_1_3": "1x3",
|
||||
"ASPECT_3_1": "3x1",
|
||||
# Additional V3 supported ratios
|
||||
"ASPECT_1_2": "1x2",
|
||||
"ASPECT_2_1": "2x1",
|
||||
"ASPECT_4_5": "4x5",
|
||||
"ASPECT_5_4": "5x4",
|
||||
}
|
||||
|
||||
v3_aspect_ratio = aspect_ratio_map.get(
|
||||
aspect_ratio, "1x1"
|
||||
) # Default to 1x1 if not found
|
||||
|
||||
# Use JSON for V3 endpoint (simpler than multipart/form-data)
|
||||
data: Dict[str, Any] = {
|
||||
"prompt": prompt,
|
||||
"aspect_ratio": v3_aspect_ratio,
|
||||
"magic_prompt": magic_prompt_option,
|
||||
"style_type": style_type,
|
||||
}
|
||||
|
||||
if seed is not None:
|
||||
data["seed"] = seed
|
||||
|
||||
if negative_prompt:
|
||||
data["negative_prompt"] = negative_prompt
|
||||
|
||||
# Note: V3 endpoint may have different color palette support
|
||||
# For now, we'll omit color palettes for V3 to avoid errors
|
||||
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=data)
|
||||
return response.json()["data"][0]["url"]
|
||||
except RequestException as e:
|
||||
raise Exception(f"Failed to fetch image with V3 endpoint: {str(e)}")
|
||||
|
||||
async def _run_model_legacy(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
seed: Optional[int],
|
||||
aspect_ratio: str,
|
||||
magic_prompt_option: str,
|
||||
style_type: str,
|
||||
negative_prompt: Optional[str],
|
||||
color_palette_name: str,
|
||||
custom_colors: Optional[list[str]],
|
||||
):
|
||||
url = "https://api.ideogram.ai/generate"
|
||||
headers = {
|
||||
@@ -249,28 +355,33 @@ class IdeogramModelBlock(Block):
|
||||
"model": model_name,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"magic_prompt_option": magic_prompt_option,
|
||||
"style_type": style_type,
|
||||
}
|
||||
}
|
||||
|
||||
# Only add style_type for V2, V2_TURBO, and V3 models (V1 models don't support it)
|
||||
if model_name in ["V_2", "V_2_TURBO", "V_3"]:
|
||||
data["image_request"]["style_type"] = style_type
|
||||
|
||||
if seed is not None:
|
||||
data["image_request"]["seed"] = seed
|
||||
|
||||
if negative_prompt:
|
||||
data["image_request"]["negative_prompt"] = negative_prompt
|
||||
|
||||
if color_palette_name != "NONE":
|
||||
data["color_palette"] = {"name": color_palette_name}
|
||||
elif custom_colors:
|
||||
data["color_palette"] = {
|
||||
"members": [{"color_hex": color} for color in custom_colors]
|
||||
}
|
||||
# Only add color palette for V2 and V2_TURBO models (V1 models don't support it)
|
||||
if model_name in ["V_2", "V_2_TURBO"]:
|
||||
if color_palette_name != "NONE":
|
||||
data["color_palette"] = {"name": color_palette_name}
|
||||
elif custom_colors:
|
||||
data["color_palette"] = {
|
||||
"members": [{"color_hex": color} for color in custom_colors]
|
||||
}
|
||||
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=data)
|
||||
return response.json()["data"][0]["url"]
|
||||
except RequestException as e:
|
||||
raise Exception(f"Failed to fetch image: {str(e)}")
|
||||
raise Exception(f"Failed to fetch image with legacy endpoint: {str(e)}")
|
||||
|
||||
async def upscale_image(self, api_key: SecretStr, image_url: str):
|
||||
url = "https://api.ideogram.ai/upscale"
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
stagehand = (
|
||||
ProviderBuilder("stagehand")
|
||||
.with_api_key("STAGEHAND_API_KEY", "Stagehand API Key")
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
393
autogpt_platform/backend/backend/blocks/stagehand/blocks.py
Normal file
393
autogpt_platform/backend/backend/blocks/stagehand/blocks.py
Normal file
@@ -0,0 +1,393 @@
|
||||
import logging
|
||||
import signal
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
|
||||
# Monkey patch Stagehands to prevent signal handling in worker threads
|
||||
import stagehand.main
|
||||
from stagehand import Stagehand
|
||||
|
||||
from backend.blocks.llm import (
|
||||
MODEL_METADATA,
|
||||
AICredentials,
|
||||
AICredentialsField,
|
||||
LlmModel,
|
||||
ModelMetadata,
|
||||
)
|
||||
from backend.blocks.stagehand._config import stagehand as stagehand_provider
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
# Store the original method
|
||||
original_register_signal_handlers = stagehand.main.Stagehand._register_signal_handlers
|
||||
|
||||
|
||||
def safe_register_signal_handlers(self):
|
||||
"""Only register signal handlers in the main thread"""
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
original_register_signal_handlers(self)
|
||||
else:
|
||||
# Skip signal handling in worker threads
|
||||
pass
|
||||
|
||||
|
||||
# Replace the method
|
||||
stagehand.main.Stagehand._register_signal_handlers = safe_register_signal_handlers
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_signal_handling():
|
||||
"""Context manager to temporarily disable signal handling"""
|
||||
if threading.current_thread() is not threading.main_thread():
|
||||
# In worker threads, temporarily replace signal.signal with a no-op
|
||||
original_signal = signal.signal
|
||||
|
||||
def noop_signal(*args, **kwargs):
|
||||
pass
|
||||
|
||||
signal.signal = noop_signal
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.signal = original_signal
|
||||
else:
|
||||
# In main thread, don't modify anything
|
||||
yield
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StagehandRecommendedLlmModel(str, Enum):
|
||||
"""
|
||||
This is subset of LLModel from autogpt_platform/backend/backend/blocks/llm.py
|
||||
|
||||
It contains only the models recommended by Stagehand
|
||||
"""
|
||||
|
||||
# OpenAI
|
||||
GPT41 = "gpt-4.1-2025-04-14"
|
||||
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
|
||||
|
||||
# Anthropic
|
||||
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
"""
|
||||
Returns the provider name for the model in the required format for Stagehand:
|
||||
provider/model_name
|
||||
"""
|
||||
model_metadata = MODEL_METADATA[LlmModel(self.value)]
|
||||
model_name = self.value
|
||||
|
||||
if len(model_name.split("/")) == 1 and not self.value.startswith(
|
||||
model_metadata.provider
|
||||
):
|
||||
assert (
|
||||
model_metadata.provider != "open_router"
|
||||
), "Logic failed and open_router provider attempted to be prepended to model name! in stagehand/_config.py"
|
||||
model_name = f"{model_metadata.provider}/{model_name}"
|
||||
|
||||
logger.error(f"Model name: {model_name}")
|
||||
return model_name
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
return MODEL_METADATA[LlmModel(self.value)].provider
|
||||
|
||||
@property
|
||||
def metadata(self) -> ModelMetadata:
|
||||
return MODEL_METADATA[LlmModel(self.value)]
|
||||
|
||||
@property
|
||||
def context_window(self) -> int:
|
||||
return MODEL_METADATA[LlmModel(self.value)].context_window
|
||||
|
||||
@property
|
||||
def max_output_tokens(self) -> int | None:
|
||||
return MODEL_METADATA[LlmModel(self.value)].max_output_tokens
|
||||
|
||||
|
||||
class StagehandObserveBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
# Browserbase credentials (Stagehand provider) or raw API key
|
||||
stagehand_credentials: CredentialsMetaInput = (
|
||||
stagehand_provider.credentials_field(
|
||||
description="Stagehand/Browserbase API key"
|
||||
)
|
||||
)
|
||||
browserbase_project_id: str = SchemaField(
|
||||
description="Browserbase project ID (required if using Browserbase)",
|
||||
)
|
||||
# Model selection and credentials (provider-discriminated like llm.py)
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
url: str = SchemaField(
|
||||
description="URL to navigate to.",
|
||||
)
|
||||
instruction: str = SchemaField(
|
||||
description="Natural language description of elements or actions to discover.",
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
selector: str = SchemaField(description="XPath selector to locate element.")
|
||||
description: str = SchemaField(description="Human-readable description")
|
||||
method: str | None = SchemaField(description="Suggested action method")
|
||||
arguments: list[str] | None = SchemaField(
|
||||
description="Additional action parameters"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d3863944-0eaf-45c4-a0c9-63e0fe1ee8b9",
|
||||
description="Find suggested actions for your workflows",
|
||||
categories={BlockCategory.AI, BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=StagehandObserveBlock.Input,
|
||||
output_schema=StagehandObserveBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
stagehand_credentials: APIKeyCredentials,
|
||||
model_credentials: APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
logger.info(f"OBSERVE: Stagehand credentials: {stagehand_credentials}")
|
||||
logger.info(
|
||||
f"OBSERVE: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
|
||||
)
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
|
||||
observe_results = await page.observe(
|
||||
input_data.instruction,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
)
|
||||
for result in observe_results:
|
||||
yield "selector", result.selector
|
||||
yield "description", result.description
|
||||
yield "method", result.method
|
||||
yield "arguments", result.arguments
|
||||
|
||||
|
||||
class StagehandActBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
# Browserbase credentials (Stagehand provider) or raw API key
|
||||
stagehand_credentials: CredentialsMetaInput = (
|
||||
stagehand_provider.credentials_field(
|
||||
description="Stagehand/Browserbase API key"
|
||||
)
|
||||
)
|
||||
browserbase_project_id: str = SchemaField(
|
||||
description="Browserbase project ID (required if using Browserbase)",
|
||||
)
|
||||
# Model selection and credentials (provider-discriminated like llm.py)
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
url: str = SchemaField(
|
||||
description="URL to navigate to.",
|
||||
)
|
||||
action: list[str] = SchemaField(
|
||||
description="Action to perform. Suggested actions are: click, fill, type, press, scroll, select from dropdown. For multi-step actions, add an entry for each step.",
|
||||
)
|
||||
variables: dict[str, str] = SchemaField(
|
||||
description="Variables to use in the action. Variables contains data you want the action to use.",
|
||||
default_factory=dict,
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
)
|
||||
timeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM ready. Extended timeout for slow-loading forms",
|
||||
default=60000,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the action was completed successfully"
|
||||
)
|
||||
message: str = SchemaField(description="Details about the action’s execution.")
|
||||
action: str = SchemaField(description="Action performed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="86eba68b-9549-4c0b-a0db-47d85a56cc27",
|
||||
description="Interact with a web page by performing actions on a web page. Use it to build self-healing and deterministic automations that adapt to website chang.",
|
||||
categories={BlockCategory.AI, BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=StagehandActBlock.Input,
|
||||
output_schema=StagehandActBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
stagehand_credentials: APIKeyCredentials,
|
||||
model_credentials: APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
logger.info(f"ACT: Stagehand credentials: {stagehand_credentials}")
|
||||
logger.info(
|
||||
f"ACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
|
||||
)
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
for action in input_data.action:
|
||||
action_results = await page.act(
|
||||
action,
|
||||
variables=input_data.variables,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
timeoutMs=input_data.timeoutMs,
|
||||
)
|
||||
yield "success", action_results.success
|
||||
yield "message", action_results.message
|
||||
yield "action", action_results.action
|
||||
|
||||
|
||||
class StagehandExtractBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
# Browserbase credentials (Stagehand provider) or raw API key
|
||||
stagehand_credentials: CredentialsMetaInput = (
|
||||
stagehand_provider.credentials_field(
|
||||
description="Stagehand/Browserbase API key"
|
||||
)
|
||||
)
|
||||
browserbase_project_id: str = SchemaField(
|
||||
description="Browserbase project ID (required if using Browserbase)",
|
||||
)
|
||||
# Model selection and credentials (provider-discriminated like llm.py)
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
url: str = SchemaField(
|
||||
description="URL to navigate to.",
|
||||
)
|
||||
instruction: str = SchemaField(
|
||||
description="Natural language description of elements or actions to discover.",
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
extraction: str = SchemaField(description="Extracted data from the page.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="fd3c0b18-2ba6-46ae-9339-fcb40537ad98",
|
||||
description="Extract structured data from a webpage.",
|
||||
categories={BlockCategory.AI, BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=StagehandExtractBlock.Input,
|
||||
output_schema=StagehandExtractBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
stagehand_credentials: APIKeyCredentials,
|
||||
model_credentials: APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
logger.info(f"EXTRACT: Stagehand credentials: {stagehand_credentials}")
|
||||
logger.info(
|
||||
f"EXTRACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
|
||||
)
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
extraction = await page.extract(
|
||||
input_data.instruction,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
)
|
||||
yield "extraction", str(extraction.model_dump()["extraction"])
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,283 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Duplicate pydantic models for store data so we don't accidently change the data shape in the blocks unintentionally when editing the backend
|
||||
class LibraryAgent(BaseModel):
|
||||
"""Model representing an agent in the user's library."""
|
||||
|
||||
library_agent_id: str = ""
|
||||
agent_id: str = ""
|
||||
agent_version: int = 0
|
||||
agent_name: str = ""
|
||||
description: str = ""
|
||||
creator: str = ""
|
||||
is_archived: bool = False
|
||||
categories: list[str] = []
|
||||
|
||||
|
||||
class AddToLibraryFromStoreBlock(Block):
|
||||
"""
|
||||
Block that adds an agent from the store to the user's library.
|
||||
This enables users to easily import agents from the marketplace into their personal collection.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
store_listing_version_id: str = SchemaField(
|
||||
description="The ID of the store listing version to add to library"
|
||||
)
|
||||
agent_name: str | None = SchemaField(
|
||||
description="Optional custom name for the agent in your library",
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the agent was successfully added to library"
|
||||
)
|
||||
library_agent_id: str = SchemaField(
|
||||
description="The ID of the library agent entry"
|
||||
)
|
||||
agent_id: str = SchemaField(description="The ID of the agent graph")
|
||||
agent_version: int = SchemaField(
|
||||
description="The version number of the agent graph"
|
||||
)
|
||||
agent_name: str = SchemaField(description="The name of the agent")
|
||||
message: str = SchemaField(description="Success or error message")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2602a7b1-3f4d-4e5f-9c8b-1a2b3c4d5e6f",
|
||||
description="Add an agent from the store to your personal library",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=AddToLibraryFromStoreBlock.Input,
|
||||
output_schema=AddToLibraryFromStoreBlock.Output,
|
||||
test_input={
|
||||
"store_listing_version_id": "test-listing-id",
|
||||
"agent_name": "My Custom Agent",
|
||||
},
|
||||
test_output=[
|
||||
("success", True),
|
||||
("library_agent_id", "test-library-id"),
|
||||
("agent_id", "test-agent-id"),
|
||||
("agent_version", 1),
|
||||
("agent_name", "Test Agent"),
|
||||
("message", "Agent successfully added to library"),
|
||||
],
|
||||
test_mock={
|
||||
"_add_to_library": lambda *_, **__: LibraryAgent(
|
||||
library_agent_id="test-library-id",
|
||||
agent_id="test-agent-id",
|
||||
agent_version=1,
|
||||
agent_name="Test Agent",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
library_agent = await self._add_to_library(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=input_data.store_listing_version_id,
|
||||
custom_name=input_data.agent_name,
|
||||
)
|
||||
|
||||
yield "success", True
|
||||
yield "library_agent_id", library_agent.library_agent_id
|
||||
yield "agent_id", library_agent.agent_id
|
||||
yield "agent_version", library_agent.agent_version
|
||||
yield "agent_name", library_agent.agent_name
|
||||
yield "message", "Agent successfully added to library"
|
||||
|
||||
async def _add_to_library(
|
||||
self,
|
||||
user_id: str,
|
||||
store_listing_version_id: str,
|
||||
custom_name: str | None = None,
|
||||
) -> LibraryAgent:
|
||||
"""
|
||||
Add a store agent to the user's library using the existing library database function.
|
||||
"""
|
||||
library_agent = (
|
||||
await get_database_manager_async_client().add_store_agent_to_library(
|
||||
store_listing_version_id=store_listing_version_id, user_id=user_id
|
||||
)
|
||||
)
|
||||
|
||||
# If custom name is provided, we could update the library agent name here
|
||||
# For now, we'll just return the agent info
|
||||
agent_name = custom_name if custom_name else library_agent.name
|
||||
|
||||
return LibraryAgent(
|
||||
library_agent_id=library_agent.id,
|
||||
agent_id=library_agent.graph_id,
|
||||
agent_version=library_agent.graph_version,
|
||||
agent_name=agent_name,
|
||||
)
|
||||
|
||||
|
||||
class ListLibraryAgentsBlock(Block):
|
||||
"""
|
||||
Block that lists all agents in the user's library.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
search_query: str | None = SchemaField(
|
||||
description="Optional search query to filter agents", default=None
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of agents to return", default=50, ge=1, le=100
|
||||
)
|
||||
page: int = SchemaField(
|
||||
description="Page number for pagination", default=1, ge=1
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
agents: list[LibraryAgent] = SchemaField(
|
||||
description="List of agents in the library",
|
||||
default_factory=list,
|
||||
)
|
||||
agent: LibraryAgent = SchemaField(
|
||||
description="Individual library agent (yielded for each agent)"
|
||||
)
|
||||
total_count: int = SchemaField(
|
||||
description="Total number of agents in library", default=0
|
||||
)
|
||||
page: int = SchemaField(description="Current page number", default=1)
|
||||
total_pages: int = SchemaField(description="Total number of pages", default=1)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="082602d3-a74d-4600-9e9c-15b3af7eae98",
|
||||
description="List all agents in your personal library",
|
||||
categories={BlockCategory.BASIC, BlockCategory.DATA},
|
||||
input_schema=ListLibraryAgentsBlock.Input,
|
||||
output_schema=ListLibraryAgentsBlock.Output,
|
||||
test_input={
|
||||
"search_query": None,
|
||||
"limit": 10,
|
||||
"page": 1,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"agents",
|
||||
[
|
||||
LibraryAgent(
|
||||
library_agent_id="test-lib-id",
|
||||
agent_id="test-agent-id",
|
||||
agent_version=1,
|
||||
agent_name="Test Library Agent",
|
||||
description="A test agent in library",
|
||||
creator="Test User",
|
||||
),
|
||||
],
|
||||
),
|
||||
("total_count", 1),
|
||||
("page", 1),
|
||||
("total_pages", 1),
|
||||
(
|
||||
"agent",
|
||||
LibraryAgent(
|
||||
library_agent_id="test-lib-id",
|
||||
agent_id="test-agent-id",
|
||||
agent_version=1,
|
||||
agent_name="Test Library Agent",
|
||||
description="A test agent in library",
|
||||
creator="Test User",
|
||||
),
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"_list_library_agents": lambda *_, **__: {
|
||||
"agents": [
|
||||
LibraryAgent(
|
||||
library_agent_id="test-lib-id",
|
||||
agent_id="test-agent-id",
|
||||
agent_version=1,
|
||||
agent_name="Test Library Agent",
|
||||
description="A test agent in library",
|
||||
creator="Test User",
|
||||
)
|
||||
],
|
||||
"total": 1,
|
||||
"page": 1,
|
||||
"total_pages": 1,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
result = await self._list_library_agents(
|
||||
user_id=user_id,
|
||||
search_query=input_data.search_query,
|
||||
limit=input_data.limit,
|
||||
page=input_data.page,
|
||||
)
|
||||
|
||||
agents = result["agents"]
|
||||
|
||||
yield "agents", agents
|
||||
yield "total_count", result["total"]
|
||||
yield "page", result["page"]
|
||||
yield "total_pages", result["total_pages"]
|
||||
|
||||
# Yield each agent individually for better graph connectivity
|
||||
for agent in agents:
|
||||
yield "agent", agent
|
||||
|
||||
async def _list_library_agents(
|
||||
self,
|
||||
user_id: str,
|
||||
search_query: str | None = None,
|
||||
limit: int = 50,
|
||||
page: int = 1,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
List agents in the user's library using the database client.
|
||||
"""
|
||||
result = await get_database_manager_async_client().list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=search_query,
|
||||
page=page,
|
||||
page_size=limit,
|
||||
)
|
||||
|
||||
agents = [
|
||||
LibraryAgent(
|
||||
library_agent_id=agent.id,
|
||||
agent_id=agent.graph_id,
|
||||
agent_version=agent.graph_version,
|
||||
agent_name=agent.name,
|
||||
description=getattr(agent, "description", ""),
|
||||
creator=getattr(agent, "creator", ""),
|
||||
is_archived=getattr(agent, "is_archived", False),
|
||||
categories=getattr(agent, "categories", []),
|
||||
)
|
||||
for agent in result.agents
|
||||
]
|
||||
|
||||
return {
|
||||
"agents": agents,
|
||||
"total": result.pagination.total_items,
|
||||
"page": result.pagination.current_page,
|
||||
"total_pages": result.pagination.total_pages,
|
||||
}
|
||||
@@ -0,0 +1,311 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Duplicate pydantic models for store data so we don't accidently change the data shape in the blocks unintentionally when editing the backend
|
||||
class StoreAgent(BaseModel):
|
||||
"""Model representing a store agent."""
|
||||
|
||||
slug: str = ""
|
||||
name: str = ""
|
||||
description: str = ""
|
||||
creator: str = ""
|
||||
rating: float = 0.0
|
||||
runs: int = 0
|
||||
categories: list[str] = []
|
||||
|
||||
|
||||
class StoreAgentDict(BaseModel):
|
||||
"""Dictionary representation of a store agent."""
|
||||
|
||||
slug: str
|
||||
name: str
|
||||
description: str
|
||||
creator: str
|
||||
rating: float
|
||||
runs: int
|
||||
|
||||
|
||||
class SearchAgentsResponse(BaseModel):
|
||||
"""Response from searching store agents."""
|
||||
|
||||
agents: list[StoreAgentDict]
|
||||
total_count: int
|
||||
|
||||
|
||||
class StoreAgentDetails(BaseModel):
|
||||
"""Detailed information about a store agent."""
|
||||
|
||||
found: bool
|
||||
store_listing_version_id: str = ""
|
||||
agent_name: str = ""
|
||||
description: str = ""
|
||||
creator: str = ""
|
||||
categories: list[str] = []
|
||||
runs: int = 0
|
||||
rating: float = 0.0
|
||||
|
||||
|
||||
class GetStoreAgentDetailsBlock(Block):
|
||||
"""
|
||||
Block that retrieves detailed information about an agent from the store.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
creator: str = SchemaField(description="The username of the agent creator")
|
||||
slug: str = SchemaField(description="The name of the agent")
|
||||
|
||||
class Output(BlockSchema):
|
||||
found: bool = SchemaField(
|
||||
description="Whether the agent was found in the store"
|
||||
)
|
||||
store_listing_version_id: str = SchemaField(
|
||||
description="The store listing version ID"
|
||||
)
|
||||
agent_name: str = SchemaField(description="Name of the agent")
|
||||
description: str = SchemaField(description="Description of the agent")
|
||||
creator: str = SchemaField(description="Creator of the agent")
|
||||
categories: list[str] = SchemaField(
|
||||
description="Categories the agent belongs to", default_factory=list
|
||||
)
|
||||
runs: int = SchemaField(
|
||||
description="Number of times the agent has been run", default=0
|
||||
)
|
||||
rating: float = SchemaField(
|
||||
description="Average rating of the agent", default=0.0
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b604f0ec-6e0d-40a7-bf55-9fd09997cced",
|
||||
description="Get detailed information about an agent from the store",
|
||||
categories={BlockCategory.BASIC, BlockCategory.DATA},
|
||||
input_schema=GetStoreAgentDetailsBlock.Input,
|
||||
output_schema=GetStoreAgentDetailsBlock.Output,
|
||||
test_input={"creator": "test-creator", "slug": "test-agent-slug"},
|
||||
test_output=[
|
||||
("found", True),
|
||||
("store_listing_version_id", "test-listing-id"),
|
||||
("agent_name", "Test Agent"),
|
||||
("description", "A test agent"),
|
||||
("creator", "Test Creator"),
|
||||
("categories", ["productivity", "automation"]),
|
||||
("runs", 100),
|
||||
("rating", 4.5),
|
||||
],
|
||||
test_mock={
|
||||
"_get_agent_details": lambda *_, **__: StoreAgentDetails(
|
||||
found=True,
|
||||
store_listing_version_id="test-listing-id",
|
||||
agent_name="Test Agent",
|
||||
description="A test agent",
|
||||
creator="Test Creator",
|
||||
categories=["productivity", "automation"],
|
||||
runs=100,
|
||||
rating=4.5,
|
||||
)
|
||||
},
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
details = await self._get_agent_details(
|
||||
creator=input_data.creator, slug=input_data.slug
|
||||
)
|
||||
yield "found", details.found
|
||||
yield "store_listing_version_id", details.store_listing_version_id
|
||||
yield "agent_name", details.agent_name
|
||||
yield "description", details.description
|
||||
yield "creator", details.creator
|
||||
yield "categories", details.categories
|
||||
yield "runs", details.runs
|
||||
yield "rating", details.rating
|
||||
|
||||
async def _get_agent_details(self, creator: str, slug: str) -> StoreAgentDetails:
|
||||
"""
|
||||
Retrieve detailed information about a store agent.
|
||||
"""
|
||||
# Get by specific version ID
|
||||
agent_details = (
|
||||
await get_database_manager_async_client().get_store_agent_details(
|
||||
username=creator, agent_name=slug
|
||||
)
|
||||
)
|
||||
|
||||
return StoreAgentDetails(
|
||||
found=True,
|
||||
store_listing_version_id=agent_details.store_listing_version_id,
|
||||
agent_name=agent_details.agent_name,
|
||||
description=agent_details.description,
|
||||
creator=agent_details.creator,
|
||||
categories=(
|
||||
agent_details.categories if hasattr(agent_details, "categories") else []
|
||||
),
|
||||
runs=agent_details.runs,
|
||||
rating=agent_details.rating,
|
||||
)
|
||||
|
||||
|
||||
class SearchStoreAgentsBlock(Block):
|
||||
"""
|
||||
Block that searches for agents in the store based on various criteria.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
query: str | None = SchemaField(
|
||||
description="Search query to find agents", default=None
|
||||
)
|
||||
category: str | None = SchemaField(
|
||||
description="Filter by category", default=None
|
||||
)
|
||||
sort_by: Literal["rating", "runs", "name", "recent"] = SchemaField(
|
||||
description="How to sort the results", default="rating"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of results to return", default=10, ge=1, le=100
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
agents: list[StoreAgent] = SchemaField(
|
||||
description="List of agents matching the search criteria",
|
||||
default_factory=list,
|
||||
)
|
||||
agent: StoreAgent = SchemaField(description="Basic information of the agent")
|
||||
total_count: int = SchemaField(
|
||||
description="Total number of agents found", default=0
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="39524701-026c-4328-87cc-1b88c8e2cb4c",
|
||||
description="Search for agents in the store",
|
||||
categories={BlockCategory.BASIC, BlockCategory.DATA},
|
||||
input_schema=SearchStoreAgentsBlock.Input,
|
||||
output_schema=SearchStoreAgentsBlock.Output,
|
||||
test_input={
|
||||
"query": "productivity",
|
||||
"category": None,
|
||||
"sort_by": "rating",
|
||||
"limit": 10,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"agents",
|
||||
[
|
||||
{
|
||||
"slug": "test-agent",
|
||||
"name": "Test Agent",
|
||||
"description": "A test agent",
|
||||
"creator": "Test Creator",
|
||||
"rating": 4.5,
|
||||
"runs": 100,
|
||||
}
|
||||
],
|
||||
),
|
||||
("total_count", 1),
|
||||
(
|
||||
"agent",
|
||||
{
|
||||
"slug": "test-agent",
|
||||
"name": "Test Agent",
|
||||
"description": "A test agent",
|
||||
"creator": "Test Creator",
|
||||
"rating": 4.5,
|
||||
"runs": 100,
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"_search_agents": lambda *_, **__: SearchAgentsResponse(
|
||||
agents=[
|
||||
StoreAgentDict(
|
||||
slug="test-agent",
|
||||
name="Test Agent",
|
||||
description="A test agent",
|
||||
creator="Test Creator",
|
||||
rating=4.5,
|
||||
runs=100,
|
||||
)
|
||||
],
|
||||
total_count=1,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
result = await self._search_agents(
|
||||
query=input_data.query,
|
||||
category=input_data.category,
|
||||
sort_by=input_data.sort_by,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
|
||||
agents = result.agents
|
||||
total_count = result.total_count
|
||||
|
||||
# Convert to dict for output
|
||||
agents_as_dicts = [agent.model_dump() for agent in agents]
|
||||
|
||||
yield "agents", agents_as_dicts
|
||||
yield "total_count", total_count
|
||||
|
||||
for agent_dict in agents_as_dicts:
|
||||
yield "agent", agent_dict
|
||||
|
||||
async def _search_agents(
|
||||
self,
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
sort_by: str = "rating",
|
||||
limit: int = 10,
|
||||
) -> SearchAgentsResponse:
|
||||
"""
|
||||
Search for agents in the store using the existing store database function.
|
||||
"""
|
||||
# Map our sort_by to the store's sorted_by parameter
|
||||
sorted_by_map = {
|
||||
"rating": "most_popular",
|
||||
"runs": "most_runs",
|
||||
"name": "alphabetical",
|
||||
"recent": "recently_updated",
|
||||
}
|
||||
|
||||
result = await get_database_manager_async_client().get_store_agents(
|
||||
featured=False,
|
||||
creators=None,
|
||||
sorted_by=sorted_by_map.get(sort_by, "most_popular"),
|
||||
search_query=query,
|
||||
category=category,
|
||||
page=1,
|
||||
page_size=limit,
|
||||
)
|
||||
|
||||
agents = [
|
||||
StoreAgentDict(
|
||||
slug=agent.slug,
|
||||
name=agent.agent_name,
|
||||
description=agent.description,
|
||||
creator=agent.creator,
|
||||
rating=agent.rating,
|
||||
runs=agent.runs,
|
||||
)
|
||||
for agent in result.agents
|
||||
]
|
||||
|
||||
return SearchAgentsResponse(agents=agents, total_count=len(agents))
|
||||
@@ -35,20 +35,19 @@ async def execute_graph(
|
||||
logger.info("Input data: %s", input_data)
|
||||
|
||||
# --- Test adding new executions --- #
|
||||
response = await agent_server.test_execute_graph(
|
||||
graph_exec = await agent_server.test_execute_graph(
|
||||
user_id=test_user.id,
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
node_input=input_data,
|
||||
)
|
||||
graph_exec_id = response.graph_exec_id
|
||||
logger.info("Created execution with ID: %s", graph_exec_id)
|
||||
logger.info("Created execution with ID: %s", graph_exec.id)
|
||||
|
||||
# Execution queue should be empty
|
||||
logger.info("Waiting for execution to complete...")
|
||||
result = await wait_execution(test_user.id, graph_exec_id, 30)
|
||||
result = await wait_execution(test_user.id, graph_exec.id, 30)
|
||||
logger.info("Execution completed with %d results", len(result))
|
||||
return graph_exec_id
|
||||
return graph_exec.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.system.library_operations import (
|
||||
AddToLibraryFromStoreBlock,
|
||||
LibraryAgent,
|
||||
)
|
||||
from backend.blocks.system.store_operations import (
|
||||
GetStoreAgentDetailsBlock,
|
||||
SearchAgentsResponse,
|
||||
SearchStoreAgentsBlock,
|
||||
StoreAgentDetails,
|
||||
StoreAgentDict,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_to_library_from_store_block_success(mocker):
|
||||
"""Test successful addition of agent from store to library."""
|
||||
block = AddToLibraryFromStoreBlock()
|
||||
|
||||
# Mock the library agent response
|
||||
mock_library_agent = MagicMock()
|
||||
mock_library_agent.id = "lib-agent-123"
|
||||
mock_library_agent.graph_id = "graph-456"
|
||||
mock_library_agent.graph_version = 1
|
||||
mock_library_agent.name = "Test Agent"
|
||||
|
||||
mocker.patch.object(
|
||||
block,
|
||||
"_add_to_library",
|
||||
return_value=LibraryAgent(
|
||||
library_agent_id="lib-agent-123",
|
||||
agent_id="graph-456",
|
||||
agent_version=1,
|
||||
agent_name="Test Agent",
|
||||
),
|
||||
)
|
||||
|
||||
input_data = block.Input(
|
||||
store_listing_version_id="store-listing-v1", agent_name="Custom Agent Name"
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, user_id="test-user"):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["success"] is True
|
||||
assert outputs["library_agent_id"] == "lib-agent-123"
|
||||
assert outputs["agent_id"] == "graph-456"
|
||||
assert outputs["agent_version"] == 1
|
||||
assert outputs["agent_name"] == "Test Agent"
|
||||
assert outputs["message"] == "Agent successfully added to library"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_store_agent_details_block_success(mocker):
|
||||
"""Test successful retrieval of store agent details."""
|
||||
block = GetStoreAgentDetailsBlock()
|
||||
|
||||
mocker.patch.object(
|
||||
block,
|
||||
"_get_agent_details",
|
||||
return_value=StoreAgentDetails(
|
||||
found=True,
|
||||
store_listing_version_id="version-123",
|
||||
agent_name="Test Agent",
|
||||
description="A test agent for testing",
|
||||
creator="Test Creator",
|
||||
categories=["productivity", "automation"],
|
||||
runs=100,
|
||||
rating=4.5,
|
||||
),
|
||||
)
|
||||
|
||||
input_data = block.Input(creator="Test Creator", slug="test-slug")
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["found"] is True
|
||||
assert outputs["store_listing_version_id"] == "version-123"
|
||||
assert outputs["agent_name"] == "Test Agent"
|
||||
assert outputs["description"] == "A test agent for testing"
|
||||
assert outputs["creator"] == "Test Creator"
|
||||
assert outputs["categories"] == ["productivity", "automation"]
|
||||
assert outputs["runs"] == 100
|
||||
assert outputs["rating"] == 4.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_store_agents_block(mocker):
|
||||
"""Test searching for store agents."""
|
||||
block = SearchStoreAgentsBlock()
|
||||
|
||||
mocker.patch.object(
|
||||
block,
|
||||
"_search_agents",
|
||||
return_value=SearchAgentsResponse(
|
||||
agents=[
|
||||
StoreAgentDict(
|
||||
slug="creator1/agent1",
|
||||
name="Agent One",
|
||||
description="First test agent",
|
||||
creator="Creator 1",
|
||||
rating=4.8,
|
||||
runs=500,
|
||||
),
|
||||
StoreAgentDict(
|
||||
slug="creator2/agent2",
|
||||
name="Agent Two",
|
||||
description="Second test agent",
|
||||
creator="Creator 2",
|
||||
rating=4.2,
|
||||
runs=200,
|
||||
),
|
||||
],
|
||||
total_count=2,
|
||||
),
|
||||
)
|
||||
|
||||
input_data = block.Input(
|
||||
query="test", category="productivity", sort_by="rating", limit=10
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data):
|
||||
outputs[name] = value
|
||||
|
||||
assert len(outputs["agents"]) == 2
|
||||
assert outputs["total_count"] == 2
|
||||
assert outputs["agents"][0]["name"] == "Agent One"
|
||||
assert outputs["agents"][0]["rating"] == 4.8
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_store_agents_block_empty_results(mocker):
|
||||
"""Test searching with no results."""
|
||||
block = SearchStoreAgentsBlock()
|
||||
|
||||
mocker.patch.object(
|
||||
block,
|
||||
"_search_agents",
|
||||
return_value=SearchAgentsResponse(agents=[], total_count=0),
|
||||
)
|
||||
|
||||
input_data = block.Input(query="nonexistent", limit=10)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["agents"] == []
|
||||
assert outputs["total_count"] == 0
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Literal, Union
|
||||
@@ -7,6 +8,7 @@ from zoneinfo import ZoneInfo
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.execution import UserContext
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
# Shared timezone literal type for all time/date blocks
|
||||
@@ -51,16 +53,80 @@ TimezoneLiteral = Literal[
|
||||
"Etc/GMT+12", # UTC-12:00
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_timezone(
|
||||
format_type: Any, # Any format type with timezone and use_user_timezone attributes
|
||||
user_timezone: str | None,
|
||||
) -> ZoneInfo:
|
||||
"""
|
||||
Determine which timezone to use based on format settings and user context.
|
||||
|
||||
Args:
|
||||
format_type: The format configuration containing timezone settings
|
||||
user_timezone: The user's timezone from context
|
||||
|
||||
Returns:
|
||||
ZoneInfo object for the determined timezone
|
||||
"""
|
||||
if format_type.use_user_timezone and user_timezone:
|
||||
tz = ZoneInfo(user_timezone)
|
||||
logger.debug(f"Using user timezone: {user_timezone}")
|
||||
else:
|
||||
tz = ZoneInfo(format_type.timezone)
|
||||
logger.debug(f"Using specified timezone: {format_type.timezone}")
|
||||
return tz
|
||||
|
||||
|
||||
def _format_datetime_iso8601(dt: datetime, include_microseconds: bool = False) -> str:
|
||||
"""
|
||||
Format a datetime object to ISO8601 string.
|
||||
|
||||
Args:
|
||||
dt: The datetime object to format
|
||||
include_microseconds: Whether to include microseconds in the output
|
||||
|
||||
Returns:
|
||||
ISO8601 formatted string
|
||||
"""
|
||||
if include_microseconds:
|
||||
return dt.isoformat()
|
||||
else:
|
||||
return dt.isoformat(timespec="seconds")
|
||||
|
||||
|
||||
# BACKWARDS COMPATIBILITY NOTE:
|
||||
# The timezone field is kept at the format level (not block level) for backwards compatibility.
|
||||
# Existing graphs have timezone saved within format_type, moving it would break them.
|
||||
#
|
||||
# The use_user_timezone flag was added to allow using the user's profile timezone.
|
||||
# Default is False to maintain backwards compatibility - existing graphs will continue
|
||||
# using their specified timezone.
|
||||
#
|
||||
# KNOWN ISSUE: If a user switches between format types (strftime <-> iso8601),
|
||||
# the timezone setting doesn't carry over. This is a UX issue but fixing it would
|
||||
# require either:
|
||||
# 1. Moving timezone to block level (breaking change, needs migration)
|
||||
# 2. Complex state management to sync timezone across format types
|
||||
#
|
||||
# Future migration path: When we do a major version bump, consider moving timezone
|
||||
# to the block Input level for better UX.
|
||||
|
||||
|
||||
class TimeStrftimeFormat(BaseModel):
|
||||
discriminator: Literal["strftime"]
|
||||
format: str = "%H:%M:%S"
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
# When True, overrides timezone with user's profile timezone
|
||||
use_user_timezone: bool = False
|
||||
|
||||
|
||||
class TimeISO8601Format(BaseModel):
|
||||
discriminator: Literal["iso8601"]
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
# When True, overrides timezone with user's profile timezone
|
||||
use_user_timezone: bool = False
|
||||
include_microseconds: bool = False
|
||||
|
||||
|
||||
@@ -115,25 +181,27 @@ class GetCurrentTimeBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
async def run(
|
||||
self, input_data: Input, *, user_context: UserContext, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Extract timezone from user_context (always present)
|
||||
effective_timezone = user_context.timezone
|
||||
|
||||
# Get the appropriate timezone
|
||||
tz = _get_timezone(input_data.format_type, effective_timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
|
||||
if isinstance(input_data.format_type, TimeISO8601Format):
|
||||
# ISO 8601 format for time only (extract time portion from full ISO datetime)
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
|
||||
# Get the full ISO format and extract just the time portion with timezone
|
||||
if input_data.format_type.include_microseconds:
|
||||
full_iso = dt.isoformat()
|
||||
else:
|
||||
full_iso = dt.isoformat(timespec="seconds")
|
||||
|
||||
full_iso = _format_datetime_iso8601(
|
||||
dt, input_data.format_type.include_microseconds
|
||||
)
|
||||
# Extract time portion (everything after 'T')
|
||||
current_time = full_iso.split("T")[1] if "T" in full_iso else full_iso
|
||||
current_time = f"T{current_time}" # Add T prefix for ISO 8601 time format
|
||||
else: # TimeStrftimeFormat
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
current_time = dt.strftime(input_data.format_type.format)
|
||||
|
||||
yield "time", current_time
|
||||
|
||||
|
||||
@@ -141,11 +209,15 @@ class DateStrftimeFormat(BaseModel):
|
||||
discriminator: Literal["strftime"]
|
||||
format: str = "%Y-%m-%d"
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
# When True, overrides timezone with user's profile timezone
|
||||
use_user_timezone: bool = False
|
||||
|
||||
|
||||
class DateISO8601Format(BaseModel):
|
||||
discriminator: Literal["iso8601"]
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
# When True, overrides timezone with user's profile timezone
|
||||
use_user_timezone: bool = False
|
||||
|
||||
|
||||
class GetCurrentDateBlock(Block):
|
||||
@@ -217,20 +289,23 @@ class GetCurrentDateBlock(Block):
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# Extract timezone from user_context (required keyword argument)
|
||||
user_context: UserContext = kwargs["user_context"]
|
||||
effective_timezone = user_context.timezone
|
||||
|
||||
try:
|
||||
offset = int(input_data.offset)
|
||||
except ValueError:
|
||||
offset = 0
|
||||
|
||||
# Get the appropriate timezone
|
||||
tz = _get_timezone(input_data.format_type, effective_timezone)
|
||||
current_date = datetime.now(tz=tz) - timedelta(days=offset)
|
||||
|
||||
if isinstance(input_data.format_type, DateISO8601Format):
|
||||
# ISO 8601 format for date only (YYYY-MM-DD)
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
current_date = datetime.now(tz=tz) - timedelta(days=offset)
|
||||
# ISO 8601 date format is YYYY-MM-DD
|
||||
date_str = current_date.date().isoformat()
|
||||
else: # DateStrftimeFormat
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
current_date = datetime.now(tz=tz) - timedelta(days=offset)
|
||||
date_str = current_date.strftime(input_data.format_type.format)
|
||||
|
||||
yield "date", date_str
|
||||
@@ -240,11 +315,15 @@ class StrftimeFormat(BaseModel):
|
||||
discriminator: Literal["strftime"]
|
||||
format: str = "%Y-%m-%d %H:%M:%S"
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
# When True, overrides timezone with user's profile timezone
|
||||
use_user_timezone: bool = False
|
||||
|
||||
|
||||
class ISO8601Format(BaseModel):
|
||||
discriminator: Literal["iso8601"]
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
# When True, overrides timezone with user's profile timezone
|
||||
use_user_timezone: bool = False
|
||||
include_microseconds: bool = False
|
||||
|
||||
|
||||
@@ -316,20 +395,22 @@ class GetCurrentDateAndTimeBlock(Block):
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# Extract timezone from user_context (required keyword argument)
|
||||
user_context: UserContext = kwargs["user_context"]
|
||||
effective_timezone = user_context.timezone
|
||||
|
||||
# Get the appropriate timezone
|
||||
tz = _get_timezone(input_data.format_type, effective_timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
|
||||
if isinstance(input_data.format_type, ISO8601Format):
|
||||
# ISO 8601 format with specified timezone (also RFC3339-compliant)
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
|
||||
# Format with or without microseconds
|
||||
if input_data.format_type.include_microseconds:
|
||||
current_date_time = dt.isoformat()
|
||||
else:
|
||||
current_date_time = dt.isoformat(timespec="seconds")
|
||||
current_date_time = _format_datetime_iso8601(
|
||||
dt, input_data.format_type.include_microseconds
|
||||
)
|
||||
else: # StrftimeFormat
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
current_date_time = dt.strftime(input_data.format_type.format)
|
||||
|
||||
yield "date_time", current_date_time
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ from logging import getLogger
|
||||
from typing import Any, Dict, List, Union
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from pydantic import field_serializer
|
||||
|
||||
from backend.sdk import BaseModel, Credentials, Requests
|
||||
|
||||
logger = getLogger(__name__)
|
||||
@@ -382,8 +384,9 @@ class CreatePostRequest(BaseModel):
|
||||
# Advanced
|
||||
metadata: List[Dict[str, Any]] | None = None
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
@field_serializer("date")
|
||||
def serialize_date(self, value: datetime | None) -> str | None:
|
||||
return value.isoformat() if value else None
|
||||
|
||||
|
||||
class PostAuthor(BaseModel):
|
||||
|
||||
@@ -6,8 +6,6 @@ from dotenv import load_dotenv
|
||||
|
||||
from backend.util.logging import configure_logging
|
||||
|
||||
os.environ["ENABLE_AUTH"] = "false"
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# NOTE: You can run tests like with the --log-cli-level=INFO to see the logs
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
|
||||
from .graph import NodeModel
|
||||
from .integrations import Webhook # noqa: F401
|
||||
|
||||
# Resolve Webhook <- NodeModel forward reference
|
||||
# Resolve Webhook forward references
|
||||
NodeModel.model_rebuild()
|
||||
LibraryAgentPreset.model_rebuild()
|
||||
|
||||
@@ -1,57 +1,31 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from autogpt_libs.api_key.key_manager import APIKeyManager
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from prisma.enums import APIKeyPermission, APIKeyStatus
|
||||
from prisma.errors import PrismaError
|
||||
from prisma.models import APIKey as PrismaAPIKey
|
||||
from prisma.types import (
|
||||
APIKeyCreateInput,
|
||||
APIKeyUpdateInput,
|
||||
APIKeyWhereInput,
|
||||
APIKeyWhereUniqueInput,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from prisma.types import APIKeyWhereUniqueInput
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.db import BaseDbModel
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
keysmith = APIKeySmith()
|
||||
|
||||
|
||||
# Some basic exceptions
|
||||
class APIKeyError(Exception):
|
||||
"""Base exception for API key operations"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyNotFoundError(APIKeyError):
|
||||
"""Raised when an API key is not found"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyPermissionError(APIKeyError):
|
||||
"""Raised when there are permission issues with API key operations"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyValidationError(APIKeyError):
|
||||
"""Raised when API key validation fails"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKey(BaseDbModel):
|
||||
class APIKeyInfo(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
prefix: str
|
||||
key: str
|
||||
status: APIKeyStatus = APIKeyStatus.ACTIVE
|
||||
permissions: List[APIKeyPermission]
|
||||
postfix: str
|
||||
head: str = Field(
|
||||
description=f"The first {APIKeySmith.HEAD_LENGTH} characters of the key"
|
||||
)
|
||||
tail: str = Field(
|
||||
description=f"The last {APIKeySmith.TAIL_LENGTH} characters of the key"
|
||||
)
|
||||
status: APIKeyStatus
|
||||
permissions: list[APIKeyPermission]
|
||||
created_at: datetime
|
||||
last_used_at: Optional[datetime] = None
|
||||
revoked_at: Optional[datetime] = None
|
||||
@@ -60,266 +34,211 @@ class APIKey(BaseDbModel):
|
||||
|
||||
@staticmethod
|
||||
def from_db(api_key: PrismaAPIKey):
|
||||
try:
|
||||
return APIKey(
|
||||
id=api_key.id,
|
||||
name=api_key.name,
|
||||
prefix=api_key.prefix,
|
||||
postfix=api_key.postfix,
|
||||
key=api_key.key,
|
||||
status=APIKeyStatus(api_key.status),
|
||||
permissions=[APIKeyPermission(p) for p in api_key.permissions],
|
||||
created_at=api_key.createdAt,
|
||||
last_used_at=api_key.lastUsedAt,
|
||||
revoked_at=api_key.revokedAt,
|
||||
description=api_key.description,
|
||||
user_id=api_key.userId,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating APIKey from db: {str(e)}")
|
||||
raise APIKeyError(f"Failed to create API key object: {str(e)}")
|
||||
return APIKeyInfo(
|
||||
id=api_key.id,
|
||||
name=api_key.name,
|
||||
head=api_key.head,
|
||||
tail=api_key.tail,
|
||||
status=APIKeyStatus(api_key.status),
|
||||
permissions=[APIKeyPermission(p) for p in api_key.permissions],
|
||||
created_at=api_key.createdAt,
|
||||
last_used_at=api_key.lastUsedAt,
|
||||
revoked_at=api_key.revokedAt,
|
||||
description=api_key.description,
|
||||
user_id=api_key.userId,
|
||||
)
|
||||
|
||||
|
||||
class APIKeyWithoutHash(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
prefix: str
|
||||
postfix: str
|
||||
status: APIKeyStatus
|
||||
permissions: List[APIKeyPermission]
|
||||
created_at: datetime
|
||||
last_used_at: Optional[datetime]
|
||||
revoked_at: Optional[datetime]
|
||||
description: Optional[str]
|
||||
user_id: str
|
||||
class APIKeyInfoWithHash(APIKeyInfo):
|
||||
hash: str
|
||||
salt: str | None = None # None for legacy keys
|
||||
|
||||
def match(self, plaintext_key: str) -> bool:
|
||||
"""Returns whether the given key matches this API key object."""
|
||||
return keysmith.verify_key(plaintext_key, self.hash, self.salt)
|
||||
|
||||
@staticmethod
|
||||
def from_db(api_key: PrismaAPIKey):
|
||||
try:
|
||||
return APIKeyWithoutHash(
|
||||
id=api_key.id,
|
||||
name=api_key.name,
|
||||
prefix=api_key.prefix,
|
||||
postfix=api_key.postfix,
|
||||
status=APIKeyStatus(api_key.status),
|
||||
permissions=[APIKeyPermission(p) for p in api_key.permissions],
|
||||
created_at=api_key.createdAt,
|
||||
last_used_at=api_key.lastUsedAt,
|
||||
revoked_at=api_key.revokedAt,
|
||||
description=api_key.description,
|
||||
user_id=api_key.userId,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating APIKeyWithoutHash from db: {str(e)}")
|
||||
raise APIKeyError(f"Failed to create API key object: {str(e)}")
|
||||
return APIKeyInfoWithHash(
|
||||
**APIKeyInfo.from_db(api_key).model_dump(),
|
||||
hash=api_key.hash,
|
||||
salt=api_key.salt,
|
||||
)
|
||||
|
||||
def without_hash(self) -> APIKeyInfo:
|
||||
return APIKeyInfo(**self.model_dump(exclude={"hash", "salt"}))
|
||||
|
||||
|
||||
async def generate_api_key(
|
||||
async def create_api_key(
|
||||
name: str,
|
||||
user_id: str,
|
||||
permissions: List[APIKeyPermission],
|
||||
permissions: list[APIKeyPermission],
|
||||
description: Optional[str] = None,
|
||||
) -> tuple[APIKeyWithoutHash, str]:
|
||||
) -> tuple[APIKeyInfo, str]:
|
||||
"""
|
||||
Generate a new API key and store it in the database.
|
||||
Returns the API key object (without hash) and the plain text key.
|
||||
"""
|
||||
try:
|
||||
api_manager = APIKeyManager()
|
||||
key = api_manager.generate_api_key()
|
||||
generated_key = keysmith.generate_key()
|
||||
|
||||
api_key = await PrismaAPIKey.prisma().create(
|
||||
data=APIKeyCreateInput(
|
||||
id=str(uuid.uuid4()),
|
||||
name=name,
|
||||
prefix=key.prefix,
|
||||
postfix=key.postfix,
|
||||
key=key.hash,
|
||||
permissions=[p for p in permissions],
|
||||
description=description,
|
||||
userId=user_id,
|
||||
)
|
||||
)
|
||||
saved_key_obj = await PrismaAPIKey.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": name,
|
||||
"head": generated_key.head,
|
||||
"tail": generated_key.tail,
|
||||
"hash": generated_key.hash,
|
||||
"salt": generated_key.salt,
|
||||
"permissions": [p for p in permissions],
|
||||
"description": description,
|
||||
"userId": user_id,
|
||||
}
|
||||
)
|
||||
|
||||
api_key_without_hash = APIKeyWithoutHash.from_db(api_key)
|
||||
return api_key_without_hash, key.raw
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while generating API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to generate API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while generating API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to generate API key: {str(e)}")
|
||||
return APIKeyInfo.from_db(saved_key_obj), generated_key.key
|
||||
|
||||
|
||||
async def validate_api_key(plain_text_key: str) -> Optional[APIKey]:
|
||||
async def get_active_api_keys_by_head(head: str) -> list[APIKeyInfoWithHash]:
|
||||
results = await PrismaAPIKey.prisma().find_many(
|
||||
where={"head": head, "status": APIKeyStatus.ACTIVE}
|
||||
)
|
||||
return [APIKeyInfoWithHash.from_db(key) for key in results]
|
||||
|
||||
|
||||
async def validate_api_key(plaintext_key: str) -> Optional[APIKeyInfo]:
|
||||
"""
|
||||
Validate an API key and return the API key object if valid.
|
||||
Validate an API key and return the API key object if valid and active.
|
||||
"""
|
||||
try:
|
||||
if not plain_text_key.startswith(APIKeyManager.PREFIX):
|
||||
if not plaintext_key.startswith(APIKeySmith.PREFIX):
|
||||
logger.warning("Invalid API key format")
|
||||
return None
|
||||
|
||||
prefix = plain_text_key[: APIKeyManager.PREFIX_LENGTH]
|
||||
api_manager = APIKeyManager()
|
||||
head = plaintext_key[: APIKeySmith.HEAD_LENGTH]
|
||||
potential_matches = await get_active_api_keys_by_head(head)
|
||||
|
||||
api_key = await PrismaAPIKey.prisma().find_first(
|
||||
where=APIKeyWhereInput(prefix=prefix, status=(APIKeyStatus.ACTIVE))
|
||||
matched_api_key = next(
|
||||
(pm for pm in potential_matches if pm.match(plaintext_key)),
|
||||
None,
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
logger.warning(f"No active API key found with prefix {prefix}")
|
||||
if not matched_api_key:
|
||||
# API key not found or invalid
|
||||
return None
|
||||
|
||||
is_valid = api_manager.verify_api_key(plain_text_key, api_key.key)
|
||||
if not is_valid:
|
||||
logger.warning("API key verification failed")
|
||||
return None
|
||||
|
||||
return APIKey.from_db(api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating API key: {str(e)}")
|
||||
raise APIKeyValidationError(f"Failed to validate API key: {str(e)}")
|
||||
|
||||
|
||||
async def revoke_api_key(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]:
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if not api_key:
|
||||
raise APIKeyNotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise APIKeyPermissionError(
|
||||
"You do not have permission to revoke this API key."
|
||||
# Migrate legacy keys to secure format on successful validation
|
||||
if matched_api_key.salt is None:
|
||||
matched_api_key = await _migrate_key_to_secure_hash(
|
||||
plaintext_key, matched_api_key
|
||||
)
|
||||
|
||||
where_clause: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=where_clause,
|
||||
data=APIKeyUpdateInput(
|
||||
status=APIKeyStatus.REVOKED, revokedAt=datetime.now(timezone.utc)
|
||||
),
|
||||
)
|
||||
return matched_api_key.without_hash()
|
||||
except Exception as e:
|
||||
logger.error(f"Error while validating API key: {e}")
|
||||
raise RuntimeError("Failed to validate API key") from e
|
||||
|
||||
if updated_api_key:
|
||||
return APIKeyWithoutHash.from_db(updated_api_key)
|
||||
|
||||
async def _migrate_key_to_secure_hash(
|
||||
plaintext_key: str, key_obj: APIKeyInfoWithHash
|
||||
) -> APIKeyInfoWithHash:
|
||||
"""Replace the SHA256 hash of a legacy API key with a salted Scrypt hash."""
|
||||
try:
|
||||
new_hash, new_salt = keysmith.hash_key(plaintext_key)
|
||||
await PrismaAPIKey.prisma().update(
|
||||
where={"id": key_obj.id}, data={"hash": new_hash, "salt": new_salt}
|
||||
)
|
||||
logger.info(f"Migrated legacy API key #{key_obj.id} to secure format")
|
||||
# Update the API key object with new values for return
|
||||
key_obj.hash = new_hash
|
||||
key_obj.salt = new_salt
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to migrate legacy API key #{key_obj.id}: {e}")
|
||||
|
||||
return key_obj
|
||||
|
||||
|
||||
async def revoke_api_key(key_id: str, user_id: str) -> APIKeyInfo:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if not api_key:
|
||||
raise NotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise NotAuthorizedError("You do not have permission to revoke this API key.")
|
||||
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where={"id": key_id},
|
||||
data={
|
||||
"status": APIKeyStatus.REVOKED,
|
||||
"revokedAt": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
if not updated_api_key:
|
||||
raise NotFoundError(f"API key #{key_id} vanished while trying to revoke.")
|
||||
|
||||
return APIKeyInfo.from_db(updated_api_key)
|
||||
|
||||
|
||||
async def list_user_api_keys(user_id: str) -> list[APIKeyInfo]:
|
||||
api_keys = await PrismaAPIKey.prisma().find_many(
|
||||
where={"userId": user_id}, order={"createdAt": "desc"}
|
||||
)
|
||||
|
||||
return [APIKeyInfo.from_db(key) for key in api_keys]
|
||||
|
||||
|
||||
async def suspend_api_key(key_id: str, user_id: str) -> APIKeyInfo:
|
||||
selector: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where=selector)
|
||||
|
||||
if not api_key:
|
||||
raise NotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise NotAuthorizedError("You do not have permission to suspend this API key.")
|
||||
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=selector, data={"status": APIKeyStatus.SUSPENDED}
|
||||
)
|
||||
if not updated_api_key:
|
||||
raise NotFoundError(f"API key #{key_id} vanished while trying to suspend.")
|
||||
|
||||
return APIKeyInfo.from_db(updated_api_key)
|
||||
|
||||
|
||||
def has_permission(api_key: APIKeyInfo, required_permission: APIKeyPermission) -> bool:
|
||||
return required_permission in api_key.permissions
|
||||
|
||||
|
||||
async def get_api_key_by_id(key_id: str, user_id: str) -> Optional[APIKeyInfo]:
|
||||
api_key = await PrismaAPIKey.prisma().find_first(
|
||||
where={"id": key_id, "userId": user_id}
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
return None
|
||||
except (APIKeyNotFoundError, APIKeyPermissionError) as e:
|
||||
raise e
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while revoking API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to revoke API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while revoking API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to revoke API key: {str(e)}")
|
||||
|
||||
|
||||
async def list_user_api_keys(user_id: str) -> List[APIKeyWithoutHash]:
|
||||
try:
|
||||
where_clause: APIKeyWhereInput = {"userId": user_id}
|
||||
|
||||
api_keys = await PrismaAPIKey.prisma().find_many(
|
||||
where=where_clause, order={"createdAt": "desc"}
|
||||
)
|
||||
|
||||
return [APIKeyWithoutHash.from_db(key) for key in api_keys]
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while listing API keys: {str(e)}")
|
||||
raise APIKeyError(f"Failed to list API keys: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while listing API keys: {str(e)}")
|
||||
raise APIKeyError(f"Failed to list API keys: {str(e)}")
|
||||
|
||||
|
||||
async def suspend_api_key(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]:
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if not api_key:
|
||||
raise APIKeyNotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise APIKeyPermissionError(
|
||||
"You do not have permission to suspend this API key."
|
||||
)
|
||||
|
||||
where_clause: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=where_clause,
|
||||
data=APIKeyUpdateInput(status=APIKeyStatus.SUSPENDED),
|
||||
)
|
||||
|
||||
if updated_api_key:
|
||||
return APIKeyWithoutHash.from_db(updated_api_key)
|
||||
return None
|
||||
except (APIKeyNotFoundError, APIKeyPermissionError) as e:
|
||||
raise e
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while suspending API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to suspend API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while suspending API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to suspend API key: {str(e)}")
|
||||
|
||||
|
||||
def has_permission(api_key: APIKey, required_permission: APIKeyPermission) -> bool:
|
||||
try:
|
||||
return required_permission in api_key.permissions
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking API key permissions: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_api_key_by_id(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]:
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_first(
|
||||
where=APIKeyWhereInput(id=key_id, userId=user_id)
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
return APIKeyWithoutHash.from_db(api_key)
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while getting API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to get API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while getting API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to get API key: {str(e)}")
|
||||
return APIKeyInfo.from_db(api_key)
|
||||
|
||||
|
||||
async def update_api_key_permissions(
|
||||
key_id: str, user_id: str, permissions: List[APIKeyPermission]
|
||||
) -> Optional[APIKeyWithoutHash]:
|
||||
key_id: str, user_id: str, permissions: list[APIKeyPermission]
|
||||
) -> APIKeyInfo:
|
||||
"""
|
||||
Update the permissions of an API key.
|
||||
"""
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if api_key is None:
|
||||
raise APIKeyNotFoundError("No such API key found.")
|
||||
if api_key is None:
|
||||
raise NotFoundError("No such API key found.")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise APIKeyPermissionError(
|
||||
"You do not have permission to update this API key."
|
||||
)
|
||||
if api_key.userId != user_id:
|
||||
raise NotAuthorizedError("You do not have permission to update this API key.")
|
||||
|
||||
where_clause: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=where_clause,
|
||||
data=APIKeyUpdateInput(permissions=permissions),
|
||||
)
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where={"id": key_id},
|
||||
data={"permissions": permissions},
|
||||
)
|
||||
if not updated_api_key:
|
||||
raise NotFoundError(f"API key #{key_id} vanished while trying to update.")
|
||||
|
||||
if updated_api_key:
|
||||
return APIKeyWithoutHash.from_db(updated_api_key)
|
||||
return None
|
||||
except (APIKeyNotFoundError, APIKeyPermissionError) as e:
|
||||
raise e
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while updating API key permissions: {str(e)}")
|
||||
raise APIKeyError(f"Failed to update API key permissions: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while updating API key permissions: {str(e)}")
|
||||
raise APIKeyError(f"Failed to update API key permissions: {str(e)}")
|
||||
return APIKeyInfo.from_db(updated_api_key)
|
||||
|
||||
@@ -8,6 +8,7 @@ from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Generic,
|
||||
Optional,
|
||||
@@ -44,9 +45,10 @@ if TYPE_CHECKING:
|
||||
|
||||
app_config = Config()
|
||||
|
||||
BlockData = tuple[str, Any] # Input & Output data should be a tuple of (name, data).
|
||||
BlockInput = dict[str, Any] # Input: 1 input pin consumes 1 data.
|
||||
BlockOutput = AsyncGen[BlockData, None] # Output: 1 output pin produces n data.
|
||||
BlockOutputEntry = tuple[str, Any] # Output data should be a tuple of (name, value).
|
||||
BlockOutput = AsyncGen[BlockOutputEntry, None] # Output: 1 output pin produces n data.
|
||||
BlockTestOutput = BlockOutputEntry | tuple[str, Callable[[Any], bool]]
|
||||
CompletedBlockOutput = dict[str, list[Any]] # Completed stream, collected as a dict.
|
||||
|
||||
|
||||
@@ -89,6 +91,45 @@ class BlockCategory(Enum):
|
||||
return {"category": self.name, "description": self.value}
|
||||
|
||||
|
||||
class BlockCostType(str, Enum):
|
||||
RUN = "run" # cost X credits per run
|
||||
BYTE = "byte" # cost X credits per byte
|
||||
SECOND = "second" # cost X credits per second
|
||||
|
||||
|
||||
class BlockCost(BaseModel):
|
||||
cost_amount: int
|
||||
cost_filter: BlockInput
|
||||
cost_type: BlockCostType
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cost_amount: int,
|
||||
cost_type: BlockCostType = BlockCostType.RUN,
|
||||
cost_filter: Optional[BlockInput] = None,
|
||||
**data: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
cost_amount=cost_amount,
|
||||
cost_filter=cost_filter or {},
|
||||
cost_type=cost_type,
|
||||
**data,
|
||||
)
|
||||
|
||||
|
||||
class BlockInfo(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
inputSchema: dict[str, Any]
|
||||
outputSchema: dict[str, Any]
|
||||
costs: list[BlockCost]
|
||||
description: str
|
||||
categories: list[dict[str, str]]
|
||||
contributors: list[dict[str, Any]]
|
||||
staticOutput: bool
|
||||
uiType: str
|
||||
|
||||
|
||||
class BlockSchema(BaseModel):
|
||||
cached_jsonschema: ClassVar[dict[str, Any]]
|
||||
|
||||
@@ -306,7 +347,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
input_schema: Type[BlockSchemaInputType] = EmptySchema,
|
||||
output_schema: Type[BlockSchemaOutputType] = EmptySchema,
|
||||
test_input: BlockInput | list[BlockInput] | None = None,
|
||||
test_output: BlockData | list[BlockData] | None = None,
|
||||
test_output: BlockTestOutput | list[BlockTestOutput] | None = None,
|
||||
test_mock: dict[str, Any] | None = None,
|
||||
test_credentials: Optional[Credentials | dict[str, Credentials]] = None,
|
||||
disabled: bool = False,
|
||||
@@ -452,6 +493,24 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
"uiType": self.block_type.value,
|
||||
}
|
||||
|
||||
def get_info(self) -> BlockInfo:
|
||||
from backend.data.credit import get_block_cost
|
||||
|
||||
return BlockInfo(
|
||||
id=self.id,
|
||||
name=self.name,
|
||||
inputSchema=self.input_schema.jsonschema(),
|
||||
outputSchema=self.output_schema.jsonschema(),
|
||||
costs=get_block_cost(self),
|
||||
description=self.description,
|
||||
categories=[category.dict() for category in self.categories],
|
||||
contributors=[
|
||||
contributor.model_dump() for contributor in self.contributors
|
||||
],
|
||||
staticOutput=self.static_output,
|
||||
uiType=self.block_type.value,
|
||||
)
|
||||
|
||||
async def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise ValueError(
|
||||
|
||||
@@ -29,8 +29,7 @@ from backend.blocks.replicate.replicate_block import ReplicateModelBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
from backend.data.block import Block
|
||||
from backend.data.cost import BlockCost, BlockCostType
|
||||
from backend.data.block import Block, BlockCost, BlockCostType
|
||||
from backend.integrations.credentials_store import (
|
||||
aiml_api_credentials,
|
||||
anthropic_credentials,
|
||||
@@ -307,7 +306,18 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
"type": ideogram_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
),
|
||||
BlockCost(
|
||||
cost_amount=18,
|
||||
cost_filter={
|
||||
"ideogram_model_name": "V_3",
|
||||
"credentials": {
|
||||
"id": ideogram_credentials.id,
|
||||
"provider": ideogram_credentials.provider,
|
||||
"type": ideogram_credentials.type,
|
||||
},
|
||||
},
|
||||
),
|
||||
],
|
||||
AIShortformVideoCreatorBlock: [
|
||||
BlockCost(
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import BlockInput
|
||||
|
||||
|
||||
class BlockCostType(str, Enum):
|
||||
RUN = "run" # cost X credits per run
|
||||
BYTE = "byte" # cost X credits per byte
|
||||
SECOND = "second" # cost X credits per second
|
||||
|
||||
|
||||
class BlockCost(BaseModel):
|
||||
cost_amount: int
|
||||
cost_filter: BlockInput
|
||||
cost_type: BlockCostType
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cost_amount: int,
|
||||
cost_type: BlockCostType = BlockCostType.RUN,
|
||||
cost_filter: Optional[BlockInput] = None,
|
||||
**data: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
cost_amount=cost_amount,
|
||||
cost_filter=cost_filter or {},
|
||||
cost_type=cost_type,
|
||||
**data,
|
||||
)
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import stripe
|
||||
from prisma import Json
|
||||
@@ -23,7 +23,6 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCost
|
||||
from backend.data.model import (
|
||||
AutoTopUpConfig,
|
||||
RefundRequest,
|
||||
@@ -34,13 +33,16 @@ from backend.data.model import (
|
||||
from backend.data.notifications import NotificationEventModel, RefundRequestData
|
||||
from backend.data.user import get_user_by_id, get_user_email_by_id
|
||||
from backend.notifications.notifications import queue_notification_async
|
||||
from backend.server.model import Pagination
|
||||
from backend.server.v2.admin.model import UserHistoryResponse
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.block import Block, BlockCost
|
||||
|
||||
settings = Settings()
|
||||
stripe.api_key = settings.secrets.stripe_api_key
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -997,10 +999,14 @@ def get_user_credit_model() -> UserCreditBase:
|
||||
return UserCredit()
|
||||
|
||||
|
||||
def get_block_costs() -> dict[str, list[BlockCost]]:
|
||||
def get_block_costs() -> dict[str, list["BlockCost"]]:
|
||||
return {block().id: costs for block, costs in BLOCK_COSTS.items()}
|
||||
|
||||
|
||||
def get_block_cost(block: "Block") -> list["BlockCost"]:
|
||||
return BLOCK_COSTS.get(block.__class__, [])
|
||||
|
||||
|
||||
async def get_stripe_customer_id(user_id: str) -> str:
|
||||
user = await get_user_by_id(user_id)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from prisma.models import CreditTransaction
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.data.block import get_block
|
||||
from backend.data.credit import BetaUserCredit, UsageTransactionMetadata
|
||||
from backend.data.execution import NodeExecutionEntry
|
||||
from backend.data.execution import NodeExecutionEntry, UserContext
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.executor.utils import block_usage_cost
|
||||
from backend.integrations.credentials_store import openai_credentials
|
||||
@@ -75,6 +75,7 @@ async def test_block_credit_usage(server: SpinTestServer):
|
||||
"type": openai_credentials.type,
|
||||
},
|
||||
},
|
||||
user_context=UserContext(timezone="UTC"),
|
||||
),
|
||||
)
|
||||
assert spending_amount_1 > 0
|
||||
@@ -88,6 +89,7 @@ async def test_block_credit_usage(server: SpinTestServer):
|
||||
node_exec_id="test_node_exec",
|
||||
block_id=AITextGeneratorBlock().id,
|
||||
inputs={"model": "gpt-4-turbo", "api_key": "owned_api_key"},
|
||||
user_context=UserContext(timezone="UTC"),
|
||||
),
|
||||
)
|
||||
assert spending_amount_2 == 0
|
||||
|
||||
@@ -11,11 +11,14 @@ from typing import (
|
||||
Generator,
|
||||
Generic,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from prisma import Json
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
from prisma.models import (
|
||||
AgentGraphExecution,
|
||||
@@ -24,7 +27,6 @@ from prisma.models import (
|
||||
AgentNodeExecutionKeyValueData,
|
||||
)
|
||||
from prisma.types import (
|
||||
AgentGraphExecutionCreateInput,
|
||||
AgentGraphExecutionUpdateManyMutationInput,
|
||||
AgentGraphExecutionWhereInput,
|
||||
AgentNodeExecutionCreateInput,
|
||||
@@ -39,6 +41,7 @@ from pydantic.fields import Field
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
from backend.util import type as type_utils
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.settings import Config
|
||||
from backend.util.truncate import truncate
|
||||
@@ -59,7 +62,7 @@ from .includes import (
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
graph_execution_include,
|
||||
)
|
||||
from .model import GraphExecutionStats, NodeExecutionStats
|
||||
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -86,13 +89,19 @@ class BlockErrorStats(BaseModel):
|
||||
|
||||
|
||||
ExecutionStatus = AgentExecutionStatus
|
||||
NodeInputMask = Mapping[str, JsonValue]
|
||||
NodesInputMasks = Mapping[str, NodeInputMask]
|
||||
|
||||
|
||||
class GraphExecutionMeta(BaseDbModel):
|
||||
id: str # type: ignore # Override base class to make this required
|
||||
user_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
preset_id: Optional[str] = None
|
||||
inputs: Optional[BlockInput] # no default -> required in the OpenAPI spec
|
||||
credential_inputs: Optional[dict[str, CredentialsMetaInput]]
|
||||
nodes_input_masks: Optional[dict[str, BlockInput]]
|
||||
preset_id: Optional[str]
|
||||
status: ExecutionStatus
|
||||
started_at: datetime
|
||||
ended_at: datetime
|
||||
@@ -177,6 +186,18 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
user_id=_graph_exec.userId,
|
||||
graph_id=_graph_exec.agentGraphId,
|
||||
graph_version=_graph_exec.agentGraphVersion,
|
||||
inputs=cast(BlockInput | None, _graph_exec.inputs),
|
||||
credential_inputs=(
|
||||
{
|
||||
name: CredentialsMetaInput.model_validate(cmi)
|
||||
for name, cmi in cast(dict, _graph_exec.credentialInputs).items()
|
||||
}
|
||||
if _graph_exec.credentialInputs
|
||||
else None
|
||||
),
|
||||
nodes_input_masks=cast(
|
||||
dict[str, BlockInput] | None, _graph_exec.nodesInputMasks
|
||||
),
|
||||
preset_id=_graph_exec.agentPresetId,
|
||||
status=ExecutionStatus(_graph_exec.executionStatus),
|
||||
started_at=start_time,
|
||||
@@ -204,7 +225,7 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
|
||||
|
||||
class GraphExecution(GraphExecutionMeta):
|
||||
inputs: BlockInput
|
||||
inputs: BlockInput # type: ignore - incompatible override is intentional
|
||||
outputs: CompletedBlockOutput
|
||||
|
||||
@staticmethod
|
||||
@@ -224,15 +245,18 @@ class GraphExecution(GraphExecutionMeta):
|
||||
)
|
||||
|
||||
inputs = {
|
||||
**{
|
||||
# inputs from Agent Input Blocks
|
||||
exec.input_data["name"]: exec.input_data.get("value")
|
||||
for exec in complete_node_executions
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
and block.block_type == BlockType.INPUT
|
||||
)
|
||||
},
|
||||
**(
|
||||
graph_exec.inputs
|
||||
or {
|
||||
# fallback: extract inputs from Agent Input Blocks
|
||||
exec.input_data["name"]: exec.input_data.get("value")
|
||||
for exec in complete_node_executions
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
and block.block_type == BlockType.INPUT
|
||||
)
|
||||
}
|
||||
),
|
||||
**{
|
||||
# input from webhook-triggered block
|
||||
"payload": exec.input_data["payload"]
|
||||
@@ -250,14 +274,13 @@ class GraphExecution(GraphExecutionMeta):
|
||||
if (
|
||||
block := get_block(exec.block_id)
|
||||
) and block.block_type == BlockType.OUTPUT:
|
||||
outputs[exec.input_data["name"]].append(
|
||||
exec.input_data.get("value", None)
|
||||
)
|
||||
outputs[exec.input_data["name"]].append(exec.input_data.get("value"))
|
||||
|
||||
return GraphExecution(
|
||||
**{
|
||||
field_name: getattr(graph_exec, field_name)
|
||||
for field_name in GraphExecutionMeta.model_fields
|
||||
if field_name != "inputs"
|
||||
},
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
@@ -290,13 +313,18 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
node_executions=node_executions,
|
||||
)
|
||||
|
||||
def to_graph_execution_entry(self):
|
||||
def to_graph_execution_entry(
|
||||
self,
|
||||
user_context: "UserContext",
|
||||
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
):
|
||||
return GraphExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
graph_id=self.graph_id,
|
||||
graph_version=self.graph_version or 0,
|
||||
graph_exec_id=self.id,
|
||||
nodes_input_masks={}, # FIXME: store credentials on AgentGraphExecution
|
||||
nodes_input_masks=compiled_nodes_input_masks,
|
||||
user_context=user_context,
|
||||
)
|
||||
|
||||
|
||||
@@ -332,7 +360,7 @@ class NodeExecutionResult(BaseModel):
|
||||
else:
|
||||
input_data: BlockInput = defaultdict()
|
||||
for data in _node_exec.Input or []:
|
||||
input_data[data.name] = type_utils.convert(data.data, type[Any])
|
||||
input_data[data.name] = type_utils.convert(data.data, JsonValue)
|
||||
|
||||
output_data: CompletedBlockOutput = defaultdict(list)
|
||||
|
||||
@@ -341,7 +369,7 @@ class NodeExecutionResult(BaseModel):
|
||||
output_data[name].extend(messages)
|
||||
else:
|
||||
for data in _node_exec.Output or []:
|
||||
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
|
||||
output_data[data.name].append(type_utils.convert(data.data, JsonValue))
|
||||
|
||||
graph_execution: AgentGraphExecution | None = _node_exec.GraphExecution
|
||||
if graph_execution:
|
||||
@@ -368,7 +396,9 @@ class NodeExecutionResult(BaseModel):
|
||||
end_time=_node_exec.endedTime,
|
||||
)
|
||||
|
||||
def to_node_execution_entry(self) -> "NodeExecutionEntry":
|
||||
def to_node_execution_entry(
|
||||
self, user_context: "UserContext"
|
||||
) -> "NodeExecutionEntry":
|
||||
return NodeExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
graph_exec_id=self.graph_exec_id,
|
||||
@@ -377,6 +407,7 @@ class NodeExecutionResult(BaseModel):
|
||||
node_id=self.node_id,
|
||||
block_id=self.block_id,
|
||||
inputs=self.input_data,
|
||||
user_context=user_context,
|
||||
)
|
||||
|
||||
|
||||
@@ -384,13 +415,13 @@ class NodeExecutionResult(BaseModel):
|
||||
|
||||
|
||||
async def get_graph_executions(
|
||||
graph_exec_id: str | None = None,
|
||||
graph_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
statuses: list[ExecutionStatus] | None = None,
|
||||
created_time_gte: datetime | None = None,
|
||||
created_time_lte: datetime | None = None,
|
||||
limit: int | None = None,
|
||||
graph_exec_id: Optional[str] = None,
|
||||
graph_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
statuses: Optional[list[ExecutionStatus]] = None,
|
||||
created_time_gte: Optional[datetime] = None,
|
||||
created_time_lte: Optional[datetime] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> list[GraphExecutionMeta]:
|
||||
"""⚠️ **Optional `user_id` check**: MUST USE check in user-facing endpoints."""
|
||||
where_filter: AgentGraphExecutionWhereInput = {
|
||||
@@ -418,6 +449,60 @@ async def get_graph_executions(
|
||||
return [GraphExecutionMeta.from_db(execution) for execution in executions]
|
||||
|
||||
|
||||
class GraphExecutionsPaginated(BaseModel):
|
||||
"""Response schema for paginated graph executions."""
|
||||
|
||||
executions: list[GraphExecutionMeta]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
async def get_graph_executions_paginated(
|
||||
user_id: str,
|
||||
graph_id: Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 25,
|
||||
statuses: Optional[list[ExecutionStatus]] = None,
|
||||
created_time_gte: Optional[datetime] = None,
|
||||
created_time_lte: Optional[datetime] = None,
|
||||
) -> GraphExecutionsPaginated:
|
||||
"""Get paginated graph executions for a specific graph."""
|
||||
where_filter: AgentGraphExecutionWhereInput = {
|
||||
"isDeleted": False,
|
||||
"userId": user_id,
|
||||
}
|
||||
|
||||
if graph_id:
|
||||
where_filter["agentGraphId"] = graph_id
|
||||
if created_time_gte or created_time_lte:
|
||||
where_filter["createdAt"] = {
|
||||
"gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
|
||||
"lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
|
||||
}
|
||||
if statuses:
|
||||
where_filter["OR"] = [{"executionStatus": status} for status in statuses]
|
||||
|
||||
total_count = await AgentGraphExecution.prisma().count(where=where_filter)
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
executions = await AgentGraphExecution.prisma().find_many(
|
||||
where=where_filter,
|
||||
order={"createdAt": "desc"},
|
||||
take=page_size,
|
||||
skip=offset,
|
||||
)
|
||||
|
||||
return GraphExecutionsPaginated(
|
||||
executions=[GraphExecutionMeta.from_db(execution) for execution in executions],
|
||||
pagination=Pagination(
|
||||
total_items=total_count,
|
||||
total_pages=total_pages,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def get_graph_execution_meta(
|
||||
user_id: str, execution_id: str
|
||||
) -> GraphExecutionMeta | None:
|
||||
@@ -479,9 +564,12 @@ async def get_graph_execution(
|
||||
async def create_graph_execution(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
starting_nodes_input: list[tuple[str, BlockInput]],
|
||||
starting_nodes_input: list[tuple[str, BlockInput]], # list[(node_id, BlockInput)]
|
||||
inputs: Mapping[str, JsonValue],
|
||||
user_id: str,
|
||||
preset_id: str | None = None,
|
||||
preset_id: Optional[str] = None,
|
||||
credential_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
Create a new AgentGraphExecution record.
|
||||
@@ -489,11 +577,18 @@ async def create_graph_execution(
|
||||
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
|
||||
"""
|
||||
result = await AgentGraphExecution.prisma().create(
|
||||
data=AgentGraphExecutionCreateInput(
|
||||
agentGraphId=graph_id,
|
||||
agentGraphVersion=graph_version,
|
||||
executionStatus=ExecutionStatus.QUEUED,
|
||||
NodeExecutions={
|
||||
data={
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": graph_version,
|
||||
"executionStatus": ExecutionStatus.QUEUED,
|
||||
"inputs": SafeJson(inputs),
|
||||
"credentialInputs": (
|
||||
SafeJson(credential_inputs) if credential_inputs else Json({})
|
||||
),
|
||||
"nodesInputMasks": (
|
||||
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
|
||||
),
|
||||
"NodeExecutions": {
|
||||
"create": [
|
||||
AgentNodeExecutionCreateInput(
|
||||
agentNodeId=node_id,
|
||||
@@ -509,9 +604,9 @@ async def create_graph_execution(
|
||||
for node_id, node_input in starting_nodes_input
|
||||
]
|
||||
},
|
||||
userId=user_id,
|
||||
agentPresetId=preset_id,
|
||||
),
|
||||
"userId": user_id,
|
||||
"agentPresetId": preset_id,
|
||||
},
|
||||
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
)
|
||||
|
||||
@@ -522,7 +617,7 @@ async def upsert_execution_input(
|
||||
node_id: str,
|
||||
graph_exec_id: str,
|
||||
input_name: str,
|
||||
input_data: Any,
|
||||
input_data: JsonValue,
|
||||
node_exec_id: str | None = None,
|
||||
) -> tuple[str, BlockInput]:
|
||||
"""
|
||||
@@ -571,7 +666,7 @@ async def upsert_execution_input(
|
||||
)
|
||||
return existing_execution.id, {
|
||||
**{
|
||||
input_data.name: type_utils.convert(input_data.data, type[Any])
|
||||
input_data.name: type_utils.convert(input_data.data, JsonValue)
|
||||
for input_data in existing_execution.Input or []
|
||||
},
|
||||
input_name: input_data,
|
||||
@@ -817,12 +912,19 @@ async def get_latest_node_execution(
|
||||
# ----------------- Execution Infrastructure ----------------- #
|
||||
|
||||
|
||||
class UserContext(BaseModel):
|
||||
"""Generic user context for graph execution containing user-specific settings."""
|
||||
|
||||
timezone: str
|
||||
|
||||
|
||||
class GraphExecutionEntry(BaseModel):
|
||||
user_id: str
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None
|
||||
user_context: UserContext
|
||||
|
||||
|
||||
class NodeExecutionEntry(BaseModel):
|
||||
@@ -833,6 +935,7 @@ class NodeExecutionEntry(BaseModel):
|
||||
node_id: str
|
||||
block_id: str
|
||||
inputs: BlockInput
|
||||
user_context: UserContext
|
||||
|
||||
|
||||
class ExecutionQueue(Generic[T]):
|
||||
|
||||
@@ -12,7 +12,7 @@ from prisma.types import (
|
||||
AgentNodeLinkCreateInput,
|
||||
StoreListingVersionWhereInput,
|
||||
)
|
||||
from pydantic import Field, JsonValue, create_model
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
@@ -34,6 +34,7 @@ from .db import BaseDbModel, query_raw_with_schema, transaction
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .execution import NodesInputMasks
|
||||
from .integrations import Webhook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -159,6 +160,7 @@ class BaseGraph(BaseDbModel):
|
||||
is_active: bool = True
|
||||
name: str
|
||||
description: str
|
||||
recommended_schedule_cron: str | None = None
|
||||
nodes: list[Node] = []
|
||||
links: list[Link] = []
|
||||
forked_from_id: str | None = None
|
||||
@@ -205,6 +207,35 @@ class BaseGraph(BaseDbModel):
|
||||
None,
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def trigger_setup_info(self) -> "GraphTriggerInfo | None":
|
||||
if not (
|
||||
self.webhook_input_node
|
||||
and (trigger_block := self.webhook_input_node.block).webhook_config
|
||||
):
|
||||
return None
|
||||
|
||||
return GraphTriggerInfo(
|
||||
provider=trigger_block.webhook_config.provider,
|
||||
config_schema={
|
||||
**(json_schema := trigger_block.input_schema.jsonschema()),
|
||||
"properties": {
|
||||
pn: sub_schema
|
||||
for pn, sub_schema in json_schema["properties"].items()
|
||||
if not is_credentials_field_name(pn)
|
||||
},
|
||||
"required": [
|
||||
pn
|
||||
for pn in json_schema.get("required", [])
|
||||
if not is_credentials_field_name(pn)
|
||||
],
|
||||
},
|
||||
credentials_input_name=next(
|
||||
iter(trigger_block.input_schema.get_credentials_fields()), None
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _generate_schema(
|
||||
*props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
|
||||
@@ -238,6 +269,14 @@ class BaseGraph(BaseDbModel):
|
||||
}
|
||||
|
||||
|
||||
class GraphTriggerInfo(BaseModel):
|
||||
provider: ProviderName
|
||||
config_schema: dict[str, Any] = Field(
|
||||
description="Input schema for the trigger block"
|
||||
)
|
||||
credentials_input_name: Optional[str]
|
||||
|
||||
|
||||
class Graph(BaseGraph):
|
||||
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs
|
||||
|
||||
@@ -414,7 +453,7 @@ class GraphModel(Graph):
|
||||
def validate_graph(
|
||||
self,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
nodes_input_masks: Optional["NodesInputMasks"] = None,
|
||||
):
|
||||
"""
|
||||
Validate graph structure and raise `ValueError` on issues.
|
||||
@@ -428,7 +467,7 @@ class GraphModel(Graph):
|
||||
def _validate_graph(
|
||||
graph: BaseGraph,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
nodes_input_masks: Optional["NodesInputMasks"] = None,
|
||||
) -> None:
|
||||
errors = GraphModel._validate_graph_get_errors(
|
||||
graph, for_run, nodes_input_masks
|
||||
@@ -442,7 +481,7 @@ class GraphModel(Graph):
|
||||
def validate_graph_get_errors(
|
||||
self,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
nodes_input_masks: Optional["NodesInputMasks"] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Validate graph and return structured errors per node.
|
||||
@@ -464,7 +503,7 @@ class GraphModel(Graph):
|
||||
def _validate_graph_get_errors(
|
||||
graph: BaseGraph,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
nodes_input_masks: Optional["NodesInputMasks"] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Validate graph and return structured errors per node.
|
||||
@@ -658,6 +697,7 @@ class GraphModel(Graph):
|
||||
is_active=graph.isActive,
|
||||
name=graph.name or "",
|
||||
description=graph.description or "",
|
||||
recommended_schedule_cron=graph.recommendedScheduleCron,
|
||||
nodes=[NodeModel.from_db(node, for_export) for node in graph.Nodes or []],
|
||||
links=list(
|
||||
{
|
||||
@@ -1045,6 +1085,7 @@ async def __create_graph(tx, graph: Graph, user_id: str):
|
||||
version=graph.version,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
recommendedScheduleCron=graph.recommended_schedule_cron,
|
||||
isActive=graph.is_active,
|
||||
userId=user_id,
|
||||
forkedFromId=graph.forked_from_id,
|
||||
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import autogpt_libs.auth.models
|
||||
import fastapi.exceptions
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
@@ -317,12 +316,7 @@ async def test_access_store_listing_graph(server: SpinTestServer):
|
||||
is_approved=True,
|
||||
comments="Test comments",
|
||||
),
|
||||
autogpt_libs.auth.models.User(
|
||||
user_id=admin_user.id,
|
||||
role="admin",
|
||||
email=admin_user.email,
|
||||
phone_number="1234567890",
|
||||
),
|
||||
user_id=admin_user.id,
|
||||
)
|
||||
|
||||
# Now we check the graph can be accessed by a user that does not own the graph
|
||||
|
||||
@@ -59,9 +59,15 @@ def graph_execution_include(
|
||||
}
|
||||
|
||||
|
||||
AGENT_PRESET_INCLUDE: prisma.types.AgentPresetInclude = {
|
||||
"InputPresets": True,
|
||||
"Webhook": True,
|
||||
}
|
||||
|
||||
|
||||
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE},
|
||||
"AgentPresets": {"include": {"InputPresets": True}},
|
||||
"AgentPresets": {"include": AGENT_PRESET_INCLUDE},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -96,6 +96,12 @@ class User(BaseModel):
|
||||
default=True, description="Notify on monthly summary"
|
||||
)
|
||||
|
||||
# User timezone for scheduling and time display
|
||||
timezone: str = Field(
|
||||
default="not-set",
|
||||
description="User timezone (IANA timezone identifier or 'not-set')",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, prisma_user: "PrismaUser") -> "User":
|
||||
"""Convert a database User object to application User model."""
|
||||
@@ -149,6 +155,7 @@ class User(BaseModel):
|
||||
notify_on_daily_summary=prisma_user.notifyOnDailySummary or True,
|
||||
notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True,
|
||||
notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True,
|
||||
timezone=prisma_user.timezone or "not-set",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -54,19 +54,6 @@ class AgentRunData(BaseNotificationData):
|
||||
|
||||
|
||||
class ZeroBalanceData(BaseNotificationData):
|
||||
last_transaction: float
|
||||
last_transaction_time: datetime
|
||||
top_up_link: str
|
||||
|
||||
@field_validator("last_transaction_time")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: datetime):
|
||||
if value.tzinfo is None:
|
||||
raise ValueError("datetime must have timezone information")
|
||||
return value
|
||||
|
||||
|
||||
class LowBalanceData(BaseNotificationData):
|
||||
agent_name: str = Field(..., description="Name of the agent")
|
||||
current_balance: float = Field(
|
||||
..., description="Current balance in credits (100 = $1)"
|
||||
@@ -75,6 +62,13 @@ class LowBalanceData(BaseNotificationData):
|
||||
shortfall: float = Field(..., description="Amount of credits needed to continue")
|
||||
|
||||
|
||||
class LowBalanceData(BaseNotificationData):
|
||||
current_balance: float = Field(
|
||||
..., description="Current balance in credits (100 = $1)"
|
||||
)
|
||||
billing_page_link: str = Field(..., description="Link to billing page")
|
||||
|
||||
|
||||
class BlockExecutionFailedData(BaseNotificationData):
|
||||
block_name: str
|
||||
block_id: str
|
||||
@@ -181,6 +175,42 @@ class RefundRequestData(BaseNotificationData):
|
||||
balance: int
|
||||
|
||||
|
||||
class AgentApprovalData(BaseNotificationData):
|
||||
agent_name: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
reviewer_name: str
|
||||
reviewer_email: str
|
||||
comments: str
|
||||
reviewed_at: datetime
|
||||
store_url: str
|
||||
|
||||
@field_validator("reviewed_at")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: datetime):
|
||||
if value.tzinfo is None:
|
||||
raise ValueError("datetime must have timezone information")
|
||||
return value
|
||||
|
||||
|
||||
class AgentRejectionData(BaseNotificationData):
|
||||
agent_name: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
reviewer_name: str
|
||||
reviewer_email: str
|
||||
comments: str
|
||||
reviewed_at: datetime
|
||||
resubmit_url: str
|
||||
|
||||
@field_validator("reviewed_at")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: datetime):
|
||||
if value.tzinfo is None:
|
||||
raise ValueError("datetime must have timezone information")
|
||||
return value
|
||||
|
||||
|
||||
NotificationData = Annotated[
|
||||
Union[
|
||||
AgentRunData,
|
||||
@@ -240,6 +270,8 @@ def get_notif_data_type(
|
||||
NotificationType.MONTHLY_SUMMARY: MonthlySummaryData,
|
||||
NotificationType.REFUND_REQUEST: RefundRequestData,
|
||||
NotificationType.REFUND_PROCESSED: RefundRequestData,
|
||||
NotificationType.AGENT_APPROVED: AgentApprovalData,
|
||||
NotificationType.AGENT_REJECTED: AgentRejectionData,
|
||||
}[notification_type]
|
||||
|
||||
|
||||
@@ -274,7 +306,7 @@ class NotificationTypeOverride:
|
||||
# These are batched by the notification service
|
||||
NotificationType.AGENT_RUN: QueueType.BATCH,
|
||||
# These are batched by the notification service, but with a backoff strategy
|
||||
NotificationType.ZERO_BALANCE: QueueType.BACKOFF,
|
||||
NotificationType.ZERO_BALANCE: QueueType.IMMEDIATE,
|
||||
NotificationType.LOW_BALANCE: QueueType.IMMEDIATE,
|
||||
NotificationType.BLOCK_EXECUTION_FAILED: QueueType.BACKOFF,
|
||||
NotificationType.CONTINUOUS_AGENT_ERROR: QueueType.BACKOFF,
|
||||
@@ -283,6 +315,8 @@ class NotificationTypeOverride:
|
||||
NotificationType.MONTHLY_SUMMARY: QueueType.SUMMARY,
|
||||
NotificationType.REFUND_REQUEST: QueueType.ADMIN,
|
||||
NotificationType.REFUND_PROCESSED: QueueType.ADMIN,
|
||||
NotificationType.AGENT_APPROVED: QueueType.IMMEDIATE,
|
||||
NotificationType.AGENT_REJECTED: QueueType.IMMEDIATE,
|
||||
}
|
||||
return BATCHING_RULES.get(self.notification_type, QueueType.IMMEDIATE)
|
||||
|
||||
@@ -300,6 +334,8 @@ class NotificationTypeOverride:
|
||||
NotificationType.MONTHLY_SUMMARY: "monthly_summary.html",
|
||||
NotificationType.REFUND_REQUEST: "refund_request.html",
|
||||
NotificationType.REFUND_PROCESSED: "refund_processed.html",
|
||||
NotificationType.AGENT_APPROVED: "agent_approved.html",
|
||||
NotificationType.AGENT_REJECTED: "agent_rejected.html",
|
||||
}[self.notification_type]
|
||||
|
||||
@property
|
||||
@@ -315,6 +351,8 @@ class NotificationTypeOverride:
|
||||
NotificationType.MONTHLY_SUMMARY: "We did a lot this month!",
|
||||
NotificationType.REFUND_REQUEST: "[ACTION REQUIRED] You got a ${{data.amount / 100}} refund request from {{data.user_name}}",
|
||||
NotificationType.REFUND_PROCESSED: "Refund for ${{data.amount / 100}} to {{data.user_name}} has been processed",
|
||||
NotificationType.AGENT_APPROVED: "🎉 Your agent '{{data.agent_name}}' has been approved!",
|
||||
NotificationType.AGENT_REJECTED: "Your agent '{{data.agent_name}}' needs some updates",
|
||||
}[self.notification_type]
|
||||
|
||||
|
||||
|
||||
151
autogpt_platform/backend/backend/data/notifications_test.py
Normal file
151
autogpt_platform/backend/backend/data/notifications_test.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Tests for notification data models."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.data.notifications import AgentApprovalData, AgentRejectionData
|
||||
|
||||
|
||||
class TestAgentApprovalData:
|
||||
"""Test cases for AgentApprovalData model."""
|
||||
|
||||
def test_valid_agent_approval_data(self):
|
||||
"""Test creating valid AgentApprovalData."""
|
||||
data = AgentApprovalData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="John Doe",
|
||||
reviewer_email="john@example.com",
|
||||
comments="Great agent, approved!",
|
||||
reviewed_at=datetime.now(timezone.utc),
|
||||
store_url="https://app.autogpt.com/store/test-agent-123",
|
||||
)
|
||||
|
||||
assert data.agent_name == "Test Agent"
|
||||
assert data.agent_id == "test-agent-123"
|
||||
assert data.agent_version == 1
|
||||
assert data.reviewer_name == "John Doe"
|
||||
assert data.reviewer_email == "john@example.com"
|
||||
assert data.comments == "Great agent, approved!"
|
||||
assert data.store_url == "https://app.autogpt.com/store/test-agent-123"
|
||||
assert data.reviewed_at.tzinfo is not None
|
||||
|
||||
def test_agent_approval_data_without_timezone_raises_error(self):
|
||||
"""Test that AgentApprovalData raises error without timezone."""
|
||||
with pytest.raises(
|
||||
ValidationError, match="datetime must have timezone information"
|
||||
):
|
||||
AgentApprovalData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="John Doe",
|
||||
reviewer_email="john@example.com",
|
||||
comments="Great agent, approved!",
|
||||
reviewed_at=datetime.now(), # No timezone
|
||||
store_url="https://app.autogpt.com/store/test-agent-123",
|
||||
)
|
||||
|
||||
def test_agent_approval_data_with_empty_comments(self):
|
||||
"""Test AgentApprovalData with empty comments."""
|
||||
data = AgentApprovalData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="John Doe",
|
||||
reviewer_email="john@example.com",
|
||||
comments="", # Empty comments
|
||||
reviewed_at=datetime.now(timezone.utc),
|
||||
store_url="https://app.autogpt.com/store/test-agent-123",
|
||||
)
|
||||
|
||||
assert data.comments == ""
|
||||
|
||||
|
||||
class TestAgentRejectionData:
|
||||
"""Test cases for AgentRejectionData model."""
|
||||
|
||||
def test_valid_agent_rejection_data(self):
|
||||
"""Test creating valid AgentRejectionData."""
|
||||
data = AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments="Please fix the security issues before resubmitting.",
|
||||
reviewed_at=datetime.now(timezone.utc),
|
||||
resubmit_url="https://app.autogpt.com/build/test-agent-123",
|
||||
)
|
||||
|
||||
assert data.agent_name == "Test Agent"
|
||||
assert data.agent_id == "test-agent-123"
|
||||
assert data.agent_version == 1
|
||||
assert data.reviewer_name == "Jane Doe"
|
||||
assert data.reviewer_email == "jane@example.com"
|
||||
assert data.comments == "Please fix the security issues before resubmitting."
|
||||
assert data.resubmit_url == "https://app.autogpt.com/build/test-agent-123"
|
||||
assert data.reviewed_at.tzinfo is not None
|
||||
|
||||
def test_agent_rejection_data_without_timezone_raises_error(self):
|
||||
"""Test that AgentRejectionData raises error without timezone."""
|
||||
with pytest.raises(
|
||||
ValidationError, match="datetime must have timezone information"
|
||||
):
|
||||
AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments="Please fix the security issues.",
|
||||
reviewed_at=datetime.now(), # No timezone
|
||||
resubmit_url="https://app.autogpt.com/build/test-agent-123",
|
||||
)
|
||||
|
||||
def test_agent_rejection_data_with_long_comments(self):
|
||||
"""Test AgentRejectionData with long comments."""
|
||||
long_comment = "A" * 1000 # Very long comment
|
||||
data = AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments=long_comment,
|
||||
reviewed_at=datetime.now(timezone.utc),
|
||||
resubmit_url="https://app.autogpt.com/build/test-agent-123",
|
||||
)
|
||||
|
||||
assert data.comments == long_comment
|
||||
|
||||
def test_model_serialization(self):
|
||||
"""Test that models can be serialized and deserialized."""
|
||||
original_data = AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments="Please fix the issues.",
|
||||
reviewed_at=datetime.now(timezone.utc),
|
||||
resubmit_url="https://app.autogpt.com/build/test-agent-123",
|
||||
)
|
||||
|
||||
# Serialize to dict
|
||||
data_dict = original_data.model_dump()
|
||||
|
||||
# Deserialize back
|
||||
restored_data = AgentRejectionData.model_validate(data_dict)
|
||||
|
||||
assert restored_data.agent_name == original_data.agent_name
|
||||
assert restored_data.agent_id == original_data.agent_id
|
||||
assert restored_data.agent_version == original_data.agent_version
|
||||
assert restored_data.reviewer_name == original_data.reviewer_name
|
||||
assert restored_data.reviewer_email == original_data.reviewer_email
|
||||
assert restored_data.comments == original_data.comments
|
||||
assert restored_data.reviewed_at == original_data.reviewed_at
|
||||
assert restored_data.resubmit_url == original_data.resubmit_url
|
||||
@@ -208,6 +208,8 @@ async def get_user_notification_preference(user_id: str) -> NotificationPreferen
|
||||
NotificationType.DAILY_SUMMARY: user.notifyOnDailySummary or False,
|
||||
NotificationType.WEEKLY_SUMMARY: user.notifyOnWeeklySummary or False,
|
||||
NotificationType.MONTHLY_SUMMARY: user.notifyOnMonthlySummary or False,
|
||||
NotificationType.AGENT_APPROVED: user.notifyOnAgentApproved or False,
|
||||
NotificationType.AGENT_REJECTED: user.notifyOnAgentRejected or False,
|
||||
}
|
||||
daily_limit = user.maxEmailsPerDay or 3
|
||||
notification_preference = NotificationPreference(
|
||||
@@ -266,6 +268,14 @@ async def update_user_notification_preference(
|
||||
update_data["notifyOnMonthlySummary"] = data.preferences[
|
||||
NotificationType.MONTHLY_SUMMARY
|
||||
]
|
||||
if NotificationType.AGENT_APPROVED in data.preferences:
|
||||
update_data["notifyOnAgentApproved"] = data.preferences[
|
||||
NotificationType.AGENT_APPROVED
|
||||
]
|
||||
if NotificationType.AGENT_REJECTED in data.preferences:
|
||||
update_data["notifyOnAgentRejected"] = data.preferences[
|
||||
NotificationType.AGENT_REJECTED
|
||||
]
|
||||
if data.daily_limit:
|
||||
update_data["maxEmailsPerDay"] = data.daily_limit
|
||||
|
||||
@@ -286,6 +296,8 @@ async def update_user_notification_preference(
|
||||
NotificationType.DAILY_SUMMARY: user.notifyOnDailySummary or True,
|
||||
NotificationType.WEEKLY_SUMMARY: user.notifyOnWeeklySummary or True,
|
||||
NotificationType.MONTHLY_SUMMARY: user.notifyOnMonthlySummary or True,
|
||||
NotificationType.AGENT_APPROVED: user.notifyOnAgentApproved or True,
|
||||
NotificationType.AGENT_REJECTED: user.notifyOnAgentRejected or True,
|
||||
}
|
||||
notification_preference = NotificationPreference(
|
||||
user_id=user.id,
|
||||
@@ -384,3 +396,17 @@ async def unsubscribe_user_by_token(token: str) -> None:
|
||||
)
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to unsubscribe user by token {token}: {e}") from e
|
||||
|
||||
|
||||
async def update_user_timezone(user_id: str, timezone: str) -> User:
|
||||
"""Update a user's timezone setting."""
|
||||
try:
|
||||
user = await PrismaUser.prisma().update(
|
||||
where={"id": user_id},
|
||||
data={"timezone": timezone},
|
||||
)
|
||||
if not user:
|
||||
raise ValueError(f"User not found with ID: {user_id}")
|
||||
return User.from_db(user)
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to update timezone for user {user_id}: {e}") from e
|
||||
|
||||
@@ -42,6 +42,8 @@ from backend.data.user import (
|
||||
get_user_notification_preference,
|
||||
update_user_integrations,
|
||||
)
|
||||
from backend.server.v2.library.db import add_store_agent_to_library, list_library_agents
|
||||
from backend.server.v2.store.db import get_store_agent_details, get_store_agents
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
@@ -145,6 +147,14 @@ class DatabaseManager(AppService):
|
||||
get_user_notification_oldest_message_in_batch
|
||||
)
|
||||
|
||||
# Library
|
||||
list_library_agents = _(list_library_agents)
|
||||
add_store_agent_to_library = _(add_store_agent_to_library)
|
||||
|
||||
# Store
|
||||
get_store_agents = _(get_store_agents)
|
||||
get_store_agent_details = _(get_store_agent_details)
|
||||
|
||||
# Summary data - async
|
||||
get_user_execution_summary_data = _(get_user_execution_summary_data)
|
||||
|
||||
@@ -173,12 +183,20 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
spend_credits = _(d.spend_credits)
|
||||
get_credits = _(d.get_credits)
|
||||
|
||||
# Summary data - async
|
||||
get_user_execution_summary_data = _(d.get_user_execution_summary_data)
|
||||
|
||||
# Block error monitoring
|
||||
get_block_error_stats = _(d.get_block_error_stats)
|
||||
|
||||
# User Emails
|
||||
get_user_email_by_id = _(d.get_user_email_by_id)
|
||||
|
||||
# Library
|
||||
list_library_agents = _(d.list_library_agents)
|
||||
add_store_agent_to_library = _(d.add_store_agent_to_library)
|
||||
|
||||
# Store
|
||||
get_store_agents = _(d.get_store_agents)
|
||||
get_store_agent_details = _(d.get_store_agent_details)
|
||||
|
||||
|
||||
class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
d = DatabaseManager
|
||||
@@ -223,5 +241,13 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
d.get_user_notification_oldest_message_in_batch
|
||||
)
|
||||
|
||||
# Library
|
||||
list_library_agents = d.list_library_agents
|
||||
add_store_agent_to_library = d.add_store_agent_to_library
|
||||
|
||||
# Store
|
||||
get_store_agents = d.get_store_agents
|
||||
get_store_agent_details = d.get_store_agent_details
|
||||
|
||||
# Summary data
|
||||
get_user_execution_summary_data = d.get_user_execution_summary_data
|
||||
|
||||
@@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
|
||||
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pika.spec import Basic, BasicProperties
|
||||
from pydantic import JsonValue
|
||||
from redis.asyncio.lock import Lock as RedisLock
|
||||
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
@@ -20,6 +19,7 @@ from backend.data.notifications import (
|
||||
LowBalanceData,
|
||||
NotificationEventModel,
|
||||
NotificationType,
|
||||
ZeroBalanceData,
|
||||
)
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.activity_status_generator import (
|
||||
@@ -37,9 +37,9 @@ from prometheus_client import Gauge, start_http_server
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.block import (
|
||||
BlockData,
|
||||
BlockInput,
|
||||
BlockOutput,
|
||||
BlockOutputEntry,
|
||||
BlockSchema,
|
||||
get_block,
|
||||
)
|
||||
@@ -51,6 +51,8 @@ from backend.data.execution import (
|
||||
GraphExecutionEntry,
|
||||
NodeExecutionEntry,
|
||||
NodeExecutionResult,
|
||||
NodesInputMasks,
|
||||
UserContext,
|
||||
)
|
||||
from backend.data.graph import Link, Node
|
||||
from backend.executor.utils import (
|
||||
@@ -74,6 +76,7 @@ from backend.util.clients import (
|
||||
get_database_manager_async_client,
|
||||
get_database_manager_client,
|
||||
get_execution_event_bus,
|
||||
get_notification_manager_client,
|
||||
)
|
||||
from backend.util.decorator import (
|
||||
async_error_logged,
|
||||
@@ -83,6 +86,7 @@ from backend.util.decorator import (
|
||||
)
|
||||
from backend.util.file import clean_exec_files
|
||||
from backend.util.logging import TruncatedLogger, configure_logging
|
||||
from backend.util.metrics import DiscordChannel
|
||||
from backend.util.process import AppProcess, set_service_name
|
||||
from backend.util.retry import continuous_retry, func_retry
|
||||
from backend.util.settings import Settings
|
||||
@@ -127,7 +131,7 @@ async def execute_node(
|
||||
creds_manager: IntegrationCredentialsManager,
|
||||
data: NodeExecutionEntry,
|
||||
execution_stats: NodeExecutionStats | None = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Execute a node in the graph. This will trigger a block execution on a node,
|
||||
@@ -190,6 +194,9 @@ async def execute_node(
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
# Add user context from NodeExecutionEntry
|
||||
extra_exec_kwargs["user_context"] = data.user_context
|
||||
|
||||
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
|
||||
# changes during execution. ⚠️ This means a set of credentials can only be used by
|
||||
# one (running) block at a time; simultaneous execution of blocks using same
|
||||
@@ -230,12 +237,13 @@ async def execute_node(
|
||||
async def _enqueue_next_nodes(
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
node: Node,
|
||||
output: BlockData,
|
||||
output: BlockOutputEntry,
|
||||
user_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
nodes_input_masks: Optional[NodesInputMasks],
|
||||
user_context: UserContext,
|
||||
) -> list[NodeExecutionEntry]:
|
||||
async def add_enqueued_execution(
|
||||
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
|
||||
@@ -254,6 +262,7 @@ async def _enqueue_next_nodes(
|
||||
node_id=node_id,
|
||||
block_id=block_id,
|
||||
inputs=data,
|
||||
user_context=user_context,
|
||||
)
|
||||
|
||||
async def register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
|
||||
@@ -410,7 +419,7 @@ class ExecutionProcessor:
|
||||
self,
|
||||
node_exec: NodeExecutionEntry,
|
||||
node_exec_progress: NodeExecutionProgress,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
nodes_input_masks: Optional[NodesInputMasks],
|
||||
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
|
||||
) -> NodeExecutionStats:
|
||||
log_metadata = LogMetadata(
|
||||
@@ -478,7 +487,7 @@ class ExecutionProcessor:
|
||||
stats: NodeExecutionStats,
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> ExecutionStatus:
|
||||
status = ExecutionStatus.RUNNING
|
||||
|
||||
@@ -678,19 +687,20 @@ class ExecutionProcessor:
|
||||
self,
|
||||
node_exec: NodeExecutionEntry,
|
||||
execution_count: int,
|
||||
) -> int:
|
||||
) -> tuple[int, int]:
|
||||
total_cost = 0
|
||||
remaining_balance = 0
|
||||
db_client = get_db_client()
|
||||
block = get_block(node_exec.block_id)
|
||||
if not block:
|
||||
logger.error(f"Block {node_exec.block_id} not found.")
|
||||
return total_cost
|
||||
return total_cost, 0
|
||||
|
||||
cost, matching_filter = block_usage_cost(
|
||||
block=block, input_data=node_exec.inputs
|
||||
)
|
||||
if cost > 0:
|
||||
db_client.spend_credits(
|
||||
remaining_balance = db_client.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
@@ -708,7 +718,7 @@ class ExecutionProcessor:
|
||||
|
||||
cost, usage_count = execution_usage_cost(execution_count)
|
||||
if cost > 0:
|
||||
db_client.spend_credits(
|
||||
remaining_balance = db_client.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
@@ -723,7 +733,7 @@ class ExecutionProcessor:
|
||||
)
|
||||
total_cost += cost
|
||||
|
||||
return total_cost
|
||||
return total_cost, remaining_balance
|
||||
|
||||
@time_measured
|
||||
def _on_graph_execution(
|
||||
@@ -787,7 +797,8 @@ class ExecutionProcessor:
|
||||
ExecutionStatus.TERMINATED,
|
||||
],
|
||||
):
|
||||
execution_queue.add(node_exec.to_node_execution_entry())
|
||||
node_entry = node_exec.to_node_execution_entry(graph_exec.user_context)
|
||||
execution_queue.add(node_entry)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Main dispatch / polling loop -----------------------------
|
||||
@@ -805,12 +816,19 @@ class ExecutionProcessor:
|
||||
|
||||
# Charge usage (may raise) ------------------------------
|
||||
try:
|
||||
cost = self._charge_usage(
|
||||
cost, remaining_balance = self._charge_usage(
|
||||
node_exec=queued_node_exec,
|
||||
execution_count=increment_execution_count(graph_exec.user_id),
|
||||
)
|
||||
with execution_stats_lock:
|
||||
execution_stats.cost += cost
|
||||
# Check if we crossed the low balance threshold
|
||||
self._handle_low_balance(
|
||||
db_client=db_client,
|
||||
user_id=graph_exec.user_id,
|
||||
current_balance=remaining_balance,
|
||||
transaction_cost=cost,
|
||||
)
|
||||
except InsufficientBalanceError as balance_error:
|
||||
error = balance_error # Set error to trigger FAILED status
|
||||
node_exec_id = queued_node_exec.node_exec_id
|
||||
@@ -825,11 +843,10 @@ class ExecutionProcessor:
|
||||
status=ExecutionStatus.FAILED,
|
||||
)
|
||||
|
||||
self._handle_low_balance_notif(
|
||||
self._handle_insufficient_funds_notif(
|
||||
db_client,
|
||||
graph_exec.user_id,
|
||||
graph_exec.graph_id,
|
||||
execution_stats,
|
||||
error,
|
||||
)
|
||||
# Gracefully stop the execution loop
|
||||
@@ -1036,7 +1053,7 @@ class ExecutionProcessor:
|
||||
node_id: str,
|
||||
graph_exec: GraphExecutionEntry,
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
nodes_input_masks: Optional[NodesInputMasks],
|
||||
execution_queue: ExecutionQueue[NodeExecutionEntry],
|
||||
) -> None:
|
||||
"""Process a node's output, update its status, and enqueue next nodes.
|
||||
@@ -1052,6 +1069,7 @@ class ExecutionProcessor:
|
||||
db_client = get_db_async_client()
|
||||
|
||||
log_metadata.debug(f"Enqueue nodes for {node_id}: {output}")
|
||||
|
||||
for next_execution in await _enqueue_next_nodes(
|
||||
db_client=db_client,
|
||||
node=output.node,
|
||||
@@ -1061,6 +1079,7 @@ class ExecutionProcessor:
|
||||
graph_id=graph_exec.graph_id,
|
||||
log_metadata=log_metadata,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
user_context=graph_exec.user_context,
|
||||
):
|
||||
execution_queue.add(next_execution)
|
||||
|
||||
@@ -1101,25 +1120,25 @@ class ExecutionProcessor:
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_low_balance_notif(
|
||||
def _handle_insufficient_funds_notif(
|
||||
self,
|
||||
db_client: "DatabaseManagerClient",
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
exec_stats: GraphExecutionStats,
|
||||
e: InsufficientBalanceError,
|
||||
):
|
||||
shortfall = e.balance - e.amount
|
||||
shortfall = abs(e.amount) - e.balance
|
||||
metadata = db_client.get_graph_metadata(graph_id)
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
|
||||
queue_notification(
|
||||
NotificationEventModel(
|
||||
user_id=user_id,
|
||||
type=NotificationType.LOW_BALANCE,
|
||||
data=LowBalanceData(
|
||||
current_balance=exec_stats.cost,
|
||||
type=NotificationType.ZERO_BALANCE,
|
||||
data=ZeroBalanceData(
|
||||
current_balance=e.balance,
|
||||
billing_page_link=f"{base_url}/profile/credits",
|
||||
shortfall=shortfall,
|
||||
agent_name=metadata.name if metadata else "Unknown Agent",
|
||||
@@ -1127,6 +1146,73 @@ class ExecutionProcessor:
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
user_email = db_client.get_user_email_by_id(user_id)
|
||||
|
||||
alert_message = (
|
||||
f"❌ **Insufficient Funds Alert**\n"
|
||||
f"User: {user_email or user_id}\n"
|
||||
f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
|
||||
f"Current balance: ${e.balance/100:.2f}\n"
|
||||
f"Attempted cost: ${abs(e.amount)/100:.2f}\n"
|
||||
f"Shortfall: ${abs(shortfall)/100:.2f}\n"
|
||||
f"[View User Details]({base_url}/admin/spending?search={user_email})"
|
||||
)
|
||||
|
||||
get_notification_manager_client().discord_system_alert(
|
||||
alert_message, DiscordChannel.PRODUCT
|
||||
)
|
||||
except Exception as alert_error:
|
||||
logger.error(
|
||||
f"Failed to send insufficient funds Discord alert: {alert_error}"
|
||||
)
|
||||
|
||||
def _handle_low_balance(
|
||||
self,
|
||||
db_client: "DatabaseManagerClient",
|
||||
user_id: str,
|
||||
current_balance: int,
|
||||
transaction_cost: int,
|
||||
):
|
||||
"""Check and handle low balance scenarios after a transaction"""
|
||||
LOW_BALANCE_THRESHOLD = settings.config.low_balance_threshold
|
||||
|
||||
balance_before = current_balance + transaction_cost
|
||||
|
||||
if (
|
||||
current_balance < LOW_BALANCE_THRESHOLD
|
||||
and balance_before >= LOW_BALANCE_THRESHOLD
|
||||
):
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
queue_notification(
|
||||
NotificationEventModel(
|
||||
user_id=user_id,
|
||||
type=NotificationType.LOW_BALANCE,
|
||||
data=LowBalanceData(
|
||||
current_balance=current_balance,
|
||||
billing_page_link=f"{base_url}/profile/credits",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
user_email = db_client.get_user_email_by_id(user_id)
|
||||
alert_message = (
|
||||
f"⚠️ **Low Balance Alert**\n"
|
||||
f"User: {user_email or user_id}\n"
|
||||
f"Balance dropped below ${LOW_BALANCE_THRESHOLD/100:.2f}\n"
|
||||
f"Current balance: ${current_balance/100:.2f}\n"
|
||||
f"Transaction cost: ${transaction_cost/100:.2f}\n"
|
||||
f"[View User Details]({base_url}/admin/spending?search={user_email})"
|
||||
)
|
||||
get_notification_manager_client().discord_system_alert(
|
||||
alert_message, DiscordChannel.PRODUCT
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send low balance Discord alert: {e}")
|
||||
|
||||
|
||||
class ExecutionManager(AppProcess):
|
||||
def __init__(self):
|
||||
@@ -1308,14 +1394,14 @@ class ExecutionManager(AppProcess):
|
||||
delivery_tag = method.delivery_tag
|
||||
|
||||
@func_retry
|
||||
def _ack_message(reject: bool = False):
|
||||
def _ack_message(reject: bool, requeue: bool):
|
||||
"""Acknowledge or reject the message based on execution status."""
|
||||
|
||||
# Connection can be lost, so always get a fresh channel
|
||||
channel = self.run_client.get_channel()
|
||||
if reject:
|
||||
channel.connection.add_callback_threadsafe(
|
||||
lambda: channel.basic_nack(delivery_tag, requeue=True)
|
||||
lambda: channel.basic_nack(delivery_tag, requeue=requeue)
|
||||
)
|
||||
else:
|
||||
channel.connection.add_callback_threadsafe(
|
||||
@@ -1327,13 +1413,13 @@ class ExecutionManager(AppProcess):
|
||||
logger.info(
|
||||
f"[{self.service_name}] Rejecting new execution during shutdown"
|
||||
)
|
||||
_ack_message(reject=True)
|
||||
_ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
# Check if we can accept more runs
|
||||
self._cleanup_completed_runs()
|
||||
if len(self.active_graph_runs) >= self.pool_size:
|
||||
_ack_message(reject=True)
|
||||
_ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -1342,7 +1428,7 @@ class ExecutionManager(AppProcess):
|
||||
logger.error(
|
||||
f"[{self.service_name}] Could not parse run message: {e}, body={body}"
|
||||
)
|
||||
_ack_message(reject=True)
|
||||
_ack_message(reject=True, requeue=False)
|
||||
return
|
||||
|
||||
graph_exec_id = graph_exec_entry.graph_exec_id
|
||||
@@ -1354,7 +1440,7 @@ class ExecutionManager(AppProcess):
|
||||
logger.error(
|
||||
f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
|
||||
)
|
||||
_ack_message(reject=True)
|
||||
_ack_message(reject=True, requeue=False)
|
||||
return
|
||||
|
||||
cancel_event = threading.Event()
|
||||
@@ -1370,9 +1456,9 @@ class ExecutionManager(AppProcess):
|
||||
logger.error(
|
||||
f"[{self.service_name}] Execution for {graph_exec_id} failed: {type(exec_error)} {exec_error}"
|
||||
)
|
||||
_ack_message(reject=True)
|
||||
_ack_message(reject=True, requeue=True)
|
||||
else:
|
||||
_ack_message(reject=False)
|
||||
_ack_message(reject=False, requeue=False)
|
||||
except BaseException as e:
|
||||
logger.exception(
|
||||
f"[{self.service_name}] Error in run completion callback: {e}"
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import NotificationType
|
||||
|
||||
from backend.data.notifications import LowBalanceData
|
||||
from backend.executor.manager import ExecutionProcessor
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
|
||||
"""Test that _handle_low_balance triggers notification when crossing threshold."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
current_balance = 400 # $4 - below $5 threshold
|
||||
transaction_cost = 600 # $6 transaction
|
||||
|
||||
# Mock dependencies
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.low_balance_threshold = 500 # $5 threshold
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the low balance handler
|
||||
execution_processor._handle_low_balance(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
current_balance=current_balance,
|
||||
transaction_cost=transaction_cost,
|
||||
)
|
||||
|
||||
# Verify notification was queued
|
||||
mock_queue_notif.assert_called_once()
|
||||
notification_call = mock_queue_notif.call_args[0][0]
|
||||
|
||||
# Verify notification details
|
||||
assert notification_call.type == NotificationType.LOW_BALANCE
|
||||
assert notification_call.user_id == user_id
|
||||
assert isinstance(notification_call.data, LowBalanceData)
|
||||
assert notification_call.data.current_balance == current_balance
|
||||
|
||||
# Verify Discord alert was sent
|
||||
mock_client.discord_system_alert.assert_called_once()
|
||||
discord_message = mock_client.discord_system_alert.call_args[0][0]
|
||||
assert "Low Balance Alert" in discord_message
|
||||
assert "test@example.com" in discord_message
|
||||
assert "$4.00" in discord_message
|
||||
assert "$6.00" in discord_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_low_balance_no_notification_when_not_crossing(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that no notification is sent when not crossing the threshold."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
current_balance = 600 # $6 - above $5 threshold
|
||||
transaction_cost = (
|
||||
100 # $1 transaction (balance before was $7, still above threshold)
|
||||
)
|
||||
|
||||
# Mock dependencies
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.low_balance_threshold = 500 # $5 threshold
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
|
||||
# Test the low balance handler
|
||||
execution_processor._handle_low_balance(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
current_balance=current_balance,
|
||||
transaction_cost=transaction_cost,
|
||||
)
|
||||
|
||||
# Verify no notification was sent
|
||||
mock_queue_notif.assert_not_called()
|
||||
mock_client.discord_system_alert.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_low_balance_no_duplicate_when_already_below(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that no notification is sent when already below threshold."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
current_balance = 300 # $3 - below $5 threshold
|
||||
transaction_cost = (
|
||||
100 # $1 transaction (balance before was $4, also below threshold)
|
||||
)
|
||||
|
||||
# Mock dependencies
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.low_balance_threshold = 500 # $5 threshold
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
|
||||
# Test the low balance handler
|
||||
execution_processor._handle_low_balance(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
current_balance=current_balance,
|
||||
transaction_cost=transaction_cost,
|
||||
)
|
||||
|
||||
# Verify no notification was sent (user was already below threshold)
|
||||
mock_queue_notif.assert_not_called()
|
||||
mock_client.discord_system_alert.assert_not_called()
|
||||
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
|
||||
import autogpt_libs.auth.models
|
||||
import fastapi.responses
|
||||
import pytest
|
||||
|
||||
@@ -36,21 +35,20 @@ async def execute_graph(
|
||||
logger.info(f"Input data: {input_data}")
|
||||
|
||||
# --- Test adding new executions --- #
|
||||
response = await agent_server.test_execute_graph(
|
||||
graph_exec = await agent_server.test_execute_graph(
|
||||
user_id=test_user.id,
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
node_input=input_data,
|
||||
)
|
||||
graph_exec_id = response.graph_exec_id
|
||||
logger.info(f"Created execution with ID: {graph_exec_id}")
|
||||
logger.info(f"Created execution with ID: {graph_exec.id}")
|
||||
|
||||
# Execution queue should be empty
|
||||
logger.info("Waiting for execution to complete...")
|
||||
result = await wait_execution(test_user.id, graph_exec_id, 30)
|
||||
result = await wait_execution(test_user.id, graph_exec.id, 30)
|
||||
logger.info(f"Execution completed with {len(result)} results")
|
||||
assert len(result) == num_execs
|
||||
return graph_exec_id
|
||||
return graph_exec.id
|
||||
|
||||
|
||||
async def assert_sample_graph_executions(
|
||||
@@ -380,7 +378,7 @@ async def test_execute_preset(server: SpinTestServer):
|
||||
|
||||
# Verify execution
|
||||
assert result is not None
|
||||
graph_exec_id = result["id"]
|
||||
graph_exec_id = result.id
|
||||
|
||||
# Wait for execution to complete
|
||||
executions = await wait_execution(test_user.id, graph_exec_id)
|
||||
@@ -469,7 +467,7 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
|
||||
|
||||
# Verify execution
|
||||
assert result is not None, "Result must not be None"
|
||||
graph_exec_id = result["id"]
|
||||
graph_exec_id = result.id
|
||||
|
||||
# Wait for execution to complete
|
||||
executions = await wait_execution(test_user.id, graph_exec_id)
|
||||
@@ -521,12 +519,7 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
is_approved=True,
|
||||
comments="Test comments",
|
||||
),
|
||||
autogpt_libs.auth.models.User(
|
||||
user_id=admin_user.id,
|
||||
role="admin",
|
||||
email=admin_user.email,
|
||||
phone_number="1234567890",
|
||||
),
|
||||
user_id=admin_user.id,
|
||||
)
|
||||
alt_test_user = admin_user
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from apscheduler.jobstores.memory import MemoryJobStore
|
||||
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.util import ZoneInfo
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import MetaData, create_engine
|
||||
@@ -303,6 +304,7 @@ class Scheduler(AppService):
|
||||
Jobstores.WEEKLY_NOTIFICATIONS.value: MemoryJobStore(),
|
||||
},
|
||||
logger=apscheduler_logger,
|
||||
timezone=ZoneInfo("UTC"),
|
||||
)
|
||||
|
||||
if self.register_system_tasks:
|
||||
@@ -406,6 +408,8 @@ class Scheduler(AppService):
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"Scheduling job for user {user_id} in UTC (cron: {cron})")
|
||||
|
||||
job_args = GraphExecutionJobArgs(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
@@ -418,12 +422,12 @@ class Scheduler(AppService):
|
||||
execute_graph,
|
||||
kwargs=job_args.model_dump(),
|
||||
name=name,
|
||||
trigger=CronTrigger.from_crontab(cron),
|
||||
trigger=CronTrigger.from_crontab(cron, timezone="UTC"),
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
replace_existing=True,
|
||||
)
|
||||
logger.info(
|
||||
f"Added job {job.id} with cron schedule '{cron}' input data: {input_data}"
|
||||
f"Added job {job.id} with cron schedule '{cron}' in UTC, input data: {input_data}"
|
||||
)
|
||||
return GraphExecutionJobInfo.from_db(job_args, job)
|
||||
|
||||
|
||||
@@ -4,24 +4,33 @@ import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Mapping, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, JsonValue, ValidationError
|
||||
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCostType,
|
||||
BlockInput,
|
||||
BlockOutputEntry,
|
||||
BlockType,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCostType
|
||||
from backend.data.db import prisma
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
GraphExecutionStats,
|
||||
GraphExecutionWithNodes,
|
||||
NodesInputMasks,
|
||||
UserContext,
|
||||
)
|
||||
from backend.data.graph import GraphModel, Node
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
get_async_execution_queue,
|
||||
@@ -34,6 +43,27 @@ from backend.util.mock import MockObject
|
||||
from backend.util.settings import Config
|
||||
from backend.util.type import convert
|
||||
|
||||
|
||||
async def get_user_context(user_id: str) -> UserContext:
|
||||
"""
|
||||
Get UserContext for a user, always returns a valid context with timezone.
|
||||
Defaults to UTC if user has no timezone set.
|
||||
"""
|
||||
user_context = UserContext(timezone="UTC") # Default to UTC
|
||||
try:
|
||||
user = await get_user_by_id(user_id)
|
||||
if user and user.timezone and user.timezone != "not-set":
|
||||
user_context.timezone = user.timezone
|
||||
logger.debug(f"Retrieved user context: timezone={user.timezone}")
|
||||
else:
|
||||
logger.debug("User has no timezone set, using UTC")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not fetch user timezone: {e}")
|
||||
# Continue with UTC as default
|
||||
|
||||
return user_context
|
||||
|
||||
|
||||
config = Config()
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")
|
||||
|
||||
@@ -216,7 +246,7 @@ def _tokenise(path: str) -> list[tuple[str, str]] | None:
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def parse_execution_output(output: BlockData, name: str) -> Any | None:
|
||||
def parse_execution_output(output: BlockOutputEntry, name: str) -> JsonValue | None:
|
||||
"""
|
||||
Retrieve a nested value out of `output` using the flattened *name*.
|
||||
|
||||
@@ -240,7 +270,7 @@ def parse_execution_output(output: BlockData, name: str) -> Any | None:
|
||||
if tokens is None:
|
||||
return None
|
||||
|
||||
cur: Any = data
|
||||
cur: JsonValue = data
|
||||
for delim, ident in tokens:
|
||||
if delim == LIST_SPLIT:
|
||||
# list[index]
|
||||
@@ -405,7 +435,7 @@ def validate_exec(
|
||||
async def _validate_node_input_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Checks all credentials for all nodes of the graph and returns structured errors.
|
||||
@@ -485,8 +515,8 @@ async def _validate_node_input_credentials(
|
||||
|
||||
def make_node_credentials_input_map(
|
||||
graph: GraphModel,
|
||||
graph_credentials_input: dict[str, CredentialsMetaInput],
|
||||
) -> dict[str, dict[str, JsonValue]]:
|
||||
graph_credentials_input: Mapping[str, CredentialsMetaInput],
|
||||
) -> NodesInputMasks:
|
||||
"""
|
||||
Maps credentials for an execution to the correct nodes.
|
||||
|
||||
@@ -521,8 +551,8 @@ def make_node_credentials_input_map(
|
||||
async def validate_graph_with_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> Mapping[str, Mapping[str, str]]:
|
||||
"""
|
||||
Validate graph including credentials and return structured errors per node.
|
||||
|
||||
@@ -552,7 +582,7 @@ async def _construct_starting_node_execution_input(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
graph_inputs: BlockInput,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> list[tuple[str, BlockInput]]:
|
||||
"""
|
||||
Validates and prepares the input data for executing a graph.
|
||||
@@ -593,7 +623,7 @@ async def _construct_starting_node_execution_input(
|
||||
|
||||
# Extract request input data, and assign it to the input pin.
|
||||
if block.block_type == BlockType.INPUT:
|
||||
input_name = node.input_default.get("name")
|
||||
input_name = cast(str | None, node.input_default.get("name"))
|
||||
if input_name and input_name in graph_inputs:
|
||||
input_data = {"value": graph_inputs[input_name]}
|
||||
|
||||
@@ -620,9 +650,9 @@ async def validate_and_construct_node_execution_input(
|
||||
user_id: str,
|
||||
graph_inputs: BlockInput,
|
||||
graph_version: Optional[int] = None,
|
||||
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], dict[str, dict[str, JsonValue]]]:
|
||||
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]:
|
||||
"""
|
||||
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
|
||||
This centralizes the logic used by both scheduler validation and actual execution.
|
||||
@@ -636,7 +666,9 @@ async def validate_and_construct_node_execution_input(
|
||||
nodes_input_masks: Node inputs to use.
|
||||
|
||||
Returns:
|
||||
tuple[GraphModel, list[tuple[str, BlockInput]]]: Graph model and list of tuples for node execution input.
|
||||
GraphModel: Full graph object for the given `graph_id`.
|
||||
list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
|
||||
dict[str, BlockInput]: Node input masks including all passed-in credentials.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the graph is not found.
|
||||
@@ -677,11 +709,11 @@ async def validate_and_construct_node_execution_input(
|
||||
|
||||
|
||||
def _merge_nodes_input_masks(
|
||||
overrides_map_1: dict[str, dict[str, JsonValue]],
|
||||
overrides_map_2: dict[str, dict[str, JsonValue]],
|
||||
) -> dict[str, dict[str, JsonValue]]:
|
||||
overrides_map_1: NodesInputMasks,
|
||||
overrides_map_2: NodesInputMasks,
|
||||
) -> NodesInputMasks:
|
||||
"""Perform a per-node merge of input overrides"""
|
||||
result = overrides_map_1.copy()
|
||||
result = dict(overrides_map_1).copy()
|
||||
for node_id, overrides2 in overrides_map_2.items():
|
||||
if node_id in result:
|
||||
result[node_id] = {**result[node_id], **overrides2}
|
||||
@@ -831,8 +863,8 @@ async def add_graph_execution(
|
||||
inputs: Optional[BlockInput] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
Adds a graph execution to the queue and returns the execution entry.
|
||||
@@ -856,7 +888,7 @@ async def add_graph_execution(
|
||||
else:
|
||||
edb = get_database_manager_async_client()
|
||||
|
||||
graph, starting_nodes_input, nodes_input_masks = (
|
||||
graph, starting_nodes_input, compiled_nodes_input_masks = (
|
||||
await validate_and_construct_node_execution_input(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
@@ -869,18 +901,26 @@ async def add_graph_execution(
|
||||
graph_exec = None
|
||||
|
||||
try:
|
||||
# Sanity check: running add_graph_execution with the properties of
|
||||
# the graph_exec created here should create the same execution again.
|
||||
graph_exec = await edb.create_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
inputs=inputs or {},
|
||||
credential_inputs=graph_credentials_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
starting_nodes_input=starting_nodes_input,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
|
||||
# Fetch user context for the graph execution
|
||||
user_context = await get_user_context(user_id)
|
||||
|
||||
queue = await get_async_execution_queue()
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry()
|
||||
if nodes_input_masks:
|
||||
graph_exec_entry.nodes_input_masks = nodes_input_masks
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry(
|
||||
user_context, compiled_nodes_input_masks
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created graph execution #{graph_exec.id} for graph "
|
||||
@@ -926,7 +966,7 @@ async def add_graph_execution(
|
||||
class ExecutionOutputEntry(BaseModel):
|
||||
node: Node
|
||||
node_exec_id: str
|
||||
data: BlockData
|
||||
data: BlockOutputEntry
|
||||
|
||||
|
||||
class NodeExecutionProgress:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from backend.executor.utils import merge_execution_input, parse_execution_output
|
||||
from backend.util.mock import MockObject
|
||||
@@ -276,3 +277,142 @@ def test_merge_execution_input():
|
||||
result = merge_execution_input(data)
|
||||
assert "mixed" in result
|
||||
assert result["mixed"].attr[0]["key"] == "value3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
"""
|
||||
Verify that calling the function with its own output creates the same execution again.
|
||||
"""
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
# Mock data
|
||||
graph_id = "test-graph-id"
|
||||
user_id = "test-user-id"
|
||||
inputs = {"test_input": "test_value"}
|
||||
preset_id = "test-preset-id"
|
||||
graph_version = 1
|
||||
graph_credentials_inputs = {
|
||||
"cred_key": CredentialsMetaInput(
|
||||
id="cred-id", provider=ProviderName("test_provider"), type="oauth2"
|
||||
)
|
||||
}
|
||||
nodes_input_masks = {"node1": {"input1": "masked_value"}}
|
||||
|
||||
# Mock the graph object returned by validate_and_construct_node_execution_input
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.version = graph_version
|
||||
|
||||
# Mock the starting nodes input and compiled nodes input masks
|
||||
starting_nodes_input = [
|
||||
("node1", {"input1": "value1"}),
|
||||
("node2", {"input1": "value2"}),
|
||||
]
|
||||
compiled_nodes_input_masks = {"node1": {"input1": "compiled_mask"}}
|
||||
|
||||
# Mock the graph execution object
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec.id = "execution-id-123"
|
||||
mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()
|
||||
|
||||
# Mock user context
|
||||
mock_user_context = {"user_id": user_id, "context": "test_context"}
|
||||
|
||||
# Mock the queue and event bus
|
||||
mock_queue = mocker.AsyncMock()
|
||||
mock_event_bus = mocker.MagicMock()
|
||||
mock_event_bus.publish = mocker.AsyncMock()
|
||||
|
||||
# Setup mocks
|
||||
mock_validate = mocker.patch(
|
||||
"backend.executor.utils.validate_and_construct_node_execution_input"
|
||||
)
|
||||
mock_edb = mocker.patch("backend.executor.utils.execution_db")
|
||||
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||
mock_get_user_context = mocker.patch("backend.executor.utils.get_user_context")
|
||||
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
|
||||
# Setup mock returns
|
||||
mock_validate.return_value = (
|
||||
mock_graph,
|
||||
starting_nodes_input,
|
||||
compiled_nodes_input_masks,
|
||||
)
|
||||
mock_prisma.is_connected.return_value = True
|
||||
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
|
||||
mock_get_user_context.return_value = mock_user_context
|
||||
mock_get_queue.return_value = mock_queue
|
||||
mock_get_event_bus.return_value = mock_event_bus
|
||||
|
||||
# Call the function - first execution
|
||||
result1 = await add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
preset_id=preset_id,
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=graph_credentials_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
|
||||
# Store the parameters used in the first call to create_graph_execution
|
||||
first_call_kwargs = mock_edb.create_graph_execution.call_args[1]
|
||||
|
||||
# Verify the create_graph_execution was called with correct parameters
|
||||
mock_edb.create_graph_execution.assert_called_once_with(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=mock_graph.version,
|
||||
inputs=inputs,
|
||||
credential_inputs=graph_credentials_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
starting_nodes_input=starting_nodes_input,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
|
||||
# Set up the graph execution mock to have properties we can extract
|
||||
mock_graph_exec.graph_id = graph_id
|
||||
mock_graph_exec.user_id = user_id
|
||||
mock_graph_exec.graph_version = graph_version
|
||||
mock_graph_exec.inputs = inputs
|
||||
mock_graph_exec.credential_inputs = graph_credentials_inputs
|
||||
mock_graph_exec.nodes_input_masks = nodes_input_masks
|
||||
mock_graph_exec.preset_id = preset_id
|
||||
|
||||
# Create a second mock execution for the sanity check
|
||||
mock_graph_exec_2 = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec_2.id = "execution-id-456"
|
||||
mock_graph_exec_2.to_graph_execution_entry.return_value = mocker.MagicMock()
|
||||
|
||||
# Reset mocks and set up for second call
|
||||
mock_edb.create_graph_execution.reset_mock()
|
||||
mock_edb.create_graph_execution.return_value = mock_graph_exec_2
|
||||
mock_validate.reset_mock()
|
||||
|
||||
# Sanity check: call add_graph_execution with properties from first result
|
||||
# This should create the same execution parameters
|
||||
result2 = await add_graph_execution(
|
||||
graph_id=mock_graph_exec.graph_id,
|
||||
user_id=mock_graph_exec.user_id,
|
||||
inputs=mock_graph_exec.inputs,
|
||||
preset_id=mock_graph_exec.preset_id,
|
||||
graph_version=mock_graph_exec.graph_version,
|
||||
graph_credentials_inputs=mock_graph_exec.credential_inputs,
|
||||
nodes_input_masks=mock_graph_exec.nodes_input_masks,
|
||||
)
|
||||
|
||||
# Verify that create_graph_execution was called with identical parameters
|
||||
second_call_kwargs = mock_edb.create_graph_execution.call_args[1]
|
||||
|
||||
# The sanity check: both calls should use identical parameters
|
||||
assert first_call_kwargs == second_call_kwargs
|
||||
|
||||
# Both executions should succeed (though they create different objects)
|
||||
assert result1 == mock_graph_exec
|
||||
assert result2 == mock_graph_exec_2
|
||||
@@ -4,6 +4,7 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.integrations.oauth.todoist import TodoistOAuthHandler
|
||||
|
||||
from .discord import DiscordOAuthHandler
|
||||
from .github import GitHubOAuthHandler
|
||||
from .google import GoogleOAuthHandler
|
||||
from .notion import NotionOAuthHandler
|
||||
@@ -15,6 +16,7 @@ if TYPE_CHECKING:
|
||||
# --8<-- [start:HANDLERS_BY_NAMEExample]
|
||||
# Build handlers dict with string keys for compatibility with SDK auto-registration
|
||||
_ORIGINAL_HANDLERS = [
|
||||
DiscordOAuthHandler,
|
||||
GitHubOAuthHandler,
|
||||
GoogleOAuthHandler,
|
||||
NotionOAuthHandler,
|
||||
|
||||
175
autogpt_platform/backend/backend/integrations/oauth/discord.py
Normal file
175
autogpt_platform/backend/backend/integrations/oauth/discord.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import Requests
|
||||
|
||||
from .base import BaseOAuthHandler
|
||||
|
||||
|
||||
class DiscordOAuthHandler(BaseOAuthHandler):
|
||||
"""
|
||||
Discord OAuth2 handler implementation.
|
||||
|
||||
Based on the documentation at:
|
||||
- https://discord.com/developers/docs/topics/oauth2
|
||||
|
||||
Discord OAuth2 tokens expire after 7 days by default and include refresh tokens.
|
||||
"""
|
||||
|
||||
PROVIDER_NAME = ProviderName.DISCORD
|
||||
DEFAULT_SCOPES = ["identify"] # Basic user information
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
self.auth_base_url = "https://discord.com/oauth2/authorize"
|
||||
self.token_url = "https://discord.com/api/oauth2/token"
|
||||
self.revoke_url = "https://discord.com/api/oauth2/token/revoke"
|
||||
|
||||
def get_login_url(
|
||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||
) -> str:
|
||||
# Handle default scopes
|
||||
scopes = self.handle_default_scopes(scopes)
|
||||
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"response_type": "code",
|
||||
"scope": " ".join(scopes),
|
||||
"state": state,
|
||||
}
|
||||
|
||||
# Discord supports PKCE
|
||||
if code_challenge:
|
||||
params["code_challenge"] = code_challenge
|
||||
params["code_challenge_method"] = "S256"
|
||||
|
||||
return f"{self.auth_base_url}?{urlencode(params)}"
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||
) -> OAuth2Credentials:
|
||||
params = {
|
||||
"code": code,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"grant_type": "authorization_code",
|
||||
}
|
||||
|
||||
# Include PKCE verifier if provided
|
||||
if code_verifier:
|
||||
params["code_verifier"] = code_verifier
|
||||
|
||||
return await self._request_tokens(params)
|
||||
|
||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
||||
if not credentials.access_token:
|
||||
raise ValueError("No access token to revoke")
|
||||
|
||||
# Discord requires client authentication for token revocation
|
||||
data = {
|
||||
"token": credentials.access_token.get_secret_value(),
|
||||
"token_type_hint": "access_token",
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
}
|
||||
|
||||
response = await Requests().post(
|
||||
url=self.revoke_url,
|
||||
data=data,
|
||||
headers=headers,
|
||||
auth=(self.client_id, self.client_secret),
|
||||
)
|
||||
|
||||
# Discord returns 200 OK for successful revocation
|
||||
return response.status == 200
|
||||
|
||||
async def _refresh_tokens(
|
||||
self, credentials: OAuth2Credentials
|
||||
) -> OAuth2Credentials:
|
||||
if not credentials.refresh_token:
|
||||
return credentials
|
||||
|
||||
return await self._request_tokens(
|
||||
{
|
||||
"refresh_token": credentials.refresh_token.get_secret_value(),
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
current_credentials=credentials,
|
||||
)
|
||||
|
||||
async def _request_tokens(
|
||||
self,
|
||||
params: dict[str, str],
|
||||
current_credentials: Optional[OAuth2Credentials] = None,
|
||||
) -> OAuth2Credentials:
|
||||
request_body = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
**params,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
}
|
||||
|
||||
response = await Requests().post(
|
||||
self.token_url, data=request_body, headers=headers
|
||||
)
|
||||
token_data: dict = response.json()
|
||||
|
||||
# Get username if this is a new token request
|
||||
username = None
|
||||
if "access_token" in token_data:
|
||||
username = await self._request_username(token_data["access_token"])
|
||||
|
||||
now = int(time.time())
|
||||
new_credentials = OAuth2Credentials(
|
||||
provider=self.PROVIDER_NAME,
|
||||
title=current_credentials.title if current_credentials else None,
|
||||
username=username,
|
||||
access_token=token_data["access_token"],
|
||||
scopes=token_data.get("scope", "").split()
|
||||
or (current_credentials.scopes if current_credentials else []),
|
||||
refresh_token=token_data.get("refresh_token"),
|
||||
# Discord tokens expire after expires_in seconds (typically 7 days)
|
||||
access_token_expires_at=(
|
||||
now + expires_in
|
||||
if (expires_in := token_data.get("expires_in", None))
|
||||
else None
|
||||
),
|
||||
# Discord doesn't provide separate refresh token expiration
|
||||
refresh_token_expires_at=None,
|
||||
)
|
||||
|
||||
if current_credentials:
|
||||
new_credentials.id = current_credentials.id
|
||||
|
||||
return new_credentials
|
||||
|
||||
async def _request_username(self, access_token: str) -> str | None:
|
||||
"""
|
||||
Fetch the username using the Discord OAuth2 @me endpoint.
|
||||
"""
|
||||
url = "https://discord.com/api/oauth2/@me"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
response = await Requests().get(url, headers=headers)
|
||||
|
||||
if not response.ok:
|
||||
return None
|
||||
|
||||
# Get user info from the response
|
||||
data = response.json()
|
||||
user_info = data.get("user", {})
|
||||
|
||||
# Return username (without discriminator)
|
||||
return user_info.get("username")
|
||||
@@ -29,7 +29,7 @@ from backend.data.user import generate_unsubscribe_link
|
||||
from backend.notifications.email import EmailSender
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.metrics import discord_send_alert
|
||||
from backend.util.metrics import DiscordChannel, discord_send_alert
|
||||
from backend.util.retry import continuous_retry
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
@@ -382,8 +382,10 @@ class NotificationManager(AppService):
|
||||
}
|
||||
|
||||
@expose
|
||||
async def discord_system_alert(self, content: str):
|
||||
await discord_send_alert(content)
|
||||
async def discord_system_alert(
|
||||
self, content: str, channel: DiscordChannel = DiscordChannel.PLATFORM
|
||||
):
|
||||
await discord_send_alert(content, channel)
|
||||
|
||||
async def _queue_scheduled_notification(self, event: SummaryParamsEventModel):
|
||||
"""Queue a scheduled notification - exposed method for other services to call"""
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
{# Agent Approved Notification Email Template #}
|
||||
{#
|
||||
Template variables:
|
||||
data.agent_name: the name of the approved agent
|
||||
data.agent_id: the ID of the agent
|
||||
data.agent_version: the version of the agent
|
||||
data.reviewer_name: the name of the reviewer who approved it
|
||||
data.reviewer_email: the email of the reviewer
|
||||
data.comments: comments from the reviewer
|
||||
data.reviewed_at: when the agent was reviewed
|
||||
data.store_url: URL to view the agent in the store
|
||||
|
||||
Subject: 🎉 Your agent '{{ data.agent_name }}' has been approved!
|
||||
#}
|
||||
|
||||
{% block content %}
|
||||
<h1 style="color: #28a745; font-size: 32px; font-weight: 700; margin: 0 0 24px 0; text-align: center;">
|
||||
🎉 Congratulations!
|
||||
</h1>
|
||||
|
||||
<p style="color: #586069; font-size: 18px; text-align: center; margin: 0 0 24px 0;">
|
||||
Your agent <strong>'{{ data.agent_name }}'</strong> has been approved and is now live in the store!
|
||||
</p>
|
||||
|
||||
<div style="height: 32px; background: transparent;"></div>
|
||||
|
||||
{% if data.comments %}
|
||||
<div style="background: #d4edda; border: 1px solid #c3e6cb; border-radius: 8px; padding: 20px; margin: 0;">
|
||||
<h3 style="color: #155724; font-size: 16px; font-weight: 600; margin: 0 0 12px 0;">
|
||||
💬 Creator feedback area
|
||||
</h3>
|
||||
<p style="color: #155724; margin: 0; font-size: 16px; line-height: 1.5;">
|
||||
{{ data.comments }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div style="height: 40px; background: transparent;"></div>
|
||||
{% endif %}
|
||||
|
||||
<div style="background: #d1ecf1; border: 1px solid #bee5eb; border-radius: 8px; padding: 20px; margin: 0;">
|
||||
<h3 style="color: #0c5460; font-size: 16px; font-weight: 600; margin: 0 0 12px 0;">
|
||||
What's Next?
|
||||
</h3>
|
||||
<ul style="color: #0c5460; margin: 0; padding-left: 18px; font-size: 16px; line-height: 1.6;">
|
||||
<li>Your agent is now live and discoverable in the AutoGPT Store</li>
|
||||
<li>Users can find, install, and run your agent</li>
|
||||
<li>You can update your agent anytime by submitting a new version</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div style="height: 32px; background: transparent;"></div>
|
||||
|
||||
<div style="text-align: center; margin: 24px 0;">
|
||||
<a href="{{ data.store_url }}" style="display: inline-block; background: linear-gradient(135deg, #7c3aed 0%, #5b21b6 100%); color: black; text-decoration: none; padding: 14px 28px; border-radius: 6px; font-weight: 600; font-size: 16px;">
|
||||
View Your Agent in Store
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div style="height: 32px; background: transparent;"></div>
|
||||
|
||||
<div style="background: #fff3cd; border: 1px solid #ffeaa7; border-radius: 6px; padding: 16px; margin: 24px 0; text-align: center;">
|
||||
<p style="margin: 0; color: #856404; font-size: 14px;">
|
||||
<strong>💡 Pro Tip:</strong> Share your agent with the community! Post about it on social media, forums, or your blog to help more users discover and benefit from your creation.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div style="height: 32px; background: transparent;"></div>
|
||||
|
||||
<p style="color: #6a737d; font-size: 14px; text-align: center; margin: 24px 0;">
|
||||
Thank you for contributing to the AutoGPT ecosystem! 🚀
|
||||
</p>
|
||||
|
||||
{% endblock %}
|
||||
@@ -0,0 +1,77 @@
|
||||
{# Agent Rejected Notification Email Template #}
|
||||
{#
|
||||
Template variables:
|
||||
data.agent_name: the name of the rejected agent
|
||||
data.agent_id: the ID of the agent
|
||||
data.agent_version: the version of the agent
|
||||
data.reviewer_name: the name of the reviewer who rejected it
|
||||
data.reviewer_email: the email of the reviewer
|
||||
data.comments: comments from the reviewer explaining the rejection
|
||||
data.reviewed_at: when the agent was reviewed
|
||||
data.resubmit_url: URL to resubmit the agent
|
||||
|
||||
Subject: Your agent '{{ data.agent_name }}' needs some updates
|
||||
#}
|
||||
|
||||
|
||||
{% block content %}
|
||||
<h1 style="color: #d73a49; font-size: 32px; font-weight: 700; margin: 0 0 24px 0; text-align: center;">
|
||||
📝 Review Complete
|
||||
</h1>
|
||||
|
||||
<p style="color: #586069; font-size: 18px; text-align: center; margin: 0 0 24px 0;">
|
||||
Your agent <strong>'{{ data.agent_name }}'</strong> needs some updates before approval.
|
||||
</p>
|
||||
|
||||
<div style="height: 32px; background: transparent;"></div>
|
||||
|
||||
<div style="background: #f8d7da; border: 1px solid #f5c6cb; border-radius: 8px; padding: 20px; margin: 0 0 24px 0;">
|
||||
<h3 style="color: #721c24; font-size: 16px; font-weight: 600; margin: 0 0 12px 0;">
|
||||
💬 Creator feedback area
|
||||
</h3>
|
||||
<p style="color: #721c24; margin: 0; font-size: 16px; line-height: 1.5;">
|
||||
{{ data.comments }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div style="height: 40px; background: transparent;"></div>
|
||||
|
||||
<div style="background: #d4edda; border: 1px solid #c3e6cb; border-radius: 8px; padding: 20px; margin: 0 0 24px 0;">
|
||||
<h3 style="color: #155724; font-size: 16px; font-weight: 600; margin: 0 0 12px 0;">
|
||||
☑ Steps to Resubmit:
|
||||
</h3>
|
||||
<ul style="color: #155724; margin: 0; padding-left: 18px; font-size: 16px; line-height: 1.6;">
|
||||
<li>Review the feedback provided above carefully</li>
|
||||
<li>Make the necessary updates to your agent</li>
|
||||
<li>Test your agent thoroughly to ensure it works as expected</li>
|
||||
<li>Submit your updated agent for review</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div style="height: 32px; background: transparent;"></div>
|
||||
|
||||
<div style="background: #fff3cd; border: 1px solid #ffeaa7; border-radius: 6px; padding: 12px; margin: 0 0 24px 0; text-align: center;">
|
||||
<p style="margin: 0; color: #856404; font-size: 14px;">
|
||||
<strong>💡 Tip:</strong> Address all the points mentioned in the feedback to increase your chances of approval in the next review.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div style="text-align: center; margin: 32px 0;">
|
||||
<a href="{{ data.resubmit_url }}" style="display: inline-block; background: linear-gradient(135deg, #7c3aed 0%, #5b21b6 100%); color: black; text-decoration: none; padding: 14px 28px; border-radius: 6px; font-weight: 600; font-size: 16px;">
|
||||
Update & Resubmit Agent
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div style="background: #d1ecf1; border: 1px solid #bee5eb; border-radius: 6px; padding: 16px; margin: 24px 0;">
|
||||
<p style="margin: 0; color: #0c5460; font-size: 14px; text-align: center;">
|
||||
<strong>🌟 Don't Give Up!</strong> Many successful agents go through multiple iterations before approval. Our review team is here to help you succeed!
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div style="height: 32px; background: transparent;"></div>
|
||||
|
||||
<p style="color: #6a737d; font-size: 14px; text-align: center; margin: 32px 0 24px 0;">
|
||||
We're excited to see your improved agent submission! 🚀
|
||||
</p>
|
||||
|
||||
{% endblock %}
|
||||
@@ -1,9 +1,7 @@
|
||||
{# Low Balance Notification Email Template #}
|
||||
{# Template variables:
|
||||
data.agent_name: the name of the agent
|
||||
data.current_balance: the current balance of the user
|
||||
data.billing_page_link: the link to the billing page
|
||||
data.shortfall: the shortfall amount
|
||||
#}
|
||||
|
||||
<p style="
|
||||
@@ -25,7 +23,7 @@ data.shortfall: the shortfall amount
|
||||
margin-top: 0;
|
||||
margin-bottom: 20px;
|
||||
">
|
||||
Your agent "<strong>{{ data.agent_name }}</strong>" has been stopped due to low balance.
|
||||
Your account balance has dropped below the recommended threshold.
|
||||
</p>
|
||||
|
||||
<div style="
|
||||
@@ -44,15 +42,6 @@ data.shortfall: the shortfall amount
|
||||
">
|
||||
<strong>Current Balance:</strong> ${{ "{:.2f}".format((data.current_balance|float)/100) }}
|
||||
</p>
|
||||
<p style="
|
||||
font-family: 'Poppins', sans-serif;
|
||||
color: #070629;
|
||||
font-size: 16px;
|
||||
margin-top: 0;
|
||||
margin-bottom: 10px;
|
||||
">
|
||||
<strong>Shortfall:</strong> ${{ "{:.2f}".format((data.shortfall|float)/100) }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -79,7 +68,7 @@ data.shortfall: the shortfall amount
|
||||
margin-top: 0;
|
||||
margin-bottom: 5px;
|
||||
">
|
||||
Your agent "<strong>{{ data.agent_name }}</strong>" requires additional credits to continue running. The current operation has been canceled until your balance is replenished.
|
||||
Your account requires additional credits to continue running agents. Please add credits to your account to avoid service interruption.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
@@ -110,5 +99,5 @@ data.shortfall: the shortfall amount
|
||||
margin-bottom: 10px;
|
||||
font-style: italic;
|
||||
">
|
||||
This is an automated notification. Your agent is stopped and will need manually restarted unless set to trigger automatically.
|
||||
This is an automated low balance notification. Consider adding credits soon to avoid service interruption.
|
||||
</p>
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
{# Low Balance Notification Email Template #}
|
||||
{# Template variables:
|
||||
data.agent_name: the name of the agent
|
||||
data.current_balance: the current balance of the user
|
||||
data.billing_page_link: the link to the billing page
|
||||
data.shortfall: the shortfall amount
|
||||
#}
|
||||
|
||||
<p style="
|
||||
font-family: 'Poppins', sans-serif;
|
||||
color: #070629;
|
||||
font-size: 16px;
|
||||
line-height: 165%;
|
||||
margin-top: 0;
|
||||
margin-bottom: 10px;
|
||||
">
|
||||
<strong>Zero Balance Warning</strong>
|
||||
</p>
|
||||
|
||||
<p style="
|
||||
font-family: 'Poppins', sans-serif;
|
||||
color: #070629;
|
||||
font-size: 16px;
|
||||
line-height: 165%;
|
||||
margin-top: 0;
|
||||
margin-bottom: 20px;
|
||||
">
|
||||
Your agent "<strong>{{ data.agent_name }}</strong>" has been stopped due to low balance.
|
||||
</p>
|
||||
|
||||
<div style="
|
||||
margin-left: 15px;
|
||||
margin-bottom: 20px;
|
||||
padding: 15px;
|
||||
border-left: 4px solid #5D23BB;
|
||||
background-color: #f8f8ff;
|
||||
">
|
||||
<p style="
|
||||
font-family: 'Poppins', sans-serif;
|
||||
color: #070629;
|
||||
font-size: 16px;
|
||||
margin-top: 0;
|
||||
margin-bottom: 10px;
|
||||
">
|
||||
<strong>Current Balance:</strong> ${{ "{:.2f}".format((data.current_balance|float)/100) }}
|
||||
</p>
|
||||
<p style="
|
||||
font-family: 'Poppins', sans-serif;
|
||||
color: #070629;
|
||||
font-size: 16px;
|
||||
margin-top: 0;
|
||||
margin-bottom: 10px;
|
||||
">
|
||||
<strong>Shortfall:</strong> ${{ "{:.2f}".format((data.shortfall|float)/100) }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
<div style="
|
||||
margin-left: 15px;
|
||||
margin-bottom: 20px;
|
||||
padding: 15px;
|
||||
border-left: 4px solid #FF6B6B;
|
||||
background-color: #FFF0F0;
|
||||
">
|
||||
<p style="
|
||||
font-family: 'Poppins', sans-serif;
|
||||
color: #070629;
|
||||
font-size: 16px;
|
||||
margin-top: 0;
|
||||
margin-bottom: 10px;
|
||||
">
|
||||
<strong>Low Balance:</strong>
|
||||
</p>
|
||||
<p style="
|
||||
font-family: 'Poppins', sans-serif;
|
||||
color: #070629;
|
||||
font-size: 16px;
|
||||
margin-top: 0;
|
||||
margin-bottom: 5px;
|
||||
">
|
||||
Your agent "<strong>{{ data.agent_name }}</strong>" requires additional credits to continue running. The current operation has been canceled until your balance is replenished.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div style="
|
||||
text-align: center;
|
||||
margin: 30px 0;
|
||||
">
|
||||
<a href="{{ data.billing_page_link }}" style="
|
||||
font-family: 'Poppins', sans-serif;
|
||||
background-color: #5D23BB;
|
||||
color: white;
|
||||
padding: 12px 24px;
|
||||
text-decoration: none;
|
||||
border-radius: 4px;
|
||||
font-weight: 500;
|
||||
display: inline-block;
|
||||
">
|
||||
Manage Billing
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<p style="
|
||||
font-family: 'Poppins', sans-serif;
|
||||
color: #070629;
|
||||
font-size: 16px;
|
||||
line-height: 150%;
|
||||
margin-top: 30px;
|
||||
margin-bottom: 10px;
|
||||
font-style: italic;
|
||||
">
|
||||
This is an automated notification. Your agent is stopped and will need manually restarted unless set to trigger automatically.
|
||||
</p>
|
||||
@@ -63,7 +63,7 @@ except ImportError:
|
||||
|
||||
# Cost System
|
||||
try:
|
||||
from backend.data.cost import BlockCost, BlockCostType
|
||||
from backend.data.block import BlockCost, BlockCostType
|
||||
except ImportError:
|
||||
from backend.data.block_cost_config import BlockCost, BlockCostType
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Callable, List, Optional, Type
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.cost import BlockCost, BlockCostType
|
||||
from backend.data.block import BlockCost, BlockCostType
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
|
||||
@@ -8,9 +8,8 @@ BLOCK_COSTS configuration used by the execution system.
|
||||
import logging
|
||||
from typing import List, Type
|
||||
|
||||
from backend.data.block import Block
|
||||
from backend.data.block import Block, BlockCost
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCost
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -7,13 +7,15 @@ from typing import Any, Callable, List, Optional, Set, Type
|
||||
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.data.cost import BlockCost
|
||||
from backend.data.block import BlockCost
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
CredentialsType,
|
||||
OAuth2Credentials,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
@@ -104,14 +106,39 @@ class Provider:
|
||||
)
|
||||
|
||||
def get_test_credentials(self) -> Credentials:
|
||||
"""Get test credentials for the provider."""
|
||||
return APIKeyCredentials(
|
||||
id=str(self.test_credentials_uuid),
|
||||
provider=self.name,
|
||||
api_key=SecretStr("mock-api-key"),
|
||||
title=f"Mock {self.name.title()} API key",
|
||||
expires_at=None,
|
||||
)
|
||||
"""Get test credentials for the provider based on supported auth types."""
|
||||
test_id = str(self.test_credentials_uuid)
|
||||
|
||||
# Return credentials based on the first supported auth type
|
||||
if "user_password" in self.supported_auth_types:
|
||||
return UserPasswordCredentials(
|
||||
id=test_id,
|
||||
provider=self.name,
|
||||
username=SecretStr(f"mock-{self.name}-username"),
|
||||
password=SecretStr(f"mock-{self.name}-password"),
|
||||
title=f"Mock {self.name.title()} credentials",
|
||||
)
|
||||
elif "oauth2" in self.supported_auth_types:
|
||||
return OAuth2Credentials(
|
||||
id=test_id,
|
||||
provider=self.name,
|
||||
username=f"mock-{self.name}-username",
|
||||
access_token=SecretStr(f"mock-{self.name}-access-token"),
|
||||
access_token_expires_at=None,
|
||||
refresh_token=SecretStr(f"mock-{self.name}-refresh-token"),
|
||||
refresh_token_expires_at=None,
|
||||
scopes=[f"mock-{self.name}-scope"],
|
||||
title=f"Mock {self.name.title()} OAuth credentials",
|
||||
)
|
||||
else:
|
||||
# Default to API key credentials
|
||||
return APIKeyCredentials(
|
||||
id=test_id,
|
||||
provider=self.name,
|
||||
api_key=SecretStr(f"mock-{self.name}-api-key"),
|
||||
title=f"Mock {self.name.title()} API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
def get_api(self, credentials: Credentials) -> Any:
|
||||
"""Get API client instance for the given credentials."""
|
||||
|
||||
@@ -11,7 +11,45 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
|
||||
return snapshot
|
||||
|
||||
|
||||
# Test ID constants
|
||||
TEST_USER_ID = "test-user-id"
|
||||
ADMIN_USER_ID = "admin-user-id"
|
||||
TARGET_USER_ID = "target-user-id"
|
||||
@pytest.fixture
|
||||
def test_user_id() -> str:
|
||||
"""Test user ID fixture."""
|
||||
return "test-user-id"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user_id() -> str:
|
||||
"""Admin user ID fixture."""
|
||||
return "admin-user-id"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_user_id() -> str:
|
||||
"""Target user ID fixture."""
|
||||
return "target-user-id"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_jwt_user(test_user_id):
|
||||
"""Provide mock JWT payload for regular user testing."""
|
||||
import fastapi
|
||||
|
||||
def override_get_jwt_payload(request: fastapi.Request) -> dict[str, str]:
|
||||
return {"sub": test_user_id, "role": "user", "email": "test@example.com"}
|
||||
|
||||
return {"get_jwt_payload": override_get_jwt_payload, "user_id": test_user_id}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_jwt_admin(admin_user_id):
|
||||
"""Provide mock JWT payload for admin user testing."""
|
||||
import fastapi
|
||||
|
||||
def override_get_jwt_payload(request: fastapi.Request) -> dict[str, str]:
|
||||
return {
|
||||
"sub": admin_user_id,
|
||||
"role": "admin",
|
||||
"email": "test-admin@example.com",
|
||||
}
|
||||
|
||||
return {"get_jwt_payload": override_get_jwt_payload, "user_id": admin_user_id}
|
||||
|
||||
@@ -88,6 +88,7 @@ async def test_send_graph_execution_result(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
preset_id=None,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
started_at=datetime.now(tz=timezone.utc),
|
||||
ended_at=datetime.now(tz=timezone.utc),
|
||||
@@ -101,6 +102,8 @@ async def test_send_graph_execution_result(
|
||||
"input_1": "some input value :)",
|
||||
"input_2": "some *other* input value",
|
||||
},
|
||||
credential_inputs=None,
|
||||
nodes_input_masks=None,
|
||||
outputs={
|
||||
"the_output": ["some output value"],
|
||||
"other_output": ["sike there was another output"],
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi.security import APIKeyHeader
|
||||
from prisma.enums import APIKeyPermission
|
||||
|
||||
from backend.data.api_key import has_permission, validate_api_key
|
||||
from backend.data.api_key import APIKeyInfo, has_permission, validate_api_key
|
||||
|
||||
api_key_header = APIKeyHeader(name="X-API-Key")
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
|
||||
async def require_api_key(request: Request):
|
||||
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
|
||||
"""Base middleware for API key authentication"""
|
||||
api_key = await api_key_header(request)
|
||||
|
||||
if api_key is None:
|
||||
raise HTTPException(status_code=401, detail="Missing API key")
|
||||
|
||||
@@ -19,18 +17,19 @@ async def require_api_key(request: Request):
|
||||
if not api_key_obj:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
request.state.api_key = api_key_obj
|
||||
return api_key_obj
|
||||
|
||||
|
||||
def require_permission(permission: APIKeyPermission):
|
||||
"""Dependency function for checking specific permissions"""
|
||||
|
||||
async def check_permission(api_key=Depends(require_api_key)):
|
||||
async def check_permission(
|
||||
api_key: APIKeyInfo = Security(require_api_key),
|
||||
) -> APIKeyInfo:
|
||||
if not has_permission(api_key, permission):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"API key missing required permission: {permission}",
|
||||
detail=f"API key lacks the required permission '{permission}'",
|
||||
)
|
||||
return api_key
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user