Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Synced 2026-01-12 08:38:09 -05:00

Compare commits: 139 commits (update-ins...copilot/fi)
| SHA1 |
|---|
| 9a61b45644 |
| 1d207a9b52 |
| 7f01df5bee |
| 634f826d82 |
| 6d6bf308fc |
| dd84fb5c66 |
| 33679f3ffe |
| fc8c5ccbb6 |
| 7d2ab61546 |
| c2f11dbcfa |
| f82adeb959 |
| 6f08a1cca7 |
| 1ddf92eed4 |
| 4c0dd27157 |
| 17fcf68f2e |
| 381558342a |
| 1fdc02467b |
| f262bb9307 |
| 5a6978b07d |
| 339ec733cb |
| 6575b655f0 |
| 7c2df24d7c |
| 23eafa178c |
| 27fccdbf31 |
| fb8fbc9d1f |
| 6a86e70fd6 |
| 6a2d7e0fb0 |
| 3d6ea3088e |
| 64b4480b1e |
| f490b01abb |
| e56a4a135d |
| e70c970ab6 |
| 3bbce71678 |
| 34fbf4377f |
| f682ef885a |
| 2ffd249aac |
| 986245ec43 |
| f89717153f |
| 5da41e0753 |
| cddeb185a8 |
| 08a3fd6d26 |
| 39b30bc82c |
| 2df0e2b750 |
| 925f249ce1 |
| e8cf3edbf4 |
| dc03ea718c |
| dbee580d80 |
| 0325ec0a2c |
| 3952a1a226 |
| cfc975d39b |
| 46e0f6cc45 |
| c03af5c196 |
| 00cbfb8f80 |
| 3beafae955 |
| 9cd186a2f3 |
| dcf26bd3d4 |
| b97f097c9d |
| ce24975a9d |
| 75c90e49ce |
| 2e38f132e7 |
| 5be6987d58 |
| b59592be9b |
| 7ea17df9ed |
| 22e692bdda |
| 901e9eba5d |
| bbd6709bd6 |
| 3ec1721d6d |
| 8c8a2ab0c2 |
| 4041e1f39c |
| 7cbdc1ad1a |
| 2a19aa0ed3 |
| 6d39dfe382 |
| 57ecc10535 |
| 4928ce3f90 |
| e16e69ca55 |
| 4bcc73f784 |
| c8240a4d6b |
| f669db4a10 |
| 0e755a5c85 |
| dfdc71f97f |
| def008408c |
| 916d0adabb |
| 417ee7f0e1 |
| ae4c9897b4 |
| 7544028b94 |
| ad49946890 |
| b333d56492 |
| b23cb14e49 |
| e71a44521a |
| 15aa091f65 |
| 3e4ca19036 |
| a595da02f7 |
| 12cdd45551 |
| df3c81a7a6 |
| 0f477e2392 |
| 82618aede0 |
| b713093276 |
| 3718b948ea |
| 44d739386b |
| 533d2d0277 |
| c6821484c7 |
| fecbd3042d |
| c0172c93aa |
| 8a68e03eb1 |
| da585a34e1 |
| bc43d05cac |
| 469b1fccbb |
| 1aa7e10cbd |
| 890bb3b8b4 |
| 2bb8e91040 |
| 76090f0ba2 |
| 5b12e02c4e |
| e0520f5e0a |
| a9530b7304 |
| 8a9c165faf |
| 476bfc6c84 |
| 61f17e5b97 |
| a54bed6d68 |
| 5502256bea |
| bd97727763 |
| aa256f21cd |
| 2848e62f8a |
| fa5ff9ca3c |
| 7c908c10b8 |
| 68cd1cb398 |
| f4538d6f5a |
| 4589b15450 |
| ccc4d0dc6c |
| c4483fa6c7 |
| c2af8c1a6a |
| 2610c4579f |
| 0c09b0c459 |
| 1105e6c0d2 |
| 483c399812 |
| 260dd526c9 |
| 75a159db01 |
| 62032e6584 |
| 105d5dc7e9 |
| 92515b3683 |
244  .github/copilot-instructions.md  vendored  Normal file

@@ -0,0 +1,244 @@

# GitHub Copilot Instructions for AutoGPT

This file provides comprehensive onboarding information for the GitHub Copilot coding agent to work efficiently with the AutoGPT repository.

## Repository Overview

**AutoGPT** is a powerful platform for creating, deploying, and managing continuous AI agents that automate complex workflows. This is a large monorepo (~150MB) containing multiple components:

- **AutoGPT Platform** (`autogpt_platform/`) - Main focus: modern AI agent platform (Polyform Shield License)
- **Classic AutoGPT** (`classic/`) - Legacy agent system (MIT License)
- **Documentation** (`docs/`) - MkDocs-based documentation site
- **Infrastructure** - Docker configurations, CI/CD, and development tools

**Primary Languages & Frameworks:**

- **Backend**: Python 3.10-3.13, FastAPI, Prisma ORM, PostgreSQL, RabbitMQ
- **Frontend**: TypeScript, Next.js 15, React, Tailwind CSS, Radix UI
- **Development**: Docker, Poetry, pnpm, Playwright, Storybook

## Build and Validation Instructions

### Essential Setup Commands

**Always run these commands in the correct directory and in this order:**

1. **Initial Setup** (required once):

   ```bash
   # Clone and enter repository
   git clone <repo> && cd AutoGPT

   # Start all services (database, redis, rabbitmq, clamav)
   cd autogpt_platform && docker compose --profile local up deps --build --detach
   ```

2. **Backend Setup** (always run before backend development):

   ```bash
   cd autogpt_platform/backend
   poetry install                 # Install dependencies
   poetry run prisma migrate dev  # Run database migrations
   poetry run prisma generate     # Generate Prisma client
   ```

3. **Frontend Setup** (always run before frontend development):

   ```bash
   cd autogpt_platform/frontend
   pnpm install                   # Install dependencies
   ```

### Runtime Requirements

**Critical:** Always ensure Docker services are running before starting development:

```bash
cd autogpt_platform && docker compose --profile local up deps --build --detach
```

**Python Version:** Use Python 3.11 (required; managed by Poetry via `pyproject.toml`)
**Node.js Version:** Use Node.js 21+ with the pnpm package manager

### Development Commands

**Backend Development:**

```bash
cd autogpt_platform/backend
poetry run serve                   # Start development server (port 8000)
poetry run test                    # Run all tests (takes ~5 minutes)
poetry run pytest path/to/test.py  # Run a specific test
poetry run format                  # Format code (Black + isort) - always run first
poetry run lint                    # Lint code (ruff) - run after format
```

**Frontend Development:**

```bash
cd autogpt_platform/frontend
pnpm dev        # Start development server (port 3000) - use for active development
pnpm build      # Build for production (only needed for E2E tests or deployment)
pnpm test       # Run Playwright E2E tests (requires build first)
pnpm test-ui    # Run tests with UI
pnpm format     # Format and lint code
pnpm storybook  # Start component development server
```

### Testing Strategy

**Backend Tests:**

- **Block Tests**: `poetry run pytest backend/blocks/test/test_block.py -xvs` (validates all blocks)
- **Specific Block**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[BlockName]' -xvs`
- **Snapshot Tests**: Use `--snapshot-update` when output changes; always review the result with `git diff`

**Frontend Tests:**

- **E2E Tests**: Always run `pnpm dev` before `pnpm test` (Playwright requires a running instance)
- **Component Tests**: Use Storybook for isolated component development

### Critical Validation Steps

**Before committing changes:**

1. Run `poetry run format` (backend) and `pnpm format` (frontend)
2. Ensure all tests pass in the areas you modified
3. Verify Docker services are still running
4. Check that database migrations apply cleanly

**Common Issues & Workarounds:**

- **Prisma issues**: Run `poetry run prisma generate` after schema changes
- **Permission errors**: Ensure Docker has proper permissions
- **Port conflicts**: Check the `docker-compose.yml` file for the current list of exposed ports. You can list all mapped ports with `docker compose ps`.
- **Test timeouts**: Backend tests can take 5+ minutes; use the `-x` flag to stop on the first failure

## Project Layout & Architecture

### Core Architecture

**AutoGPT Platform** (`autogpt_platform/`):

- `backend/` - FastAPI server with async support
  - `backend/backend/` - Core API logic
  - `backend/blocks/` - Agent execution blocks
  - `backend/data/` - Database models and schemas
  - `schema.prisma` - Database schema definition
- `frontend/` - Next.js application
  - `src/app/` - App Router pages and layouts
  - `src/components/` - Reusable React components
  - `src/lib/` - Utilities and configurations
- `autogpt_libs/` - Shared Python utilities
- `docker-compose.yml` - Development stack orchestration

**Key Configuration Files:**

- `pyproject.toml` - Python dependencies and tooling
- `package.json` - Node.js dependencies and scripts
- `schema.prisma` - Database schema and migrations
- `next.config.mjs` - Next.js configuration
- `tailwind.config.ts` - Styling configuration

### Security & Middleware

**Cache Protection**: The backend includes middleware that prevents sensitive data from being cached in browsers/proxies
**Authentication**: JWT-based with Supabase integration
**User ID Validation**: All data access requires user ID checks - verify this for any `data/*.py` changes

### Development Workflow

**GitHub Actions**: Multiple CI/CD workflows in `.github/workflows/`

- `platform-backend-ci.yml` - Backend testing and validation
- `platform-frontend-ci.yml` - Frontend testing and validation
- `platform-fullstack-ci.yml` - End-to-end integration tests

**Pre-commit Hooks**: Run linting and formatting checks
**Conventional Commits**: Use the format `type(scope): description` (e.g., `feat(backend): add API`)

### Key Source Files

**Backend Entry Points:**

- `backend/backend/server/server.py` - FastAPI application setup
- `backend/backend/data/` - Database models and user management
- `backend/blocks/` - Agent execution blocks and logic

**Frontend Entry Points:**

- `frontend/src/app/layout.tsx` - Root application layout
- `frontend/src/app/page.tsx` - Home page
- `frontend/src/lib/supabase/` - Authentication and database client

**Protected Routes**: Update `frontend/lib/supabase/middleware.ts` when adding protected routes

### Agent Block System

Agents are built using a visual block-based system where each block performs a single action. Blocks are defined in `backend/blocks/` and must include:

- A block definition with input/output schemas
- Execution logic with proper error handling
- Tests validating functionality

### Database & ORM

**Prisma ORM** with a PostgreSQL backend, including pgvector for embeddings (a query sketch follows the list below):

- Schema in `schema.prisma`
- Migrations in `backend/migrations/`
- Always run `prisma migrate dev` and `prisma generate` after schema changes
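The generated Prisma client is what the data layer uses for these queries. The sketch below shows the user-scoping pattern described under Security & Middleware; the `AgentGraph` model and `userId` field are hypothetical placeholders, so check `schema.prisma` for the real names before reusing it.

```python
# Hypothetical sketch: scoping a Prisma query to the requesting user.
# Model and field names (agentgraph, userId) are illustrative; the real ones live in schema.prisma.
from prisma import Prisma


async def list_graphs_for_user(user_id: str) -> list:
    db = Prisma()
    await db.connect()
    try:
        # Every data-access helper should filter by the authenticated user's ID.
        return await db.agentgraph.find_many(where={"userId": user_id})
    finally:
        await db.disconnect()
```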
## Environment Configuration

### Configuration Files Priority Order

1. **Backend**: `/backend/.env.default` → `/backend/.env` (user overrides)
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
3. **Platform**: `/.env.default` (Supabase/shared) → `/.env` (user overrides)
4. Docker Compose `environment:` sections override file-based config
5. Shell environment variables have the highest precedence

### Docker Environment Setup

- All services use hardcoded defaults (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Copy `.env.default` files to `.env` for local development customization

## Advanced Development Patterns

### Adding New Blocks

1. Create a file in `/backend/backend/blocks/`
2. Inherit from the `Block` base class with input/output schemas (see the sketch after this list)
3. Implement the `run` method with proper error handling
4. Generate the block's UUID using `uuid.uuid4()`
5. Register the block in the block registry
6. Write tests alongside the block implementation
7. Consider how inputs/outputs connect with other blocks in the graph editor
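A minimal sketch of that shape is shown below. The import paths, the `SchemaField` helper, and the constructor arguments are assumptions based on the description above; copy the exact signatures from an existing block in `backend/blocks/` rather than from this example.

```python
# Rough sketch of the block shape described above. Import paths and base-class
# signatures are assumed; mirror an existing block in backend/blocks/ in real code.
import uuid

from backend.data.block import Block, BlockOutput, BlockSchema  # assumed import path
from backend.data.model import SchemaField  # assumed field-metadata helper


class WordCountBlock(Block):
    class Input(BlockSchema):
        text: str = SchemaField(description="Text to count words in")

    class Output(BlockSchema):
        word_count: int = SchemaField(description="Number of words in the text")

    def __init__(self):
        super().__init__(
            id=str(uuid.uuid4()),  # in practice, generate once and hard-code the UUID
            description="Counts the words in the input text",
            input_schema=WordCountBlock.Input,
            output_schema=WordCountBlock.Output,
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Yield (output_name, value) pairs so downstream blocks can consume them.
        yield "word_count", len(input_data.text.split())
```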
### API Development

1. Update routes in `/backend/backend/server/routers/` (a route sketch follows this list)
2. Add/update Pydantic models in the same directory
3. Write tests alongside the route files
4. For `data/*.py` changes, validate user ID checks
5. Run `poetry run test` to verify changes
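A hedged sketch of a user-scoped route is shown below. The path, models, and helper names are illustrative only; the real routers, auth dependency, and data helpers live under `backend/backend/server/routers/` and `backend/backend/data/`.

```python
# Illustrative-only route sketch; not the project's actual router code.
from fastapi import APIRouter, Depends
from pydantic import BaseModel

router = APIRouter()


class GraphSummary(BaseModel):
    id: str
    name: str


def get_user_id() -> str:
    """Placeholder for the real JWT/Supabase auth dependency."""
    raise NotImplementedError


async def fetch_graphs_for_user(user_id: str) -> list[GraphSummary]:
    """Hypothetical data-layer helper; the real one would query Prisma filtered by user ID."""
    return []


@router.get("/graphs", response_model=list[GraphSummary])
async def list_graphs(user_id: str = Depends(get_user_id)):
    # Pass the authenticated user ID down so data-layer queries stay user-scoped.
    return await fetch_graphs_for_user(user_id)
```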
### Frontend Development

1. Components live in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for component development
4. Test user-facing features with Playwright E2E tests
5. Update protected routes in the middleware when needed

### Security Guidelines

**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):

- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow-list approach for cacheable paths (static assets, health checks, public pages); a sketch follows this list
- Prevents sensitive data from being cached in browsers/proxies
- Add new cacheable endpoints to `CACHEABLE_PATHS`
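The sketch below illustrates the allow-list pattern with a generic Starlette middleware. It is not the actual contents of `security.py`, and the example `CACHEABLE_PATHS` entries are placeholders.

```python
# Illustrative sketch of the allow-list caching pattern; not the real security.py.
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

CACHEABLE_PATHS = {"/health", "/static"}  # example entries only


class CacheControlMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        if not any(request.url.path.startswith(path) for path in CACHEABLE_PATHS):
            # Everything off the allow list is marked non-cacheable.
            response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, private"
        return response
```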
### CI/CD Alignment

The repository has comprehensive CI workflows that test:

- **Backend**: Python 3.11-3.13, services (Redis/RabbitMQ/ClamAV), Prisma migrations, Poetry lock validation
- **Frontend**: Node.js 21, pnpm, Playwright with the Docker Compose stack, API schema validation
- **Integration**: Full-stack type checking and E2E testing

Match these patterns when developing locally - the Copilot setup environment mirrors these CI configurations.

## Collaboration with Other AI Assistants

This repository is actively developed with assistance from Claude (via CLAUDE.md files). When working on this codebase:

- Check for existing CLAUDE.md files that provide additional context
- Follow established patterns and conventions already in the codebase
- Maintain consistency with existing code style and architecture
- Consider that changes may be reviewed and extended by both human developers and AI assistants

## Trust These Instructions

These instructions are comprehensive and tested. Only perform additional searches if:

1. The information here is incomplete for your specific task
2. You encounter errors not covered by the workarounds
3. You need to understand implementation details not covered above

For detailed platform development patterns, refer to `autogpt_platform/CLAUDE.md` and `AGENTS.md` in the repository root.
97  .github/workflows/claude-ci-failure-auto-fix.yml  vendored  Normal file

@@ -0,0 +1,97 @@

name: Auto Fix CI Failures

on:
  workflow_run:
    workflows: ["CI"]
    types:
      - completed

permissions:
  contents: write
  pull-requests: write
  actions: read
  issues: write
  id-token: write # Required for OIDC token exchange

jobs:
  auto-fix:
    if: |
      github.event.workflow_run.conclusion == 'failure' &&
      github.event.workflow_run.pull_requests[0] &&
      !startsWith(github.event.workflow_run.head_branch, 'claude-auto-fix-ci-')
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.workflow_run.head_branch }}
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Setup git identity
        run: |
          git config --global user.email "claude[bot]@users.noreply.github.com"
          git config --global user.name "claude[bot]"

      - name: Create fix branch
        id: branch
        run: |
          BRANCH_NAME="claude-auto-fix-ci-${{ github.event.workflow_run.head_branch }}-${{ github.run_id }}"
          git checkout -b "$BRANCH_NAME"
          echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT

      - name: Get CI failure details
        id: failure_details
        uses: actions/github-script@v7
        with:
          script: |
            const run = await github.rest.actions.getWorkflowRun({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: ${{ github.event.workflow_run.id }}
            });

            const jobs = await github.rest.actions.listJobsForWorkflowRun({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: ${{ github.event.workflow_run.id }}
            });

            const failedJobs = jobs.data.jobs.filter(job => job.conclusion === 'failure');

            let errorLogs = [];
            for (const job of failedJobs) {
              const logs = await github.rest.actions.downloadJobLogsForWorkflowRun({
                owner: context.repo.owner,
                repo: context.repo.repo,
                job_id: job.id
              });
              errorLogs.push({
                jobName: job.name,
                logs: logs.data
              });
            }

            return {
              runUrl: run.data.html_url,
              failedJobs: failedJobs.map(j => j.name),
              errorLogs: errorLogs
            };

      - name: Fix CI failures with Claude
        id: claude
        uses: anthropics/claude-code-action@v1
        with:
          prompt: |
            /fix-ci
            Failed CI Run: ${{ fromJSON(steps.failure_details.outputs.result).runUrl }}
            Failed Jobs: ${{ join(fromJSON(steps.failure_details.outputs.result).failedJobs, ', ') }}
            PR Number: ${{ github.event.workflow_run.pull_requests[0].number }}
            Branch Name: ${{ steps.branch.outputs.branch_name }}
            Base Branch: ${{ github.event.workflow_run.head_branch }}
            Repository: ${{ github.repository }}

            Error logs:
            ${{ toJSON(fromJSON(steps.failure_details.outputs.result).errorLogs) }}
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
          claude_args: "--allowedTools 'Edit,MultiEdit,Write,Read,Glob,Grep,LS,Bash(git:*),Bash(bun:*),Bash(npm:*),Bash(npx:*),Bash(gh:*)'"
379  .github/workflows/claude-dependabot.yml  vendored  Normal file

@@ -0,0 +1,379 @@

# Claude Dependabot PR Review Workflow
#
# This workflow automatically runs Claude analysis on Dependabot PRs to:
# - Identify dependency changes and their versions
# - Look up changelogs for updated packages
# - Assess breaking changes and security impacts
# - Provide actionable recommendations for the development team
#
# Triggered on: Dependabot PRs (opened, synchronize)
# Requirements: ANTHROPIC_API_KEY secret must be configured

name: Claude Dependabot PR Review

on:
  pull_request:
    types: [opened, synchronize]

jobs:
  dependabot-review:
    # Only run on Dependabot PRs
    if: github.actor == 'dependabot[bot]'
    runs-on: ubuntu-latest
    timeout-minutes: 30

    permissions:
      contents: write
      pull-requests: read
      issues: read
      id-token: write
      actions: read # Required for CI access
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 1

      # Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11" # Use standard version matching CI

      - name: Set up Python dependency cache
        uses: actions/cache@v4
        with:
          path: ~/.cache/pypoetry
          key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

      - name: Install Poetry
        run: |
          # Extract Poetry version from backend/poetry.lock (matches CI)
          cd autogpt_platform/backend
          HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
          echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"

          # Install Poetry
          curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -

          # Add Poetry to PATH
          echo "$HOME/.local/bin" >> $GITHUB_PATH

      - name: Check poetry.lock
        working-directory: autogpt_platform/backend
        run: |
          poetry lock
          if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
            echo "Warning: poetry.lock not up to date, but continuing for setup"
            git checkout poetry.lock # Reset for clean setup
          fi

      - name: Install Python dependencies
        working-directory: autogpt_platform/backend
        run: poetry install

      - name: Generate Prisma Client
        working-directory: autogpt_platform/backend
        run: poetry run prisma generate

      # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: "21"

      - name: Enable corepack
        run: corepack enable

      - name: Set pnpm store directory
        run: |
          pnpm config set store-dir ~/.pnpm-store
          echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV

      - name: Cache frontend dependencies
        uses: actions/cache@v4
        with:
          path: ~/.pnpm-store
          key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
          restore-keys: |
            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
            ${{ runner.os }}-pnpm-

      - name: Install JavaScript dependencies
        working-directory: autogpt_platform/frontend
        run: pnpm install --frozen-lockfile

      # Install Playwright browsers for frontend testing
      # NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
      # - name: Install Playwright browsers
      #   working-directory: autogpt_platform/frontend
      #   run: pnpm playwright install --with-deps chromium

      # Docker setup for development environment
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Copy default environment files
        working-directory: autogpt_platform
        run: |
          # Copy default environment files for development
          cp .env.default .env
          cp backend/.env.default backend/.env
          cp frontend/.env.default frontend/.env

      # Phase 1: Cache and load Docker images for faster setup
      - name: Set up Docker image cache
        id: docker-cache
        uses: actions/cache@v4
        with:
          path: ~/docker-cache
          # Use a versioned key for cache invalidation when image list changes
          key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
          restore-keys: |
            docker-images-v2-${{ runner.os }}-
            docker-images-v1-${{ runner.os }}-

      - name: Load or pull Docker images
        working-directory: autogpt_platform
        run: |
          mkdir -p ~/docker-cache

          # Define image list for easy maintenance
          IMAGES=(
            "redis:latest"
            "rabbitmq:management"
            "clamav/clamav-debian:latest"
            "busybox:latest"
            "kong:2.8.1"
            "supabase/gotrue:v2.170.0"
            "supabase/postgres:15.8.1.049"
            "supabase/postgres-meta:v0.86.1"
            "supabase/studio:20250224-d10db0f"
          )

          # Check if any cached tar files exist (more reliable than cache-hit)
          if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
            echo "Docker cache found, loading images in parallel..."
            for image in "${IMAGES[@]}"; do
              # Convert image name to filename (replace : and / with -)
              filename=$(echo "$image" | tr ':/' '--')
              if [ -f ~/docker-cache/${filename}.tar ]; then
                echo "Loading $image..."
                docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
              fi
            done
            wait
            echo "All cached images loaded"
          else
            echo "No Docker cache found, pulling images in parallel..."
            # Pull all images in parallel
            for image in "${IMAGES[@]}"; do
              docker pull "$image" &
            done
            wait

            # Only save cache on main branches (not PRs) to avoid cache pollution
            if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
              echo "Saving Docker images to cache in parallel..."
              for image in "${IMAGES[@]}"; do
                # Convert image name to filename (replace : and / with -)
                filename=$(echo "$image" | tr ':/' '--')
                echo "Saving $image..."
                docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
              done
              wait
              echo "Docker image cache saved"
            else
              echo "Skipping cache save for PR/feature branch"
            fi
          fi

          echo "Docker images ready for use"

      # Phase 2: Build migrate service with GitHub Actions cache
      - name: Build migrate Docker image with cache
        working-directory: autogpt_platform
        run: |
          # Build the migrate image with buildx for GHA caching
          docker buildx build \
            --cache-from type=gha \
            --cache-to type=gha,mode=max \
            --target migrate \
            --tag autogpt_platform-migrate:latest \
            --load \
            -f backend/Dockerfile \
            ..

      # Start services using pre-built images
      - name: Start Docker services for development
        working-directory: autogpt_platform
        run: |
          # Start essential services (migrate image already built with correct tag)
          docker compose --profile local up deps --no-build --detach
          echo "Waiting for services to be ready..."

          # Wait for database to be ready
          echo "Checking database readiness..."
          timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
            echo "  Waiting for database..."
            sleep 2
          done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."

          # Check migrate service status
          echo "Checking migration status..."
          docker compose ps migrate || echo "  Migrate service not visible in ps output"

          # Wait for migrate service to complete
          echo "Waiting for migrations to complete..."
          timeout 30 bash -c '
            ATTEMPTS=0
            while [ $ATTEMPTS -lt 15 ]; do
              ATTEMPTS=$((ATTEMPTS + 1))

              # Check using docker directly (more reliable than docker compose ps)
              CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)

              if [ -z "$CONTAINER_STATUS" ]; then
                echo "  Attempt $ATTEMPTS: Migrate container not found yet..."
              elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
                echo "✅ Migrations completed successfully"
                docker compose logs migrate --tail=5 2>/dev/null || true
                exit 0
              elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
                EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
                echo "❌ Migrations failed with exit code: $EXIT_CODE"
                echo "Migration logs:"
                docker compose logs migrate --tail=20 2>/dev/null || true
                exit 1
              elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
                echo "  Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
              else
                echo "  Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
              fi

              sleep 2
            done

            echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
            echo "Final container check:"
            docker ps -a --filter "label=com.docker.compose.service=migrate" || true
            echo "Migration logs (if available):"
            docker compose logs migrate --tail=10 2>/dev/null || echo "  No logs available"
          ' || echo "⚠️ Migration check completed with warnings, continuing..."

          # Brief wait for other services to stabilize
          echo "Waiting 5 seconds for other services to stabilize..."
          sleep 5

      # Verify installations and provide environment info
      - name: Verify setup and show environment info
        run: |
          echo "=== Python Setup ==="
          python --version
          poetry --version

          echo "=== Node.js Setup ==="
          node --version
          pnpm --version

          echo "=== Additional Tools ==="
          docker --version
          docker compose version
          gh --version || true

          echo "=== Services Status ==="
          cd autogpt_platform
          docker compose ps || true

          echo "=== Backend Dependencies ==="
          cd backend
          poetry show | head -10 || true

          echo "=== Frontend Dependencies ==="
          cd ../frontend
          pnpm list --depth=0 | head -10 || true

          echo "=== Environment Files ==="
          ls -la ../.env* || true
          ls -la .env* || true
          ls -la ../backend/.env* || true

          echo "✅ AutoGPT Platform development environment setup complete!"
          echo "🚀 Ready for development with Docker services running"
          echo "📝 Backend server: poetry run serve (port 8000)"
          echo "🌐 Frontend server: pnpm dev (port 3000)"

      - name: Run Claude Dependabot Analysis
        id: claude_review
        uses: anthropics/claude-code-action@v1
        with:
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
          claude_args: |
            --allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
          prompt: |
            You are Claude, an AI assistant specialized in reviewing Dependabot dependency update PRs.

            Your primary tasks are:
            1. **Analyze the dependency changes** in this Dependabot PR
            2. **Look up changelogs** for all updated dependencies to understand what changed
            3. **Identify breaking changes** and assess potential impact on the AutoGPT codebase
            4. **Provide actionable recommendations** for the development team

            ## Analysis Process:

            1. **Identify Changed Dependencies**:
               - Use git diff to see what dependencies were updated
               - Parse package.json, poetry.lock, requirements files, etc.
               - List all package versions: old → new

            2. **Changelog Research**:
               - For each updated dependency, look up its changelog/release notes
               - Use WebFetch to access GitHub releases, NPM package pages, and PyPI project pages; the PR description should also have some details
               - Focus on versions between the old and new versions
               - Identify: breaking changes, deprecations, security fixes, new features

            3. **Breaking Change Assessment**:
               - Categorize changes: BREAKING, MAJOR, MINOR, PATCH, SECURITY
               - Assess impact on AutoGPT's usage patterns
               - Check if AutoGPT uses affected APIs/features
               - Look for migration guides or upgrade instructions

            4. **Codebase Impact Analysis**:
               - Search the AutoGPT codebase for usage of changed APIs
               - Identify files that might be affected by breaking changes
               - Check test files for deprecated usage patterns
               - Look for configuration changes needed

            ## Output Format:

            Provide a comprehensive review comment with:

            ### 🔍 Dependency Analysis Summary
            - List of updated packages with version changes
            - Overall risk assessment (LOW/MEDIUM/HIGH)

            ### 📋 Detailed Changelog Review
            For each updated dependency:
            - **Package**: name (old_version → new_version)
            - **Changes**: Summary of key changes
            - **Breaking Changes**: List any breaking changes
            - **Security Fixes**: Note security improvements
            - **Migration Notes**: Any upgrade steps needed

            ### ⚠️ Impact Assessment
            - **Breaking Changes Found**: Yes/No with details
            - **Affected Files**: List AutoGPT files that may need updates
            - **Test Impact**: Any tests that may need updating
            - **Configuration Changes**: Required config updates

            ### 🛠️ Recommendations
            - **Action Required**: What the team should do
            - **Testing Focus**: Areas to test thoroughly
            - **Follow-up Tasks**: Any additional work needed
            - **Merge Recommendation**: APPROVE/REVIEW_NEEDED/HOLD

            ### 📚 Useful Links
            - Links to relevant changelogs, migration guides, documentation

            Be thorough but concise. Focus on actionable insights that help the development team make informed decisions about the dependency updates.
284  .github/workflows/claude.yml  vendored

@@ -30,18 +30,296 @@ jobs:

      github.event.issue.author_association == 'COLLABORATOR'
      )
    runs-on: ubuntu-latest
    timeout-minutes: 45

    permissions:
      contents: read
      contents: write
      pull-requests: read
      issues: read
      id-token: write
      actions: read # Required for CI access
    steps:
      - name: Checkout repository
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 1

      # Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11" # Use standard version matching CI

      - name: Set up Python dependency cache
        uses: actions/cache@v4
        with:
          path: ~/.cache/pypoetry
          key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

      - name: Install Poetry
        run: |
          # Extract Poetry version from backend/poetry.lock (matches CI)
          cd autogpt_platform/backend
          HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
          echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"

          # Install Poetry
          curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -

          # Add Poetry to PATH
          echo "$HOME/.local/bin" >> $GITHUB_PATH

      - name: Check poetry.lock
        working-directory: autogpt_platform/backend
        run: |
          poetry lock
          if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
            echo "Warning: poetry.lock not up to date, but continuing for setup"
            git checkout poetry.lock # Reset for clean setup
          fi

      - name: Install Python dependencies
        working-directory: autogpt_platform/backend
        run: poetry install

      - name: Generate Prisma Client
        working-directory: autogpt_platform/backend
        run: poetry run prisma generate

      # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: "21"

      - name: Enable corepack
        run: corepack enable

      - name: Set pnpm store directory
        run: |
          pnpm config set store-dir ~/.pnpm-store
          echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV

      - name: Cache frontend dependencies
        uses: actions/cache@v4
        with:
          path: ~/.pnpm-store
          key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
          restore-keys: |
            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
            ${{ runner.os }}-pnpm-

      - name: Install JavaScript dependencies
        working-directory: autogpt_platform/frontend
        run: pnpm install --frozen-lockfile

      # Install Playwright browsers for frontend testing
      # NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
      # - name: Install Playwright browsers
      #   working-directory: autogpt_platform/frontend
      #   run: pnpm playwright install --with-deps chromium

      # Docker setup for development environment
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Copy default environment files
        working-directory: autogpt_platform
        run: |
          # Copy default environment files for development
          cp .env.default .env
          cp backend/.env.default backend/.env
          cp frontend/.env.default frontend/.env

      # Phase 1: Cache and load Docker images for faster setup
      - name: Set up Docker image cache
        id: docker-cache
        uses: actions/cache@v4
        with:
          path: ~/docker-cache
          # Use a versioned key for cache invalidation when image list changes
          key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
          restore-keys: |
            docker-images-v2-${{ runner.os }}-
            docker-images-v1-${{ runner.os }}-

      - name: Load or pull Docker images
        working-directory: autogpt_platform
        run: |
          mkdir -p ~/docker-cache

          # Define image list for easy maintenance
          IMAGES=(
            "redis:latest"
            "rabbitmq:management"
            "clamav/clamav-debian:latest"
            "busybox:latest"
            "kong:2.8.1"
            "supabase/gotrue:v2.170.0"
            "supabase/postgres:15.8.1.049"
            "supabase/postgres-meta:v0.86.1"
            "supabase/studio:20250224-d10db0f"
          )

          # Check if any cached tar files exist (more reliable than cache-hit)
          if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
            echo "Docker cache found, loading images in parallel..."
            for image in "${IMAGES[@]}"; do
              # Convert image name to filename (replace : and / with -)
              filename=$(echo "$image" | tr ':/' '--')
              if [ -f ~/docker-cache/${filename}.tar ]; then
                echo "Loading $image..."
                docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
              fi
            done
            wait
            echo "All cached images loaded"
          else
            echo "No Docker cache found, pulling images in parallel..."
            # Pull all images in parallel
            for image in "${IMAGES[@]}"; do
              docker pull "$image" &
            done
            wait

            # Only save cache on main branches (not PRs) to avoid cache pollution
            if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
              echo "Saving Docker images to cache in parallel..."
              for image in "${IMAGES[@]}"; do
                # Convert image name to filename (replace : and / with -)
                filename=$(echo "$image" | tr ':/' '--')
                echo "Saving $image..."
                docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
              done
              wait
              echo "Docker image cache saved"
            else
              echo "Skipping cache save for PR/feature branch"
            fi
          fi

          echo "Docker images ready for use"

      # Phase 2: Build migrate service with GitHub Actions cache
      - name: Build migrate Docker image with cache
        working-directory: autogpt_platform
        run: |
          # Build the migrate image with buildx for GHA caching
          docker buildx build \
            --cache-from type=gha \
            --cache-to type=gha,mode=max \
            --target migrate \
            --tag autogpt_platform-migrate:latest \
            --load \
            -f backend/Dockerfile \
            ..

      # Start services using pre-built images
      - name: Start Docker services for development
        working-directory: autogpt_platform
        run: |
          # Start essential services (migrate image already built with correct tag)
          docker compose --profile local up deps --no-build --detach
          echo "Waiting for services to be ready..."

          # Wait for database to be ready
          echo "Checking database readiness..."
          timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
            echo "  Waiting for database..."
            sleep 2
          done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."

          # Check migrate service status
          echo "Checking migration status..."
          docker compose ps migrate || echo "  Migrate service not visible in ps output"

          # Wait for migrate service to complete
          echo "Waiting for migrations to complete..."
          timeout 30 bash -c '
            ATTEMPTS=0
            while [ $ATTEMPTS -lt 15 ]; do
              ATTEMPTS=$((ATTEMPTS + 1))

              # Check using docker directly (more reliable than docker compose ps)
              CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)

              if [ -z "$CONTAINER_STATUS" ]; then
                echo "  Attempt $ATTEMPTS: Migrate container not found yet..."
              elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
                echo "✅ Migrations completed successfully"
                docker compose logs migrate --tail=5 2>/dev/null || true
                exit 0
              elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
                EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
                echo "❌ Migrations failed with exit code: $EXIT_CODE"
                echo "Migration logs:"
                docker compose logs migrate --tail=20 2>/dev/null || true
                exit 1
              elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
                echo "  Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
              else
                echo "  Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
              fi

              sleep 2
            done

            echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
            echo "Final container check:"
            docker ps -a --filter "label=com.docker.compose.service=migrate" || true
            echo "Migration logs (if available):"
            docker compose logs migrate --tail=10 2>/dev/null || echo "  No logs available"
          ' || echo "⚠️ Migration check completed with warnings, continuing..."

          # Brief wait for other services to stabilize
          echo "Waiting 5 seconds for other services to stabilize..."
          sleep 5

      # Verify installations and provide environment info
      - name: Verify setup and show environment info
        run: |
          echo "=== Python Setup ==="
          python --version
          poetry --version

          echo "=== Node.js Setup ==="
          node --version
          pnpm --version

          echo "=== Additional Tools ==="
          docker --version
          docker compose version
          gh --version || true

          echo "=== Services Status ==="
          cd autogpt_platform
          docker compose ps || true

          echo "=== Backend Dependencies ==="
          cd backend
          poetry show | head -10 || true

          echo "=== Frontend Dependencies ==="
          cd ../frontend
          pnpm list --depth=0 | head -10 || true

          echo "=== Environment Files ==="
          ls -la ../.env* || true
          ls -la .env* || true
          ls -la ../backend/.env* || true

          echo "✅ AutoGPT Platform development environment setup complete!"
          echo "🚀 Ready for development with Docker services running"
          echo "📝 Backend server: poetry run serve (port 8000)"
          echo "🌐 Frontend server: pnpm dev (port 3000)"

      - name: Run Claude Code
        id: claude
        uses: anthropics/claude-code-action@beta
        uses: anthropics/claude-code-action@v1
        with:
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
          claude_args: |
            --allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*), Bash(gh pr edit:*)"
            --model opus
          additional_permissions: |
            actions: read
302  .github/workflows/copilot-setup-steps.yml  vendored  Normal file

@@ -0,0 +1,302 @@

name: "Copilot Setup Steps"

# Automatically run the setup steps when they are changed to allow for easy validation, and
# allow manual testing through the repository's "Actions" tab
on:
  workflow_dispatch:
  push:
    paths:
      - .github/workflows/copilot-setup-steps.yml
  pull_request:
    paths:
      - .github/workflows/copilot-setup-steps.yml

jobs:
  # The job MUST be called `copilot-setup-steps` or it will not be picked up by Copilot.
  copilot-setup-steps:
    runs-on: ubuntu-latest
    timeout-minutes: 45

    # Set the permissions to the lowest permissions possible needed for your steps.
    # Copilot will be given its own token for its operations.
    permissions:
      # If you want to clone the repository as part of your setup steps, for example to install dependencies, you'll need the `contents: read` permission. If you don't clone the repository in your setup steps, Copilot will do this for you automatically after the steps complete.
      contents: read

    # You can define any steps you want, and they will run before the agent starts.
    # If you do not check out your code, Copilot will do this for you.
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
          submodules: true

      # Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11" # Use standard version matching CI

      - name: Set up Python dependency cache
        uses: actions/cache@v4
        with:
          path: ~/.cache/pypoetry
          key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

      - name: Install Poetry
        run: |
          # Extract Poetry version from backend/poetry.lock (matches CI)
          cd autogpt_platform/backend
          HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
          echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"

          # Install Poetry
          curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -

          # Add Poetry to PATH
          echo "$HOME/.local/bin" >> $GITHUB_PATH

      - name: Check poetry.lock
        working-directory: autogpt_platform/backend
        run: |
          poetry lock
          if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
            echo "Warning: poetry.lock not up to date, but continuing for setup"
            git checkout poetry.lock # Reset for clean setup
          fi

      - name: Install Python dependencies
        working-directory: autogpt_platform/backend
        run: poetry install

      - name: Generate Prisma Client
        working-directory: autogpt_platform/backend
        run: poetry run prisma generate

      # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: "21"

      - name: Enable corepack
        run: corepack enable

      - name: Set pnpm store directory
        run: |
          pnpm config set store-dir ~/.pnpm-store
          echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV

      - name: Cache frontend dependencies
        uses: actions/cache@v4
        with:
          path: ~/.pnpm-store
          key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
          restore-keys: |
            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
            ${{ runner.os }}-pnpm-

      - name: Install JavaScript dependencies
        working-directory: autogpt_platform/frontend
        run: pnpm install --frozen-lockfile

      # Install Playwright browsers for frontend testing
      # NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
      # - name: Install Playwright browsers
      #   working-directory: autogpt_platform/frontend
      #   run: pnpm playwright install --with-deps chromium

      # Docker setup for development environment
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Copy default environment files
        working-directory: autogpt_platform
        run: |
          # Copy default environment files for development
          cp .env.default .env
          cp backend/.env.default backend/.env
          cp frontend/.env.default frontend/.env

      # Phase 1: Cache and load Docker images for faster setup
      - name: Set up Docker image cache
        id: docker-cache
        uses: actions/cache@v4
        with:
          path: ~/docker-cache
          # Use a versioned key for cache invalidation when image list changes
          key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
          restore-keys: |
            docker-images-v2-${{ runner.os }}-
            docker-images-v1-${{ runner.os }}-

      - name: Load or pull Docker images
        working-directory: autogpt_platform
        run: |
          mkdir -p ~/docker-cache

          # Define image list for easy maintenance
          IMAGES=(
            "redis:latest"
            "rabbitmq:management"
            "clamav/clamav-debian:latest"
            "busybox:latest"
            "kong:2.8.1"
            "supabase/gotrue:v2.170.0"
            "supabase/postgres:15.8.1.049"
            "supabase/postgres-meta:v0.86.1"
            "supabase/studio:20250224-d10db0f"
          )

          # Check if any cached tar files exist (more reliable than cache-hit)
          if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
            echo "Docker cache found, loading images in parallel..."
            for image in "${IMAGES[@]}"; do
              # Convert image name to filename (replace : and / with -)
              filename=$(echo "$image" | tr ':/' '--')
              if [ -f ~/docker-cache/${filename}.tar ]; then
                echo "Loading $image..."
                docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
              fi
            done
            wait
            echo "All cached images loaded"
          else
            echo "No Docker cache found, pulling images in parallel..."
            # Pull all images in parallel
            for image in "${IMAGES[@]}"; do
              docker pull "$image" &
            done
            wait

            # Only save cache on main branches (not PRs) to avoid cache pollution
            if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
              echo "Saving Docker images to cache in parallel..."
              for image in "${IMAGES[@]}"; do
                # Convert image name to filename (replace : and / with -)
                filename=$(echo "$image" | tr ':/' '--')
                echo "Saving $image..."
                docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
              done
              wait
              echo "Docker image cache saved"
            else
              echo "Skipping cache save for PR/feature branch"
            fi
          fi

          echo "Docker images ready for use"

      # Phase 2: Build migrate service with GitHub Actions cache
      - name: Build migrate Docker image with cache
        working-directory: autogpt_platform
        run: |
          # Build the migrate image with buildx for GHA caching
          docker buildx build \
            --cache-from type=gha \
            --cache-to type=gha,mode=max \
            --target migrate \
            --tag autogpt_platform-migrate:latest \
            --load \
            -f backend/Dockerfile \
            ..

      # Start services using pre-built images
      - name: Start Docker services for development
        working-directory: autogpt_platform
        run: |
          # Start essential services (migrate image already built with correct tag)
          docker compose --profile local up deps --no-build --detach
          echo "Waiting for services to be ready..."

          # Wait for database to be ready
          echo "Checking database readiness..."
          timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
            echo "  Waiting for database..."
            sleep 2
          done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."

          # Check migrate service status
          echo "Checking migration status..."
          docker compose ps migrate || echo "  Migrate service not visible in ps output"

          # Wait for migrate service to complete
          echo "Waiting for migrations to complete..."
          timeout 30 bash -c '
            ATTEMPTS=0
            while [ $ATTEMPTS -lt 15 ]; do
              ATTEMPTS=$((ATTEMPTS + 1))

              # Check using docker directly (more reliable than docker compose ps)
              CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)

              if [ -z "$CONTAINER_STATUS" ]; then
                echo "  Attempt $ATTEMPTS: Migrate container not found yet..."
              elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
                echo "✅ Migrations completed successfully"
                docker compose logs migrate --tail=5 2>/dev/null || true
                exit 0
              elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
                EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
                echo "❌ Migrations failed with exit code: $EXIT_CODE"
                echo "Migration logs:"
                docker compose logs migrate --tail=20 2>/dev/null || true
                exit 1
              elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
                echo "  Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
              else
                echo "  Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
              fi

              sleep 2
            done

            echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
            echo "Final container check:"
            docker ps -a --filter "label=com.docker.compose.service=migrate" || true
            echo "Migration logs (if available):"
            docker compose logs migrate --tail=10 2>/dev/null || echo "  No logs available"
          ' || echo "⚠️ Migration check completed with warnings, continuing..."

          # Brief wait for other services to stabilize
          echo "Waiting 5 seconds for other services to stabilize..."
          sleep 5

      # Verify installations and provide environment info
      - name: Verify setup and show environment info
        run: |
          echo "=== Python Setup ==="
          python --version
          poetry --version

          echo "=== Node.js Setup ==="
          node --version
          pnpm --version

          echo "=== Additional Tools ==="
          docker --version
          docker compose version
          gh --version || true

          echo "=== Services Status ==="
          cd autogpt_platform
          docker compose ps || true

          echo "=== Backend Dependencies ==="
          cd backend
          poetry show | head -10 || true

          echo "=== Frontend Dependencies ==="
          cd ../frontend
          pnpm list --depth=0 | head -10 || true

          echo "=== Environment Files ==="
          ls -la ../.env* || true
          ls -la .env* || true
          ls -la ../backend/.env* || true

          echo "✅ AutoGPT Platform development environment setup complete!"
          echo "🚀 Ready for development with Docker services running"
          echo "📝 Backend server: poetry run serve (port 8000)"
          echo "🌐 Frontend server: pnpm dev (port 3000)"
@@ -5,6 +5,13 @@ on:
    branches: [ dev ]
    paths:
      - 'autogpt_platform/**'
  workflow_dispatch:
    inputs:
      git_ref:
        description: 'Git ref (branch/tag) of AutoGPT to deploy'
        required: true
        default: 'master'
        type: string

permissions:
  contents: 'read'
@@ -19,6 +26,8 @@ jobs:
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.inputs.git_ref || github.ref_name }}

      - name: Set up Python
        uses: actions/setup-python@v5
@@ -48,4 +57,4 @@ jobs:
          token: ${{ secrets.DEPLOY_TOKEN }}
          repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
          event-type: build_deploy_dev
          client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "repository": "${{ github.repository }}"}'
          client-payload: '{"ref": "${{ github.event.inputs.git_ref || github.ref }}", "repository": "${{ github.repository }}"}'

@@ -3,6 +3,7 @@ name: AutoGPT Platform - Deploy Prod Environment
on:
  release:
    types: [published]
  workflow_dispatch:

permissions:
  contents: 'read'
@@ -17,6 +18,8 @@ jobs:
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          ref: ${{ github.ref_name || 'master' }}

      - name: Set up Python
        uses: actions/setup-python@v5
@@ -36,7 +39,7 @@ jobs:
          DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
          DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}


  trigger:
    needs: migrate
    runs-on: ubuntu-latest
@@ -47,4 +50,5 @@ jobs:
          token: ${{ secrets.DEPLOY_TOKEN }}
          repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
          event-type: build_deploy_prod
          client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "repository": "${{ github.repository }}"}'
          client-payload: |
            {"ref": "${{ github.ref_name || 'master' }}", "repository": "${{ github.repository }}"}
.github/workflows/platform-backend-ci.yml (vendored, 4 changed lines)
@@ -32,7 +32,7 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.11"]
        python-version: ["3.11", "3.12", "3.13"]
    runs-on: ubuntu-latest

    services:
@@ -201,7 +201,7 @@ jobs:
        DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
        SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
        SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
        SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
        JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
        REDIS_HOST: "localhost"
        REDIS_PORT: "6379"
        REDIS_PASSWORD: "testpassword"

.github/workflows/platform-container-publish.yml (vendored, Normal file, 113 lines)
@@ -0,0 +1,113 @@
|
||||
name: Platform - Container Publishing
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
no_cache:
|
||||
type: boolean
|
||||
description: 'Build from scratch, without using cached layers'
|
||||
default: false
|
||||
registry:
|
||||
type: choice
|
||||
description: 'Container registry to publish to'
|
||||
options:
|
||||
- 'both'
|
||||
- 'ghcr'
|
||||
- 'dockerhub'
|
||||
default: 'both'
|
||||
|
||||
env:
|
||||
GHCR_REGISTRY: ghcr.io
|
||||
GHCR_IMAGE_BASE: ${{ github.repository_owner }}/autogpt-platform
|
||||
DOCKERHUB_IMAGE_BASE: ${{ secrets.DOCKER_USER }}/autogpt-platform
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
build-and-publish:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
component: [backend, frontend]
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
if: inputs.registry == 'both' || inputs.registry == 'ghcr' || github.event_name == 'release'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.GHCR_REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
if: (inputs.registry == 'both' || inputs.registry == 'dockerhub' || github.event_name == 'release') && secrets.DOCKER_USER
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_PASSWORD }}
|
||||
|
||||
- name: Extract metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
${{ env.GHCR_REGISTRY }}/${{ env.GHCR_IMAGE_BASE }}-${{ matrix.component }}
|
||||
${{ secrets.DOCKER_USER && format('{0}-{1}', env.DOCKERHUB_IMAGE_BASE, matrix.component) || '' }}
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=pr
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=semver,pattern={{major}}
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
|
||||
- name: Set build context and dockerfile for backend
|
||||
if: matrix.component == 'backend'
|
||||
run: |
|
||||
echo "BUILD_CONTEXT=." >> $GITHUB_ENV
|
||||
echo "DOCKERFILE=autogpt_platform/backend/Dockerfile" >> $GITHUB_ENV
|
||||
echo "BUILD_TARGET=server" >> $GITHUB_ENV
|
||||
|
||||
- name: Set build context and dockerfile for frontend
|
||||
if: matrix.component == 'frontend'
|
||||
run: |
|
||||
echo "BUILD_CONTEXT=." >> $GITHUB_ENV
|
||||
echo "DOCKERFILE=autogpt_platform/frontend/Dockerfile" >> $GITHUB_ENV
|
||||
echo "BUILD_TARGET=prod" >> $GITHUB_ENV
|
||||
|
||||
- name: Build and push container image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: ${{ env.BUILD_CONTEXT }}
|
||||
file: ${{ env.DOCKERFILE }}
|
||||
target: ${{ env.BUILD_TARGET }}
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: ${{ !inputs.no_cache && 'type=gha' || '' }},scope=platform-${{ matrix.component }}
|
||||
cache-to: type=gha,scope=platform-${{ matrix.component }},mode=max
|
||||
|
||||
- name: Generate build summary
|
||||
run: |
|
||||
echo "## 🐳 Container Build Summary" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "**Component:** ${{ matrix.component }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "**Registry:** ${{ inputs.registry || 'both' }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "**Tags:** ${{ steps.meta.outputs.tags }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "### Images Published:" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
echo "${{ steps.meta.outputs.tags }}" | sed 's/,/\n/g' >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
.github/workflows/platform-frontend-ci.yml (vendored, 2 changed lines)
@@ -160,7 +160,7 @@ jobs:

      - name: Run docker compose
        run: |
          docker compose -f ../docker-compose.yml up -d
          NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d
        env:
          DOCKER_BUILDKIT: 1
          BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache

@@ -61,24 +61,27 @@ poetry run pytest path/to/test.py --snapshot-update

```bash
# Install dependencies
cd frontend && npm install
cd frontend && pnpm i

# Start development server
npm run dev
pnpm dev

# Run E2E tests
npm run test
pnpm test

# Run Storybook for component development
npm run storybook
pnpm storybook

# Build production
npm run build
pnpm build

# Type checking
npm run types
pnpm types
```

We have a components library in autogpt_platform/frontend/src/components/atoms that should be used when adding new pages and components.


## Architecture Overview

### Backend Architecture
@@ -149,14 +152,23 @@ Key models (defined in `/backend/schema.prisma`):

**Adding a new block:**

1. Create new file in `/backend/backend/blocks/`
2. Inherit from `Block` base class
3. Define input/output schemas
4. Implement `run` method
5. Register in block registry
6. Generate the block uuid using `uuid.uuid4()`
Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization

Note: when making many new blocks analyze the interfaces for each of these blcoks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
Quick steps:
1. Create new file in `/backend/backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`

Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
ex: do the inputs and outputs tie well together?
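For orientation, a minimal block following these quick steps might look roughly like the sketch below. The import paths and `SchemaField` helper are assumptions based on existing blocks; check neighbouring files in `/backend/backend/blocks/` for the exact interfaces.

```python
import uuid

from backend.data.block import Block, BlockOutput, BlockSchema  # assumed import path
from backend.data.model import SchemaField  # assumed field-metadata helper


class WordCountBlock(Block):
    """Hypothetical block that counts the words in a text input."""

    class Input(BlockSchema):
        text: str = SchemaField(description="Text to count words in")

    class Output(BlockSchema):
        word_count: int = SchemaField(description="Number of words in the text")

    def __init__(self):
        super().__init__(
            # Generate once with uuid.uuid4() and hard-code the value in a real block
            id=str(uuid.uuid4()),
            description="Counts the words in a piece of text",
            input_schema=WordCountBlock.Input,
            output_schema=WordCountBlock.Output,
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Outputs are yielded as (name, value) pairs so downstream nodes can connect to them
        yield "word_count", len(input_data.text.split())
```

Keeping inputs and outputs simple and typed like this is what makes blocks easy to wire together in the graph editor.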
**Modifying the API:**

autogpt_platform/CONTAINERS.md (Normal file, 389 lines)
@@ -0,0 +1,389 @@
|
||||
# AutoGPT Platform Container Publishing
|
||||
|
||||
This document describes the container publishing infrastructure and deployment options for the AutoGPT Platform.
|
||||
|
||||
## Published Container Images
|
||||
|
||||
### GitHub Container Registry (GHCR) - Recommended
|
||||
|
||||
- **Backend**: `ghcr.io/significant-gravitas/autogpt-platform-backend`
|
||||
- **Frontend**: `ghcr.io/significant-gravitas/autogpt-platform-frontend`
|
||||
|
||||
### Docker Hub
|
||||
|
||||
- **Backend**: `significantgravitas/autogpt-platform-backend`
|
||||
- **Frontend**: `significantgravitas/autogpt-platform-frontend`
|
||||
|
||||
## Available Tags
|
||||
|
||||
- `latest` - Latest stable release from master branch
|
||||
- `v1.0.0`, `v1.1.0`, etc. - Specific version releases
|
||||
- `main` - Latest development build (use with caution)
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Using Docker Compose (Recommended)
|
||||
|
||||
```bash
|
||||
# Clone the repository (or just download the compose file)
|
||||
git clone https://github.com/Significant-Gravitas/AutoGPT.git
|
||||
cd AutoGPT/autogpt_platform
|
||||
|
||||
# Deploy with published images
|
||||
./deploy.sh deploy
|
||||
```
|
||||
|
||||
### Manual Docker Run
|
||||
|
||||
```bash
|
||||
# Start dependencies first
|
||||
docker network create autogpt
|
||||
|
||||
# PostgreSQL
|
||||
docker run -d --name postgres --network autogpt \
|
||||
-e POSTGRES_DB=autogpt \
|
||||
-e POSTGRES_USER=autogpt \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-v postgres_data:/var/lib/postgresql/data \
|
||||
postgres:15
|
||||
|
||||
# Redis
|
||||
docker run -d --name redis --network autogpt \
|
||||
-v redis_data:/data \
|
||||
redis:7-alpine redis-server --requirepass password
|
||||
|
||||
# RabbitMQ
|
||||
docker run -d --name rabbitmq --network autogpt \
|
||||
-e RABBITMQ_DEFAULT_USER=autogpt \
|
||||
-e RABBITMQ_DEFAULT_PASS=password \
|
||||
-p 15672:15672 \
|
||||
rabbitmq:3-management
|
||||
|
||||
# Backend
|
||||
docker run -d --name backend --network autogpt \
|
||||
-p 8000:8000 \
|
||||
-e DATABASE_URL=postgresql://autogpt:password@postgres:5432/autogpt \
|
||||
-e REDIS_HOST=redis \
|
||||
-e RABBITMQ_HOST=rabbitmq \
|
||||
ghcr.io/significant-gravitas/autogpt-platform-backend:latest
|
||||
|
||||
# Frontend
|
||||
docker run -d --name frontend --network autogpt \
|
||||
-p 3000:3000 \
|
||||
-e AGPT_SERVER_URL=http://localhost:8000/api \
|
||||
ghcr.io/significant-gravitas/autogpt-platform-frontend:latest
|
||||
```
|
||||
|
||||
## Deployment Scripts
|
||||
|
||||
### Deploy Script
|
||||
|
||||
The included `deploy.sh` script provides a complete deployment solution:
|
||||
|
||||
```bash
|
||||
# Basic deployment
|
||||
./deploy.sh deploy
|
||||
|
||||
# Deploy specific version
|
||||
./deploy.sh -v v1.0.0 deploy
|
||||
|
||||
# Deploy from Docker Hub
|
||||
./deploy.sh -r docker.io deploy
|
||||
|
||||
# Production deployment
|
||||
./deploy.sh -p production deploy
|
||||
|
||||
# Other operations
|
||||
./deploy.sh start # Start services
|
||||
./deploy.sh stop # Stop services
|
||||
./deploy.sh restart # Restart services
|
||||
./deploy.sh update # Update to latest
|
||||
./deploy.sh backup # Create backup
|
||||
./deploy.sh status # Show status
|
||||
./deploy.sh logs # Show logs
|
||||
./deploy.sh cleanup # Remove everything
|
||||
```
|
||||
|
||||
## Platform-Specific Deployment Guides
|
||||
|
||||
### Unraid
|
||||
|
||||
See [Unraid Deployment Guide](../docs/content/platform/deployment/unraid.md)
|
||||
|
||||
Key features:
|
||||
- Community Applications template
|
||||
- Web UI management
|
||||
- Automatic updates
|
||||
- Built-in backup system
|
||||
|
||||
### Home Assistant Add-on
|
||||
|
||||
See [Home Assistant Add-on Guide](../docs/content/platform/deployment/home-assistant.md)
|
||||
|
||||
Key features:
|
||||
- Native Home Assistant integration
|
||||
- Automation services
|
||||
- Entity monitoring
|
||||
- Backup integration
|
||||
|
||||
### Kubernetes
|
||||
|
||||
See [Kubernetes Deployment Guide](../docs/content/platform/deployment/kubernetes.md)
|
||||
|
||||
Key features:
|
||||
- Helm charts
|
||||
- Horizontal scaling
|
||||
- Health checks
|
||||
- Persistent volumes
|
||||
|
||||
## Container Architecture
|
||||
|
||||
### Backend Container
|
||||
|
||||
- **Base Image**: `debian:13-slim`
|
||||
- **Runtime**: Python 3.13 with Poetry
|
||||
- **Services**: REST API, WebSocket, Executor, Scheduler, Database Manager, Notification
|
||||
- **Ports**: 8000-8007 (depending on service)
|
||||
- **Health Check**: `GET /health`
|
||||
|
||||
### Frontend Container
|
||||
|
||||
- **Base Image**: `node:21-alpine`
|
||||
- **Runtime**: Next.js production build
|
||||
- **Port**: 3000
|
||||
- **Health Check**: HTTP 200 on root path
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
### Required Environment Variables
|
||||
|
||||
#### Backend
|
||||
```env
|
||||
DATABASE_URL=postgresql://user:pass@host:5432/db
|
||||
REDIS_HOST=redis
|
||||
RABBITMQ_HOST=rabbitmq
|
||||
JWT_SECRET=your-secret-key
|
||||
```
|
||||
|
||||
#### Frontend
|
||||
```env
|
||||
AGPT_SERVER_URL=http://backend:8000/api
|
||||
SUPABASE_URL=http://auth:8000
|
||||
```
|
||||
|
||||
### Optional Configuration
|
||||
|
||||
```env
|
||||
# Logging
|
||||
LOG_LEVEL=INFO
|
||||
ENABLE_DEBUG=false
|
||||
|
||||
# Performance
|
||||
REDIS_PASSWORD=your-redis-password
|
||||
RABBITMQ_PASSWORD=your-rabbitmq-password
|
||||
|
||||
# Security
|
||||
CORS_ORIGINS=http://localhost:3000
|
||||
```
|
||||
|
||||
## CI/CD Pipeline
|
||||
|
||||
### GitHub Actions Workflow
|
||||
|
||||
The publishing workflow (`.github/workflows/platform-container-publish.yml`) automatically:
|
||||
|
||||
1. **Triggers** on releases and manual dispatch
|
||||
2. **Builds** both backend and frontend containers
|
||||
3. **Tests** container functionality
|
||||
4. **Publishes** to both GHCR and Docker Hub
|
||||
5. **Tags** with version and latest
|
||||
|
||||
### Manual Publishing
|
||||
|
||||
```bash
|
||||
# Build and tag locally
|
||||
docker build -t ghcr.io/significant-gravitas/autogpt-platform-backend:latest \
|
||||
-f autogpt_platform/backend/Dockerfile \
|
||||
--target server .
|
||||
|
||||
docker build -t ghcr.io/significant-gravitas/autogpt-platform-frontend:latest \
|
||||
-f autogpt_platform/frontend/Dockerfile \
|
||||
--target prod .
|
||||
|
||||
# Push to registry
|
||||
docker push ghcr.io/significant-gravitas/autogpt-platform-backend:latest
|
||||
docker push ghcr.io/significant-gravitas/autogpt-platform-frontend:latest
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### Container Security
|
||||
|
||||
1. **Non-root users** - Containers run as non-root
|
||||
2. **Minimal base images** - Using slim/alpine images
|
||||
3. **No secrets in images** - All secrets via environment variables
|
||||
4. **Read-only filesystem** - Where possible
|
||||
5. **Resource limits** - CPU and memory limits set
|
||||
|
||||
### Deployment Security
|
||||
|
||||
1. **Network isolation** - Use dedicated networks
|
||||
2. **TLS encryption** - Enable HTTPS in production
|
||||
3. **Secret management** - Use Docker secrets or external secret stores
|
||||
4. **Regular updates** - Keep images updated
|
||||
5. **Vulnerability scanning** - Regular security scans
|
||||
|
||||
## Monitoring
|
||||
|
||||
### Health Checks
|
||||
|
||||
All containers include health checks:
|
||||
|
||||
```bash
|
||||
# Check container health
|
||||
docker inspect --format='{{.State.Health.Status}}' container_name
|
||||
|
||||
# Manual health check
|
||||
curl http://localhost:8000/health
|
||||
```
|
||||
|
||||
### Metrics
|
||||
|
||||
The backend exposes Prometheus metrics at `/metrics`:
|
||||
|
||||
```bash
|
||||
curl http://localhost:8000/metrics
|
||||
```
|
||||
|
||||
### Logging
|
||||
|
||||
Containers log to stdout/stderr for easy aggregation:
|
||||
|
||||
```bash
|
||||
# View logs
|
||||
docker logs container_name
|
||||
|
||||
# Follow logs
|
||||
docker logs -f container_name
|
||||
|
||||
# Aggregate logs
|
||||
docker compose logs -f
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Container won't start**
|
||||
```bash
|
||||
# Check logs
|
||||
docker logs container_name
|
||||
|
||||
# Check environment
|
||||
docker exec container_name env
|
||||
```
|
||||
|
||||
2. **Database connection failed**
|
||||
```bash
|
||||
# Test connectivity
|
||||
docker exec backend ping postgres
|
||||
|
||||
# Check database status
|
||||
docker exec postgres pg_isready
|
||||
```
|
||||
|
||||
3. **Port conflicts**
|
||||
```bash
|
||||
# Check port usage
|
||||
ss -tuln | grep :3000
|
||||
|
||||
# Use different ports
|
||||
docker run -p 3001:3000 ...
|
||||
```
|
||||
|
||||
### Debug Mode
|
||||
|
||||
Enable debug mode for detailed logging:
|
||||
|
||||
```env
|
||||
LOG_LEVEL=DEBUG
|
||||
ENABLE_DEBUG=true
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### Resource Limits
|
||||
|
||||
```yaml
|
||||
# Docker Compose
|
||||
services:
|
||||
backend:
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
memory: 2G
|
||||
cpus: '1.0'
|
||||
reservations:
|
||||
memory: 1G
|
||||
cpus: '0.5'
|
||||
```
|
||||
|
||||
### Scaling
|
||||
|
||||
```bash
|
||||
# Scale backend services
|
||||
docker compose up -d --scale backend=3
|
||||
|
||||
# Or use Docker Swarm
|
||||
docker service scale backend=3
|
||||
```
|
||||
|
||||
## Backup and Recovery
|
||||
|
||||
### Data Backup
|
||||
|
||||
```bash
|
||||
# Database backup
|
||||
docker exec postgres pg_dump -U autogpt autogpt > backup.sql
|
||||
|
||||
# Volume backup
|
||||
docker run --rm -v postgres_data:/data -v $(pwd):/backup \
|
||||
alpine tar czf /backup/postgres_backup.tar.gz /data
|
||||
```
|
||||
|
||||
### Restore
|
||||
|
||||
```bash
|
||||
# Database restore
|
||||
docker exec -i postgres psql -U autogpt autogpt < backup.sql
|
||||
|
||||
# Volume restore
|
||||
docker run --rm -v postgres_data:/data -v $(pwd):/backup \
|
||||
alpine tar xzf /backup/postgres_backup.tar.gz -C /
|
||||
```
|
||||
|
||||
## Support
|
||||
|
||||
- **Documentation**: [Platform Docs](../docs/content/platform/)
|
||||
- **Issues**: [GitHub Issues](https://github.com/Significant-Gravitas/AutoGPT/issues)
|
||||
- **Discord**: [AutoGPT Community](https://discord.gg/autogpt)
|
||||
- **Docker Hub**: [Container Registry](https://hub.docker.com/r/significantgravitas/)
|
||||
|
||||
## Contributing
|
||||
|
||||
To contribute to the container infrastructure:
|
||||
|
||||
1. **Test locally** with `docker build` and `docker run`
|
||||
2. **Update documentation** if making changes
|
||||
3. **Test deployment scripts** on your platform
|
||||
4. **Submit PR** with clear description of changes
|
||||
|
||||
## Roadmap
|
||||
|
||||
- [ ] ARM64 support for Apple Silicon
|
||||
- [ ] Helm charts for Kubernetes
|
||||
- [ ] Official Unraid template
|
||||
- [ ] Home Assistant Add-on store submission
|
||||
- [ ] Multi-stage builds optimization
|
||||
- [ ] Security scanning integration
|
||||
- [ ] Performance benchmarking
|
||||
@@ -2,16 +2,38 @@

Welcome to the AutoGPT Platform - a powerful system for creating and running AI agents to solve business problems. This platform enables you to harness the power of artificial intelligence to automate tasks, analyze data, and generate insights for your organization.

## Getting Started
## Deployment Options

### Quick Deploy with Published Containers (Recommended)

The fastest way to get started is using our pre-built containers:

```bash
# Download and run with published images
curl -fsSL https://raw.githubusercontent.com/Significant-Gravitas/AutoGPT/master/autogpt_platform/deploy.sh -o deploy.sh
chmod +x deploy.sh
./deploy.sh deploy
```

Access the platform at http://localhost:3000 after deployment completes.

### Platform-Specific Deployments

- **Unraid**: [Deployment Guide](../docs/content/platform/deployment/unraid.md)
- **Home Assistant**: [Add-on Guide](../docs/content/platform/deployment/home-assistant.md)
- **Kubernetes**: [K8s Deployment](../docs/content/platform/deployment/kubernetes.md)
- **General Containers**: [Container Guide](../docs/content/platform/container-deployment.md)

## Development Setup

### Prerequisites

- Docker
- Docker Compose V2 (comes with Docker Desktop, or can be installed separately)

### Running the System
### Running from Source

To run the AutoGPT Platform, follow these steps:
To run the AutoGPT Platform from source for development:

1. Clone this repository to your local machine and navigate to the `autogpt_platform` directory within the repository:

@@ -157,3 +179,28 @@ If you need to update the API client after making changes to the backend API:
```

This will fetch the latest OpenAPI specification and regenerate the TypeScript client code.

## Container Deployment

For production deployments and specific platforms, see our container deployment guides:

- **[Container Deployment Overview](CONTAINERS.md)** - Complete guide to using published containers
- **[Deployment Script](deploy.sh)** - Automated deployment and management tool
- **[Published Images](docker-compose.published.yml)** - Docker Compose for published containers

### Published Container Images

- **Backend**: `ghcr.io/significant-gravitas/autogpt-platform-backend:latest`
- **Frontend**: `ghcr.io/significant-gravitas/autogpt-platform-frontend:latest`

### Quick Production Deployment

```bash
# Deploy with published containers
./deploy.sh deploy

# Or use the published compose file directly
docker compose -f docker-compose.published.yml up -d
```

For detailed deployment instructions, troubleshooting, and platform-specific guides, see the [Container Documentation](CONTAINERS.md).

@@ -1,35 +0,0 @@
import hashlib
import secrets
from typing import NamedTuple


class APIKeyContainer(NamedTuple):
    """Container for API key parts."""

    raw: str
    prefix: str
    postfix: str
    hash: str


class APIKeyManager:
    PREFIX: str = "agpt_"
    PREFIX_LENGTH: int = 8
    POSTFIX_LENGTH: int = 8

    def generate_api_key(self) -> APIKeyContainer:
        """Generate a new API key with all its parts."""
        raw_key = f"{self.PREFIX}{secrets.token_urlsafe(32)}"
        return APIKeyContainer(
            raw=raw_key,
            prefix=raw_key[: self.PREFIX_LENGTH],
            postfix=raw_key[-self.POSTFIX_LENGTH :],
            hash=hashlib.sha256(raw_key.encode()).hexdigest(),
        )

    def verify_api_key(self, provided_key: str, stored_hash: str) -> bool:
        """Verify if a provided API key matches the stored hash."""
        if not provided_key.startswith(self.PREFIX):
            return False
        provided_hash = hashlib.sha256(provided_key.encode()).hexdigest()
        return secrets.compare_digest(provided_hash, stored_hash)
@@ -0,0 +1,78 @@
import hashlib
import secrets
from typing import NamedTuple

from cryptography.hazmat.primitives.kdf.scrypt import Scrypt


class APIKeyContainer(NamedTuple):
    """Container for API key parts."""

    key: str
    head: str
    tail: str
    hash: str
    salt: str


class APIKeySmith:
    PREFIX: str = "agpt_"
    HEAD_LENGTH: int = 8
    TAIL_LENGTH: int = 8

    def generate_key(self) -> APIKeyContainer:
        """Generate a new API key with secure hashing."""
        raw_key = f"{self.PREFIX}{secrets.token_urlsafe(32)}"
        hash, salt = self.hash_key(raw_key)

        return APIKeyContainer(
            key=raw_key,
            head=raw_key[: self.HEAD_LENGTH],
            tail=raw_key[-self.TAIL_LENGTH :],
            hash=hash,
            salt=salt,
        )

    def verify_key(
        self, provided_key: str, known_hash: str, known_salt: str | None = None
    ) -> bool:
        """
        Verify an API key against a known hash (+ salt).
        Supports verifying both legacy SHA256 and secure Scrypt hashes.
        """
        if not provided_key.startswith(self.PREFIX):
            return False

        # Handle legacy SHA256 hashes (migration support)
        if known_salt is None:
            legacy_hash = hashlib.sha256(provided_key.encode()).hexdigest()
            return secrets.compare_digest(legacy_hash, known_hash)

        try:
            salt_bytes = bytes.fromhex(known_salt)
            provided_hash = self._hash_key_with_salt(provided_key, salt_bytes)
            return secrets.compare_digest(provided_hash, known_hash)
        except (ValueError, TypeError):
            return False

    def hash_key(self, raw_key: str) -> tuple[str, str]:
        """Migrate a legacy hash to secure hash format."""
        salt = self._generate_salt()
        hash = self._hash_key_with_salt(raw_key, salt)
        return hash, salt.hex()

    def _generate_salt(self) -> bytes:
        """Generate a random salt for hashing."""
        return secrets.token_bytes(32)

    def _hash_key_with_salt(self, raw_key: str, salt: bytes) -> str:
        """Hash API key using Scrypt with salt."""
        kdf = Scrypt(
            length=32,
            salt=salt,
            n=2**14,  # CPU/memory cost parameter
            r=8,  # Block size parameter
            p=1,  # Parallelization parameter
        )
        key_hash = kdf.derive(raw_key.encode())
        return key_hash.hex()
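The intended usage of `APIKeySmith`, based only on the interface above (persisting the `head`, `tail`, `hash`, and `salt` fields is left to the caller):

```python
import hashlib

from autogpt_libs.api_key.keysmith import APIKeySmith

keysmith = APIKeySmith()

# Issue a new key: return `generated.key` to the user once, store the rest.
generated = keysmith.generate_key()
stored_hash, stored_salt = generated.hash, generated.salt

# Verify an incoming key against the stored hash and salt (Scrypt path).
assert keysmith.verify_key(generated.key, stored_hash, stored_salt)

# Legacy records that only have a SHA256 hash still verify when the salt is omitted.
legacy_hash = hashlib.sha256(generated.key.encode()).hexdigest()
assert keysmith.verify_key(generated.key, legacy_hash)

# Such records can be upgraded to the Scrypt format the next time the key is seen.
new_hash, new_salt = keysmith.hash_key(generated.key)
assert keysmith.verify_key(generated.key, new_hash, new_salt)
```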
@@ -0,0 +1,79 @@
import hashlib

from autogpt_libs.api_key.keysmith import APIKeySmith


def test_generate_api_key():
    keysmith = APIKeySmith()
    key = keysmith.generate_key()

    assert key.key.startswith(keysmith.PREFIX)
    assert key.head == key.key[: keysmith.HEAD_LENGTH]
    assert key.tail == key.key[-keysmith.TAIL_LENGTH :]
    assert len(key.hash) == 64  # 32 bytes hex encoded
    assert len(key.salt) == 64  # 32 bytes hex encoded


def test_verify_new_secure_key():
    keysmith = APIKeySmith()
    key = keysmith.generate_key()

    # Test correct key validates
    assert keysmith.verify_key(key.key, key.hash, key.salt) is True

    # Test wrong key fails
    wrong_key = f"{keysmith.PREFIX}wrongkey123"
    assert keysmith.verify_key(wrong_key, key.hash, key.salt) is False


def test_verify_legacy_key():
    keysmith = APIKeySmith()
    legacy_key = f"{keysmith.PREFIX}legacykey123"
    legacy_hash = hashlib.sha256(legacy_key.encode()).hexdigest()

    # Test legacy key validates without salt
    assert keysmith.verify_key(legacy_key, legacy_hash) is True

    # Test wrong legacy key fails
    wrong_key = f"{keysmith.PREFIX}wronglegacy"
    assert keysmith.verify_key(wrong_key, legacy_hash) is False


def test_rehash_existing_key():
    keysmith = APIKeySmith()
    legacy_key = f"{keysmith.PREFIX}migratekey123"

    # Migrate the legacy key
    new_hash, new_salt = keysmith.hash_key(legacy_key)

    # Verify migrated key works
    assert keysmith.verify_key(legacy_key, new_hash, new_salt) is True

    # Verify different key fails with migrated hash
    wrong_key = f"{keysmith.PREFIX}wrongkey"
    assert keysmith.verify_key(wrong_key, new_hash, new_salt) is False


def test_invalid_key_prefix():
    keysmith = APIKeySmith()
    key = keysmith.generate_key()

    # Test key without proper prefix fails
    invalid_key = "invalid_prefix_key"
    assert keysmith.verify_key(invalid_key, key.hash, key.salt) is False


def test_secure_hash_requires_salt():
    keysmith = APIKeySmith()
    key = keysmith.generate_key()

    # Secure hash without salt should fail
    assert keysmith.verify_key(key.key, key.hash) is False


def test_invalid_salt_format():
    keysmith = APIKeySmith()
    key = keysmith.generate_key()

    # Invalid salt format should fail gracefully
    assert keysmith.verify_key(key.key, key.hash, "invalid_hex") is False
@@ -1,13 +1,13 @@
from .depends import requires_admin_user, requires_user
from .jwt_utils import parse_jwt_token
from .middleware import APIKeyValidator, auth_middleware
from .config import verify_settings
from .dependencies import get_user_id, requires_admin_user, requires_user
from .helpers import add_auth_responses_to_openapi
from .models import User

__all__ = [
    "parse_jwt_token",
    "requires_user",
    "verify_settings",
    "get_user_id",
    "requires_admin_user",
    "APIKeyValidator",
    "auth_middleware",
    "requires_user",
    "add_auth_responses_to_openapi",
    "User",
]

@@ -1,11 +1,90 @@
import logging
import os

from jwt.algorithms import get_default_algorithms, has_crypto

logger = logging.getLogger(__name__)


class AuthConfigError(ValueError):
    """Raised when authentication configuration is invalid."""

    pass


ALGO_RECOMMENDATION = (
    "We highly recommend using an asymmetric algorithm such as ES256, "
    "because when leaked, a shared secret would allow anyone to "
    "forge valid tokens and impersonate users. "
    "More info: https://supabase.com/docs/guides/auth/signing-keys#choosing-the-right-signing-algorithm"  # noqa
)


class Settings:
    def __init__(self):
        self.JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
        self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
        self.JWT_ALGORITHM: str = "HS256"
        self.JWT_VERIFY_KEY: str = os.getenv(
            "JWT_VERIFY_KEY", os.getenv("SUPABASE_JWT_SECRET", "")
        ).strip()
        self.JWT_ALGORITHM: str = os.getenv("JWT_SIGN_ALGORITHM", "HS256").strip()

        self.validate()

    def validate(self):
        if not self.JWT_VERIFY_KEY:
            raise AuthConfigError(
                "JWT_VERIFY_KEY must be set. "
                "An empty JWT secret would allow anyone to forge valid tokens."
            )

        if len(self.JWT_VERIFY_KEY) < 32:
            logger.warning(
                "⚠️ JWT_VERIFY_KEY appears weak (less than 32 characters). "
                "Consider using a longer, cryptographically secure secret."
            )

        supported_algorithms = get_default_algorithms().keys()

        if not has_crypto:
            logger.warning(
                "⚠️ Asymmetric JWT verification is not available "
                "because the 'cryptography' package is not installed. "
                + ALGO_RECOMMENDATION
            )

        if (
            self.JWT_ALGORITHM not in supported_algorithms
            or self.JWT_ALGORITHM == "none"
        ):
            raise AuthConfigError(
                f"Invalid JWT_SIGN_ALGORITHM: '{self.JWT_ALGORITHM}'. "
                "Supported algorithms are listed on "
                "https://pyjwt.readthedocs.io/en/stable/algorithms.html"
            )

        if self.JWT_ALGORITHM.startswith("HS"):
            logger.warning(
                f"⚠️ JWT_SIGN_ALGORITHM is set to '{self.JWT_ALGORITHM}', "
                "a symmetric shared-key signature algorithm. " + ALGO_RECOMMENDATION
            )


settings = Settings()
_settings: Settings = None  # type: ignore


def get_settings() -> Settings:
    global _settings

    if not _settings:
        _settings = Settings()

    return _settings


def verify_settings() -> None:
    global _settings

    if not _settings:
        _settings = Settings()  # calls validation indirectly
        return

    _settings.validate()

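Based on the module above, services are expected to call `verify_settings()` once at startup and read values through `get_settings()`. A minimal sketch, with placeholder environment values set inline only for illustration:

```python
import os

# Placeholder values; real deployments set these in the environment or secret store.
os.environ["JWT_VERIFY_KEY"] = "a-long-random-secret-of-at-least-32-characters"
os.environ["JWT_SIGN_ALGORITHM"] = "HS256"  # an asymmetric algorithm such as ES256 is preferred

from autogpt_libs.auth.config import get_settings, verify_settings

# Raises AuthConfigError if the key is missing or the algorithm is unsupported/"none".
verify_settings()

settings = get_settings()  # cached Settings instance
print(settings.JWT_ALGORITHM)  # "HS256", with a warning logged about symmetric signing
```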
autogpt_platform/autogpt_libs/autogpt_libs/auth/config_test.py (Normal file, 306 lines)
@@ -0,0 +1,306 @@
|
||||
"""
|
||||
Comprehensive tests for auth configuration to ensure 100% line and branch coverage.
|
||||
These tests verify critical security checks preventing JWT token forgery.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from autogpt_libs.auth.config import AuthConfigError, Settings
|
||||
|
||||
|
||||
def test_environment_variable_precedence(mocker: MockerFixture):
|
||||
"""Test that environment variables take precedence over defaults."""
|
||||
secret = "environment-secret-key-with-proper-length-123456"
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == secret
|
||||
|
||||
|
||||
def test_environment_variable_backwards_compatible(mocker: MockerFixture):
|
||||
"""Test that SUPABASE_JWT_SECRET is read if JWT_VERIFY_KEY is not set."""
|
||||
secret = "environment-secret-key-with-proper-length-123456"
|
||||
mocker.patch.dict(os.environ, {"SUPABASE_JWT_SECRET": secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == secret
|
||||
|
||||
|
||||
def test_auth_config_error_inheritance():
|
||||
"""Test that AuthConfigError is properly defined as an Exception."""
|
||||
assert issubclass(AuthConfigError, Exception)
|
||||
error = AuthConfigError("test message")
|
||||
assert str(error) == "test message"
|
||||
|
||||
|
||||
def test_settings_static_after_creation(mocker: MockerFixture):
|
||||
"""Test that settings maintain their values after creation."""
|
||||
secret = "immutable-secret-key-with-proper-length-12345"
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
original_secret = settings.JWT_VERIFY_KEY
|
||||
|
||||
# Changing environment after creation shouldn't affect settings
|
||||
os.environ["JWT_VERIFY_KEY"] = "different-secret"
|
||||
|
||||
assert settings.JWT_VERIFY_KEY == original_secret
|
||||
|
||||
|
||||
def test_settings_load_with_valid_secret(mocker: MockerFixture):
|
||||
"""Test auth enabled with a valid JWT secret."""
|
||||
valid_secret = "a" * 32 # 32 character secret
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": valid_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == valid_secret
|
||||
|
||||
|
||||
def test_settings_load_with_strong_secret(mocker: MockerFixture):
|
||||
"""Test auth enabled with a cryptographically strong secret."""
|
||||
strong_secret = "super-secret-jwt-token-with-at-least-32-characters-long"
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": strong_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == strong_secret
|
||||
assert len(settings.JWT_VERIFY_KEY) >= 32
|
||||
|
||||
|
||||
def test_secret_empty_raises_error(mocker: MockerFixture):
|
||||
"""Test that auth enabled with empty secret raises AuthConfigError."""
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": ""}, clear=True)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
Settings()
|
||||
assert "JWT_VERIFY_KEY" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_secret_missing_raises_error(mocker: MockerFixture):
|
||||
"""Test that auth enabled without secret env var raises AuthConfigError."""
|
||||
mocker.patch.dict(os.environ, {}, clear=True)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
Settings()
|
||||
assert "JWT_VERIFY_KEY" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("secret", [" ", " ", "\t", "\n", " \t\n "])
|
||||
def test_secret_only_whitespace_raises_error(mocker: MockerFixture, secret: str):
|
||||
"""Test that auth enabled with whitespace-only secret raises error."""
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Settings()
|
||||
|
||||
|
||||
def test_secret_weak_logs_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
"""Test that weak JWT secret triggers warning log."""
|
||||
weak_secret = "short" # Less than 32 characters
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": weak_secret}, clear=True)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == weak_secret
|
||||
assert "key appears weak" in caplog.text.lower()
|
||||
assert "less than 32 characters" in caplog.text
|
||||
|
||||
|
||||
def test_secret_31_char_logs_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
"""Test that 31-character secret triggers warning (boundary test)."""
|
||||
secret_31 = "a" * 31 # Exactly 31 characters
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret_31}, clear=True)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
assert len(settings.JWT_VERIFY_KEY) == 31
|
||||
assert "key appears weak" in caplog.text.lower()
|
||||
|
||||
|
||||
def test_secret_32_char_no_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
"""Test that 32-character secret does not trigger warning (boundary test)."""
|
||||
secret_32 = "a" * 32 # Exactly 32 characters
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret_32}, clear=True)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
assert len(settings.JWT_VERIFY_KEY) == 32
|
||||
assert "JWT secret appears weak" not in caplog.text
|
||||
|
||||
|
||||
def test_secret_whitespace_stripped(mocker: MockerFixture):
|
||||
"""Test that JWT secret whitespace is stripped."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": f" {secret} "}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == secret
|
||||
|
||||
|
||||
def test_secret_with_special_characters(mocker: MockerFixture):
|
||||
"""Test JWT secret with special characters."""
|
||||
special_secret = "!@#$%^&*()_+-=[]{}|;:,.<>?`~" + "a" * 10 # 40 chars total
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": special_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == special_secret
|
||||
|
||||
|
||||
def test_secret_with_unicode(mocker: MockerFixture):
|
||||
"""Test JWT secret with unicode characters."""
|
||||
unicode_secret = "秘密🔐キー" + "a" * 25 # Ensure >32 bytes
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": unicode_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == unicode_secret
|
||||
|
||||
|
||||
def test_secret_very_long(mocker: MockerFixture):
|
||||
"""Test JWT secret with excessive length."""
|
||||
long_secret = "a" * 1000 # 1000 character secret
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": long_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == long_secret
|
||||
assert len(settings.JWT_VERIFY_KEY) == 1000
|
||||
|
||||
|
||||
def test_secret_with_newline(mocker: MockerFixture):
|
||||
"""Test JWT secret containing newlines."""
|
||||
multiline_secret = "secret\nwith\nnewlines" + "a" * 20
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": multiline_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == multiline_secret
|
||||
|
||||
|
||||
def test_secret_base64_encoded(mocker: MockerFixture):
|
||||
"""Test JWT secret that looks like base64."""
|
||||
base64_secret = "dGhpc19pc19hX3NlY3JldF9rZXlfd2l0aF9wcm9wZXJfbGVuZ3Ro"
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": base64_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == base64_secret
|
||||
|
||||
|
||||
def test_secret_numeric_only(mocker: MockerFixture):
|
||||
"""Test JWT secret with only numbers."""
|
||||
numeric_secret = "1234567890" * 4 # 40 character numeric secret
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": numeric_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == numeric_secret
|
||||
|
||||
|
||||
def test_algorithm_default_hs256(mocker: MockerFixture):
|
||||
"""Test that JWT algorithm defaults to HS256."""
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": "a" * 32}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_ALGORITHM == "HS256"
|
||||
|
||||
|
||||
def test_algorithm_whitespace_stripped(mocker: MockerFixture):
|
||||
"""Test that JWT algorithm whitespace is stripped."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": " HS256 "},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_ALGORITHM == "HS256"
|
||||
|
||||
|
||||
def test_no_crypto_warning(mocker: MockerFixture, caplog: pytest.LogCaptureFixture):
|
||||
"""Test warning when crypto package is not available."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
|
||||
|
||||
# Mock has_crypto to return False
|
||||
mocker.patch("autogpt_libs.auth.config.has_crypto", False)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
Settings()
|
||||
assert "Asymmetric JWT verification is not available" in caplog.text
|
||||
assert "cryptography" in caplog.text
|
||||
|
||||
|
||||
def test_algorithm_invalid_raises_error(mocker: MockerFixture):
|
||||
"""Test that invalid JWT algorithm raises AuthConfigError."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": "INVALID_ALG"},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
with pytest.raises(AuthConfigError) as exc_info:
|
||||
Settings()
|
||||
assert "Invalid JWT_SIGN_ALGORITHM" in str(exc_info.value)
|
||||
assert "INVALID_ALG" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_algorithm_none_raises_error(mocker: MockerFixture):
|
||||
"""Test that 'none' algorithm raises AuthConfigError."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": "none"},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
with pytest.raises(AuthConfigError) as exc_info:
|
||||
Settings()
|
||||
assert "Invalid JWT_SIGN_ALGORITHM" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("algorithm", ["HS256", "HS384", "HS512"])
|
||||
def test_algorithm_symmetric_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture, algorithm: str
|
||||
):
|
||||
"""Test warning for symmetric algorithms (HS256, HS384, HS512)."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": algorithm},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
assert algorithm in caplog.text
|
||||
assert "symmetric shared-key signature algorithm" in caplog.text
|
||||
assert settings.JWT_ALGORITHM == algorithm
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"algorithm",
|
||||
["ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512"],
|
||||
)
|
||||
def test_algorithm_asymmetric_no_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture, algorithm: str
|
||||
):
|
||||
"""Test that asymmetric algorithms do not trigger warning."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": algorithm},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
# Should not contain the symmetric algorithm warning
|
||||
assert "symmetric shared-key signature algorithm" not in caplog.text
|
||||
assert settings.JWT_ALGORITHM == algorithm
|
||||
@@ -0,0 +1,45 @@
"""
FastAPI dependency functions for JWT-based authentication and authorization.

These are the high-level dependency functions used in route definitions.
"""

import fastapi

from .jwt_utils import get_jwt_payload, verify_user
from .models import User


def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
    """
    FastAPI dependency that requires a valid authenticated user.

    Raises:
        HTTPException: 401 for authentication failures
    """
    return verify_user(jwt_payload, admin_only=False)


def requires_admin_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
    """
    FastAPI dependency that requires a valid admin user.

    Raises:
        HTTPException: 401 for authentication failures, 403 for insufficient permissions
    """
    return verify_user(jwt_payload, admin_only=True)


def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
    """
    FastAPI dependency that returns the ID of the authenticated user.

    Raises:
        HTTPException: 401 for authentication failures or missing user ID
    """
    user_id = jwt_payload.get("sub")
    if not user_id:
        raise fastapi.HTTPException(
            status_code=401, detail="User ID not found in token"
        )
    return user_id
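These dependencies are attached to routes via `fastapi.Security`, as the integration tests further below also do. A minimal sketch (route paths are illustrative):

```python
from fastapi import FastAPI, Security

from autogpt_libs.auth import User, get_user_id, requires_admin_user, requires_user

app = FastAPI()


@app.get("/me")
def read_me(user: User = Security(requires_user)):
    # Any authenticated user; 401 without a valid bearer token
    return {"user_id": user.user_id, "role": user.role}


@app.get("/admin/stats")
def read_stats(admin: User = Security(requires_admin_user)):
    # 401 without a token, 403 for non-admin users
    return {"requested_by": admin.user_id}


@app.get("/graphs")
def list_graphs(user_id: str = Security(get_user_id)):
    # When only the user ID from the token's "sub" claim is needed
    return {"owner": user_id}
```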
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Comprehensive integration tests for authentication dependencies.
|
||||
Tests the full authentication flow from HTTP requests to user validation.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException, Security
|
||||
from fastapi.testclient import TestClient
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from autogpt_libs.auth.dependencies import (
|
||||
get_user_id,
|
||||
requires_admin_user,
|
||||
requires_user,
|
||||
)
|
||||
from autogpt_libs.auth.models import User
|
||||
|
||||
|
||||
class TestAuthDependencies:
|
||||
"""Test suite for authentication dependency functions."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create a test FastAPI application."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/user")
|
||||
def get_user_endpoint(user: User = Security(requires_user)):
|
||||
return {"user_id": user.user_id, "role": user.role}
|
||||
|
||||
@app.get("/admin")
|
||||
def get_admin_endpoint(user: User = Security(requires_admin_user)):
|
||||
return {"user_id": user.user_id, "role": user.role}
|
||||
|
||||
@app.get("/user-id")
|
||||
def get_user_id_endpoint(user_id: str = Security(get_user_id)):
|
||||
return {"user_id": user_id}
|
||||
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def client(self, app):
|
||||
"""Create a test client."""
|
||||
return TestClient(app)
|
||||
|
||||
def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
|
||||
"""Test requires_user with valid JWT payload."""
|
||||
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
|
||||
|
||||
# Mock get_jwt_payload to return our test payload
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = requires_user(jwt_payload)
|
||||
assert isinstance(user, User)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.role == "user"
|
||||
|
||||
def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
|
||||
"""Test requires_user accepts admin users."""
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = requires_user(jwt_payload)
|
||||
assert user.user_id == "admin-456"
|
||||
assert user.role == "admin"
|
||||
|
||||
def test_requires_user_missing_sub(self):
|
||||
"""Test requires_user with missing user ID."""
|
||||
jwt_payload = {"role": "user", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
requires_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found" in exc_info.value.detail
|
||||
|
||||
def test_requires_user_empty_sub(self):
|
||||
"""Test requires_user with empty user ID."""
|
||||
jwt_payload = {"sub": "", "role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
requires_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
|
||||
"""Test requires_admin_user with admin role."""
|
||||
jwt_payload = {
|
||||
"sub": "admin-789",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = requires_admin_user(jwt_payload)
|
||||
assert user.user_id == "admin-789"
|
||||
assert user.role == "admin"
|
||||
|
||||
def test_requires_admin_user_with_regular_user(self):
|
||||
"""Test requires_admin_user rejects regular users."""
|
||||
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
requires_admin_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Admin access required" in exc_info.value.detail
|
||||
|
||||
def test_requires_admin_user_missing_role(self):
|
||||
"""Test requires_admin_user with missing role."""
|
||||
jwt_payload = {"sub": "user-123", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
requires_admin_user(jwt_payload)
|
||||
|
||||
def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
|
||||
"""Test get_user_id extracts user ID correctly."""
|
||||
jwt_payload = {"sub": "user-id-xyz", "role": "user"}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user_id = get_user_id(jwt_payload)
|
||||
assert user_id == "user-id-xyz"
|
||||
|
||||
def test_get_user_id_missing_sub(self):
|
||||
"""Test get_user_id with missing user ID."""
|
||||
jwt_payload = {"role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_user_id(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found" in exc_info.value.detail
|
||||
|
||||
def test_get_user_id_none_sub(self):
|
||||
"""Test get_user_id with None user ID."""
|
||||
jwt_payload = {"sub": None, "role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_user_id(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
class TestAuthDependenciesIntegration:
|
||||
"""Integration tests for auth dependencies with FastAPI."""
|
||||
|
||||
acceptable_jwt_secret = "test-secret-with-proper-length-123456"
|
||||
|
||||
@pytest.fixture
|
||||
def create_token(self, mocker: MockerFixture):
|
||||
"""Helper to create JWT tokens."""
|
||||
import jwt
|
||||
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": self.acceptable_jwt_secret},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
def _create_token(payload, secret=self.acceptable_jwt_secret):
|
||||
return jwt.encode(payload, secret, algorithm="HS256")
|
||||
|
||||
return _create_token
|
||||
|
||||
def test_endpoint_auth_enabled_no_token(self):
|
||||
"""Test endpoints require token when auth is enabled."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint(user: User = Security(requires_user)):
|
||||
return {"user_id": user.user_id}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Should fail without auth header
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_endpoint_with_valid_token(self, create_token):
|
||||
"""Test endpoint with valid JWT token."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint(user: User = Security(requires_user)):
|
||||
return {"user_id": user.user_id, "role": user.role}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
token = create_token(
|
||||
{"sub": "test-user", "role": "user", "aud": "authenticated"},
|
||||
secret=self.acceptable_jwt_secret,
|
||||
)
|
||||
|
||||
response = client.get("/test", headers={"Authorization": f"Bearer {token}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["user_id"] == "test-user"
|
||||
|
||||
def test_admin_endpoint_requires_admin_role(self, create_token):
|
||||
"""Test admin endpoint rejects non-admin users."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/admin")
|
||||
def admin_endpoint(user: User = Security(requires_admin_user)):
|
||||
return {"user_id": user.user_id}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Regular user token
|
||||
user_token = create_token(
|
||||
{"sub": "regular-user", "role": "user", "aud": "authenticated"},
|
||||
secret=self.acceptable_jwt_secret,
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin", headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Admin token
|
||||
admin_token = create_token(
|
||||
{"sub": "admin-user", "role": "admin", "aud": "authenticated"},
|
||||
secret=self.acceptable_jwt_secret,
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin", headers={"Authorization": f"Bearer {admin_token}"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["user_id"] == "admin-user"
|
||||
|
||||
|
||||
class TestAuthDependenciesEdgeCases:
|
||||
"""Edge case tests for authentication dependencies."""
|
||||
|
||||
def test_dependency_with_complex_payload(self):
|
||||
"""Test dependencies handle complex JWT payloads."""
|
||||
complex_payload = {
|
||||
"sub": "user-123",
|
||||
"role": "admin",
|
||||
"email": "test@example.com",
|
||||
"app_metadata": {"provider": "email", "providers": ["email"]},
|
||||
"user_metadata": {
|
||||
"full_name": "Test User",
|
||||
"avatar_url": "https://example.com/avatar.jpg",
|
||||
},
|
||||
"aud": "authenticated",
|
||||
"iat": 1234567890,
|
||||
"exp": 9999999999,
|
||||
}
|
||||
|
||||
user = requires_user(complex_payload)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.email == "test@example.com"
|
||||
|
||||
admin = requires_admin_user(complex_payload)
|
||||
assert admin.role == "admin"
|
||||
|
||||
def test_dependency_with_unicode_in_payload(self):
|
||||
"""Test dependencies handle unicode in JWT payloads."""
|
||||
unicode_payload = {
|
||||
"sub": "user-😀-123",
|
||||
"role": "user",
|
||||
"email": "测试@example.com",
|
||||
"name": "日本語",
|
||||
}
|
||||
|
||||
user = requires_user(unicode_payload)
|
||||
assert "😀" in user.user_id
|
||||
assert user.email == "测试@example.com"
|
||||
|
||||
def test_dependency_with_null_values(self):
|
||||
"""Test dependencies handle null values in payload."""
|
||||
null_payload = {
|
||||
"sub": "user-123",
|
||||
"role": "user",
|
||||
"email": None,
|
||||
"phone": None,
|
||||
"metadata": None,
|
||||
}
|
||||
|
||||
user = requires_user(null_payload)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.email is None
|
||||
|
||||
def test_concurrent_requests_isolation(self):
|
||||
"""Test that concurrent requests don't interfere with each other."""
|
||||
payload1 = {"sub": "user-1", "role": "user"}
|
||||
payload2 = {"sub": "user-2", "role": "admin"}
|
||||
|
||||
# Simulate concurrent processing
|
||||
user1 = requires_user(payload1)
|
||||
user2 = requires_admin_user(payload2)
|
||||
|
||||
assert user1.user_id == "user-1"
|
||||
assert user2.user_id == "user-2"
|
||||
assert user1.role == "user"
|
||||
assert user2.role == "admin"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"payload,expected_error,admin_only",
|
||||
[
|
||||
(None, "Authorization header is missing", False),
|
||||
({}, "User ID not found", False),
|
||||
({"sub": ""}, "User ID not found", False),
|
||||
({"role": "user"}, "User ID not found", False),
|
||||
({"sub": "user", "role": "user"}, "Admin access required", True),
|
||||
],
|
||||
)
|
||||
def test_dependency_error_cases(
|
||||
self, payload, expected_error: str, admin_only: bool
|
||||
):
|
||||
"""Test that errors propagate correctly through dependencies."""
|
||||
# Import verify_user to test it directly since dependencies use FastAPI Security
|
||||
from autogpt_libs.auth.jwt_utils import verify_user
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
verify_user(payload, admin_only=admin_only)
|
||||
assert expected_error in exc_info.value.detail
|
||||
|
||||
def test_dependency_valid_user(self):
|
||||
"""Test valid user case for dependency."""
|
||||
# Import verify_user to test it directly since dependencies use FastAPI Security
|
||||
from autogpt_libs.auth.jwt_utils import verify_user
|
||||
|
||||
# Valid case
|
||||
user = verify_user({"sub": "user", "role": "user"}, admin_only=False)
|
||||
assert user.user_id == "user"
|
||||
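For reference, the `get_user_id` dependency exercised in the tests above is attached to a route through `Security`, the same way the integration tests wire `requires_user`. A minimal sketch; the app and route path are illustrative and not part of this diff:

```python
from fastapi import FastAPI, Security

from autogpt_libs.auth.dependencies import get_user_id

app = FastAPI()


@app.get("/whoami")
def whoami(user_id: str = Security(get_user_id)) -> dict:
    # get_user_id raises a 401 HTTPException when the token is missing
    # or its payload has no usable "sub" claim.
    return {"user_id": user_id}
```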
@@ -1,46 +0,0 @@
import fastapi

from .config import settings
from .middleware import auth_middleware
from .models import DEFAULT_USER_ID, User


def requires_user(payload: dict = fastapi.Depends(auth_middleware)) -> User:
    return verify_user(payload, admin_only=False)


def requires_admin_user(
    payload: dict = fastapi.Depends(auth_middleware),
) -> User:
    return verify_user(payload, admin_only=True)


def verify_user(payload: dict | None, admin_only: bool) -> User:
    if not payload:
        if settings.ENABLE_AUTH:
            raise fastapi.HTTPException(
                status_code=401, detail="Authorization header is missing"
            )
        # This handles the case when authentication is disabled
        payload = {"sub": DEFAULT_USER_ID, "role": "admin"}

    user_id = payload.get("sub")

    if not user_id:
        raise fastapi.HTTPException(
            status_code=401, detail="User ID not found in token"
        )

    if admin_only and payload["role"] != "admin":
        raise fastapi.HTTPException(status_code=403, detail="Admin access required")

    return User.from_payload(payload)


def get_user_id(payload: dict = fastapi.Depends(auth_middleware)) -> str:
    user_id = payload.get("sub")
    if not user_id:
        raise fastapi.HTTPException(
            status_code=401, detail="User ID not found in token"
        )
    return user_id
@@ -1,68 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from .depends import requires_admin_user, requires_user, verify_user
|
||||
|
||||
|
||||
def test_verify_user_no_payload():
|
||||
user = verify_user(None, admin_only=False)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "admin"
|
||||
|
||||
|
||||
def test_verify_user_no_user_id():
|
||||
with pytest.raises(Exception):
|
||||
verify_user({"role": "admin"}, admin_only=False)
|
||||
|
||||
|
||||
def test_verify_user_not_admin():
|
||||
with pytest.raises(Exception):
|
||||
verify_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"},
|
||||
admin_only=True,
|
||||
)
|
||||
|
||||
|
||||
def test_verify_user_with_admin_role():
|
||||
user = verify_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "admin"},
|
||||
admin_only=True,
|
||||
)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "admin"
|
||||
|
||||
|
||||
def test_verify_user_with_user_role():
|
||||
user = verify_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"},
|
||||
admin_only=False,
|
||||
)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "user"
|
||||
|
||||
|
||||
def test_requires_user():
|
||||
user = requires_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"}
|
||||
)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "user"
|
||||
|
||||
|
||||
def test_requires_user_no_user_id():
|
||||
with pytest.raises(Exception):
|
||||
requires_user({"role": "user"})
|
||||
|
||||
|
||||
def test_requires_admin_user():
|
||||
user = requires_admin_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "admin"}
|
||||
)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "admin"
|
||||
|
||||
|
||||
def test_requires_admin_user_not_admin():
|
||||
with pytest.raises(Exception):
|
||||
requires_admin_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"}
|
||||
)
|
||||
autogpt_platform/autogpt_libs/autogpt_libs/auth/helpers.py (new file, 68 lines)
@@ -0,0 +1,68 @@
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi

from .jwt_utils import bearer_jwt_auth


def add_auth_responses_to_openapi(app: FastAPI) -> None:
    """
    Set up custom OpenAPI schema generation that adds 401 responses
    to all authenticated endpoints.

    This is needed when using HTTPBearer with auto_error=False to get proper
    401 responses instead of 403, but FastAPI only automatically adds security
    responses when auto_error=True.
    """

    def custom_openapi():
        if app.openapi_schema:
            return app.openapi_schema

        openapi_schema = get_openapi(
            title=app.title,
            version=app.version,
            description=app.description,
            routes=app.routes,
        )

        # Add 401 response to all endpoints that have security requirements
        for path, methods in openapi_schema["paths"].items():
            for method, details in methods.items():
                security_schemas = [
                    schema
                    for auth_option in details.get("security", [])
                    for schema in auth_option.keys()
                ]
                if bearer_jwt_auth.scheme_name not in security_schemas:
                    continue

                if "responses" not in details:
                    details["responses"] = {}

                details["responses"]["401"] = {
                    "$ref": "#/components/responses/HTTP401NotAuthenticatedError"
                }

        # Ensure #/components/responses exists
        if "components" not in openapi_schema:
            openapi_schema["components"] = {}
        if "responses" not in openapi_schema["components"]:
            openapi_schema["components"]["responses"] = {}

        # Define 401 response
        openapi_schema["components"]["responses"]["HTTP401NotAuthenticatedError"] = {
            "description": "Authentication required",
            "content": {
                "application/json": {
                    "schema": {
                        "type": "object",
                        "properties": {"detail": {"type": "string"}},
                    }
                }
            },
        }

        app.openapi_schema = openapi_schema
        return app.openapi_schema

    app.openapi = custom_openapi
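A minimal usage sketch for `add_auth_responses_to_openapi`; the example app and route are illustrative, while the helper and `get_jwt_payload` are the names introduced in this diff:

```python
from fastapi import FastAPI, Security

from autogpt_libs.auth.helpers import add_auth_responses_to_openapi
from autogpt_libs.auth.jwt_utils import get_jwt_payload

app = FastAPI(title="Example API")


@app.get("/me")
def read_me(jwt_payload: dict = Security(get_jwt_payload)):
    return {"sub": jwt_payload["sub"]}


# Replace app.openapi so every route secured with the bearer JWT scheme
# documents a 401 response in the generated schema.
add_auth_responses_to_openapi(app)
```

Note that the generated schema is cached on `app.openapi_schema` after the first access, so routes registered later only show up if the cache is cleared, as the persistence test below demonstrates.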
autogpt_platform/autogpt_libs/autogpt_libs/auth/helpers_test.py (new file, 435 lines)
@@ -0,0 +1,435 @@
|
||||
"""
|
||||
Comprehensive tests for auth helpers module to achieve 100% coverage.
|
||||
Tests OpenAPI schema generation and authentication response handling.
|
||||
"""
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from autogpt_libs.auth.helpers import add_auth_responses_to_openapi
|
||||
from autogpt_libs.auth.jwt_utils import bearer_jwt_auth
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_basic():
|
||||
"""Test adding 401 responses to OpenAPI schema."""
|
||||
app = FastAPI(title="Test App", version="1.0.0")
|
||||
|
||||
# Add some test endpoints with authentication
|
||||
from fastapi import Depends
|
||||
|
||||
from autogpt_libs.auth.dependencies import requires_user
|
||||
|
||||
@app.get("/protected", dependencies=[Depends(requires_user)])
|
||||
def protected_endpoint():
|
||||
return {"message": "Protected"}
|
||||
|
||||
@app.get("/public")
|
||||
def public_endpoint():
|
||||
return {"message": "Public"}
|
||||
|
||||
# Apply the OpenAPI customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get the OpenAPI schema
|
||||
schema = app.openapi()
|
||||
|
||||
# Verify basic schema properties
|
||||
assert schema["info"]["title"] == "Test App"
|
||||
assert schema["info"]["version"] == "1.0.0"
|
||||
|
||||
# Verify 401 response component is added
|
||||
assert "components" in schema
|
||||
assert "responses" in schema["components"]
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
# Verify 401 response structure
|
||||
error_response = schema["components"]["responses"]["HTTP401NotAuthenticatedError"]
|
||||
assert error_response["description"] == "Authentication required"
|
||||
assert "application/json" in error_response["content"]
|
||||
assert "schema" in error_response["content"]["application/json"]
|
||||
|
||||
# Verify schema properties
|
||||
response_schema = error_response["content"]["application/json"]["schema"]
|
||||
assert response_schema["type"] == "object"
|
||||
assert "detail" in response_schema["properties"]
|
||||
assert response_schema["properties"]["detail"]["type"] == "string"
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_with_security():
|
||||
"""Test that 401 responses are added only to secured endpoints."""
|
||||
app = FastAPI()
|
||||
|
||||
# Mock endpoint with security
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.dependencies import get_user_id
|
||||
|
||||
@app.get("/secured")
|
||||
def secured_endpoint(user_id: str = Security(get_user_id)):
|
||||
return {"user_id": user_id}
|
||||
|
||||
@app.post("/also-secured")
|
||||
def another_secured(user_id: str = Security(get_user_id)):
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/unsecured")
|
||||
def unsecured_endpoint():
|
||||
return {"public": True}
|
||||
|
||||
# Apply OpenAPI customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get schema
|
||||
schema = app.openapi()
|
||||
|
||||
# Check that secured endpoints have 401 responses
|
||||
if "/secured" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/secured"]:
|
||||
secured_get = schema["paths"]["/secured"]["get"]
|
||||
if "responses" in secured_get:
|
||||
assert "401" in secured_get["responses"]
|
||||
assert (
|
||||
secured_get["responses"]["401"]["$ref"]
|
||||
== "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
)
|
||||
|
||||
if "/also-secured" in schema["paths"]:
|
||||
if "post" in schema["paths"]["/also-secured"]:
|
||||
secured_post = schema["paths"]["/also-secured"]["post"]
|
||||
if "responses" in secured_post:
|
||||
assert "401" in secured_post["responses"]
|
||||
|
||||
# Check that unsecured endpoint does not have 401 response
|
||||
if "/unsecured" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/unsecured"]:
|
||||
unsecured_get = schema["paths"]["/unsecured"]["get"]
|
||||
if "responses" in unsecured_get:
|
||||
assert "401" not in unsecured_get.get("responses", {})
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_cached_schema():
|
||||
"""Test that OpenAPI schema is cached after first generation."""
|
||||
app = FastAPI()
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get schema twice
|
||||
schema1 = app.openapi()
|
||||
schema2 = app.openapi()
|
||||
|
||||
# Should return the same cached object
|
||||
assert schema1 is schema2
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_existing_responses():
|
||||
"""Test handling endpoints that already have responses defined."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
@app.get(
|
||||
"/with-responses",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
404: {"description": "Not found"},
|
||||
},
|
||||
)
|
||||
def endpoint_with_responses(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"data": "test"}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Check that existing responses are preserved and 401 is added
|
||||
if "/with-responses" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/with-responses"]:
|
||||
responses = schema["paths"]["/with-responses"]["get"].get("responses", {})
|
||||
# Original responses should be preserved
|
||||
if "200" in responses:
|
||||
assert responses["200"]["description"] == "Success"
|
||||
if "404" in responses:
|
||||
assert responses["404"]["description"] == "Not found"
|
||||
# 401 should be added
|
||||
if "401" in responses:
|
||||
assert (
|
||||
responses["401"]["$ref"]
|
||||
== "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
)
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_no_security_endpoints():
|
||||
"""Test with app that has no secured endpoints."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/public1")
|
||||
def public1():
|
||||
return {"message": "public1"}
|
||||
|
||||
@app.post("/public2")
|
||||
def public2():
|
||||
return {"message": "public2"}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Component should still be added for consistency
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
# But no endpoints should have 401 responses
|
||||
for path in schema["paths"].values():
|
||||
for method in path.values():
|
||||
if isinstance(method, dict) and "responses" in method:
|
||||
assert "401" not in method["responses"]
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_multiple_security_schemes():
|
||||
"""Test endpoints with multiple security requirements."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.dependencies import requires_admin_user, requires_user
|
||||
from autogpt_libs.auth.models import User
|
||||
|
||||
@app.get("/multi-auth")
|
||||
def multi_auth(
|
||||
user: User = Security(requires_user),
|
||||
admin: User = Security(requires_admin_user),
|
||||
):
|
||||
return {"status": "super secure"}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Should have 401 response
|
||||
if "/multi-auth" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/multi-auth"]:
|
||||
responses = schema["paths"]["/multi-auth"]["get"].get("responses", {})
|
||||
if "401" in responses:
|
||||
assert (
|
||||
responses["401"]["$ref"]
|
||||
== "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
)
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_empty_components():
|
||||
"""Test when OpenAPI schema has no components section initially."""
|
||||
app = FastAPI()
|
||||
|
||||
# Mock get_openapi to return schema without components
|
||||
original_get_openapi = get_openapi
|
||||
|
||||
def mock_get_openapi(*args, **kwargs):
|
||||
schema = original_get_openapi(*args, **kwargs)
|
||||
# Remove components if it exists
|
||||
if "components" in schema:
|
||||
del schema["components"]
|
||||
return schema
|
||||
|
||||
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Components should be created
|
||||
assert "components" in schema
|
||||
assert "responses" in schema["components"]
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_all_http_methods():
|
||||
"""Test that all HTTP methods are handled correctly."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
@app.get("/resource")
|
||||
def get_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "GET"}
|
||||
|
||||
@app.post("/resource")
|
||||
def post_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "POST"}
|
||||
|
||||
@app.put("/resource")
|
||||
def put_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "PUT"}
|
||||
|
||||
@app.patch("/resource")
|
||||
def patch_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "PATCH"}
|
||||
|
||||
@app.delete("/resource")
|
||||
def delete_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "DELETE"}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# All methods should have 401 response
|
||||
if "/resource" in schema["paths"]:
|
||||
for method in ["get", "post", "put", "patch", "delete"]:
|
||||
if method in schema["paths"]["/resource"]:
|
||||
method_spec = schema["paths"]["/resource"][method]
|
||||
if "responses" in method_spec:
|
||||
assert "401" in method_spec["responses"]
|
||||
|
||||
|
||||
def test_bearer_jwt_auth_scheme_config():
|
||||
"""Test that bearer_jwt_auth is configured correctly."""
|
||||
assert bearer_jwt_auth.scheme_name == "HTTPBearerJWT"
|
||||
assert bearer_jwt_auth.auto_error is False
|
||||
|
||||
|
||||
def test_add_auth_responses_with_no_routes():
|
||||
"""Test OpenAPI generation with app that has no routes."""
|
||||
app = FastAPI(title="Empty App")
|
||||
|
||||
# Apply customization to empty app
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Should still have basic structure
|
||||
assert schema["info"]["title"] == "Empty App"
|
||||
assert "components" in schema
|
||||
assert "responses" in schema["components"]
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
|
||||
def test_custom_openapi_function_replacement():
|
||||
"""Test that the custom openapi function properly replaces the default."""
|
||||
app = FastAPI()
|
||||
|
||||
# Store original function
|
||||
original_openapi = app.openapi
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Function should be replaced
|
||||
assert app.openapi != original_openapi
|
||||
assert callable(app.openapi)
|
||||
|
||||
|
||||
def test_endpoint_without_responses_section():
|
||||
"""Test endpoint that has security but no responses section initially."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
from fastapi.openapi.utils import get_openapi as original_get_openapi
|
||||
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
# Create endpoint
|
||||
@app.get("/no-responses")
|
||||
def endpoint_without_responses(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"data": "test"}
|
||||
|
||||
# Mock get_openapi to remove responses from the endpoint
|
||||
def mock_get_openapi(*args, **kwargs):
|
||||
schema = original_get_openapi(*args, **kwargs)
|
||||
# Remove responses from our endpoint to trigger line 40
|
||||
if "/no-responses" in schema.get("paths", {}):
|
||||
if "get" in schema["paths"]["/no-responses"]:
|
||||
# Delete responses to force the code to create it
|
||||
if "responses" in schema["paths"]["/no-responses"]["get"]:
|
||||
del schema["paths"]["/no-responses"]["get"]["responses"]
|
||||
return schema
|
||||
|
||||
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get schema and verify 401 was added
|
||||
schema = app.openapi()
|
||||
|
||||
# The endpoint should now have 401 response
|
||||
if "/no-responses" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/no-responses"]:
|
||||
responses = schema["paths"]["/no-responses"]["get"].get("responses", {})
|
||||
assert "401" in responses
|
||||
assert (
|
||||
responses["401"]["$ref"]
|
||||
== "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
)
|
||||
|
||||
|
||||
def test_components_with_existing_responses():
|
||||
"""Test when components already has a responses section."""
|
||||
app = FastAPI()
|
||||
|
||||
# Mock get_openapi to return schema with existing components/responses
|
||||
from fastapi.openapi.utils import get_openapi as original_get_openapi
|
||||
|
||||
def mock_get_openapi(*args, **kwargs):
|
||||
schema = original_get_openapi(*args, **kwargs)
|
||||
# Add existing components/responses
|
||||
if "components" not in schema:
|
||||
schema["components"] = {}
|
||||
schema["components"]["responses"] = {
|
||||
"ExistingResponse": {"description": "An existing response"}
|
||||
}
|
||||
return schema
|
||||
|
||||
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Both responses should exist
|
||||
assert "ExistingResponse" in schema["components"]["responses"]
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
# Verify our 401 response structure
|
||||
error_response = schema["components"]["responses"][
|
||||
"HTTP401NotAuthenticatedError"
|
||||
]
|
||||
assert error_response["description"] == "Authentication required"
|
||||
|
||||
|
||||
def test_openapi_schema_persistence():
|
||||
"""Test that modifications to OpenAPI schema persist correctly."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"test": True}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get schema multiple times
|
||||
schema1 = app.openapi()
|
||||
|
||||
# Modify the cached schema (shouldn't affect future calls)
|
||||
schema1["info"]["title"] = "Modified Title"
|
||||
|
||||
# Clear cache and get again
|
||||
app.openapi_schema = None
|
||||
schema2 = app.openapi()
|
||||
|
||||
# Should regenerate with original title
|
||||
assert schema2["info"]["title"] == app.title
|
||||
assert schema2["info"]["title"] != "Modified Title"
|
||||
@@ -1,11 +1,48 @@
|
||||
from typing import Any, Dict
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from .config import settings
|
||||
from .config import get_settings
|
||||
from .models import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Bearer token authentication scheme
|
||||
bearer_jwt_auth = HTTPBearer(
|
||||
bearerFormat="jwt", scheme_name="HTTPBearerJWT", auto_error=False
|
||||
)
|
||||
|
||||
|
||||
def parse_jwt_token(token: str) -> Dict[str, Any]:
|
||||
def get_jwt_payload(
|
||||
credentials: HTTPAuthorizationCredentials | None = Security(bearer_jwt_auth),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Extract and validate JWT payload from HTTP Authorization header.
|
||||
|
||||
This is the core authentication function that handles:
|
||||
- Reading the `Authorization` header to obtain the JWT token
|
||||
- Verifying the JWT token's signature
|
||||
- Decoding the JWT token's payload
|
||||
|
||||
:param credentials: HTTP Authorization credentials from bearer token
|
||||
:return: JWT payload dictionary
|
||||
:raises HTTPException: 401 if authentication fails
|
||||
"""
|
||||
if not credentials:
|
||||
raise HTTPException(status_code=401, detail="Authorization header is missing")
|
||||
|
||||
try:
|
||||
payload = parse_jwt_token(credentials.credentials)
|
||||
logger.debug("Token decoded successfully")
|
||||
return payload
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
|
||||
|
||||
def parse_jwt_token(token: str) -> dict[str, Any]:
|
||||
"""
|
||||
Parse and validate a JWT token.
|
||||
|
||||
@@ -13,10 +50,11 @@ def parse_jwt_token(token: str) -> Dict[str, Any]:
|
||||
:return: The decoded payload
|
||||
:raises ValueError: If the token is invalid or expired
|
||||
"""
|
||||
settings = get_settings()
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.JWT_SECRET_KEY,
|
||||
settings.JWT_VERIFY_KEY,
|
||||
algorithms=[settings.JWT_ALGORITHM],
|
||||
audience="authenticated",
|
||||
)
|
||||
@@ -25,3 +63,18 @@ def parse_jwt_token(token: str) -> Dict[str, Any]:
|
||||
raise ValueError("Token has expired")
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise ValueError(f"Invalid token: {str(e)}")
|
||||
|
||||
|
||||
def verify_user(jwt_payload: dict | None, admin_only: bool) -> User:
|
||||
if jwt_payload is None:
|
||||
raise HTTPException(status_code=401, detail="Authorization header is missing")
|
||||
|
||||
user_id = jwt_payload.get("sub")
|
||||
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found in token")
|
||||
|
||||
if admin_only and jwt_payload["role"] != "admin":
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
return User.from_payload(jwt_payload)
|
||||
|
||||
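`get_jwt_payload` and `verify_user` above are designed to compose: the first extracts and verifies the bearer token, the second turns the payload into a `User` and enforces the role. A hedged sketch of that composition; the `current_user` wrapper is illustrative and the real `dependencies` module may wire it differently:

```python
from typing import Any

from fastapi import FastAPI, Security

from autogpt_libs.auth.jwt_utils import get_jwt_payload, verify_user
from autogpt_libs.auth.models import User

app = FastAPI()


def current_user(jwt_payload: dict[str, Any] = Security(get_jwt_payload)) -> User:
    # 401 for a missing/invalid token or missing "sub"; admin_only=False skips the 403 check.
    return verify_user(jwt_payload, admin_only=False)


@app.get("/profile")
def profile(user: User = Security(current_user)):
    return {"user_id": user.user_id, "email": user.email}
```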
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
Comprehensive tests for JWT token parsing and validation.
|
||||
Ensures 100% line and branch coverage for JWT security functions.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from autogpt_libs.auth import config, jwt_utils
|
||||
from autogpt_libs.auth.config import Settings
|
||||
from autogpt_libs.auth.models import User
|
||||
|
||||
MOCK_JWT_SECRET = "test-secret-key-with-at-least-32-characters"
|
||||
TEST_USER_PAYLOAD = {
|
||||
"sub": "test-user-id",
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
TEST_ADMIN_PAYLOAD = {
|
||||
"sub": "admin-user-id",
|
||||
"role": "admin",
|
||||
"aud": "authenticated",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_config(mocker: MockerFixture):
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": MOCK_JWT_SECRET}, clear=True)
|
||||
mocker.patch.object(config, "_settings", Settings())
|
||||
yield
|
||||
|
||||
|
||||
def create_token(payload, secret=None, algorithm="HS256"):
|
||||
"""Helper to create JWT tokens."""
|
||||
if secret is None:
|
||||
secret = MOCK_JWT_SECRET
|
||||
return jwt.encode(payload, secret, algorithm=algorithm)
|
||||
|
||||
|
||||
def test_parse_jwt_token_valid():
|
||||
"""Test parsing a valid JWT token."""
|
||||
token = create_token(TEST_USER_PAYLOAD)
|
||||
result = jwt_utils.parse_jwt_token(token)
|
||||
|
||||
assert result["sub"] == "test-user-id"
|
||||
assert result["role"] == "user"
|
||||
assert result["aud"] == "authenticated"
|
||||
|
||||
|
||||
def test_parse_jwt_token_expired():
|
||||
"""Test parsing an expired JWT token."""
|
||||
expired_payload = {
|
||||
**TEST_USER_PAYLOAD,
|
||||
"exp": datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
}
|
||||
token = create_token(expired_payload)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Token has expired" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_parse_jwt_token_invalid_signature():
|
||||
"""Test parsing a token with invalid signature."""
|
||||
# Create token with different secret
|
||||
token = create_token(TEST_USER_PAYLOAD, secret="wrong-secret")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_parse_jwt_token_malformed():
|
||||
"""Test parsing a malformed token."""
|
||||
malformed_tokens = [
|
||||
"not.a.token",
|
||||
"invalid",
|
||||
"",
|
||||
# Header only
|
||||
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9",
|
||||
# No signature
|
||||
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0",
|
||||
]
|
||||
|
||||
for token in malformed_tokens:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_parse_jwt_token_wrong_audience():
|
||||
"""Test parsing a token with wrong audience."""
|
||||
wrong_aud_payload = {**TEST_USER_PAYLOAD, "aud": "wrong-audience"}
|
||||
token = create_token(wrong_aud_payload)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_parse_jwt_token_missing_audience():
|
||||
"""Test parsing a token without audience claim."""
|
||||
no_aud_payload = {k: v for k, v in TEST_USER_PAYLOAD.items() if k != "aud"}
|
||||
token = create_token(no_aud_payload)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_get_jwt_payload_with_valid_token():
|
||||
"""Test extracting JWT payload with valid bearer token."""
|
||||
token = create_token(TEST_USER_PAYLOAD)
|
||||
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
|
||||
|
||||
result = jwt_utils.get_jwt_payload(credentials)
|
||||
assert result["sub"] == "test-user-id"
|
||||
assert result["role"] == "user"
|
||||
|
||||
|
||||
def test_get_jwt_payload_no_credentials():
|
||||
"""Test JWT payload when no credentials provided."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.get_jwt_payload(None)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Authorization header is missing" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_get_jwt_payload_invalid_token():
|
||||
"""Test JWT payload extraction with invalid token."""
|
||||
credentials = HTTPAuthorizationCredentials(
|
||||
scheme="Bearer", credentials="invalid.token.here"
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.get_jwt_payload(credentials)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid token" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_with_valid_user():
|
||||
"""Test verifying a valid user."""
|
||||
user = jwt_utils.verify_user(TEST_USER_PAYLOAD, admin_only=False)
|
||||
assert isinstance(user, User)
|
||||
assert user.user_id == "test-user-id"
|
||||
assert user.role == "user"
|
||||
assert user.email == "test@example.com"
|
||||
|
||||
|
||||
def test_verify_user_with_admin():
|
||||
"""Test verifying an admin user."""
|
||||
user = jwt_utils.verify_user(TEST_ADMIN_PAYLOAD, admin_only=True)
|
||||
assert isinstance(user, User)
|
||||
assert user.user_id == "admin-user-id"
|
||||
assert user.role == "admin"
|
||||
|
||||
|
||||
def test_verify_user_admin_only_with_regular_user():
|
||||
"""Test verifying regular user when admin is required."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(TEST_USER_PAYLOAD, admin_only=True)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Admin access required" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_no_payload():
|
||||
"""Test verifying user with no payload."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(None, admin_only=False)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Authorization header is missing" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_missing_sub():
|
||||
"""Test verifying user with payload missing 'sub' field."""
|
||||
invalid_payload = {"role": "user", "email": "test@example.com"}
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(invalid_payload, admin_only=False)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found in token" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_empty_sub():
|
||||
"""Test verifying user with empty 'sub' field."""
|
||||
invalid_payload = {"sub": "", "role": "user"}
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(invalid_payload, admin_only=False)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found in token" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_none_sub():
|
||||
"""Test verifying user with None 'sub' field."""
|
||||
invalid_payload = {"sub": None, "role": "user"}
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(invalid_payload, admin_only=False)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found in token" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_missing_role_admin_check():
|
||||
"""Test verifying admin when role field is missing."""
|
||||
no_role_payload = {"sub": "user-id"}
|
||||
with pytest.raises(KeyError):
|
||||
# This will raise KeyError when checking payload["role"]
|
||||
jwt_utils.verify_user(no_role_payload, admin_only=True)
|
||||
|
||||
|
||||
# ======================== EDGE CASES ======================== #
|
||||
|
||||
|
||||
def test_jwt_with_additional_claims():
|
||||
"""Test JWT token with additional custom claims."""
|
||||
extra_claims_payload = {
|
||||
"sub": "user-id",
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
"custom_claim": "custom_value",
|
||||
"permissions": ["read", "write"],
|
||||
"metadata": {"key": "value"},
|
||||
}
|
||||
token = create_token(extra_claims_payload)
|
||||
|
||||
result = jwt_utils.parse_jwt_token(token)
|
||||
assert result["sub"] == "user-id"
|
||||
assert result["custom_claim"] == "custom_value"
|
||||
assert result["permissions"] == ["read", "write"]
|
||||
|
||||
|
||||
def test_jwt_with_numeric_sub():
|
||||
"""Test JWT token with numeric user ID."""
|
||||
payload = {
|
||||
"sub": 12345, # Numeric ID
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
}
|
||||
# The numeric ID is carried through as-is; the test expects no string coercion
|
||||
user = jwt_utils.verify_user(payload, admin_only=False)
|
||||
assert user.user_id == 12345
|
||||
|
||||
|
||||
def test_jwt_with_very_long_sub():
|
||||
"""Test JWT token with very long user ID."""
|
||||
long_id = "a" * 1000
|
||||
payload = {
|
||||
"sub": long_id,
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
}
|
||||
user = jwt_utils.verify_user(payload, admin_only=False)
|
||||
assert user.user_id == long_id
|
||||
|
||||
|
||||
def test_jwt_with_special_characters_in_claims():
|
||||
"""Test JWT token with special characters in claims."""
|
||||
payload = {
|
||||
"sub": "user@example.com/special-chars!@#$%",
|
||||
"role": "admin",
|
||||
"aud": "authenticated",
|
||||
"email": "test+special@example.com",
|
||||
}
|
||||
user = jwt_utils.verify_user(payload, admin_only=True)
|
||||
assert "special-chars!@#$%" in user.user_id
|
||||
|
||||
|
||||
def test_jwt_with_future_iat():
|
||||
"""Test JWT token with issued-at time in future."""
|
||||
future_payload = {
|
||||
"sub": "user-id",
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
"iat": datetime.now(timezone.utc) + timedelta(hours=1),
|
||||
}
|
||||
token = create_token(future_payload)
|
||||
|
||||
# PyJWT validates iat claim and should reject future tokens
|
||||
with pytest.raises(ValueError, match="not yet valid"):
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
|
||||
|
||||
def test_jwt_with_different_algorithms():
|
||||
"""Test that only HS256 algorithm is accepted."""
|
||||
payload = {
|
||||
"sub": "user-id",
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
}
|
||||
|
||||
# Try different algorithms
|
||||
algorithms = ["HS384", "HS512", "none"]
|
||||
for algo in algorithms:
|
||||
if algo == "none":
|
||||
# Special case for 'none' algorithm (security vulnerability if accepted)
|
||||
token = create_token(payload, "", algorithm="none")
|
||||
else:
|
||||
token = create_token(payload, algorithm=algo)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
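For ad-hoc checks outside pytest, a token can be minted the same way these tests do. A sketch that mirrors the `mock_config` fixture above; resetting `config._settings` is an assumption about how the settings cache is populated and may not be needed in every setup:

```python
import os

os.environ["JWT_VERIFY_KEY"] = "test-secret-key-with-at-least-32-characters"

import jwt

from autogpt_libs.auth import config, jwt_utils
from autogpt_libs.auth.config import Settings

# Rebuild the cached settings so the key above is picked up (mirrors mock_config).
config._settings = Settings()

token = jwt.encode(
    {"sub": "local-user", "role": "user", "aud": "authenticated"},
    os.environ["JWT_VERIFY_KEY"],
    algorithm="HS256",
)
print(jwt_utils.parse_jwt_token(token)["sub"])  # -> "local-user"
```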
@@ -1,140 +0,0 @@
|
||||
import inspect
|
||||
import logging
|
||||
import secrets
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from fastapi import HTTPException, Request, Security
|
||||
from fastapi.security import APIKeyHeader, HTTPBearer
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
|
||||
from .config import settings
|
||||
from .jwt_utils import parse_jwt_token
|
||||
|
||||
security = HTTPBearer()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def auth_middleware(request: Request):
|
||||
if not settings.ENABLE_AUTH:
|
||||
# If authentication is disabled, allow the request to proceed
|
||||
logger.warning("Auth disabled")
|
||||
return {}
|
||||
|
||||
security = HTTPBearer()
|
||||
credentials = await security(request)
|
||||
|
||||
if not credentials:
|
||||
raise HTTPException(status_code=401, detail="Authorization header is missing")
|
||||
|
||||
try:
|
||||
payload = parse_jwt_token(credentials.credentials)
|
||||
request.state.user = payload
|
||||
logger.debug("Token decoded successfully")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
return payload
|
||||
|
||||
|
||||
class APIKeyValidator:
|
||||
"""
|
||||
Configurable API key validator that supports custom validation functions
|
||||
for FastAPI applications.
|
||||
|
||||
This class provides a flexible way to implement API key authentication with optional
|
||||
custom validation logic. It can be used for simple token matching
|
||||
or more complex validation scenarios like database lookups.
|
||||
|
||||
Examples:
|
||||
Simple token validation:
|
||||
```python
|
||||
validator = APIKeyValidator(
|
||||
header_name="X-API-Key",
|
||||
expected_token="your-secret-token"
|
||||
)
|
||||
|
||||
@app.get("/protected", dependencies=[Depends(validator.get_dependency())])
|
||||
def protected_endpoint():
|
||||
return {"message": "Access granted"}
|
||||
```
|
||||
|
||||
Custom validation with database lookup:
|
||||
```python
|
||||
async def validate_with_db(api_key: str):
|
||||
api_key_obj = await db.get_api_key(api_key)
|
||||
return api_key_obj if api_key_obj and api_key_obj.is_active else None
|
||||
|
||||
validator = APIKeyValidator(
|
||||
header_name="X-API-Key",
|
||||
validate_fn=validate_with_db
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
header_name (str): The name of the header containing the API key
|
||||
expected_token (Optional[str]): The expected API key value for simple token matching
|
||||
validate_fn (Optional[Callable]): Custom validation function that takes an API key
|
||||
string and returns a boolean or object. Can be async.
|
||||
error_status (int): HTTP status code to use for validation errors
|
||||
error_message (str): Error message to return when validation fails
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
header_name: str,
|
||||
expected_token: Optional[str] = None,
|
||||
validate_fn: Optional[Callable[[str], bool]] = None,
|
||||
error_status: int = HTTP_401_UNAUTHORIZED,
|
||||
error_message: str = "Invalid API key",
|
||||
):
|
||||
# Create the APIKeyHeader as a class property
|
||||
self.security_scheme = APIKeyHeader(name=header_name)
|
||||
self.expected_token = expected_token
|
||||
self.custom_validate_fn = validate_fn
|
||||
self.error_status = error_status
|
||||
self.error_message = error_message
|
||||
|
||||
async def default_validator(self, api_key: str) -> bool:
|
||||
if not self.expected_token:
|
||||
raise ValueError(
|
||||
"Expected Token Required to be set when uisng API Key Validator default validation"
|
||||
)
|
||||
return secrets.compare_digest(api_key, self.expected_token)
|
||||
|
||||
async def __call__(
|
||||
self, request: Request, api_key: str = Security(APIKeyHeader)
|
||||
) -> Any:
|
||||
if api_key is None:
|
||||
raise HTTPException(status_code=self.error_status, detail="Missing API key")
|
||||
|
||||
# Use custom validation if provided, otherwise use default equality check
|
||||
validator = self.custom_validate_fn or self.default_validator
|
||||
result = (
|
||||
await validator(api_key)
|
||||
if inspect.iscoroutinefunction(validator)
|
||||
else validator(api_key)
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=self.error_status, detail=self.error_message
|
||||
)
|
||||
|
||||
# Store validation result in request state if it's not just a boolean
|
||||
if result is not True:
|
||||
request.state.api_key = result
|
||||
|
||||
return result
|
||||
|
||||
def get_dependency(self):
|
||||
"""
|
||||
Returns a callable dependency that FastAPI will recognize as a security scheme
|
||||
"""
|
||||
|
||||
async def validate_api_key(
|
||||
request: Request, api_key: str = Security(self.security_scheme)
|
||||
) -> Any:
|
||||
return await self(request, api_key)
|
||||
|
||||
# This helps FastAPI recognize it as a security dependency
|
||||
validate_api_key.__name__ = f"validate_{self.security_scheme.model.name}"
|
||||
return validate_api_key
|
||||
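The removal above retires the `auth_middleware` / `fastapi.Depends` pattern that the old `depends.py` built on; routes now declare authentication through `Security` with the new dependencies, as the tests earlier in this diff show. A sketch of the replacement call site; the route itself is illustrative:

```python
from fastapi import FastAPI, Security

from autogpt_libs.auth.dependencies import requires_admin_user
from autogpt_libs.auth.models import User

app = FastAPI()


@app.get("/admin/metrics")
def admin_metrics(admin: User = Security(requires_admin_user)):
    # 401 without a valid bearer token, 403 when the token's role is not "admin".
    return {"requested_by": admin.user_id}
```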
autogpt_platform/autogpt_libs/poetry.lock (generated, 379 lines changed)
@@ -54,7 +54,7 @@ version = "1.2.0"
|
||||
description = "Backport of asyncio.Runner, a context manager that controls event loop life cycle."
|
||||
optional = false
|
||||
python-versions = "<3.11,>=3.8"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
files = [
|
||||
{file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
|
||||
@@ -85,6 +85,87 @@ files = [
|
||||
{file = "certifi-2025.7.14.tar.gz", hash = "sha256:8ea99dbdfaaf2ba2f9bac77b9249ef62ec5218e7c2b2e903378ed5fccf765995"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cffi"
|
||||
version = "1.17.1"
|
||||
description = "Foreign Function Interface for Python calling C code."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
markers = "platform_python_implementation != \"PyPy\""
|
||||
files = [
|
||||
{file = "cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5da5719280082ac6bd9aa7becb3938dc9f9cbd57fac7d2871717b1feb0902ab6"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb1a08b8008b281856e5971307cc386a8e9c5b625ac297e853d36da6efe9c17"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-win32.whl", hash = "sha256:c9c3d058ebabb74db66e431095118094d06abf53284d9c81f27300d0e0d8bc7c"},
|
||||
{file = "cffi-1.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:0f048dcf80db46f0098ccac01132761580d28e28bc0f78ae0d58048063317e15"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-win32.whl", hash = "sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655"},
|
||||
{file = "cffi-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:805b4371bf7197c329fcb3ead37e710d1bca9da5d583f5073b799d5c5bd1eee4"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:733e99bc2df47476e3848417c5a4540522f234dfd4ef3ab7fafdf555b082ec0c"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f79fc4fc25f1c8698ff97788206bb3c2598949bfe0fef03d299eb1b5356ada99"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-win32.whl", hash = "sha256:a08d7e755f8ed21095a310a693525137cfe756ce62d066e53f502a83dc550f65"},
|
||||
{file = "cffi-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:51392eae71afec0d0c8fb1a53b204dbb3bcabcb3c9b807eedf3e1e6ccf2de903"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f3a2b4222ce6b60e2e8b337bb9596923045681d71e5a082783484d845390938e"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0984a4925a435b1da406122d4d7968dd861c1385afe3b45ba82b750f229811e2"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-win32.whl", hash = "sha256:e03eab0a8677fa80d646b5ddece1cbeaf556c313dcfac435ba11f107ba117b5d"},
|
||||
{file = "cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:636062ea65bd0195bc012fea9321aca499c0504409f413dc88af450b57ffd03b"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c7eac2ef9b63c79431bc4b25f1cd649d7f061a28808cbc6c47b534bd789ef964"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e221cf152cff04059d011ee126477f0d9588303eb57e88923578ace7baad17f9"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:31000ec67d4221a71bd3f67df918b1f88f676f1c3b535a7eb473255fdc0b83fc"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f17be4345073b0a7b8ea599688f692ac3ef23ce28e5df79c04de519dbc4912c"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e2b1fac190ae3ebfe37b979cc1ce69c81f4e4fe5746bb401dca63a9062cdaf1"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-win32.whl", hash = "sha256:7596d6620d3fa590f677e9ee430df2958d2d6d6de2feeae5b20e82c00b76fbf8"},
|
||||
{file = "cffi-1.17.1-cp38-cp38-win_amd64.whl", hash = "sha256:78122be759c3f8a014ce010908ae03364d00a1f81ab5c7f4a7a5120607ea56e1"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b2ab587605f4ba0bf81dc0cb08a41bd1c0a5906bd59243d56bad7668a6fc6c16"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:28b16024becceed8c6dfbc75629e27788d8a3f9030691a1dbf9821a128b22c36"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d599671f396c4723d016dbddb72fe8e0397082b0a77a4fab8028923bec050e8"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca74b8dbe6e8e8263c0ffd60277de77dcee6c837a3d0881d8c1ead7268c9e576"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7f5baafcc48261359e14bcd6d9bff6d4b28d9103847c9e136694cb0501aef87"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:98e3969bcff97cae1b2def8ba499ea3d6f31ddfdb7635374834cf89a1a08ecf0"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cdf5ce3acdfd1661132f2a9c19cac174758dc2352bfe37d98aa7512c6b7178b3"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9755e4345d1ec879e3849e62222a18c7174d65a6a92d5b346b1863912168b595"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f1e22e8c4419538cb197e4dd60acc919d7696e5ef98ee4da4e01d3f8cfa4cc5a"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c03e868a0b3bc35839ba98e74211ed2b05d2119be4e8a0f224fba9384f1fe02e"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-win32.whl", hash = "sha256:e31ae45bc2e29f6b2abd0de1cc3b9d5205aa847cafaecb8af1476a609a2f6eb7"},
|
||||
{file = "cffi-1.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662"},
|
||||
{file = "cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pycparser = "*"
|
||||
|
||||
[[package]]
|
||||
name = "charset-normalizer"
|
||||
version = "3.4.2"
|
||||
@@ -208,12 +289,176 @@ version = "0.4.6"
|
||||
description = "Cross-platform colored terminal text."
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
||||
groups = ["main"]
|
||||
groups = ["main", "dev"]
|
||||
files = [
|
||||
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
|
||||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "coverage"
|
||||
version = "7.10.5"
|
||||
description = "Code coverage measurement for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "coverage-7.10.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c6a5c3414bfc7451b879141ce772c546985163cf553f08e0f135f0699a911801"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bc8e4d99ce82f1710cc3c125adc30fd1487d3cf6c2cd4994d78d68a47b16989a"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:02252dc1216e512a9311f596b3169fad54abcb13827a8d76d5630c798a50a754"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:73269df37883e02d460bee0cc16be90509faea1e3bd105d77360b512d5bb9c33"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1f8a81b0614642f91c9effd53eec284f965577591f51f547a1cbeb32035b4c2f"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6a29f8e0adb7f8c2b95fa2d4566a1d6e6722e0a637634c6563cb1ab844427dd9"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:fcf6ab569436b4a647d4e91accba12509ad9f2554bc93d3aee23cc596e7f99c3"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:90dc3d6fb222b194a5de60af8d190bedeeddcbc7add317e4a3cd333ee6b7c879"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-win32.whl", hash = "sha256:414a568cd545f9dc75f0686a0049393de8098414b58ea071e03395505b73d7a8"},
|
||||
{file = "coverage-7.10.5-cp310-cp310-win_amd64.whl", hash = "sha256:e551f9d03347196271935fd3c0c165f0e8c049220280c1120de0084d65e9c7ff"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c177e6ffe2ebc7c410785307758ee21258aa8e8092b44d09a2da767834f075f2"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:14d6071c51ad0f703d6440827eaa46386169b5fdced42631d5a5ac419616046f"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:61f78c7c3bc272a410c5ae3fde7792b4ffb4acc03d35a7df73ca8978826bb7ab"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f39071caa126f69d63f99b324fb08c7b1da2ec28cbb1fe7b5b1799926492f65c"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:343a023193f04d46edc46b2616cdbee68c94dd10208ecd3adc56fcc54ef2baa1"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:585ffe93ae5894d1ebdee69fc0b0d4b7c75d8007983692fb300ac98eed146f78"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b0ef4e66f006ed181df29b59921bd8fc7ed7cd6a9289295cd8b2824b49b570df"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:eb7b0bbf7cc1d0453b843eca7b5fa017874735bef9bfdfa4121373d2cc885ed6"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-win32.whl", hash = "sha256:1d043a8a06987cc0c98516e57c4d3fc2c1591364831e9deb59c9e1b4937e8caf"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-win_amd64.whl", hash = "sha256:fefafcca09c3ac56372ef64a40f5fe17c5592fab906e0fdffd09543f3012ba50"},
|
||||
{file = "coverage-7.10.5-cp311-cp311-win_arm64.whl", hash = "sha256:7e78b767da8b5fc5b2faa69bb001edafcd6f3995b42a331c53ef9572c55ceb82"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c2d05c7e73c60a4cecc7d9b60dbfd603b4ebc0adafaef371445b47d0f805c8a9"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:32ddaa3b2c509778ed5373b177eb2bf5662405493baeff52278a0b4f9415188b"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:dd382410039fe062097aa0292ab6335a3f1e7af7bba2ef8d27dcda484918f20c"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7fa22800f3908df31cea6fb230f20ac49e343515d968cc3a42b30d5c3ebf9b5a"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f366a57ac81f5e12797136552f5b7502fa053c861a009b91b80ed51f2ce651c6"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5f1dc8f1980a272ad4a6c84cba7981792344dad33bf5869361576b7aef42733a"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:2285c04ee8676f7938b02b4936d9b9b672064daab3187c20f73a55f3d70e6b4a"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c2492e4dd9daab63f5f56286f8a04c51323d237631eb98505d87e4c4ff19ec34"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-win32.whl", hash = "sha256:38a9109c4ee8135d5df5505384fc2f20287a47ccbe0b3f04c53c9a1989c2bbaf"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-win_amd64.whl", hash = "sha256:6b87f1ad60b30bc3c43c66afa7db6b22a3109902e28c5094957626a0143a001f"},
|
||||
{file = "coverage-7.10.5-cp312-cp312-win_arm64.whl", hash = "sha256:672a6c1da5aea6c629819a0e1461e89d244f78d7b60c424ecf4f1f2556c041d8"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ef3b83594d933020f54cf65ea1f4405d1f4e41a009c46df629dd964fcb6e907c"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2b96bfdf7c0ea9faebce088a3ecb2382819da4fbc05c7b80040dbc428df6af44"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:63df1fdaffa42d914d5c4d293e838937638bf75c794cf20bee12978fc8c4e3bc"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8002dc6a049aac0e81ecec97abfb08c01ef0c1fbf962d0c98da3950ace89b869"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:63d4bb2966d6f5f705a6b0c6784c8969c468dbc4bcf9d9ded8bff1c7e092451f"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:1f672efc0731a6846b157389b6e6d5d5e9e59d1d1a23a5c66a99fd58339914d5"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:3f39cef43d08049e8afc1fde4a5da8510fc6be843f8dea350ee46e2a26b2f54c"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2968647e3ed5a6c019a419264386b013979ff1fb67dd11f5c9886c43d6a31fc2"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-win32.whl", hash = "sha256:0d511dda38595b2b6934c2b730a1fd57a3635c6aa2a04cb74714cdfdd53846f4"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-win_amd64.whl", hash = "sha256:9a86281794a393513cf117177fd39c796b3f8e3759bb2764259a2abba5cce54b"},
|
||||
{file = "coverage-7.10.5-cp313-cp313-win_arm64.whl", hash = "sha256:cebd8e906eb98bb09c10d1feed16096700b1198d482267f8bf0474e63a7b8d84"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0520dff502da5e09d0d20781df74d8189ab334a1e40d5bafe2efaa4158e2d9e7"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:d9cd64aca68f503ed3f1f18c7c9174cbb797baba02ca8ab5112f9d1c0328cd4b"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0913dd1613a33b13c4f84aa6e3f4198c1a21ee28ccb4f674985c1f22109f0aae"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:1b7181c0feeb06ed8a02da02792f42f829a7b29990fef52eff257fef0885d760"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:36d42b7396b605f774d4372dd9c49bed71cbabce4ae1ccd074d155709dd8f235"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b4fdc777e05c4940b297bf47bf7eedd56a39a61dc23ba798e4b830d585486ca5"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:42144e8e346de44a6f1dbd0a56575dd8ab8dfa7e9007da02ea5b1c30ab33a7db"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:66c644cbd7aed8fe266d5917e2c9f65458a51cfe5eeff9c05f15b335f697066e"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-win32.whl", hash = "sha256:2d1b73023854068c44b0c554578a4e1ef1b050ed07cf8b431549e624a29a66ee"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-win_amd64.whl", hash = "sha256:54a1532c8a642d8cc0bd5a9a51f5a9dcc440294fd06e9dda55e743c5ec1a8f14"},
|
||||
{file = "coverage-7.10.5-cp313-cp313t-win_arm64.whl", hash = "sha256:74d5b63fe3f5f5d372253a4ef92492c11a4305f3550631beaa432fc9df16fcff"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:68c5e0bc5f44f68053369fa0d94459c84548a77660a5f2561c5e5f1e3bed7031"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:cf33134ffae93865e32e1e37df043bef15a5e857d8caebc0099d225c579b0fa3"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:ad8fa9d5193bafcf668231294241302b5e683a0518bf1e33a9a0dfb142ec3031"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:146fa1531973d38ab4b689bc764592fe6c2f913e7e80a39e7eeafd11f0ef6db2"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6013a37b8a4854c478d3219ee8bc2392dea51602dd0803a12d6f6182a0061762"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:eb90fe20db9c3d930fa2ad7a308207ab5b86bf6a76f54ab6a40be4012d88fcae"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:384b34482272e960c438703cafe63316dfbea124ac62006a455c8410bf2a2262"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:467dc74bd0a1a7de2bedf8deaf6811f43602cb532bd34d81ffd6038d6d8abe99"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-win32.whl", hash = "sha256:556d23d4e6393ca898b2e63a5bca91e9ac2d5fb13299ec286cd69a09a7187fde"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-win_amd64.whl", hash = "sha256:f4446a9547681533c8fa3e3c6cf62121eeee616e6a92bd9201c6edd91beffe13"},
|
||||
{file = "coverage-7.10.5-cp314-cp314-win_arm64.whl", hash = "sha256:5e78bd9cf65da4c303bf663de0d73bf69f81e878bf72a94e9af67137c69b9fe9"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:5661bf987d91ec756a47c7e5df4fbcb949f39e32f9334ccd3f43233bbb65e508"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:a46473129244db42a720439a26984f8c6f834762fc4573616c1f37f13994b357"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:1f64b8d3415d60f24b058b58d859e9512624bdfa57a2d1f8aff93c1ec45c429b"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:44d43de99a9d90b20e0163f9770542357f58860a26e24dc1d924643bd6aa7cb4"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a931a87e5ddb6b6404e65443b742cb1c14959622777f2a4efd81fba84f5d91ba"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:f9559b906a100029274448f4c8b8b0a127daa4dade5661dfd821b8c188058842"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:b08801e25e3b4526ef9ced1aa29344131a8f5213c60c03c18fe4c6170ffa2874"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ed9749bb8eda35f8b636fb7632f1c62f735a236a5d4edadd8bbcc5ea0542e732"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-win32.whl", hash = "sha256:609b60d123fc2cc63ccee6d17e4676699075db72d14ac3c107cc4976d516f2df"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-win_amd64.whl", hash = "sha256:0666cf3d2c1626b5a3463fd5b05f5e21f99e6aec40a3192eee4d07a15970b07f"},
|
||||
{file = "coverage-7.10.5-cp314-cp314t-win_arm64.whl", hash = "sha256:bc85eb2d35e760120540afddd3044a5bf69118a91a296a8b3940dfc4fdcfe1e2"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:62835c1b00c4a4ace24c1a88561a5a59b612fbb83a525d1c70ff5720c97c0610"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5255b3bbcc1d32a4069d6403820ac8e6dbcc1d68cb28a60a1ebf17e47028e898"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3876385722e335d6e991c430302c24251ef9c2a9701b2b390f5473199b1b8ebf"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8048ce4b149c93447a55d279078c8ae98b08a6951a3c4d2d7e87f4efc7bfe100"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4028e7558e268dd8bcf4d9484aad393cafa654c24b4885f6f9474bf53183a82a"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:03f47dc870eec0367fcdd603ca6a01517d2504e83dc18dbfafae37faec66129a"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:2d488d7d42b6ded7ea0704884f89dcabd2619505457de8fc9a6011c62106f6e5"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b3dcf2ead47fa8be14224ee817dfc1df98043af568fe120a22f81c0eb3c34ad2"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-win32.whl", hash = "sha256:02650a11324b80057b8c9c29487020073d5e98a498f1857f37e3f9b6ea1b2426"},
|
||||
{file = "coverage-7.10.5-cp39-cp39-win_amd64.whl", hash = "sha256:b45264dd450a10f9e03237b41a9a24e85cbb1e278e5a32adb1a303f58f0017f3"},
|
||||
{file = "coverage-7.10.5-py3-none-any.whl", hash = "sha256:0be24d35e4db1d23d0db5c0f6a74a962e2ec83c426b5cac09f4234aadef38e4a"},
|
||||
{file = "coverage-7.10.5.tar.gz", hash = "sha256:f2e57716a78bc3ae80b2207be0709a3b2b63b9f2dcf9740ee6ac03588a2015b6"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}
|
||||
|
||||
[package.extras]
|
||||
toml = ["tomli ; python_full_version <= \"3.11.0a6\""]
|
||||
|
||||
[[package]]
|
||||
name = "cryptography"
|
||||
version = "45.0.6"
|
||||
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
|
||||
optional = false
|
||||
python-versions = "!=3.9.0,!=3.9.1,>=3.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "cryptography-45.0.6-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:048e7ad9e08cf4c0ab07ff7f36cc3115924e22e2266e034450a890d9e312dd74"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:44647c5d796f5fc042bbc6d61307d04bf29bccb74d188f18051b635f20a9c75f"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e40b80ecf35ec265c452eea0ba94c9587ca763e739b8e559c128d23bff7ebbbf"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:00e8724bdad672d75e6f069b27970883179bd472cd24a63f6e620ca7e41cc0c5"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7a3085d1b319d35296176af31c90338eeb2ddac8104661df79f80e1d9787b8b2"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1b7fa6a1c1188c7ee32e47590d16a5a0646270921f8020efc9a511648e1b2e08"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:275ba5cc0d9e320cd70f8e7b96d9e59903c815ca579ab96c1e37278d231fc402"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:f4028f29a9f38a2025abedb2e409973709c660d44319c61762202206ed577c42"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ee411a1b977f40bd075392c80c10b58025ee5c6b47a822a33c1198598a7a5f05"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e2a21a8eda2d86bb604934b6b37691585bd095c1f788530c1fcefc53a82b3453"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-win32.whl", hash = "sha256:d063341378d7ee9c91f9d23b431a3502fc8bfacd54ef0a27baa72a0843b29159"},
|
||||
{file = "cryptography-45.0.6-cp311-abi3-win_amd64.whl", hash = "sha256:833dc32dfc1e39b7376a87b9a6a4288a10aae234631268486558920029b086ec"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:3436128a60a5e5490603ab2adbabc8763613f638513ffa7d311c900a8349a2a0"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0d9ef57b6768d9fa58e92f4947cea96ade1233c0e236db22ba44748ffedca394"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea3c42f2016a5bbf71825537c2ad753f2870191134933196bee408aac397b3d9"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:20ae4906a13716139d6d762ceb3e0e7e110f7955f3bc3876e3a07f5daadec5f3"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:2dac5ec199038b8e131365e2324c03d20e97fe214af051d20c49db129844e8b3"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:18f878a34b90d688982e43f4b700408b478102dd58b3e39de21b5ebf6509c301"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:5bd6020c80c5b2b2242d6c48487d7b85700f5e0038e67b29d706f98440d66eb5"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:eccddbd986e43014263eda489abbddfbc287af5cddfd690477993dbb31e31016"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:550ae02148206beb722cfe4ef0933f9352bab26b087af00e48fdfb9ade35c5b3"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5b64e668fc3528e77efa51ca70fadcd6610e8ab231e3e06ae2bab3b31c2b8ed9"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-win32.whl", hash = "sha256:780c40fb751c7d2b0c6786ceee6b6f871e86e8718a8ff4bc35073ac353c7cd02"},
|
||||
{file = "cryptography-45.0.6-cp37-abi3-win_amd64.whl", hash = "sha256:20d15aed3ee522faac1a39fbfdfee25d17b1284bafd808e1640a74846d7c4d1b"},
|
||||
{file = "cryptography-45.0.6-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:705bb7c7ecc3d79a50f236adda12ca331c8e7ecfbea51edd931ce5a7a7c4f012"},
|
||||
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:826b46dae41a1155a0c0e66fafba43d0ede1dc16570b95e40c4d83bfcf0a451d"},
|
||||
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:cc4d66f5dc4dc37b89cfef1bd5044387f7a1f6f0abb490815628501909332d5d"},
|
||||
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:f68f833a9d445cc49f01097d95c83a850795921b3f7cc6488731e69bde3288da"},
|
||||
{file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:3b5bf5267e98661b9b888a9250d05b063220dfa917a8203744454573c7eb79db"},
|
||||
{file = "cryptography-45.0.6-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2384f2ab18d9be88a6e4f8972923405e2dbb8d3e16c6b43f15ca491d7831bd18"},
|
||||
{file = "cryptography-45.0.6-pp311-pypy311_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fc022c1fa5acff6def2fc6d7819bbbd31ccddfe67d075331a65d9cfb28a20983"},
|
||||
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:3de77e4df42ac8d4e4d6cdb342d989803ad37707cf8f3fbf7b088c9cbdd46427"},
|
||||
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:599c8d7df950aa68baa7e98f7b73f4f414c9f02d0e8104a30c0182a07732638b"},
|
||||
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:31a2b9a10530a1cb04ffd6aa1cd4d3be9ed49f7d77a4dafe198f3b382f41545c"},
|
||||
{file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:e5b3dda1b00fb41da3af4c5ef3f922a200e33ee5ba0f0bc9ecf0b0c173958385"},
|
||||
{file = "cryptography-45.0.6-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:629127cfdcdc6806dfe234734d7cb8ac54edaf572148274fa377a7d3405b0043"},
|
||||
{file = "cryptography-45.0.6.tar.gz", hash = "sha256:5c966c732cf6e4a276ce83b6e4c729edda2df6929083a952cc7da973c539c719"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
cffi = {version = ">=1.14", markers = "platform_python_implementation != \"PyPy\""}
|
||||
|
||||
[package.extras]
|
||||
docs = ["sphinx (>=5.3.0)", "sphinx-inline-tabs ; python_full_version >= \"3.8.0\"", "sphinx-rtd-theme (>=3.0.0) ; python_full_version >= \"3.8.0\""]
|
||||
docstest = ["pyenchant (>=3)", "readme-renderer (>=30.0)", "sphinxcontrib-spelling (>=7.3.1)"]
|
||||
nox = ["nox (>=2024.4.15)", "nox[uv] (>=2024.3.2) ; python_full_version >= \"3.8.0\""]
|
||||
pep8test = ["check-sdist ; python_full_version >= \"3.8.0\"", "click (>=8.0.1)", "mypy (>=1.4)", "ruff (>=0.3.6)"]
|
||||
sdist = ["build (>=1.0.0)"]
|
||||
ssh = ["bcrypt (>=3.1.5)"]
|
||||
test = ["certifi (>=2024)", "cryptography-vectors (==45.0.6)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
|
||||
test-randomorder = ["pytest-randomly"]
|
||||
|
||||
[[package]]
|
||||
name = "deprecation"
|
||||
version = "2.1.0"
|
||||
@@ -235,7 +480,7 @@ version = "1.3.0"
|
||||
description = "Backport of PEP 654 (exception groups)"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main"]
|
||||
groups = ["main", "dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
files = [
|
||||
{file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"},
|
||||
@@ -710,7 +955,7 @@ version = "2.1.0"
|
||||
description = "brain-dead simple config-ini parsing"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"},
|
||||
{file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"},
|
||||
@@ -757,6 +1002,18 @@ dynamodb = ["boto3 (>=1.9.71)"]
|
||||
redis = ["redis (>=2.10.5)"]
|
||||
test-filesource = ["pyyaml (>=5.3.1)", "watchdog (>=3.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "nodeenv"
|
||||
version = "1.9.1"
|
||||
description = "Node.js virtual environment builder"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"},
|
||||
{file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-api"
|
||||
version = "1.35.0"
|
||||
@@ -779,7 +1036,7 @@ version = "25.0"
|
||||
description = "Core utilities for Python packages"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
groups = ["main", "dev"]
|
||||
files = [
|
||||
{file = "packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484"},
|
||||
{file = "packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f"},
|
||||
@@ -791,7 +1048,7 @@ version = "1.6.0"
|
||||
description = "plugin and hook calling mechanisms for python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746"},
|
||||
{file = "pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3"},
|
||||
@@ -883,6 +1140,19 @@ files = [
|
||||
[package.dependencies]
|
||||
pyasn1 = ">=0.6.1,<0.7.0"
|
||||
|
||||
[[package]]
|
||||
name = "pycparser"
|
||||
version = "2.22"
|
||||
description = "C parser in Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
markers = "platform_python_implementation != \"PyPy\""
|
||||
files = [
|
||||
{file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"},
|
||||
{file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "2.11.7"
|
||||
@@ -1047,7 +1317,7 @@ version = "2.19.2"
|
||||
description = "Pygments is a syntax highlighting package written in Python."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b"},
|
||||
{file = "pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887"},
|
||||
@@ -1068,6 +1338,9 @@ files = [
|
||||
{file = "pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
cryptography = {version = ">=3.4.0", optional = true, markers = "extra == \"crypto\""}
|
||||
|
||||
[package.extras]
|
||||
crypto = ["cryptography (>=3.4.0)"]
|
||||
dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx", "sphinx-rtd-theme", "zope.interface"]
|
||||
@@ -1086,13 +1359,34 @@ files = [
|
||||
{file = "pyrfc3339-2.0.1.tar.gz", hash = "sha256:e47843379ea35c1296c3b6c67a948a1a490ae0584edfcbdea0eaffb5dd29960b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyright"
|
||||
version = "1.1.404"
|
||||
description = "Command line wrapper for pyright"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pyright-1.1.404-py3-none-any.whl", hash = "sha256:c7b7ff1fdb7219c643079e4c3e7d4125f0dafcc19d253b47e898d130ea426419"},
|
||||
{file = "pyright-1.1.404.tar.gz", hash = "sha256:455e881a558ca6be9ecca0b30ce08aa78343ecc031d37a198ffa9a7a1abeb63e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
nodeenv = ">=1.6.0"
|
||||
typing-extensions = ">=4.1"
|
||||
|
||||
[package.extras]
|
||||
all = ["nodejs-wheel-binaries", "twine (>=3.4.1)"]
|
||||
dev = ["twine (>=3.4.1)"]
|
||||
nodejs = ["nodejs-wheel-binaries"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "8.4.1"
|
||||
description = "pytest: simple powerful testing with Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7"},
|
||||
{file = "pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c"},
|
||||
@@ -1116,7 +1410,7 @@ version = "1.1.0"
|
||||
description = "Pytest support for asyncio"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest_asyncio-1.1.0-py3-none-any.whl", hash = "sha256:5fe2d69607b0bd75c656d1211f969cadba035030156745ee09e7d71740e58ecf"},
|
||||
{file = "pytest_asyncio-1.1.0.tar.gz", hash = "sha256:796aa822981e01b68c12e4827b8697108f7205020f24b5793b3c41555dab68ea"},
|
||||
@@ -1130,13 +1424,33 @@ pytest = ">=8.2,<9"
|
||||
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"]
|
||||
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "6.2.1"
|
||||
description = "Pytest plugin for measuring coverage."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest_cov-6.2.1-py3-none-any.whl", hash = "sha256:f5bc4c23f42f1cdd23c70b1dab1bbaef4fc505ba950d53e0081d0730dd7e86d5"},
|
||||
{file = "pytest_cov-6.2.1.tar.gz", hash = "sha256:25cc6cc0a5358204b8108ecedc51a9b57b34cc6b8c967cc2c01a4e00d8a67da2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
coverage = {version = ">=7.5", extras = ["toml"]}
|
||||
pluggy = ">=1.2"
|
||||
pytest = ">=6.2.5"
|
||||
|
||||
[package.extras]
|
||||
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-mock"
|
||||
version = "3.14.1"
|
||||
description = "Thin-wrapper around the mock package for easier use with pytest"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest_mock-3.14.1-py3-none-any.whl", hash = "sha256:178aefcd11307d874b4cd3100344e7e2d888d9791a6a1d9bfe90fbc1b74fd1d0"},
|
||||
{file = "pytest_mock-3.14.1.tar.gz", hash = "sha256:159e9edac4c451ce77a5cdb9fc5d1100708d2dd4ba3c3df572f14097351af80e"},
|
||||
@@ -1253,30 +1567,31 @@ pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.12.3"
|
||||
version = "0.12.11"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "ruff-0.12.3-py3-none-linux_armv6l.whl", hash = "sha256:47552138f7206454eaf0c4fe827e546e9ddac62c2a3d2585ca54d29a890137a2"},
|
||||
{file = "ruff-0.12.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:0a9153b000c6fe169bb307f5bd1b691221c4286c133407b8827c406a55282041"},
|
||||
{file = "ruff-0.12.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa6b24600cf3b750e48ddb6057e901dd5b9aa426e316addb2a1af185a7509882"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2506961bf6ead54887ba3562604d69cb430f59b42133d36976421bc8bd45901"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c4faaff1f90cea9d3033cbbcdf1acf5d7fb11d8180758feb31337391691f3df0"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40dced4a79d7c264389de1c59467d5d5cefd79e7e06d1dfa2c75497b5269a5a6"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:0262d50ba2767ed0fe212aa7e62112a1dcbfd46b858c5bf7bbd11f326998bafc"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12371aec33e1a3758597c5c631bae9a5286f3c963bdfb4d17acdd2d395406687"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:560f13b6baa49785665276c963edc363f8ad4b4fc910a883e2625bdb14a83a9e"},
|
||||
{file = "ruff-0.12.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:023040a3499f6f974ae9091bcdd0385dd9e9eb4942f231c23c57708147b06311"},
|
||||
{file = "ruff-0.12.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:883d844967bffff5ab28bba1a4d246c1a1b2933f48cb9840f3fdc5111c603b07"},
|
||||
{file = "ruff-0.12.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2120d3aa855ff385e0e562fdee14d564c9675edbe41625c87eeab744a7830d12"},
|
||||
{file = "ruff-0.12.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6b16647cbb470eaf4750d27dddc6ebf7758b918887b56d39e9c22cce2049082b"},
|
||||
{file = "ruff-0.12.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e1417051edb436230023575b149e8ff843a324557fe0a265863b7602df86722f"},
|
||||
{file = "ruff-0.12.3-py3-none-win32.whl", hash = "sha256:dfd45e6e926deb6409d0616078a666ebce93e55e07f0fb0228d4b2608b2c248d"},
|
||||
{file = "ruff-0.12.3-py3-none-win_amd64.whl", hash = "sha256:a946cf1e7ba3209bdef039eb97647f1c77f6f540e5845ec9c114d3af8df873e7"},
|
||||
{file = "ruff-0.12.3-py3-none-win_arm64.whl", hash = "sha256:5f9c7c9c8f84c2d7f27e93674d27136fbf489720251544c4da7fb3d742e011b1"},
|
||||
{file = "ruff-0.12.3.tar.gz", hash = "sha256:f1b5a4b6668fd7b7ea3697d8d98857390b40c1320a63a178eee6be0899ea2d77"},
|
||||
{file = "ruff-0.12.11-py3-none-linux_armv6l.whl", hash = "sha256:93fce71e1cac3a8bf9200e63a38ac5c078f3b6baebffb74ba5274fb2ab276065"},
|
||||
{file = "ruff-0.12.11-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b8e33ac7b28c772440afa80cebb972ffd823621ded90404f29e5ab6d1e2d4b93"},
|
||||
{file = "ruff-0.12.11-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d69fb9d4937aa19adb2e9f058bc4fbfe986c2040acb1a4a9747734834eaa0bfd"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:411954eca8464595077a93e580e2918d0a01a19317af0a72132283e28ae21bee"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6a2c0a2e1a450f387bf2c6237c727dd22191ae8c00e448e0672d624b2bbd7fb0"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ca4c3a7f937725fd2413c0e884b5248a19369ab9bdd850b5781348ba283f644"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:4d1df0098124006f6a66ecf3581a7f7e754c4df7644b2e6704cd7ca80ff95211"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5a8dd5f230efc99a24ace3b77e3555d3fbc0343aeed3fc84c8d89e75ab2ff793"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4dc75533039d0ed04cd33fb8ca9ac9620b99672fe7ff1533b6402206901c34ee"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4fc58f9266d62c6eccc75261a665f26b4ef64840887fc6cbc552ce5b29f96cc8"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:5a0113bd6eafd545146440225fe60b4e9489f59eb5f5f107acd715ba5f0b3d2f"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0d737b4059d66295c3ea5720e6efc152623bb83fde5444209b69cd33a53e2000"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:916fc5defee32dbc1fc1650b576a8fed68f5e8256e2180d4d9855aea43d6aab2"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c984f07d7adb42d3ded5be894fb4007f30f82c87559438b4879fe7aa08c62b39"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e07fbb89f2e9249f219d88331c833860489b49cdf4b032b8e4432e9b13e8a4b9"},
|
||||
{file = "ruff-0.12.11-py3-none-win32.whl", hash = "sha256:c792e8f597c9c756e9bcd4d87cf407a00b60af77078c96f7b6366ea2ce9ba9d3"},
|
||||
{file = "ruff-0.12.11-py3-none-win_amd64.whl", hash = "sha256:a3283325960307915b6deb3576b96919ee89432ebd9c48771ca12ee8afe4a0fd"},
|
||||
{file = "ruff-0.12.11-py3-none-win_arm64.whl", hash = "sha256:bae4d6e6a2676f8fb0f98b74594a048bae1b944aab17e9f5d504062303c6dbea"},
|
||||
{file = "ruff-0.12.11.tar.gz", hash = "sha256:c6b09ae8426a65bbee5425b9d0b82796dbb07cb1af045743c79bfb163001165d"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1410,7 +1725,7 @@ version = "2.2.1"
|
||||
description = "A lil' TOML parser"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
groups = ["dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
files = [
|
||||
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
|
||||
@@ -1453,7 +1768,7 @@ version = "4.14.1"
|
||||
description = "Backported and Experimental Type Hints for Python 3.9+"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
groups = ["main", "dev"]
|
||||
files = [
|
||||
{file = "typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76"},
|
||||
{file = "typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36"},
|
||||
@@ -1614,4 +1929,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "f67db13e6f68b1d67a55eee908c1c560bfa44da8509f98f842889a7570a9830f"
|
||||
content-hash = "0c40b63c3c921846cf05ccfb4e685d4959854b29c2c302245f9832e20aac6954"
|
||||
|
||||
@@ -9,21 +9,25 @@ packages = [{ include = "autogpt_libs" }]
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<4.0"
|
||||
colorama = "^0.4.6"
|
||||
cryptography = "^45.0"
|
||||
expiringdict = "^1.2.2"
|
||||
fastapi = "^0.116.1"
|
||||
google-cloud-logging = "^3.12.1"
|
||||
launchdarkly-server-sdk = "^9.12.0"
|
||||
pydantic = "^2.11.7"
|
||||
pydantic-settings = "^2.10.1"
|
||||
pyjwt = "^2.10.1"
|
||||
pytest-asyncio = "^1.1.0"
|
||||
pytest-mock = "^3.14.1"
|
||||
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
|
||||
redis = "^6.2.0"
|
||||
supabase = "^2.16.0"
|
||||
uvicorn = "^0.35.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "^0.12.3"
|
||||
pyright = "^1.1.404"
|
||||
pytest = "^8.4.1"
|
||||
pytest-asyncio = "^1.1.0"
|
||||
pytest-mock = "^3.14.1"
|
||||
pytest-cov = "^6.2.1"
|
||||
ruff = "^0.12.11"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -16,7 +16,6 @@ DB_SCHEMA=platform
|
||||
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
|
||||
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
|
||||
PRISMA_SCHEMA="postgres/schema.prisma"
|
||||
ENABLE_AUTH=true
|
||||
|
||||
## ===== REQUIRED SERVICE CREDENTIALS ===== ##
|
||||
# Redis Configuration
|
||||
@@ -31,7 +30,7 @@ RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
|
||||
# Supabase Authentication
|
||||
SUPABASE_URL=http://localhost:8000
|
||||
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
|
||||
## ===== REQUIRED SECURITY KEYS ===== ##
|
||||
# Generate using: from cryptography.fernet import Fernet;Fernet.generate_key().decode()
|
||||
@@ -106,6 +105,15 @@ TODOIST_CLIENT_SECRET=
|
||||
|
||||
NOTION_CLIENT_ID=
|
||||
NOTION_CLIENT_SECRET=
|
||||
|
||||
# Discord OAuth App credentials
|
||||
# 1. Go to https://discord.com/developers/applications
|
||||
# 2. Create a new application
|
||||
# 3. Go to OAuth2 section and add redirect URI: http://localhost:3000/auth/integrations/oauth_callback
|
||||
# 4. Copy Client ID and Client Secret below
|
||||
DISCORD_CLIENT_ID=
|
||||
DISCORD_CLIENT_SECRET=
|
||||
|
||||
REDDIT_CLIENT_ID=
|
||||
REDDIT_CLIENT_SECRET=
|
||||
|
||||
@@ -166,4 +174,4 @@ SMARTLEAD_API_KEY=
|
||||
ZEROBOUNCE_API_KEY=
|
||||
|
||||
# Other Services
|
||||
AUTOMOD_API_KEY=
|
||||
AUTOMOD_API_KEY=
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
FROM python:3.11.10-slim-bookworm AS builder
|
||||
FROM debian:13-slim AS builder
|
||||
|
||||
# Set environment variables
|
||||
ENV PYTHONDONTWRITEBYTECODE 1
|
||||
ENV PYTHONUNBUFFERED 1
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy
|
||||
|
||||
# Update package list and install build dependencies in a single layer
|
||||
# Update package list and install Python and build dependencies
|
||||
RUN apt-get update --allow-releaseinfo-change --fix-missing \
|
||||
&& apt-get install -y \
|
||||
python3.13 \
|
||||
python3.13-dev \
|
||||
python3.13-venv \
|
||||
python3-pip \
|
||||
build-essential \
|
||||
libpq5 \
|
||||
libz-dev \
|
||||
@@ -19,13 +24,11 @@ RUN apt-get update --allow-releaseinfo-change --fix-missing \
|
||||
|
||||
ENV POETRY_HOME=/opt/poetry
|
||||
ENV POETRY_NO_INTERACTION=1
|
||||
ENV POETRY_VIRTUALENVS_CREATE=false
|
||||
ENV POETRY_VIRTUALENVS_CREATE=true
|
||||
ENV POETRY_VIRTUALENVS_IN_PROJECT=true
|
||||
ENV PATH=/opt/poetry/bin:$PATH
|
||||
|
||||
# Upgrade pip and setuptools to fix security vulnerabilities
|
||||
RUN pip3 install --upgrade pip setuptools
|
||||
|
||||
RUN pip3 install poetry
|
||||
RUN pip3 install poetry --break-system-packages
|
||||
|
||||
# Copy and install dependencies
|
||||
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
|
||||
@@ -37,27 +40,30 @@ RUN poetry install --no-ansi --no-root
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
RUN poetry run prisma generate
|
||||
|
||||
FROM python:3.11.10-slim-bookworm AS server_dependencies
|
||||
FROM debian:13-slim AS server_dependencies
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
ENV POETRY_HOME=/opt/poetry \
|
||||
POETRY_NO_INTERACTION=1 \
|
||||
POETRY_VIRTUALENVS_CREATE=false
|
||||
POETRY_VIRTUALENVS_CREATE=true \
|
||||
POETRY_VIRTUALENVS_IN_PROJECT=true \
|
||||
DEBIAN_FRONTEND=noninteractive
|
||||
ENV PATH=/opt/poetry/bin:$PATH
|
||||
|
||||
# Upgrade pip and setuptools to fix security vulnerabilities
|
||||
RUN pip3 install --upgrade pip setuptools
|
||||
# Install Python without upgrading system-managed packages
|
||||
RUN apt-get update && apt-get install -y \
|
||||
python3.13 \
|
||||
python3-pip
|
||||
|
||||
# Copy only necessary files from builder
|
||||
COPY --from=builder /app /app
|
||||
COPY --from=builder /usr/local/lib/python3.11 /usr/local/lib/python3.11
|
||||
COPY --from=builder /usr/local/bin /usr/local/bin
|
||||
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
||||
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
|
||||
# Copy Prisma binaries
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
|
||||
|
||||
RUN mkdir -p /app/autogpt_platform/autogpt_libs
|
||||
RUN mkdir -p /app/autogpt_platform/backend
|
||||
|
||||
@@ -132,17 +132,58 @@ def test_endpoint_success(snapshot: Snapshot):
|
||||
|
||||
### Testing with Authentication
|
||||
|
||||
For the main API routes that use JWT authentication, auth is provided by the `autogpt_libs.auth` module. If the test actually uses the `user_id`, the recommended approach for testing is to mock the `get_jwt_payload` function, which underpins all higher-level auth functions used in the API (`requires_user`, `requires_admin_user`, `get_user_id`).
|
||||
|
||||
If the test doesn't need the `user_id` specifically, mocking is not necessary, since auth is disabled during tests anyway (see `conftest.py`).
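
For cases where a test does need a specific `user_id`, a minimal sketch of overriding `get_jwt_payload` directly might look like this (assuming a FastAPI `app` that already includes the router under test; the payload only needs a `"sub"` claim carrying the user id):

```python
import fastapi
from autogpt_libs.auth.jwt_utils import get_jwt_payload

app = fastapi.FastAPI()  # assumed: the router under test is included on this app

def fake_jwt_payload() -> dict:
    # "sub" is the user id that get_user_id() / requires_user() will resolve
    return {"sub": "test-user-id"}

app.dependency_overrides[get_jwt_payload] = fake_jwt_payload
```

In most cases, though, the global fixtures described below are the simpler option.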
|
||||
|
||||
#### Using Global Auth Fixtures
|
||||
|
||||
Two global auth fixtures are provided by `backend/server/conftest.py`:
|
||||
|
||||
- `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id")
|
||||
- `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id")
|
||||
|
||||
These provide the easiest way to set up authentication mocking in test modules:
|
||||
|
||||
```python
|
||||
def override_auth_middleware():
|
||||
return {"sub": "test-user-id"}
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
from backend.server.v2.myroute import router
|
||||
|
||||
def override_get_user_id():
|
||||
return "test-user-id"
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
app.dependency_overrides[auth_middleware] = override_auth_middleware
|
||||
app.dependency_overrides[get_user_id] = override_get_user_id
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user['get_jwt_payload']
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
```
|
||||
|
||||
For admin-only endpoints, use `mock_jwt_admin` instead:
|
||||
|
||||
```python
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_admin):
|
||||
"""Setup auth overrides for admin tests"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin['get_jwt_payload']
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
```
|
||||
|
||||
The IDs are also available separately as fixtures (a short usage sketch follows this list):
|
||||
|
||||
- `test_user_id`
|
||||
- `admin_user_id`
|
||||
- `target_user_id` (for admin <-> user operations)
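
A hypothetical usage sketch, assuming the autouse `setup_app_auth` fixture above is active and a made-up `/api/items` endpoint whose items carry a `user_id` field:

```python
def test_only_own_data_is_returned(test_user_id):
    # Auth is already mocked by the autouse setup_app_auth fixture above
    response = client.get("/api/items")  # hypothetical endpoint for illustration
    assert response.status_code == 200
    assert all(item["user_id"] == test_user_id for item in response.json())
```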
|
||||
|
||||
### Mocking External Services
|
||||
|
||||
```python
|
||||
@@ -153,10 +194,10 @@ def test_external_api_call(mocker, snapshot):
|
||||
"backend.services.external_api.call",
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
|
||||
response = client.post("/api/process")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
snapshot.assert_match(
|
||||
json.dumps(response.json(), indent=2, sort_keys=True),
|
||||
@@ -187,6 +228,17 @@ def test_external_api_call(mocker, snapshot):
|
||||
- Use `async def` with `@pytest.mark.asyncio` for testing async functions directly
|
||||
|
||||
### 5. Fixtures
|
||||
|
||||
#### Global Fixtures (conftest.py)
|
||||
|
||||
These fixtures are available globally from `conftest.py` (a brief usage sketch follows this list):
|
||||
|
||||
- `mock_jwt_user` - Standard user authentication
|
||||
- `mock_jwt_admin` - Admin user authentication
|
||||
- `configured_snapshot` - Pre-configured snapshot fixture
|
||||
|
||||
#### Custom Fixtures
|
||||
|
||||
Create reusable fixtures for common test data:
|
||||
|
||||
```python
|
||||
@@ -202,9 +254,18 @@ def test_create_user(sample_user, snapshot):
|
||||
# ... test implementation
|
||||
```
|
||||
|
||||
#### Test Isolation
|
||||
|
||||
All tests must use fixtures that ensure proper isolation:
|
||||
|
||||
- Authentication overrides are automatically cleaned up after each test
|
||||
- Database connections are properly managed with cleanup
|
||||
- Mock objects are reset between tests
|
||||
|
||||
## CI/CD Integration
|
||||
|
||||
The GitHub Actions workflow automatically runs tests on:
|
||||
|
||||
- Pull requests
|
||||
- Pushes to main branch
|
||||
|
||||
@@ -216,16 +277,19 @@ Snapshot tests work in CI by:
|
||||
## Troubleshooting
|
||||
|
||||
### Snapshot Mismatches
|
||||
|
||||
- Review the diff carefully
|
||||
- If changes are expected: `poetry run pytest --snapshot-update`
|
||||
- If changes are unexpected: Fix the code causing the difference
|
||||
|
||||
### Async Test Issues
|
||||
|
||||
- Ensure async functions use `@pytest.mark.asyncio`
|
||||
- Use `AsyncMock` for mocking async functions (see the sketch after this list)
|
||||
- FastAPI TestClient handles async automatically
|
||||
|
||||
### Import Errors
|
||||
|
||||
- Check that all dependencies are in `pyproject.toml`
|
||||
- Run `poetry install` to ensure dependencies are installed
|
||||
- Verify import paths are correct
|
||||
@@ -234,4 +298,4 @@ Snapshot tests work in CI by:
|
||||
|
||||
Snapshot testing provides a powerful way to ensure API responses remain consistent. Combined with traditional assertions, it creates a robust test suite that catches regressions while remaining maintainable.
|
||||
|
||||
Remember: Good tests are as important as good code!
|
||||
Remember: Good tests are as important as good code!
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -12,7 +10,7 @@ from backend.data.block import (
|
||||
BlockType,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.execution import ExecutionStatus, NodesInputMasks
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util.json import validate_with_jsonschema
|
||||
from backend.util.retry import func_retry
|
||||
@@ -25,12 +23,15 @@ class AgentExecutorBlock(Block):
|
||||
user_id: str = SchemaField(description="User ID")
|
||||
graph_id: str = SchemaField(description="Graph ID")
|
||||
graph_version: int = SchemaField(description="Graph Version")
|
||||
agent_name: Optional[str] = SchemaField(
|
||||
default=None, description="Name to display in the Builder UI"
|
||||
)
|
||||
|
||||
inputs: BlockInput = SchemaField(description="Input data for the graph")
|
||||
input_schema: dict = SchemaField(description="Input schema for the graph")
|
||||
output_schema: dict = SchemaField(description="Output schema for the graph")
|
||||
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = SchemaField(
|
||||
nodes_input_masks: Optional[NodesInputMasks] = SchemaField(
|
||||
default=None, hidden=True
|
||||
)
|
||||
|
||||
|
||||
154
autogpt_platform/backend/backend/blocks/ai_image_customizer.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
from replicate.client import Client as ReplicateClient
|
||||
from replicate.helpers import FileOutput
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.file import MediaFileType
|
||||
|
||||
|
||||
class GeminiImageModel(str, Enum):
|
||||
NANO_BANANA = "google/nano-banana"
|
||||
|
||||
|
||||
class OutputFormat(str, Enum):
|
||||
JPG = "jpg"
|
||||
PNG = "png"
|
||||
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="replicate",
|
||||
api_key=SecretStr("mock-replicate-api-key"),
|
||||
title="Mock Replicate API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class AIImageCustomizerBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REPLICATE], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Replicate API key with permissions for Google Gemini image models",
|
||||
)
|
||||
prompt: str = SchemaField(
|
||||
description="A text description of the image you want to generate",
|
||||
title="Prompt",
|
||||
)
|
||||
model: GeminiImageModel = SchemaField(
|
||||
description="The AI model to use for image generation and editing",
|
||||
default=GeminiImageModel.NANO_BANANA,
|
||||
title="Model",
|
||||
)
|
||||
images: list[MediaFileType] = SchemaField(
|
||||
description="Optional list of input images to reference or modify",
|
||||
default=[],
|
||||
title="Input Images",
|
||||
)
|
||||
output_format: OutputFormat = SchemaField(
|
||||
description="Format of the output image",
|
||||
default=OutputFormat.PNG,
|
||||
title="Output Format",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
image_url: MediaFileType = SchemaField(description="URL of the generated image")
|
||||
error: str = SchemaField(description="Error message if generation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d76bbe4c-930e-4894-8469-b66775511f71",
|
||||
description=(
|
||||
"Generate and edit custom images using Google's Nano-Banana model from Gemini 2.5. "
|
||||
"Provide a prompt and optional reference images to create or modify images."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
input_schema=AIImageCustomizerBlock.Input,
|
||||
output_schema=AIImageCustomizerBlock.Output,
|
||||
test_input={
|
||||
"prompt": "Make the scene more vibrant and colorful",
|
||||
"model": GeminiImageModel.NANO_BANANA,
|
||||
"images": [],
|
||||
"output_format": OutputFormat.JPG,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
("image_url", "https://replicate.delivery/generated-image.jpg"),
|
||||
],
|
||||
test_mock={
|
||||
"run_model": lambda *args, **kwargs: MediaFileType(
|
||||
"https://replicate.delivery/generated-image.jpg"
|
||||
),
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
result = await self.run_model(
|
||||
api_key=credentials.api_key,
|
||||
model_name=input_data.model.value,
|
||||
prompt=input_data.prompt,
|
||||
images=input_data.images,
|
||||
output_format=input_data.output_format.value,
|
||||
)
|
||||
yield "image_url", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
async def run_model(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
images: list[MediaFileType],
|
||||
output_format: str,
|
||||
) -> MediaFileType:
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
|
||||
input_params: dict = {
|
||||
"prompt": prompt,
|
||||
"output_format": output_format,
|
||||
}
|
||||
|
||||
# Add images to input if provided (API expects "image_input" parameter)
|
||||
if images:
|
||||
input_params["image_input"] = [str(img) for img in images]
|
||||
|
||||
output: FileOutput | str = await client.async_run( # type: ignore
|
||||
model_name,
|
||||
input=input_params,
|
||||
wait=False,
|
||||
)
|
||||
|
||||
if isinstance(output, FileOutput):
|
||||
return MediaFileType(output.url)
|
||||
if isinstance(output, str):
|
||||
return MediaFileType(output)
|
||||
|
||||
raise ValueError("No output received from the model")
|
||||
@@ -166,7 +166,7 @@ class AIMusicGeneratorBlock(Block):
|
||||
output_format=input_data.output_format,
|
||||
normalization_strategy=input_data.normalization_strategy,
|
||||
)
|
||||
if result and result != "No output received":
|
||||
if result and isinstance(result, str) and result.startswith("http"):
|
||||
yield "result", result
|
||||
return
|
||||
else:
|
||||
|
||||
@@ -661,6 +661,167 @@ async def update_field(
|
||||
#################################################################
|
||||
|
||||
|
||||
async def get_table_schema(
|
||||
credentials: Credentials,
|
||||
base_id: str,
|
||||
table_id_or_name: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Get the schema for a specific table, including all field definitions.
|
||||
|
||||
Args:
|
||||
credentials: Airtable API credentials
|
||||
base_id: The base ID
|
||||
table_id_or_name: The table ID or name
|
||||
|
||||
Returns:
|
||||
Dict containing table schema with fields information
|
||||
"""
|
||||
# First get all tables to find the right one
|
||||
response = await Requests().get(
|
||||
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
tables = data.get("tables", [])
|
||||
|
||||
# Find the matching table
|
||||
for table in tables:
|
||||
if table.get("id") == table_id_or_name or table.get("name") == table_id_or_name:
|
||||
return table
|
||||
|
||||
raise ValueError(f"Table '{table_id_or_name}' not found in base '{base_id}'")


def get_empty_value_for_field(field_type: str) -> Any:
    """
    Return the appropriate empty value for a given Airtable field type.

    Args:
        field_type: The Airtable field type

    Returns:
        The appropriate empty value for that field type
    """
    # Fields that should be false when empty
    if field_type == "checkbox":
        return False

    # Fields that should be empty arrays
    if field_type in [
        "multipleSelects",
        "multipleRecordLinks",
        "multipleAttachments",
        "multipleLookupValues",
        "multipleCollaborators",
    ]:
        return []

    # Fields that should be 0 when empty (numeric types)
    if field_type in [
        "number",
        "percent",
        "currency",
        "rating",
        "duration",
        "count",
        "autoNumber",
    ]:
        return 0

    # Fields that should be empty strings
    if field_type in [
        "singleLineText",
        "multilineText",
        "email",
        "url",
        "phoneNumber",
        "richText",
        "barcode",
    ]:
        return ""

    # Everything else gets null (dates, single selects, formulas, etc.)
    return None
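
# Illustrative sketch (not part of the diff): how the mapping above behaves for a
# few common Airtable field types. The sample types come from the branches above;
# "date" falls through to the null case.
#
#     assert get_empty_value_for_field("checkbox") is False
#     assert get_empty_value_for_field("multipleSelects") == []
#     assert get_empty_value_for_field("currency") == 0
#     assert get_empty_value_for_field("email") == ""
#     assert get_empty_value_for_field("date") is None
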
async def normalize_records(
|
||||
records: list[dict],
|
||||
table_schema: dict,
|
||||
include_field_metadata: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Normalize Airtable records to include all fields with proper empty values.
|
||||
|
||||
Args:
|
||||
records: List of record objects from Airtable API
|
||||
table_schema: Table schema containing field definitions
|
||||
include_field_metadata: Whether to include field metadata in response
|
||||
|
||||
Returns:
|
||||
Dict with normalized records and optionally field metadata
|
||||
"""
|
||||
fields = table_schema.get("fields", [])
|
||||
|
||||
# Normalize each record
|
||||
normalized_records = []
|
||||
for record in records:
|
||||
normalized = {
|
||||
"id": record.get("id"),
|
||||
"createdTime": record.get("createdTime"),
|
||||
"fields": {},
|
||||
}
|
||||
|
||||
# Add existing fields
|
||||
existing_fields = record.get("fields", {})
|
||||
|
||||
# Add all fields from schema, using empty values for missing ones
|
||||
for field in fields:
|
||||
field_name = field["name"]
|
||||
field_type = field["type"]
|
||||
|
||||
if field_name in existing_fields:
|
||||
# Field exists, use its value
|
||||
normalized["fields"][field_name] = existing_fields[field_name]
|
||||
else:
|
||||
# Field is missing, add appropriate empty value
|
||||
normalized["fields"][field_name] = get_empty_value_for_field(field_type)
|
||||
|
||||
normalized_records.append(normalized)
|
||||
|
||||
# Build result dictionary
|
||||
if include_field_metadata:
|
||||
field_metadata = {}
|
||||
for field in fields:
|
||||
metadata = {"type": field["type"], "id": field["id"]}
|
||||
|
||||
# Add type-specific metadata
|
||||
options = field.get("options", {})
|
||||
if field["type"] == "currency" and "symbol" in options:
|
||||
metadata["symbol"] = options["symbol"]
|
||||
metadata["precision"] = options.get("precision", 2)
|
||||
elif field["type"] == "duration" and "durationFormat" in options:
|
||||
metadata["format"] = options["durationFormat"]
|
||||
elif field["type"] == "percent" and "precision" in options:
|
||||
metadata["precision"] = options["precision"]
|
||||
elif (
|
||||
field["type"] in ["singleSelect", "multipleSelects"]
|
||||
and "choices" in options
|
||||
):
|
||||
metadata["choices"] = [choice["name"] for choice in options["choices"]]
|
||||
elif field["type"] == "rating" and "max" in options:
|
||||
metadata["max"] = options["max"]
|
||||
metadata["icon"] = options.get("icon", "star")
|
||||
metadata["color"] = options.get("color", "yellowBright")
|
||||
|
||||
field_metadata[field["name"]] = metadata
|
||||
|
||||
return {"records": normalized_records, "field_metadata": field_metadata}
|
||||
else:
|
||||
return {"records": normalized_records}
|
||||
|
||||
|
||||
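
# Illustrative sketch (not part of the diff): normalizing one record against a tiny
# hand-written schema. Field names here ("Name", "Done") are invented for the example.
async def _example_normalize() -> None:
    example_schema = {
        "fields": [
            {"id": "fld1", "name": "Name", "type": "singleLineText"},
            {"id": "fld2", "name": "Done", "type": "checkbox"},
        ]
    }
    example_records = [
        {"id": "rec1", "createdTime": "2024-01-01T00:00:00Z", "fields": {"Name": "Task A"}}
    ]
    result = await normalize_records(example_records, example_schema)
    # The missing "Done" field is filled with False because its type is "checkbox".
    assert result["records"][0]["fields"] == {"Name": "Task A", "Done": False}
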
async def list_records(
|
||||
credentials: Credentials,
|
||||
base_id: str,
|
||||
@@ -1249,3 +1410,26 @@ async def list_bases(
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
|
||||
async def get_base_tables(
|
||||
credentials: Credentials,
|
||||
base_id: str,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Get all tables for a specific base.
|
||||
|
||||
Args:
|
||||
credentials: Airtable API credentials
|
||||
base_id: The ID of the base
|
||||
|
||||
Returns:
|
||||
list[dict]: List of table objects with their schemas
|
||||
"""
|
||||
response = await Requests().get(
|
||||
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
return data.get("tables", [])
|
||||
|
||||
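
# Illustrative sketch (not part of the diff): listing the tables of a base and then
# resolving one table's schema by name. The base ID below is a made-up placeholder.
async def _example_base_tables(credentials: Credentials) -> dict:
    tables = await get_base_tables(credentials, "appXXXXXXXXXXXXXX")
    # Each entry mirrors the Airtable metadata API: {"id": ..., "name": ..., "fields": [...]}
    first_table_name = tables[0]["name"]
    return await get_table_schema(credentials, "appXXXXXXXXXXXXXX", first_table_name)
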
@@ -14,13 +14,13 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import create_base, list_bases
|
||||
from ._api import create_base, get_base_tables, list_bases
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableCreateBaseBlock(Block):
|
||||
"""
|
||||
Creates a new base in an Airtable workspace.
|
||||
Creates a new base in an Airtable workspace, or returns existing base if one with the same name exists.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
@@ -31,6 +31,10 @@ class AirtableCreateBaseBlock(Block):
|
||||
description="The workspace ID where the base will be created"
|
||||
)
|
||||
name: str = SchemaField(description="The name of the new base")
|
||||
find_existing: bool = SchemaField(
|
||||
description="If true, return existing base with same name instead of creating duplicate",
|
||||
default=True,
|
||||
)
|
||||
tables: list[dict] = SchemaField(
|
||||
description="At least one table and field must be specified. Array of table objects to create in the base. Each table should have 'name' and 'fields' properties",
|
||||
default=[
|
||||
@@ -50,14 +54,18 @@ class AirtableCreateBaseBlock(Block):
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
base_id: str = SchemaField(description="The ID of the created base")
|
||||
base_id: str = SchemaField(description="The ID of the created or found base")
|
||||
tables: list[dict] = SchemaField(description="Array of table objects")
|
||||
table: dict = SchemaField(description="A single table object")
|
||||
was_created: bool = SchemaField(
|
||||
description="True if a new base was created, False if existing was found",
|
||||
default=True,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f59b88a8-54ce-4676-a508-fd614b4e8dce",
|
||||
description="Create a new base in Airtable",
|
||||
description="Create or find a base in Airtable",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
@@ -66,6 +74,31 @@ class AirtableCreateBaseBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# If find_existing is true, check if a base with this name already exists
|
||||
if input_data.find_existing:
|
||||
# List all bases to check for existing one with same name
|
||||
# Note: Airtable API doesn't have a direct search, so we need to list and filter
|
||||
existing_bases = await list_bases(credentials)
|
||||
|
||||
for base in existing_bases.get("bases", []):
|
||||
if base.get("name") == input_data.name:
|
||||
# Base already exists, return it
|
||||
base_id = base.get("id")
|
||||
yield "base_id", base_id
|
||||
yield "was_created", False
|
||||
|
||||
# Get the tables for this base
|
||||
try:
|
||||
tables = await get_base_tables(credentials, base_id)
|
||||
yield "tables", tables
|
||||
for table in tables:
|
||||
yield "table", table
|
||||
except Exception:
|
||||
# If we can't get tables, return empty list
|
||||
yield "tables", []
|
||||
return
|
||||
|
||||
# No existing base found or find_existing is false, create new one
|
||||
data = await create_base(
|
||||
credentials,
|
||||
input_data.workspace_id,
|
||||
@@ -74,6 +107,7 @@ class AirtableCreateBaseBlock(Block):
|
||||
)
|
||||
|
||||
yield "base_id", data.get("id", None)
|
||||
yield "was_created", True
|
||||
yield "tables", data.get("tables", [])
|
||||
for table in data.get("tables", []):
|
||||
yield "table", table
|
||||
|
||||
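
# Illustrative sketch (not part of the diff): the name-matching step in run() above,
# applied to a made-up list_bases() payload. Names are compared exactly (case-sensitive).
_example_bases = {"bases": [{"id": "appA", "name": "CRM"}, {"id": "appB", "name": "Inventory"}]}
_match = next(
    (b for b in _example_bases["bases"] if b.get("name") == "Inventory"),
    None,
)
assert _match is not None and _match["id"] == "appB"
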
@@ -2,7 +2,7 @@
|
||||
Airtable record operation blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
@@ -18,7 +18,9 @@ from ._api import (
|
||||
create_record,
|
||||
delete_multiple_records,
|
||||
get_record,
|
||||
get_table_schema,
|
||||
list_records,
|
||||
normalize_records,
|
||||
update_multiple_records,
|
||||
)
|
||||
from ._config import airtable
|
||||
@@ -54,12 +56,24 @@ class AirtableListRecordsBlock(Block):
|
||||
return_fields: list[str] = SchemaField(
|
||||
description="Specific fields to return (comma-separated)", default=[]
|
||||
)
|
||||
normalize_output: bool = SchemaField(
|
||||
description="Normalize output to include all fields with proper empty values (disable to skip schema fetch and get raw Airtable response)",
|
||||
default=True,
|
||||
)
|
||||
include_field_metadata: bool = SchemaField(
|
||||
description="Include field type and configuration metadata (requires normalize_output=true)",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
records: list[dict] = SchemaField(description="Array of record objects")
|
||||
offset: Optional[str] = SchemaField(
|
||||
description="Offset for next page (null if no more records)", default=None
|
||||
)
|
||||
field_metadata: Optional[dict] = SchemaField(
|
||||
description="Field type and configuration metadata (only when include_field_metadata=true)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -73,6 +87,7 @@ class AirtableListRecordsBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
data = await list_records(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
@@ -88,8 +103,33 @@ class AirtableListRecordsBlock(Block):
|
||||
fields=input_data.return_fields if input_data.return_fields else None,
|
||||
)
|
||||
|
||||
yield "records", data.get("records", [])
|
||||
yield "offset", data.get("offset", None)
|
||||
records = data.get("records", [])
|
||||
|
||||
# Normalize output if requested
|
||||
if input_data.normalize_output:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the records
|
||||
normalized_data = await normalize_records(
|
||||
records,
|
||||
table_schema,
|
||||
include_field_metadata=input_data.include_field_metadata,
|
||||
)
|
||||
|
||||
yield "records", normalized_data["records"]
|
||||
yield "offset", data.get("offset", None)
|
||||
|
||||
if (
|
||||
input_data.include_field_metadata
|
||||
and "field_metadata" in normalized_data
|
||||
):
|
||||
yield "field_metadata", normalized_data["field_metadata"]
|
||||
else:
|
||||
yield "records", records
|
||||
yield "offset", data.get("offset", None)
|
||||
|
||||
|
||||
class AirtableGetRecordBlock(Block):
|
||||
@@ -104,11 +144,23 @@ class AirtableGetRecordBlock(Block):
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name")
|
||||
record_id: str = SchemaField(description="The record ID to retrieve")
|
||||
normalize_output: bool = SchemaField(
|
||||
description="Normalize output to include all fields with proper empty values (disable to skip schema fetch and get raw Airtable response)",
|
||||
default=True,
|
||||
)
|
||||
include_field_metadata: bool = SchemaField(
|
||||
description="Include field type and configuration metadata (requires normalize_output=true)",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
id: str = SchemaField(description="The record ID")
|
||||
fields: dict = SchemaField(description="The record fields")
|
||||
created_time: str = SchemaField(description="The record created time")
|
||||
field_metadata: Optional[dict] = SchemaField(
|
||||
description="Field type and configuration metadata (only when include_field_metadata=true)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -122,6 +174,7 @@ class AirtableGetRecordBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
record = await get_record(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
@@ -129,9 +182,34 @@ class AirtableGetRecordBlock(Block):
|
||||
input_data.record_id,
|
||||
)
|
||||
|
||||
yield "id", record.get("id", None)
|
||||
yield "fields", record.get("fields", None)
|
||||
yield "created_time", record.get("createdTime", None)
|
||||
# Normalize output if requested
|
||||
if input_data.normalize_output:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the single record (wrap in list and unwrap result)
|
||||
normalized_data = await normalize_records(
|
||||
[record],
|
||||
table_schema,
|
||||
include_field_metadata=input_data.include_field_metadata,
|
||||
)
|
||||
|
||||
normalized_record = normalized_data["records"][0]
|
||||
yield "id", normalized_record.get("id", None)
|
||||
yield "fields", normalized_record.get("fields", None)
|
||||
yield "created_time", normalized_record.get("createdTime", None)
|
||||
|
||||
if (
|
||||
input_data.include_field_metadata
|
||||
and "field_metadata" in normalized_data
|
||||
):
|
||||
yield "field_metadata", normalized_data["field_metadata"]
|
||||
else:
|
||||
yield "id", record.get("id", None)
|
||||
yield "fields", record.get("fields", None)
|
||||
yield "created_time", record.get("createdTime", None)
|
||||
|
||||
|
||||
class AirtableCreateRecordsBlock(Block):
|
||||
@@ -148,6 +226,10 @@ class AirtableCreateRecordsBlock(Block):
|
||||
records: list[dict] = SchemaField(
|
||||
description="Array of records to create (each with 'fields' object)"
|
||||
)
|
||||
skip_normalization: bool = SchemaField(
|
||||
description="Skip output normalization to get raw Airtable response (faster but may have missing fields)",
|
||||
default=False,
|
||||
)
|
||||
typecast: bool = SchemaField(
|
||||
description="Automatically convert string values to appropriate types",
|
||||
default=False,
|
||||
@@ -173,7 +255,7 @@ class AirtableCreateRecordsBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# The create_record API expects records in a specific format
|
||||
|
||||
data = await create_record(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
@@ -182,8 +264,22 @@ class AirtableCreateRecordsBlock(Block):
|
||||
typecast=input_data.typecast if input_data.typecast else None,
|
||||
return_fields_by_field_id=input_data.return_fields_by_field_id,
|
||||
)
|
||||
result_records = cast(list[dict], data.get("records", []))
|
||||
|
||||
yield "records", data.get("records", [])
|
||||
# Normalize output unless explicitly disabled
|
||||
if not input_data.skip_normalization and result_records:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the records
|
||||
normalized_data = await normalize_records(
|
||||
result_records, table_schema, include_field_metadata=False
|
||||
)
|
||||
result_records = normalized_data["records"]
|
||||
|
||||
yield "records", result_records
|
||||
details = data.get("details", None)
|
||||
if details:
|
||||
yield "details", details
|
||||
|
||||
205  autogpt_platform/backend/backend/blocks/baas/_api.py  Normal file
@@ -0,0 +1,205 @@
|
||||
"""
|
||||
Meeting BaaS API client module.
|
||||
All API calls centralized for consistency and maintainability.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.sdk import Requests
|
||||
|
||||
|
||||
class MeetingBaasAPI:
|
||||
"""Client for Meeting BaaS API endpoints."""
|
||||
|
||||
BASE_URL = "https://api.meetingbaas.com"
|
||||
|
||||
def __init__(self, api_key: str):
|
||||
"""Initialize API client with authentication key."""
|
||||
self.api_key = api_key
|
||||
self.headers = {"x-meeting-baas-api-key": api_key}
|
||||
self.requests = Requests()
|
||||
|
||||
# Bot Management Endpoints
|
||||
|
||||
async def join_meeting(
|
||||
self,
|
||||
bot_name: str,
|
||||
meeting_url: str,
|
||||
reserved: bool = False,
|
||||
bot_image: Optional[str] = None,
|
||||
entry_message: Optional[str] = None,
|
||||
start_time: Optional[int] = None,
|
||||
speech_to_text: Optional[Dict[str, Any]] = None,
|
||||
webhook_url: Optional[str] = None,
|
||||
automatic_leave: Optional[Dict[str, Any]] = None,
|
||||
extra: Optional[Dict[str, Any]] = None,
|
||||
recording_mode: str = "speaker_view",
|
||||
streaming: Optional[Dict[str, Any]] = None,
|
||||
deduplication_key: Optional[str] = None,
|
||||
zoom_sdk_id: Optional[str] = None,
|
||||
zoom_sdk_pwd: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Deploy a bot to join and record a meeting.
|
||||
|
||||
POST /bots
|
||||
"""
|
||||
body = {
|
||||
"bot_name": bot_name,
|
||||
"meeting_url": meeting_url,
|
||||
"reserved": reserved,
|
||||
"recording_mode": recording_mode,
|
||||
}
|
||||
|
||||
# Add optional fields if provided
|
||||
if bot_image is not None:
|
||||
body["bot_image"] = bot_image
|
||||
if entry_message is not None:
|
||||
body["entry_message"] = entry_message
|
||||
if start_time is not None:
|
||||
body["start_time"] = start_time
|
||||
if speech_to_text is not None:
|
||||
body["speech_to_text"] = speech_to_text
|
||||
if webhook_url is not None:
|
||||
body["webhook_url"] = webhook_url
|
||||
if automatic_leave is not None:
|
||||
body["automatic_leave"] = automatic_leave
|
||||
if extra is not None:
|
||||
body["extra"] = extra
|
||||
if streaming is not None:
|
||||
body["streaming"] = streaming
|
||||
if deduplication_key is not None:
|
||||
body["deduplication_key"] = deduplication_key
|
||||
if zoom_sdk_id is not None:
|
||||
body["zoom_sdk_id"] = zoom_sdk_id
|
||||
if zoom_sdk_pwd is not None:
|
||||
body["zoom_sdk_pwd"] = zoom_sdk_pwd
|
||||
|
||||
response = await self.requests.post(
|
||||
f"{self.BASE_URL}/bots",
|
||||
headers=self.headers,
|
||||
json=body,
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def leave_meeting(self, bot_id: str) -> bool:
|
||||
"""
|
||||
Remove a bot from an ongoing meeting.
|
||||
|
||||
DELETE /bots/{uuid}
|
||||
"""
|
||||
response = await self.requests.delete(
|
||||
f"{self.BASE_URL}/bots/{bot_id}",
|
||||
headers=self.headers,
|
||||
)
|
||||
return response.status in [200, 204]
|
||||
|
||||
async def retranscribe(
|
||||
self,
|
||||
bot_uuid: str,
|
||||
speech_to_text: Optional[Dict[str, Any]] = None,
|
||||
webhook_url: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Re-run transcription on a bot's audio.
|
||||
|
||||
POST /bots/retranscribe
|
||||
"""
|
||||
body: Dict[str, Any] = {"bot_uuid": bot_uuid}
|
||||
|
||||
if speech_to_text is not None:
|
||||
body["speech_to_text"] = speech_to_text
|
||||
if webhook_url is not None:
|
||||
body["webhook_url"] = webhook_url
|
||||
|
||||
response = await self.requests.post(
|
||||
f"{self.BASE_URL}/bots/retranscribe",
|
||||
headers=self.headers,
|
||||
json=body,
|
||||
)
|
||||
|
||||
if response.status == 202:
|
||||
return {"accepted": True}
|
||||
return response.json()
|
||||
|
||||
# Data Retrieval Endpoints
|
||||
|
||||
async def get_meeting_data(
|
||||
self, bot_id: str, include_transcripts: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieve meeting data including recording and transcripts.
|
||||
|
||||
GET /bots/meeting_data
|
||||
"""
|
||||
params = {
|
||||
"bot_id": bot_id,
|
||||
"include_transcripts": str(include_transcripts).lower(),
|
||||
}
|
||||
|
||||
response = await self.requests.get(
|
||||
f"{self.BASE_URL}/bots/meeting_data",
|
||||
headers=self.headers,
|
||||
params=params,
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def get_screenshots(self, bot_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve screenshots captured during a meeting.
|
||||
|
||||
GET /bots/{uuid}/screenshots
|
||||
"""
|
||||
response = await self.requests.get(
|
||||
f"{self.BASE_URL}/bots/{bot_id}/screenshots",
|
||||
headers=self.headers,
|
||||
)
|
||||
result = response.json()
|
||||
# Ensure we return a list
|
||||
if isinstance(result, list):
|
||||
return result
|
||||
return []
|
||||
|
||||
async def delete_data(self, bot_id: str) -> bool:
|
||||
"""
|
||||
Delete a bot's recorded data.
|
||||
|
||||
POST /bots/{uuid}/delete_data
|
||||
"""
|
||||
response = await self.requests.post(
|
||||
f"{self.BASE_URL}/bots/{bot_id}/delete_data",
|
||||
headers=self.headers,
|
||||
)
|
||||
return response.status == 200
|
||||
|
||||
async def list_bots_with_metadata(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
sort_by: Optional[str] = None,
|
||||
sort_order: Optional[str] = None,
|
||||
filter_by: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
List bots with metadata including IDs, names, and meeting details.
|
||||
|
||||
GET /bots/bots_with_metadata
|
||||
"""
|
||||
params = {}
|
||||
if limit is not None:
|
||||
params["limit"] = limit
|
||||
if offset is not None:
|
||||
params["offset"] = offset
|
||||
if sort_by is not None:
|
||||
params["sort_by"] = sort_by
|
||||
if sort_order is not None:
|
||||
params["sort_order"] = sort_order
|
||||
if filter_by is not None:
|
||||
params.update(filter_by)
|
||||
|
||||
response = await self.requests.get(
|
||||
f"{self.BASE_URL}/bots/bots_with_metadata",
|
||||
headers=self.headers,
|
||||
params=params,
|
||||
)
|
||||
return response.json()
|
||||
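
# Illustrative usage sketch (not part of the diff): deploying a bot with the client
# above and later fetching its recording. The API key, meeting URL, and "mp4" key
# are placeholders/assumptions based on how the blocks in bots.py use this client.
async def _example_meeting_flow() -> None:
    api = MeetingBaasAPI(api_key="mb_xxx")
    joined = await api.join_meeting(
        bot_name="Notetaker",
        meeting_url="https://meet.example.com/abc-defg-hij",
        speech_to_text={"provider": "Default"},
    )
    data = await api.get_meeting_data(
        bot_id=joined.get("bot_id", ""), include_transcripts=True
    )
    print(data.get("mp4", ""))
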
13  autogpt_platform/backend/backend/blocks/baas/_config.py  Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Shared configuration for all Meeting BaaS blocks using the SDK pattern.
|
||||
"""
|
||||
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
# Configure the Meeting BaaS provider with API key authentication
|
||||
baas = (
|
||||
ProviderBuilder("baas")
|
||||
.with_api_key("MEETING_BAAS_API_KEY", "Meeting BaaS API Key")
|
||||
.with_base_cost(5, BlockCostType.RUN) # Higher cost for meeting recording service
|
||||
.build()
|
||||
)
|
||||
217  autogpt_platform/backend/backend/blocks/baas/bots.py  Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Meeting BaaS bot (recording) blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import MeetingBaasAPI
|
||||
from ._config import baas
|
||||
|
||||
|
||||
class BaasBotJoinMeetingBlock(Block):
|
||||
"""
|
||||
Deploy a bot immediately or at a scheduled start_time to join and record a meeting.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
meeting_url: str = SchemaField(
|
||||
description="The URL of the meeting the bot should join"
|
||||
)
|
||||
bot_name: str = SchemaField(
|
||||
description="Display name for the bot in the meeting"
|
||||
)
|
||||
bot_image: str = SchemaField(
|
||||
description="URL to an image for the bot's avatar (16:9 ratio recommended)",
|
||||
default="",
|
||||
)
|
||||
entry_message: str = SchemaField(
|
||||
description="Chat message the bot will post upon entry", default=""
|
||||
)
|
||||
reserved: bool = SchemaField(
|
||||
description="Use a reserved bot slot (joins 4 min before meeting)",
|
||||
default=False,
|
||||
)
|
||||
start_time: Optional[int] = SchemaField(
|
||||
description="Unix timestamp (ms) when bot should join", default=None
|
||||
)
|
||||
webhook_url: str | None = SchemaField(
|
||||
description="URL to receive webhook events for this bot", default=None
|
||||
)
|
||||
timeouts: dict = SchemaField(
|
||||
description="Automatic leave timeouts configuration", default={}
|
||||
)
|
||||
extra: dict = SchemaField(
|
||||
description="Custom metadata to attach to the bot", default={}
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
bot_id: str = SchemaField(description="UUID of the deployed bot")
|
||||
join_response: dict = SchemaField(
|
||||
description="Full response from join operation"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="377d1a6a-a99b-46cf-9af3-1d1b12758e04",
|
||||
description="Deploy a bot to join and record a meeting",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
api = MeetingBaasAPI(api_key)
|
||||
|
||||
# Call API with all parameters
|
||||
data = await api.join_meeting(
|
||||
bot_name=input_data.bot_name,
|
||||
meeting_url=input_data.meeting_url,
|
||||
reserved=input_data.reserved,
|
||||
bot_image=input_data.bot_image if input_data.bot_image else None,
|
||||
entry_message=(
|
||||
input_data.entry_message if input_data.entry_message else None
|
||||
),
|
||||
start_time=input_data.start_time,
|
||||
speech_to_text={"provider": "Default"},
|
||||
webhook_url=input_data.webhook_url if input_data.webhook_url else None,
|
||||
automatic_leave=input_data.timeouts if input_data.timeouts else None,
|
||||
extra=input_data.extra if input_data.extra else None,
|
||||
)
|
||||
|
||||
yield "bot_id", data.get("bot_id", "")
|
||||
yield "join_response", data
|
||||
|
||||
|
||||
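
# Illustrative sketch (not part of the diff): the inputs the block above expects and
# the outputs it yields. All values are placeholders.
_example_join_input = {
    "meeting_url": "https://meet.example.com/abc-defg-hij",
    "bot_name": "Notetaker",
    "reserved": False,
    "start_time": None,  # join immediately
}
# Running the block yields, in order:
#   ("bot_id", "<uuid returned by Meeting BaaS>")
#   ("join_response", {...full API response...})
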
class BaasBotLeaveMeetingBlock(Block):
|
||||
"""
|
||||
Force the bot to exit the call.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(description="UUID of the bot to remove from meeting")
|
||||
|
||||
class Output(BlockSchema):
|
||||
left: bool = SchemaField(description="Whether the bot successfully left")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bf77d128-8b25-4280-b5c7-2d553ba7e482",
|
||||
description="Remove a bot from an ongoing meeting",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
api = MeetingBaasAPI(api_key)
|
||||
|
||||
# Leave meeting
|
||||
left = await api.leave_meeting(input_data.bot_id)
|
||||
|
||||
yield "left", left
|
||||
|
||||
|
||||
class BaasBotFetchMeetingDataBlock(Block):
|
||||
"""
|
||||
Pull MP4 URL, transcript & metadata for a completed meeting.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(description="UUID of the bot whose data to fetch")
|
||||
include_transcripts: bool = SchemaField(
|
||||
description="Include transcript data in response", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
mp4_url: str = SchemaField(
|
||||
description="URL to download the meeting recording (time-limited)"
|
||||
)
|
||||
transcript: list = SchemaField(description="Meeting transcript data")
|
||||
metadata: dict = SchemaField(description="Meeting metadata and bot information")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ea7c1309-303c-4da1-893f-89c0e9d64e78",
|
||||
description="Retrieve recorded meeting data",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
api = MeetingBaasAPI(api_key)
|
||||
|
||||
# Fetch meeting data
|
||||
data = await api.get_meeting_data(
|
||||
bot_id=input_data.bot_id,
|
||||
include_transcripts=input_data.include_transcripts,
|
||||
)
|
||||
|
||||
yield "mp4_url", data.get("mp4", "")
|
||||
yield "transcript", data.get("bot_data", {}).get("transcripts", [])
|
||||
yield "metadata", data.get("bot_data", {}).get("bot", {})
|
||||
|
||||
|
||||
class BaasBotDeleteRecordingBlock(Block):
|
||||
"""
|
||||
Purge MP4 + transcript data for privacy or storage management.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(description="UUID of the bot whose data to delete")
|
||||
|
||||
class Output(BlockSchema):
|
||||
deleted: bool = SchemaField(
|
||||
description="Whether the data was successfully deleted"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bf8d1aa6-42d8-4944-b6bd-6bac554c0d3b",
|
||||
description="Permanently delete a meeting's recorded data",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
api = MeetingBaasAPI(api_key)
|
||||
|
||||
# Delete recording data
|
||||
deleted = await api.delete_data(input_data.bot_id)
|
||||
|
||||
yield "deleted", deleted
|
||||
@@ -0,0 +1,3 @@
|
||||
from .text_overlay import BannerbearTextOverlayBlock
|
||||
|
||||
__all__ = ["BannerbearTextOverlayBlock"]
|
||||
@@ -0,0 +1,8 @@
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
bannerbear = (
|
||||
ProviderBuilder("bannerbear")
|
||||
.with_api_key("BANNERBEAR_API_KEY", "Bannerbear API Key")
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
@@ -0,0 +1,239 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import bannerbear
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="bannerbear",
|
||||
api_key=SecretStr("mock-bannerbear-api-key"),
|
||||
title="Mock Bannerbear API Key",
|
||||
)
|
||||
|
||||
|
||||
class TextModification(BlockSchema):
|
||||
name: str = SchemaField(
|
||||
description="The name of the layer to modify in the template"
|
||||
)
|
||||
text: str = SchemaField(description="The text content to add to this layer")
|
||||
color: str = SchemaField(
|
||||
description="Hex color code for the text (e.g., '#FF0000')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
font_family: str = SchemaField(
|
||||
description="Font family to use for the text",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
font_size: int = SchemaField(
|
||||
description="Font size in pixels",
|
||||
default=0,
|
||||
advanced=True,
|
||||
)
|
||||
font_weight: str = SchemaField(
|
||||
description="Font weight (e.g., 'bold', 'normal')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
text_align: str = SchemaField(
|
||||
description="Text alignment (left, center, right)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
|
||||
class BannerbearTextOverlayBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = bannerbear.credentials_field(
|
||||
description="API credentials for Bannerbear"
|
||||
)
|
||||
template_id: str = SchemaField(
|
||||
description="The unique ID of your Bannerbear template"
|
||||
)
|
||||
project_id: str = SchemaField(
|
||||
description="Optional: Project ID (required when using Master API Key)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
text_modifications: List[TextModification] = SchemaField(
|
||||
description="List of text layers to modify in the template"
|
||||
)
|
||||
image_url: str = SchemaField(
|
||||
description="Optional: URL of an image to use in the template",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
image_layer_name: str = SchemaField(
|
||||
description="Optional: Name of the image layer in the template",
|
||||
default="photo",
|
||||
advanced=True,
|
||||
)
|
||||
webhook_url: str = SchemaField(
|
||||
description="Optional: URL to receive webhook notification when image is ready",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
metadata: str = SchemaField(
|
||||
description="Optional: Custom metadata to attach to the image",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the image generation was successfully initiated"
|
||||
)
|
||||
image_url: str = SchemaField(
|
||||
description="URL of the generated image (if synchronous) or placeholder"
|
||||
)
|
||||
uid: str = SchemaField(description="Unique identifier for the generated image")
|
||||
status: str = SchemaField(description="Status of the image generation")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c7d3a5c2-05fc-450e-8dce-3b0e04626009",
|
||||
description="Add text overlay to images using Bannerbear templates. Perfect for creating social media graphics, marketing materials, and dynamic image content.",
|
||||
categories={BlockCategory.PRODUCTIVITY, BlockCategory.AI},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"template_id": "jJWBKNELpQPvbX5R93Gk",
|
||||
"text_modifications": [
|
||||
{
|
||||
"name": "headline",
|
||||
"text": "Amazing Product Launch!",
|
||||
"color": "#FF0000",
|
||||
},
|
||||
{
|
||||
"name": "subtitle",
|
||||
"text": "50% OFF Today Only",
|
||||
},
|
||||
],
|
||||
"credentials": {
|
||||
"provider": "bannerbear",
|
||||
"id": str(uuid.uuid4()),
|
||||
"type": "api_key",
|
||||
},
|
||||
},
|
||||
test_output=[
|
||||
("success", True),
|
||||
("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
|
||||
("uid", "test-uid-123"),
|
||||
("status", "completed"),
|
||||
],
|
||||
test_mock={
|
||||
"_make_api_request": lambda *args, **kwargs: {
|
||||
"uid": "test-uid-123",
|
||||
"status": "completed",
|
||||
"image_url": "https://cdn.bannerbear.com/test-image.jpg",
|
||||
}
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def _make_api_request(self, payload: dict, api_key: str) -> dict:
|
||||
"""Make the actual API request to Bannerbear. This is separated for easy mocking in tests."""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = await Requests().post(
|
||||
"https://sync.api.bannerbear.com/v2/images",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if response.status in [200, 201, 202]:
|
||||
return response.json()
|
||||
else:
|
||||
error_msg = f"API request failed with status {response.status}"
|
||||
if response.text:
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_msg = (
|
||||
f"{error_msg}: {error_data.get('message', response.text)}"
|
||||
)
|
||||
except Exception:
|
||||
error_msg = f"{error_msg}: {response.text}"
|
||||
raise Exception(error_msg)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Build the modifications array
|
||||
modifications = []
|
||||
|
||||
# Add text modifications
|
||||
for text_mod in input_data.text_modifications:
|
||||
mod_data: Dict[str, Any] = {
|
||||
"name": text_mod.name,
|
||||
"text": text_mod.text,
|
||||
}
|
||||
|
||||
# Add optional text styling parameters only if they have values
|
||||
if text_mod.color and text_mod.color.strip():
|
||||
mod_data["color"] = text_mod.color
|
||||
if text_mod.font_family and text_mod.font_family.strip():
|
||||
mod_data["font_family"] = text_mod.font_family
|
||||
if text_mod.font_size and text_mod.font_size > 0:
|
||||
mod_data["font_size"] = text_mod.font_size
|
||||
if text_mod.font_weight and text_mod.font_weight.strip():
|
||||
mod_data["font_weight"] = text_mod.font_weight
|
||||
if text_mod.text_align and text_mod.text_align.strip():
|
||||
mod_data["text_align"] = text_mod.text_align
|
||||
|
||||
modifications.append(mod_data)
|
||||
|
||||
# Add image modification if provided and not empty
|
||||
if input_data.image_url and input_data.image_url.strip():
|
||||
modifications.append(
|
||||
{
|
||||
"name": input_data.image_layer_name,
|
||||
"image_url": input_data.image_url,
|
||||
}
|
||||
)
|
||||
|
||||
# Build the request payload - only include non-empty optional fields
|
||||
payload = {
|
||||
"template": input_data.template_id,
|
||||
"modifications": modifications,
|
||||
}
|
||||
|
||||
# Add project_id if provided (required for Master API keys)
|
||||
if input_data.project_id and input_data.project_id.strip():
|
||||
payload["project_id"] = input_data.project_id
|
||||
|
||||
if input_data.webhook_url and input_data.webhook_url.strip():
|
||||
payload["webhook_url"] = input_data.webhook_url
|
||||
if input_data.metadata and input_data.metadata.strip():
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
# Make the API request using the private method
|
||||
data = await self._make_api_request(
|
||||
payload, credentials.api_key.get_secret_value()
|
||||
)
|
||||
|
||||
# Synchronous request - image should be ready
|
||||
yield "success", True
|
||||
yield "image_url", data.get("image_url", "")
|
||||
yield "uid", data.get("uid", "")
|
||||
yield "status", data.get("status", "completed")
|
||||
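
# Illustrative sketch (not part of the diff): the payload run() above assembles for a
# template with a "headline" text layer and a "photo" image layer. The template ID is
# the one used in this block's test input; the image URL is made up.
_example_bannerbear_payload = {
    "template": "jJWBKNELpQPvbX5R93Gk",
    "modifications": [
        {"name": "headline", "text": "Amazing Product Launch!", "color": "#FF0000"},
        {"name": "photo", "image_url": "https://example.com/product.jpg"},
    ],
    # Optional keys, included only when non-empty:
    # "project_id": "...", "webhook_url": "...", "metadata": "...",
}
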
178  autogpt_platform/backend/backend/blocks/dataforseo/_api.py  Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
DataForSEO API client with async support using the SDK patterns.
|
||||
"""
|
||||
|
||||
import base64
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.sdk import Requests, UserPasswordCredentials
|
||||
|
||||
|
||||
class DataForSeoClient:
|
||||
"""Client for the DataForSEO API using async requests."""
|
||||
|
||||
API_URL = "https://api.dataforseo.com"
|
||||
|
||||
def __init__(self, credentials: UserPasswordCredentials):
|
||||
self.credentials = credentials
|
||||
self.requests = Requests(
|
||||
trusted_origins=["https://api.dataforseo.com"],
|
||||
raise_for_status=False,
|
||||
)
|
||||
|
||||
def _get_headers(self) -> Dict[str, str]:
|
||||
"""Generate the authorization header using Basic Auth."""
|
||||
username = self.credentials.username.get_secret_value()
|
||||
password = self.credentials.password.get_secret_value()
|
||||
credentials_str = f"{username}:{password}"
|
||||
encoded = base64.b64encode(credentials_str.encode("ascii")).decode("ascii")
|
||||
return {
|
||||
"Authorization": f"Basic {encoded}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async def keyword_suggestions(
|
||||
self,
|
||||
keyword: str,
|
||||
location_code: Optional[int] = None,
|
||||
language_code: Optional[str] = None,
|
||||
include_seed_keyword: bool = True,
|
||||
include_serp_info: bool = False,
|
||||
include_clickstream_data: bool = False,
|
||||
limit: int = 100,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get keyword suggestions from DataForSEO Labs.
|
||||
|
||||
Args:
|
||||
keyword: Seed keyword
|
||||
location_code: Location code for targeting
|
||||
language_code: Language code (e.g., "en")
|
||||
include_seed_keyword: Include seed keyword in results
|
||||
include_serp_info: Include SERP data
|
||||
include_clickstream_data: Include clickstream metrics
|
||||
limit: Maximum number of results (up to 3000)
|
||||
|
||||
Returns:
|
||||
API response with keyword suggestions
|
||||
"""
|
||||
endpoint = f"{self.API_URL}/v3/dataforseo_labs/google/keyword_suggestions/live"
|
||||
|
||||
# Build payload only with non-None values to avoid sending null fields
|
||||
task_data: dict[str, Any] = {
|
||||
"keyword": keyword,
|
||||
}
|
||||
|
||||
if location_code is not None:
|
||||
task_data["location_code"] = location_code
|
||||
if language_code is not None:
|
||||
task_data["language_code"] = language_code
|
||||
if include_seed_keyword is not None:
|
||||
task_data["include_seed_keyword"] = include_seed_keyword
|
||||
if include_serp_info is not None:
|
||||
task_data["include_serp_info"] = include_serp_info
|
||||
if include_clickstream_data is not None:
|
||||
task_data["include_clickstream_data"] = include_clickstream_data
|
||||
if limit is not None:
|
||||
task_data["limit"] = limit
|
||||
|
||||
payload = [task_data]
|
||||
|
||||
response = await self.requests.post(
|
||||
endpoint,
|
||||
headers=self._get_headers(),
|
||||
json=payload,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Check for API errors
|
||||
if response.status != 200:
|
||||
error_message = data.get("status_message", "Unknown error")
|
||||
raise Exception(
|
||||
f"DataForSEO API error ({response.status}): {error_message}"
|
||||
)
|
||||
|
||||
# Extract the results from the response
|
||||
if data.get("tasks") and len(data["tasks"]) > 0:
|
||||
task = data["tasks"][0]
|
||||
if task.get("status_code") == 20000: # Success code
|
||||
return task.get("result", [])
|
||||
else:
|
||||
error_msg = task.get("status_message", "Task failed")
|
||||
raise Exception(f"DataForSEO task error: {error_msg}")
|
||||
|
||||
return []
|
||||
|
||||
async def related_keywords(
|
||||
self,
|
||||
keyword: str,
|
||||
location_code: Optional[int] = None,
|
||||
language_code: Optional[str] = None,
|
||||
include_seed_keyword: bool = True,
|
||||
include_serp_info: bool = False,
|
||||
include_clickstream_data: bool = False,
|
||||
limit: int = 100,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get related keywords from DataForSEO Labs.
|
||||
|
||||
Args:
|
||||
keyword: Seed keyword
|
||||
location_code: Location code for targeting
|
||||
language_code: Language code (e.g., "en")
|
||||
include_seed_keyword: Include seed keyword in results
|
||||
include_serp_info: Include SERP data
|
||||
include_clickstream_data: Include clickstream metrics
|
||||
limit: Maximum number of results (up to 3000)
|
||||
|
||||
Returns:
|
||||
API response with related keywords
|
||||
"""
|
||||
endpoint = f"{self.API_URL}/v3/dataforseo_labs/google/related_keywords/live"
|
||||
|
||||
# Build payload only with non-None values to avoid sending null fields
|
||||
task_data: dict[str, Any] = {
|
||||
"keyword": keyword,
|
||||
}
|
||||
|
||||
if location_code is not None:
|
||||
task_data["location_code"] = location_code
|
||||
if language_code is not None:
|
||||
task_data["language_code"] = language_code
|
||||
if include_seed_keyword is not None:
|
||||
task_data["include_seed_keyword"] = include_seed_keyword
|
||||
if include_serp_info is not None:
|
||||
task_data["include_serp_info"] = include_serp_info
|
||||
if include_clickstream_data is not None:
|
||||
task_data["include_clickstream_data"] = include_clickstream_data
|
||||
if limit is not None:
|
||||
task_data["limit"] = limit
|
||||
|
||||
payload = [task_data]
|
||||
|
||||
response = await self.requests.post(
|
||||
endpoint,
|
||||
headers=self._get_headers(),
|
||||
json=payload,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Check for API errors
|
||||
if response.status != 200:
|
||||
error_message = data.get("status_message", "Unknown error")
|
||||
raise Exception(
|
||||
f"DataForSEO API error ({response.status}): {error_message}"
|
||||
)
|
||||
|
||||
# Extract the results from the response
|
||||
if data.get("tasks") and len(data["tasks"]) > 0:
|
||||
task = data["tasks"][0]
|
||||
if task.get("status_code") == 20000: # Success code
|
||||
return task.get("result", [])
|
||||
else:
|
||||
error_msg = task.get("status_message", "Task failed")
|
||||
raise Exception(f"DataForSEO task error: {error_msg}")
|
||||
|
||||
return []
|
||||
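
# Illustrative usage sketch (not part of the diff): querying keyword suggestions with
# this client. The credentials object and location code are placeholders (2840 is the
# USA code used by the blocks that build on this client).
async def _example_keyword_suggestions(credentials: UserPasswordCredentials) -> list:
    client = DataForSeoClient(credentials)
    return await client.keyword_suggestions(
        keyword="digital marketing",
        location_code=2840,
        language_code="en",
        limit=10,
    )
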
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Configuration for all DataForSEO blocks using the new SDK pattern.
|
||||
"""
|
||||
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
# Build the DataForSEO provider with username/password authentication
|
||||
dataforseo = (
|
||||
ProviderBuilder("dataforseo")
|
||||
.with_user_password(
|
||||
username_env_var="DATAFORSEO_USERNAME",
|
||||
password_env_var="DATAFORSEO_PASSWORD",
|
||||
title="DataForSEO Credentials",
|
||||
)
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
DataForSEO Google Keyword Suggestions block.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
|
||||
from ._api import DataForSeoClient
|
||||
from ._config import dataforseo
|
||||
|
||||
|
||||
class KeywordSuggestion(BlockSchema):
|
||||
"""Schema for a keyword suggestion result."""
|
||||
|
||||
keyword: str = SchemaField(description="The keyword suggestion")
|
||||
search_volume: Optional[int] = SchemaField(
|
||||
description="Monthly search volume", default=None
|
||||
)
|
||||
competition: Optional[float] = SchemaField(
|
||||
description="Competition level (0-1)", default=None
|
||||
)
|
||||
cpc: Optional[float] = SchemaField(
|
||||
description="Cost per click in USD", default=None
|
||||
)
|
||||
keyword_difficulty: Optional[int] = SchemaField(
|
||||
description="Keyword difficulty score", default=None
|
||||
)
|
||||
serp_info: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="data from SERP for each keyword", default=None
|
||||
)
|
||||
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="Clickstream data metrics", default=None
|
||||
)
|
||||
|
||||
|
||||
class DataForSeoKeywordSuggestionsBlock(Block):
|
||||
"""Block for getting keyword suggestions from DataForSEO Labs."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = dataforseo.credentials_field(
|
||||
description="DataForSEO credentials (username and password)"
|
||||
)
|
||||
keyword: str = SchemaField(description="Seed keyword to get suggestions for")
|
||||
location_code: Optional[int] = SchemaField(
|
||||
description="Location code for targeting (e.g., 2840 for USA)",
|
||||
default=2840, # USA
|
||||
)
|
||||
language_code: Optional[str] = SchemaField(
|
||||
description="Language code (e.g., 'en' for English)",
|
||||
default="en",
|
||||
)
|
||||
include_seed_keyword: bool = SchemaField(
|
||||
description="Include the seed keyword in results",
|
||||
default=True,
|
||||
)
|
||||
include_serp_info: bool = SchemaField(
|
||||
description="Include SERP information",
|
||||
default=False,
|
||||
)
|
||||
include_clickstream_data: bool = SchemaField(
|
||||
description="Include clickstream metrics",
|
||||
default=False,
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of results (up to 3000)",
|
||||
default=100,
|
||||
ge=1,
|
||||
le=3000,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
suggestions: List[KeywordSuggestion] = SchemaField(
|
||||
description="List of keyword suggestions with metrics"
|
||||
)
|
||||
suggestion: KeywordSuggestion = SchemaField(
|
||||
description="A single keyword suggestion with metrics"
|
||||
)
|
||||
total_count: int = SchemaField(
|
||||
description="Total number of suggestions returned"
|
||||
)
|
||||
seed_keyword: str = SchemaField(
|
||||
description="The seed keyword used for the query"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="73c3e7c4-2b3f-4e9f-9e3e-8f7a5c3e2d45",
|
||||
description="Get keyword suggestions from DataForSEO Labs Google API",
|
||||
categories={BlockCategory.SEARCH, BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"credentials": dataforseo.get_test_credentials().model_dump(),
|
||||
"keyword": "digital marketing",
|
||||
"location_code": 2840,
|
||||
"language_code": "en",
|
||||
"limit": 1,
|
||||
},
|
||||
test_credentials=dataforseo.get_test_credentials(),
|
||||
test_output=[
|
||||
(
|
||||
"suggestion",
|
||||
lambda x: hasattr(x, "keyword")
|
||||
and x.keyword == "digital marketing strategy",
|
||||
),
|
||||
("suggestions", lambda x: isinstance(x, list) and len(x) == 1),
|
||||
("total_count", 1),
|
||||
("seed_keyword", "digital marketing"),
|
||||
],
|
||||
test_mock={
|
||||
"_fetch_keyword_suggestions": lambda *args, **kwargs: [
|
||||
{
|
||||
"items": [
|
||||
{
|
||||
"keyword": "digital marketing strategy",
|
||||
"keyword_info": {
|
||||
"search_volume": 10000,
|
||||
"competition": 0.5,
|
||||
"cpc": 2.5,
|
||||
},
|
||||
"keyword_properties": {
|
||||
"keyword_difficulty": 50,
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
async def _fetch_keyword_suggestions(
|
||||
self,
|
||||
client: DataForSeoClient,
|
||||
input_data: Input,
|
||||
) -> Any:
|
||||
"""Private method to fetch keyword suggestions - can be mocked for testing."""
|
||||
return await client.keyword_suggestions(
|
||||
keyword=input_data.keyword,
|
||||
location_code=input_data.location_code,
|
||||
language_code=input_data.language_code,
|
||||
include_seed_keyword=input_data.include_seed_keyword,
|
||||
include_serp_info=input_data.include_serp_info,
|
||||
include_clickstream_data=input_data.include_clickstream_data,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: UserPasswordCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the keyword suggestions query."""
|
||||
client = DataForSeoClient(credentials)
|
||||
|
||||
results = await self._fetch_keyword_suggestions(client, input_data)
|
||||
|
||||
# Process and format the results
|
||||
suggestions = []
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
items = (
|
||||
first_result.get("items", []) if isinstance(first_result, dict) else []
|
||||
)
|
||||
for item in items:
|
||||
# Create the KeywordSuggestion object
|
||||
suggestion = KeywordSuggestion(
|
||||
keyword=item.get("keyword", ""),
|
||||
search_volume=item.get("keyword_info", {}).get("search_volume"),
|
||||
competition=item.get("keyword_info", {}).get("competition"),
|
||||
cpc=item.get("keyword_info", {}).get("cpc"),
|
||||
keyword_difficulty=item.get("keyword_properties", {}).get(
|
||||
"keyword_difficulty"
|
||||
),
|
||||
serp_info=(
|
||||
item.get("serp_info") if input_data.include_serp_info else None
|
||||
),
|
||||
clickstream_data=(
|
||||
item.get("clickstream_keyword_info")
|
||||
if input_data.include_clickstream_data
|
||||
else None
|
||||
),
|
||||
)
|
||||
yield "suggestion", suggestion
|
||||
suggestions.append(suggestion)
|
||||
|
||||
yield "suggestions", suggestions
|
||||
yield "total_count", len(suggestions)
|
||||
yield "seed_keyword", input_data.keyword
|
||||
|
||||
|
||||
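
# Illustrative sketch (not part of the diff): the per-item shape run() above expects in
# results[0]["items"], mirroring this block's test_mock. Values are invented.
_example_item = {
    "keyword": "digital marketing strategy",
    "keyword_info": {"search_volume": 10000, "competition": 0.5, "cpc": 2.5},
    "keyword_properties": {"keyword_difficulty": 50},
}
# run() maps this to:
#   KeywordSuggestion(keyword="digital marketing strategy", search_volume=10000,
#                     competition=0.5, cpc=2.5, keyword_difficulty=50)
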
class KeywordSuggestionExtractorBlock(Block):
|
||||
"""Extracts individual fields from a KeywordSuggestion object."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
suggestion: KeywordSuggestion = SchemaField(
|
||||
description="The keyword suggestion object to extract fields from"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
keyword: str = SchemaField(description="The keyword suggestion")
|
||||
search_volume: Optional[int] = SchemaField(
|
||||
description="Monthly search volume", default=None
|
||||
)
|
||||
competition: Optional[float] = SchemaField(
|
||||
description="Competition level (0-1)", default=None
|
||||
)
|
||||
cpc: Optional[float] = SchemaField(
|
||||
description="Cost per click in USD", default=None
|
||||
)
|
||||
keyword_difficulty: Optional[int] = SchemaField(
|
||||
description="Keyword difficulty score", default=None
|
||||
)
|
||||
serp_info: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="data from SERP for each keyword", default=None
|
||||
)
|
||||
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="Clickstream data metrics", default=None
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4193cb94-677c-48b0-9eec-6ac72fffd0f2",
|
||||
description="Extract individual fields from a KeywordSuggestion object",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"suggestion": KeywordSuggestion(
|
||||
keyword="test keyword",
|
||||
search_volume=1000,
|
||||
competition=0.5,
|
||||
cpc=2.5,
|
||||
keyword_difficulty=60,
|
||||
).model_dump()
|
||||
},
|
||||
test_output=[
|
||||
("keyword", "test keyword"),
|
||||
("search_volume", 1000),
|
||||
("competition", 0.5),
|
||||
("cpc", 2.5),
|
||||
("keyword_difficulty", 60),
|
||||
("serp_info", None),
|
||||
("clickstream_data", None),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Extract fields from the KeywordSuggestion object."""
|
||||
suggestion = input_data.suggestion
|
||||
|
||||
yield "keyword", suggestion.keyword
|
||||
yield "search_volume", suggestion.search_volume
|
||||
yield "competition", suggestion.competition
|
||||
yield "cpc", suggestion.cpc
|
||||
yield "keyword_difficulty", suggestion.keyword_difficulty
|
||||
yield "serp_info", suggestion.serp_info
|
||||
yield "clickstream_data", suggestion.clickstream_data
|
||||
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
DataForSEO Google Related Keywords block.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
|
||||
from ._api import DataForSeoClient
|
||||
from ._config import dataforseo
|
||||
|
||||
|
||||
class RelatedKeyword(BlockSchema):
|
||||
"""Schema for a related keyword result."""
|
||||
|
||||
keyword: str = SchemaField(description="The related keyword")
|
||||
search_volume: Optional[int] = SchemaField(
|
||||
description="Monthly search volume", default=None
|
||||
)
|
||||
competition: Optional[float] = SchemaField(
|
||||
description="Competition level (0-1)", default=None
|
||||
)
|
||||
cpc: Optional[float] = SchemaField(
|
||||
description="Cost per click in USD", default=None
|
||||
)
|
||||
keyword_difficulty: Optional[int] = SchemaField(
|
||||
description="Keyword difficulty score", default=None
|
||||
)
|
||||
serp_info: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="SERP data for the keyword", default=None
|
||||
)
|
||||
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="Clickstream data metrics", default=None
|
||||
)
|
||||
|
||||
|
||||
class DataForSeoRelatedKeywordsBlock(Block):
|
||||
"""Block for getting related keywords from DataForSEO Labs."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = dataforseo.credentials_field(
|
||||
description="DataForSEO credentials (username and password)"
|
||||
)
|
||||
keyword: str = SchemaField(
|
||||
description="Seed keyword to find related keywords for"
|
||||
)
|
||||
location_code: Optional[int] = SchemaField(
|
||||
description="Location code for targeting (e.g., 2840 for USA)",
|
||||
default=2840, # USA
|
||||
)
|
||||
language_code: Optional[str] = SchemaField(
|
||||
description="Language code (e.g., 'en' for English)",
|
||||
default="en",
|
||||
)
|
||||
include_seed_keyword: bool = SchemaField(
|
||||
description="Include the seed keyword in results",
|
||||
default=True,
|
||||
)
|
||||
include_serp_info: bool = SchemaField(
|
||||
description="Include SERP information",
|
||||
default=False,
|
||||
)
|
||||
include_clickstream_data: bool = SchemaField(
|
||||
description="Include clickstream metrics",
|
||||
default=False,
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of results (up to 3000)",
|
||||
default=100,
|
||||
ge=1,
|
||||
le=3000,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
related_keywords: List[RelatedKeyword] = SchemaField(
|
||||
description="List of related keywords with metrics"
|
||||
)
|
||||
related_keyword: RelatedKeyword = SchemaField(
|
||||
description="A related keyword with metrics"
|
||||
)
|
||||
total_count: int = SchemaField(
|
||||
description="Total number of related keywords returned"
|
||||
)
|
||||
seed_keyword: str = SchemaField(
|
||||
description="The seed keyword used for the query"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8f2e4d6a-1b3c-4a5e-9d7f-2c8e6a4b3f1d",
|
||||
description="Get related keywords from DataForSEO Labs Google API",
|
||||
categories={BlockCategory.SEARCH, BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"credentials": dataforseo.get_test_credentials().model_dump(),
|
||||
"keyword": "content marketing",
|
||||
"location_code": 2840,
|
||||
"language_code": "en",
|
||||
"limit": 1,
|
||||
},
|
||||
test_credentials=dataforseo.get_test_credentials(),
|
||||
test_output=[
|
||||
(
|
||||
"related_keyword",
|
||||
lambda x: hasattr(x, "keyword") and x.keyword == "content strategy",
|
||||
),
|
||||
("related_keywords", lambda x: isinstance(x, list) and len(x) == 1),
|
||||
("total_count", 1),
|
||||
("seed_keyword", "content marketing"),
|
||||
],
|
||||
test_mock={
|
||||
"_fetch_related_keywords": lambda *args, **kwargs: [
|
||||
{
|
||||
"items": [
|
||||
{
|
||||
"keyword_data": {
|
||||
"keyword": "content strategy",
|
||||
"keyword_info": {
|
||||
"search_volume": 8000,
|
||||
"competition": 0.4,
|
||||
"cpc": 3.0,
|
||||
},
|
||||
"keyword_properties": {
|
||||
"keyword_difficulty": 45,
|
||||
},
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
async def _fetch_related_keywords(
|
||||
self,
|
||||
client: DataForSeoClient,
|
||||
input_data: Input,
|
||||
) -> Any:
|
||||
"""Private method to fetch related keywords - can be mocked for testing."""
|
||||
return await client.related_keywords(
|
||||
keyword=input_data.keyword,
|
||||
location_code=input_data.location_code,
|
||||
language_code=input_data.language_code,
|
||||
include_seed_keyword=input_data.include_seed_keyword,
|
||||
include_serp_info=input_data.include_serp_info,
|
||||
include_clickstream_data=input_data.include_clickstream_data,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: UserPasswordCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the related keywords query."""
|
||||
client = DataForSeoClient(credentials)
|
||||
|
||||
results = await self._fetch_related_keywords(client, input_data)
|
||||
|
||||
# Process and format the results
|
||||
related_keywords = []
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
items = (
|
||||
first_result.get("items", []) if isinstance(first_result, dict) else []
|
||||
)
|
||||
for item in items:
|
||||
# Extract keyword_data from the item
|
||||
keyword_data = item.get("keyword_data", {})
|
||||
|
||||
# Create the RelatedKeyword object
|
||||
keyword = RelatedKeyword(
|
||||
keyword=keyword_data.get("keyword", ""),
|
||||
search_volume=keyword_data.get("keyword_info", {}).get(
|
||||
"search_volume"
|
||||
),
|
||||
competition=keyword_data.get("keyword_info", {}).get("competition"),
|
||||
cpc=keyword_data.get("keyword_info", {}).get("cpc"),
|
||||
keyword_difficulty=keyword_data.get("keyword_properties", {}).get(
|
||||
"keyword_difficulty"
|
||||
),
|
||||
serp_info=(
|
||||
keyword_data.get("serp_info")
|
||||
if input_data.include_serp_info
|
||||
else None
|
||||
),
|
||||
clickstream_data=(
|
||||
keyword_data.get("clickstream_keyword_info")
|
||||
if input_data.include_clickstream_data
|
||||
else None
|
||||
),
|
||||
)
|
||||
yield "related_keyword", keyword
|
||||
related_keywords.append(keyword)
|
||||
|
||||
yield "related_keywords", related_keywords
|
||||
yield "total_count", len(related_keywords)
|
||||
yield "seed_keyword", input_data.keyword
|
||||
|
||||
|
||||
class RelatedKeywordExtractorBlock(Block):
|
||||
"""Extracts individual fields from a RelatedKeyword object."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
related_keyword: RelatedKeyword = SchemaField(
|
||||
description="The related keyword object to extract fields from"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
keyword: str = SchemaField(description="The related keyword")
|
||||
search_volume: Optional[int] = SchemaField(
|
||||
description="Monthly search volume", default=None
|
||||
)
|
||||
competition: Optional[float] = SchemaField(
|
||||
description="Competition level (0-1)", default=None
|
||||
)
|
||||
cpc: Optional[float] = SchemaField(
|
||||
description="Cost per click in USD", default=None
|
||||
)
|
||||
keyword_difficulty: Optional[int] = SchemaField(
|
||||
description="Keyword difficulty score", default=None
|
||||
)
|
||||
serp_info: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="SERP data for the keyword", default=None
|
||||
)
|
||||
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="Clickstream data metrics", default=None
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="98342061-09d2-4952-bf77-0761fc8cc9a8",
|
||||
description="Extract individual fields from a RelatedKeyword object",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"related_keyword": RelatedKeyword(
|
||||
keyword="test related keyword",
|
||||
search_volume=800,
|
||||
competition=0.4,
|
||||
cpc=3.0,
|
||||
keyword_difficulty=55,
|
||||
).model_dump()
|
||||
},
|
||||
test_output=[
|
||||
("keyword", "test related keyword"),
|
||||
("search_volume", 800),
|
||||
("competition", 0.4),
|
||||
("cpc", 3.0),
|
||||
("keyword_difficulty", 55),
|
||||
("serp_info", None),
|
||||
("clickstream_data", None),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Extract fields from the RelatedKeyword object."""
|
||||
related_keyword = input_data.related_keyword
|
||||
|
||||
yield "keyword", related_keyword.keyword
|
||||
yield "search_volume", related_keyword.search_volume
|
||||
yield "competition", related_keyword.competition
|
||||
yield "cpc", related_keyword.cpc
|
||||
yield "keyword_difficulty", related_keyword.keyword_difficulty
|
||||
yield "serp_info", related_keyword.serp_info
|
||||
yield "clickstream_data", related_keyword.clickstream_data
|
||||
autogpt_platform/backend/backend/blocks/discord/_api.py (new file, 117 lines)
@@ -0,0 +1,117 @@
"""
|
||||
Discord API helper functions for making authenticated requests.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.util.request import Requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DiscordAPIException(Exception):
|
||||
"""Exception raised for Discord API errors."""
|
||||
|
||||
def __init__(self, message: str, status_code: int):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class DiscordOAuthUser(BaseModel):
|
||||
"""Model for Discord OAuth user response."""
|
||||
|
||||
user_id: str
|
||||
username: str
|
||||
avatar_url: str
|
||||
banner: Optional[str] = None
|
||||
accent_color: Optional[int] = None
|
||||
|
||||
|
||||
def get_api(credentials: OAuth2Credentials) -> Requests:
|
||||
"""
|
||||
Create a Requests instance configured for Discord API calls with OAuth2 credentials.
|
||||
|
||||
Args:
|
||||
credentials: The OAuth2 credentials containing the access token.
|
||||
|
||||
Returns:
|
||||
A configured Requests instance for Discord API calls.
|
||||
"""
|
||||
return Requests(
|
||||
trusted_origins=[],
|
||||
extra_headers={
|
||||
"Authorization": f"Bearer {credentials.access_token.get_secret_value()}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
raise_for_status=False,
|
||||
)
|
||||
|
||||
|
||||
async def get_current_user(credentials: OAuth2Credentials) -> DiscordOAuthUser:
|
||||
"""
|
||||
Fetch the current user's information using Discord OAuth2 API.
|
||||
|
||||
Reference: https://discord.com/developers/docs/resources/user#get-current-user
|
||||
|
||||
Args:
|
||||
credentials: The OAuth2 credentials.
|
||||
|
||||
Returns:
|
||||
A model containing user data with avatar URL.
|
||||
|
||||
Raises:
|
||||
DiscordAPIException: If the API request fails.
|
||||
"""
|
||||
api = get_api(credentials)
|
||||
response = await api.get("https://discord.com/api/oauth2/@me")
|
||||
|
||||
if not response.ok:
|
||||
error_text = response.text()
|
||||
raise DiscordAPIException(
|
||||
f"Failed to fetch user info: {response.status} - {error_text}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
logger.info(f"Discord OAuth2 API Response: {data}")
|
||||
|
||||
# The /api/oauth2/@me endpoint returns a user object nested in the response
|
||||
user_info = data.get("user", {})
|
||||
logger.info(f"User info extracted: {user_info}")
|
||||
|
||||
# Build avatar URL
|
||||
user_id = user_info.get("id")
|
||||
avatar_hash = user_info.get("avatar")
|
||||
if avatar_hash:
|
||||
# Custom avatar
|
||||
avatar_ext = "gif" if avatar_hash.startswith("a_") else "png"
|
||||
avatar_url = (
|
||||
f"https://cdn.discordapp.com/avatars/{user_id}/{avatar_hash}.{avatar_ext}"
|
||||
)
|
||||
else:
|
||||
# Default avatar based on discriminator or user ID
|
||||
discriminator = user_info.get("discriminator", "0")
|
||||
if discriminator == "0":
|
||||
# New username system - use user ID for default avatar
|
||||
default_avatar_index = (int(user_id) >> 22) % 6
|
||||
else:
|
||||
# Legacy discriminator system
|
||||
default_avatar_index = int(discriminator) % 5
|
||||
avatar_url = (
|
||||
f"https://cdn.discordapp.com/embed/avatars/{default_avatar_index}.png"
|
||||
)
|
||||
|
||||
result = DiscordOAuthUser(
|
||||
user_id=user_id,
|
||||
username=user_info.get("username", ""),
|
||||
avatar_url=avatar_url,
|
||||
banner=user_info.get("banner"),
|
||||
accent_color=user_info.get("accent_color"),
|
||||
)
|
||||
|
||||
logger.info(f"Returning user data: {result.model_dump()}")
|
||||
return result
|
||||
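The helper above resolves the avatar URL itself (custom avatar hash vs. Discord's default embed avatars), so a caller only needs OAuth2 credentials carrying the identify scope. A minimal usage sketch, not part of the diff, assuming the module is importable as backend.blocks.discord._api and that OAuth2Credentials accepts the same fields as the mock credentials defined in _auth.py below:

import asyncio

from pydantic import SecretStr

from backend.blocks.discord._api import get_current_user
from backend.data.model import OAuth2Credentials


async def main() -> None:
    # Hypothetical credentials object; a real access token is required for a live call.
    credentials = OAuth2Credentials(
        provider="discord",
        access_token=SecretStr("oauth2-access-token"),
        title="My Discord OAuth",
        scopes=["identify"],
        username="me",
    )
    user = await get_current_user(credentials)
    # avatar_url is always populated, either from the custom hash or a default avatar
    print(user.user_id, user.username, user.avatar_url)


if __name__ == "__main__":
    asyncio.run(main())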
autogpt_platform/backend/backend/blocks/discord/_auth.py (new file, 74 lines)
@@ -0,0 +1,74 @@
from typing import Literal

from pydantic import SecretStr

from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    OAuth2Credentials,
)
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets

secrets = Secrets()
DISCORD_OAUTH_IS_CONFIGURED = bool(
    secrets.discord_client_id and secrets.discord_client_secret
)

# Bot token credentials (existing)
DiscordBotCredentials = APIKeyCredentials
DiscordBotCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.DISCORD], Literal["api_key"]
]

# OAuth2 credentials (new)
DiscordOAuthCredentials = OAuth2Credentials
DiscordOAuthCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.DISCORD], Literal["oauth2"]
]


def DiscordBotCredentialsField() -> DiscordBotCredentialsInput:
    """Creates a Discord bot token credentials field."""
    return CredentialsField(description="Discord bot token")


def DiscordOAuthCredentialsField(scopes: list[str]) -> DiscordOAuthCredentialsInput:
    """Creates a Discord OAuth2 credentials field."""
    return CredentialsField(
        description="Discord OAuth2 credentials",
        required_scopes=set(scopes) | {"identify"},  # Basic user info scope
    )


# Test credentials for bot tokens
TEST_BOT_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="discord",
    api_key=SecretStr("test_api_key"),
    title="Mock Discord API key",
    expires_at=None,
)
TEST_BOT_CREDENTIALS_INPUT = {
    "provider": TEST_BOT_CREDENTIALS.provider,
    "id": TEST_BOT_CREDENTIALS.id,
    "type": TEST_BOT_CREDENTIALS.type,
    "title": TEST_BOT_CREDENTIALS.type,
}

# Test credentials for OAuth2
TEST_OAUTH_CREDENTIALS = OAuth2Credentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="discord",
    access_token=SecretStr("test_access_token"),
    title="Mock Discord OAuth",
    scopes=["identify"],
    username="testuser",
)
TEST_OAUTH_CREDENTIALS_INPUT = {
    "provider": TEST_OAUTH_CREDENTIALS.provider,
    "id": TEST_OAUTH_CREDENTIALS.id,
    "type": TEST_OAUTH_CREDENTIALS.type,
    "title": TEST_OAUTH_CREDENTIALS.type,
}
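DiscordOAuthCredentialsField always folds the identify scope into whatever scopes a block requests, so basic user info stays available even when a block only asks for something else. A tiny sketch of that set union, not part of the diff:

# Requesting only "guilds" still yields both scopes in required_scopes.
requested = ["guilds"]
required_scopes = set(requested) | {"identify"}
assert required_scopes == {"guilds", "identify"}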
@@ -2,45 +2,29 @@ import base64
import io
import mimetypes
from pathlib import Path
from typing import Any, Literal
from typing import Any

import aiohttp
import discord
from pydantic import SecretStr

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.data.model import APIKeyCredentials, SchemaField
from backend.util.file import store_media_file
from backend.util.type import MediaFileType

DiscordCredentials = CredentialsMetaInput[
    Literal[ProviderName.DISCORD], Literal["api_key"]
]


def DiscordCredentialsField() -> DiscordCredentials:
    return CredentialsField(description="Discord bot token")


TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="discord",
    api_key=SecretStr("test_api_key"),
    title="Mock Discord API key",
    expires_at=None,
from ._auth import (
    TEST_BOT_CREDENTIALS,
    TEST_BOT_CREDENTIALS_INPUT,
    DiscordBotCredentialsField,
    DiscordBotCredentialsInput,
)
TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.type,
}

# Keep backward compatibility alias
DiscordCredentials = DiscordBotCredentialsInput
DiscordCredentialsField = DiscordBotCredentialsField
TEST_CREDENTIALS = TEST_BOT_CREDENTIALS
TEST_CREDENTIALS_INPUT = TEST_BOT_CREDENTIALS_INPUT


class ReadDiscordMessagesBlock(Block):
@@ -0,0 +1,99 @@
"""
Discord OAuth-based blocks.
"""

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import OAuth2Credentials, SchemaField

from ._api import DiscordOAuthUser, get_current_user
from ._auth import (
    DISCORD_OAUTH_IS_CONFIGURED,
    TEST_OAUTH_CREDENTIALS,
    TEST_OAUTH_CREDENTIALS_INPUT,
    DiscordOAuthCredentialsField,
    DiscordOAuthCredentialsInput,
)


class DiscordGetCurrentUserBlock(Block):
    """
    Gets information about the currently authenticated Discord user using OAuth2.
    This block requires Discord OAuth2 credentials (not bot tokens).
    """

    class Input(BlockSchema):
        credentials: DiscordOAuthCredentialsInput = DiscordOAuthCredentialsField(
            ["identify"]
        )

    class Output(BlockSchema):
        user_id: str = SchemaField(description="The authenticated user's Discord ID")
        username: str = SchemaField(description="The user's username")
        avatar_url: str = SchemaField(description="URL to the user's avatar image")
        banner_url: str = SchemaField(
            description="URL to the user's banner image (if set)", default=""
        )
        accent_color: int = SchemaField(
            description="The user's accent color as an integer", default=0
        )

    def __init__(self):
        super().__init__(
            id="8c7e39b8-4e9d-4f3a-b4e1-2a8c9d5f6e3b",
            input_schema=DiscordGetCurrentUserBlock.Input,
            output_schema=DiscordGetCurrentUserBlock.Output,
            description="Gets information about the currently authenticated Discord user using OAuth2 credentials.",
            categories={BlockCategory.SOCIAL},
            disabled=not DISCORD_OAUTH_IS_CONFIGURED,
            test_input={
                "credentials": TEST_OAUTH_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_OAUTH_CREDENTIALS,
            test_output=[
                ("user_id", "123456789012345678"),
                ("username", "testuser"),
                (
                    "avatar_url",
                    "https://cdn.discordapp.com/avatars/123456789012345678/avatar.png",
                ),
                ("banner_url", ""),
                ("accent_color", 0),
            ],
            test_mock={
                "get_user": lambda _: DiscordOAuthUser(
                    user_id="123456789012345678",
                    username="testuser",
                    avatar_url="https://cdn.discordapp.com/avatars/123456789012345678/avatar.png",
                    banner=None,
                    accent_color=0,
                )
            },
        )

    @staticmethod
    async def get_user(credentials: OAuth2Credentials) -> DiscordOAuthUser:
        user_info = await get_current_user(credentials)
        return user_info

    async def run(
        self, input_data: Input, *, credentials: OAuth2Credentials, **kwargs
    ) -> BlockOutput:
        try:
            result = await self.get_user(credentials)

            # Yield each output field
            yield "user_id", result.user_id
            yield "username", result.username
            yield "avatar_url", result.avatar_url

            # Handle banner URL if banner hash exists
            if result.banner:
                banner_url = f"https://cdn.discordapp.com/banners/{result.user_id}/{result.banner}.png"
                yield "banner_url", banner_url
            else:
                yield "banner_url", ""

            yield "accent_color", result.accent_color or 0

        except Exception as e:
            raise ValueError(f"Failed to get Discord user info: {e}")
@@ -93,11 +93,11 @@ class Webset(BaseModel):
    """
    Set of key-value pairs you want to associate with this object.
    """
    created_at: Annotated[datetime, Field(alias="createdAt")] | None = None
    created_at: Annotated[datetime | None, Field(alias="createdAt")] = None
    """
    The date and time the webset was created
    """
    updated_at: Annotated[datetime, Field(alias="updatedAt")] | None = None
    updated_at: Annotated[datetime | None, Field(alias="updatedAt")] = None
    """
    The date and time the webset was last updated
    """
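The webset change above moves Field(alias=...) inside the Optional annotation, so the alias and the None default are declared on the whole field type rather than on one branch of the union. A small self-contained sketch of the corrected pattern, with a hypothetical model name, not part of the diff:

from datetime import datetime
from typing import Annotated

from pydantic import BaseModel, Field


class WebsetSketch(BaseModel):
    # Alias applies to the field as a whole; the field is optional and defaults to None.
    created_at: Annotated[datetime | None, Field(alias="createdAt")] = None


print(WebsetSketch.model_validate({"createdAt": "2024-01-01T00:00:00Z"}).created_at)
print(WebsetSketch().created_at)  # None when the payload omits the key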
@@ -1094,6 +1094,117 @@ class GmailGetThreadBlock(GmailBase):
|
||||
return thread
|
||||
|
||||
|
||||
async def _build_reply_message(
|
||||
service, input_data, graph_exec_id: str, user_id: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Builds a reply MIME message for Gmail threads.
|
||||
|
||||
Returns:
|
||||
tuple: (base64-encoded raw message, threadId)
|
||||
"""
|
||||
# Get parent message for reply context
|
||||
parent = await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.messages()
|
||||
.get(
|
||||
userId="me",
|
||||
id=input_data.parentMessageId,
|
||||
format="metadata",
|
||||
metadataHeaders=[
|
||||
"Subject",
|
||||
"References",
|
||||
"Message-ID",
|
||||
"From",
|
||||
"To",
|
||||
"Cc",
|
||||
"Reply-To",
|
||||
],
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Build headers dictionary, preserving all values for duplicate headers
|
||||
headers = {}
|
||||
for h in parent.get("payload", {}).get("headers", []):
|
||||
name = h["name"].lower()
|
||||
value = h["value"]
|
||||
if name in headers:
|
||||
# For duplicate headers, keep the first occurrence (most relevant for reply context)
|
||||
continue
|
||||
headers[name] = value
|
||||
|
||||
# Determine recipients if not specified
|
||||
if not (input_data.to or input_data.cc or input_data.bcc):
|
||||
if input_data.replyAll:
|
||||
recipients = [parseaddr(headers.get("from", ""))[1]]
|
||||
recipients += [addr for _, addr in getaddresses([headers.get("to", "")])]
|
||||
recipients += [addr for _, addr in getaddresses([headers.get("cc", "")])]
|
||||
# Use dict.fromkeys() for O(n) deduplication while preserving order
|
||||
input_data.to = list(dict.fromkeys(filter(None, recipients)))
|
||||
else:
|
||||
# Check Reply-To header first, fall back to From header
|
||||
reply_to = headers.get("reply-to", "")
|
||||
from_addr = headers.get("from", "")
|
||||
sender = parseaddr(reply_to if reply_to else from_addr)[1]
|
||||
input_data.to = [sender] if sender else []
|
||||
|
||||
# Set subject with Re: prefix if not already present
|
||||
if input_data.subject:
|
||||
subject = input_data.subject
|
||||
else:
|
||||
parent_subject = headers.get("subject", "").strip()
|
||||
# Only add "Re:" if not already present (case-insensitive check)
|
||||
if parent_subject.lower().startswith("re:"):
|
||||
subject = parent_subject
|
||||
else:
|
||||
subject = f"Re: {parent_subject}" if parent_subject else "Re:"
|
||||
|
||||
# Build references header for proper threading
|
||||
references = headers.get("references", "").split()
|
||||
if headers.get("message-id"):
|
||||
references.append(headers["message-id"])
|
||||
|
||||
# Create MIME message
|
||||
msg = MIMEMultipart()
|
||||
if input_data.to:
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
if headers.get("message-id"):
|
||||
msg["In-Reply-To"] = headers["message-id"]
|
||||
if references:
|
||||
msg["References"] = " ".join(references)
|
||||
|
||||
# Use the helper function for consistent content type handling
|
||||
msg.attach(_make_mime_text(input_data.body, input_data.content_type))
|
||||
|
||||
# Handle attachments
|
||||
for attach in input_data.attachments:
|
||||
local_path = await store_media_file(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=attach,
|
||||
return_content=False,
|
||||
)
|
||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
||||
part = MIMEBase("application", "octet-stream")
|
||||
with open(abs_path, "rb") as f:
|
||||
part.set_payload(f.read())
|
||||
encoders.encode_base64(part)
|
||||
part.add_header(
|
||||
"Content-Disposition", f"attachment; filename={Path(abs_path).name}"
|
||||
)
|
||||
msg.attach(part)
|
||||
|
||||
# Encode message
|
||||
raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
|
||||
return raw, input_data.threadId
|
||||
|
||||
|
||||
class GmailReplyBlock(GmailBase):
|
||||
"""
|
||||
Replies to Gmail threads with intelligent content type detection.
|
||||
@@ -1230,93 +1341,146 @@ class GmailReplyBlock(GmailBase):
|
||||
async def _reply(
|
||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
||||
) -> dict:
|
||||
parent = await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.messages()
|
||||
.get(
|
||||
userId="me",
|
||||
id=input_data.parentMessageId,
|
||||
format="metadata",
|
||||
metadataHeaders=[
|
||||
"Subject",
|
||||
"References",
|
||||
"Message-ID",
|
||||
"From",
|
||||
"To",
|
||||
"Cc",
|
||||
"Reply-To",
|
||||
],
|
||||
)
|
||||
.execute()
|
||||
# Build the reply message using the shared helper
|
||||
raw, thread_id = await _build_reply_message(
|
||||
service, input_data, graph_exec_id, user_id
|
||||
)
|
||||
|
||||
headers = {
|
||||
h["name"].lower(): h["value"]
|
||||
for h in parent.get("payload", {}).get("headers", [])
|
||||
}
|
||||
if not (input_data.to or input_data.cc or input_data.bcc):
|
||||
if input_data.replyAll:
|
||||
recipients = [parseaddr(headers.get("from", ""))[1]]
|
||||
recipients += [
|
||||
addr for _, addr in getaddresses([headers.get("to", "")])
|
||||
]
|
||||
recipients += [
|
||||
addr for _, addr in getaddresses([headers.get("cc", "")])
|
||||
]
|
||||
dedup: list[str] = []
|
||||
for r in recipients:
|
||||
if r and r not in dedup:
|
||||
dedup.append(r)
|
||||
input_data.to = dedup
|
||||
else:
|
||||
sender = parseaddr(headers.get("reply-to", headers.get("from", "")))[1]
|
||||
input_data.to = [sender] if sender else []
|
||||
subject = input_data.subject or (f"Re: {headers.get('subject', '')}".strip())
|
||||
references = headers.get("references", "").split()
|
||||
if headers.get("message-id"):
|
||||
references.append(headers["message-id"])
|
||||
|
||||
msg = MIMEMultipart()
|
||||
if input_data.to:
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
if headers.get("message-id"):
|
||||
msg["In-Reply-To"] = headers["message-id"]
|
||||
if references:
|
||||
msg["References"] = " ".join(references)
|
||||
# Use the new helper function for consistent content type handling
|
||||
msg.attach(_make_mime_text(input_data.body, input_data.content_type))
|
||||
|
||||
for attach in input_data.attachments:
|
||||
local_path = await store_media_file(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=attach,
|
||||
return_content=False,
|
||||
)
|
||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
||||
part = MIMEBase("application", "octet-stream")
|
||||
with open(abs_path, "rb") as f:
|
||||
part.set_payload(f.read())
|
||||
encoders.encode_base64(part)
|
||||
part.add_header(
|
||||
"Content-Disposition", f"attachment; filename={Path(abs_path).name}"
|
||||
)
|
||||
msg.attach(part)
|
||||
|
||||
raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
|
||||
# Send the message
|
||||
return await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.messages()
|
||||
.send(userId="me", body={"threadId": input_data.threadId, "raw": raw})
|
||||
.send(userId="me", body={"threadId": thread_id, "raw": raw})
|
||||
.execute()
|
||||
)
|
||||
|
||||
|
||||
class GmailDraftReplyBlock(GmailBase):
|
||||
"""
|
||||
Creates draft replies to Gmail threads with intelligent content type detection.
|
||||
|
||||
Features:
|
||||
- Automatic HTML detection: Draft replies containing HTML tags are formatted as text/html
|
||||
- No hard-wrap for plain text: Plain text draft replies preserve natural line flow
|
||||
- Manual content type override: Use content_type parameter to force specific format
|
||||
- Reply-all functionality: Option to reply to all original recipients
|
||||
- Thread preservation: Maintains proper email threading with headers
|
||||
- Full Unicode/emoji support with UTF-8 encoding
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: GoogleCredentialsInput = GoogleCredentialsField(
|
||||
[
|
||||
"https://www.googleapis.com/auth/gmail.modify",
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
]
|
||||
)
|
||||
threadId: str = SchemaField(description="Thread ID to reply in")
|
||||
parentMessageId: str = SchemaField(
|
||||
description="ID of the message being replied to"
|
||||
)
|
||||
to: list[str] = SchemaField(description="To recipients", default_factory=list)
|
||||
cc: list[str] = SchemaField(description="CC recipients", default_factory=list)
|
||||
bcc: list[str] = SchemaField(description="BCC recipients", default_factory=list)
|
||||
replyAll: bool = SchemaField(
|
||||
description="Reply to all original recipients", default=False
|
||||
)
|
||||
subject: str = SchemaField(description="Email subject", default="")
|
||||
body: str = SchemaField(description="Email body (plain text or HTML)")
|
||||
content_type: Optional[Literal["auto", "plain", "html"]] = SchemaField(
|
||||
description="Content type: 'auto' (default - detects HTML), 'plain', or 'html'",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
attachments: list[MediaFileType] = SchemaField(
|
||||
description="Files to attach", default_factory=list, advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
draftId: str = SchemaField(description="Created draft ID")
|
||||
messageId: str = SchemaField(description="Draft message ID")
|
||||
threadId: str = SchemaField(description="Thread ID")
|
||||
status: str = SchemaField(description="Draft creation status")
|
||||
error: str = SchemaField(description="Error message if any")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d7a9f3e2-8b4c-4d6f-9e1a-3c5b7f8d2a6e",
|
||||
description="Create draft replies to Gmail threads with automatic HTML detection and proper text formatting. Plain text draft replies maintain natural paragraph flow without 78-character line wrapping. HTML content is automatically detected and formatted correctly.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=GmailDraftReplyBlock.Input,
|
||||
output_schema=GmailDraftReplyBlock.Output,
|
||||
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"threadId": "t1",
|
||||
"parentMessageId": "m1",
|
||||
"body": "Thanks for your message. I'll review and get back to you.",
|
||||
"replyAll": False,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("draftId", "draft1"),
|
||||
("messageId", "m2"),
|
||||
("threadId", "t1"),
|
||||
("status", "draft_created"),
|
||||
],
|
||||
test_mock={
|
||||
"_create_draft_reply": lambda *args, **kwargs: {
|
||||
"id": "draft1",
|
||||
"message": {"id": "m2", "threadId": "t1"},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GoogleCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
service = self._build_service(credentials, **kwargs)
|
||||
draft = await self._create_draft_reply(
|
||||
service,
|
||||
input_data,
|
||||
graph_exec_id,
|
||||
user_id,
|
||||
)
|
||||
yield "draftId", draft["id"]
|
||||
yield "messageId", draft["message"]["id"]
|
||||
yield "threadId", draft["message"].get("threadId", input_data.threadId)
|
||||
yield "status", "draft_created"
|
||||
|
||||
async def _create_draft_reply(
|
||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
||||
) -> dict:
|
||||
# Build the reply message using the shared helper
|
||||
raw, thread_id = await _build_reply_message(
|
||||
service, input_data, graph_exec_id, user_id
|
||||
)
|
||||
|
||||
# Create draft with proper thread association
|
||||
draft = await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.drafts()
|
||||
.create(
|
||||
userId="me",
|
||||
body={
|
||||
"message": {
|
||||
"threadId": thread_id,
|
||||
"raw": raw,
|
||||
}
|
||||
},
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
return draft
|
||||
|
||||
|
||||
class GmailGetProfileBlock(GmailBase):
|
||||
class Input(BlockSchema):
|
||||
credentials: GoogleCredentialsInput = GoogleCredentialsField(
|
||||
|
||||
@@ -30,6 +30,7 @@ TEST_CREDENTIALS_INPUT = {


class IdeogramModelName(str, Enum):
    V3 = "V_3"
    V2 = "V_2"
    V1 = "V_1"
    V1_TURBO = "V_1_TURBO"
@@ -95,8 +96,8 @@ class IdeogramModelBlock(Block):
            title="Prompt",
        )
        ideogram_model_name: IdeogramModelName = SchemaField(
            description="The name of the Image Generation Model, e.g., V_2",
            default=IdeogramModelName.V2,
            description="The name of the Image Generation Model, e.g., V_3",
            default=IdeogramModelName.V3,
            title="Image Generation Model",
            advanced=False,
        )
@@ -236,6 +237,111 @@ class IdeogramModelBlock(Block):
|
||||
negative_prompt: Optional[str],
|
||||
color_palette_name: str,
|
||||
custom_colors: Optional[list[str]],
|
||||
):
|
||||
# Use V3 endpoint for V3 model, legacy endpoint for others
|
||||
if model_name == "V_3":
|
||||
return await self._run_model_v3(
|
||||
api_key,
|
||||
prompt,
|
||||
seed,
|
||||
aspect_ratio,
|
||||
magic_prompt_option,
|
||||
style_type,
|
||||
negative_prompt,
|
||||
color_palette_name,
|
||||
custom_colors,
|
||||
)
|
||||
else:
|
||||
return await self._run_model_legacy(
|
||||
api_key,
|
||||
model_name,
|
||||
prompt,
|
||||
seed,
|
||||
aspect_ratio,
|
||||
magic_prompt_option,
|
||||
style_type,
|
||||
negative_prompt,
|
||||
color_palette_name,
|
||||
custom_colors,
|
||||
)
|
||||
|
||||
async def _run_model_v3(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
prompt: str,
|
||||
seed: Optional[int],
|
||||
aspect_ratio: str,
|
||||
magic_prompt_option: str,
|
||||
style_type: str,
|
||||
negative_prompt: Optional[str],
|
||||
color_palette_name: str,
|
||||
custom_colors: Optional[list[str]],
|
||||
):
|
||||
url = "https://api.ideogram.ai/v1/ideogram-v3/generate"
|
||||
headers = {
|
||||
"Api-Key": api_key.get_secret_value(),
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# Map legacy aspect ratio values to V3 format
|
||||
aspect_ratio_map = {
|
||||
"ASPECT_10_16": "10x16",
|
||||
"ASPECT_16_10": "16x10",
|
||||
"ASPECT_9_16": "9x16",
|
||||
"ASPECT_16_9": "16x9",
|
||||
"ASPECT_3_2": "3x2",
|
||||
"ASPECT_2_3": "2x3",
|
||||
"ASPECT_4_3": "4x3",
|
||||
"ASPECT_3_4": "3x4",
|
||||
"ASPECT_1_1": "1x1",
|
||||
"ASPECT_1_3": "1x3",
|
||||
"ASPECT_3_1": "3x1",
|
||||
# Additional V3 supported ratios
|
||||
"ASPECT_1_2": "1x2",
|
||||
"ASPECT_2_1": "2x1",
|
||||
"ASPECT_4_5": "4x5",
|
||||
"ASPECT_5_4": "5x4",
|
||||
}
|
||||
|
||||
v3_aspect_ratio = aspect_ratio_map.get(
|
||||
aspect_ratio, "1x1"
|
||||
) # Default to 1x1 if not found
|
||||
|
||||
# Use JSON for V3 endpoint (simpler than multipart/form-data)
|
||||
data: Dict[str, Any] = {
|
||||
"prompt": prompt,
|
||||
"aspect_ratio": v3_aspect_ratio,
|
||||
"magic_prompt": magic_prompt_option,
|
||||
"style_type": style_type,
|
||||
}
|
||||
|
||||
if seed is not None:
|
||||
data["seed"] = seed
|
||||
|
||||
if negative_prompt:
|
||||
data["negative_prompt"] = negative_prompt
|
||||
|
||||
# Note: V3 endpoint may have different color palette support
|
||||
# For now, we'll omit color palettes for V3 to avoid errors
|
||||
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=data)
|
||||
return response.json()["data"][0]["url"]
|
||||
except RequestException as e:
|
||||
raise Exception(f"Failed to fetch image with V3 endpoint: {str(e)}")
|
||||
|
||||
async def _run_model_legacy(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
seed: Optional[int],
|
||||
aspect_ratio: str,
|
||||
magic_prompt_option: str,
|
||||
style_type: str,
|
||||
negative_prompt: Optional[str],
|
||||
color_palette_name: str,
|
||||
custom_colors: Optional[list[str]],
|
||||
):
|
||||
url = "https://api.ideogram.ai/generate"
|
||||
headers = {
|
||||
@@ -249,28 +355,33 @@ class IdeogramModelBlock(Block):
|
||||
"model": model_name,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"magic_prompt_option": magic_prompt_option,
|
||||
"style_type": style_type,
|
||||
}
|
||||
}
|
||||
|
||||
# Only add style_type for V2, V2_TURBO, and V3 models (V1 models don't support it)
|
||||
if model_name in ["V_2", "V_2_TURBO", "V_3"]:
|
||||
data["image_request"]["style_type"] = style_type
|
||||
|
||||
if seed is not None:
|
||||
data["image_request"]["seed"] = seed
|
||||
|
||||
if negative_prompt:
|
||||
data["image_request"]["negative_prompt"] = negative_prompt
|
||||
|
||||
if color_palette_name != "NONE":
|
||||
data["color_palette"] = {"name": color_palette_name}
|
||||
elif custom_colors:
|
||||
data["color_palette"] = {
|
||||
"members": [{"color_hex": color} for color in custom_colors]
|
||||
}
|
||||
# Only add color palette for V2 and V2_TURBO models (V1 models don't support it)
|
||||
if model_name in ["V_2", "V_2_TURBO"]:
|
||||
if color_palette_name != "NONE":
|
||||
data["color_palette"] = {"name": color_palette_name}
|
||||
elif custom_colors:
|
||||
data["color_palette"] = {
|
||||
"members": [{"color_hex": color} for color in custom_colors]
|
||||
}
|
||||
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=data)
|
||||
return response.json()["data"][0]["url"]
|
||||
except RequestException as e:
|
||||
raise Exception(f"Failed to fetch image: {str(e)}")
|
||||
raise Exception(f"Failed to fetch image with legacy endpoint: {str(e)}")
|
||||
|
||||
async def upscale_image(self, api_key: SecretStr, image_url: str):
|
||||
url = "https://api.ideogram.ai/upscale"
|
||||
|
||||
@@ -896,6 +896,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
prompt = [json.to_dict(p) for p in input_data.conversation_history]
|
||||
|
||||
def trim_prompt(s: str) -> str:
|
||||
"""Removes indentation up to and including `|` from a multi-line prompt."""
|
||||
lines = s.strip().split("\n")
|
||||
return "\n".join([line.strip().lstrip("|") for line in lines])
|
||||
|
||||
@@ -909,24 +910,25 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
||||
if input_data.expected_format:
|
||||
expected_format = [
|
||||
f'"{k}": "{v}"' for k, v in input_data.expected_format.items()
|
||||
f"{json.dumps(k)}: {json.dumps(v)}"
|
||||
for k, v in input_data.expected_format.items()
|
||||
]
|
||||
if input_data.list_result:
|
||||
format_prompt = (
|
||||
f'"results": [\n {{\n {", ".join(expected_format)}\n }}\n]'
|
||||
)
|
||||
else:
|
||||
format_prompt = "\n ".join(expected_format)
|
||||
format_prompt = ",\n| ".join(expected_format)
|
||||
|
||||
sys_prompt = trim_prompt(
|
||||
f"""
|
||||
|Reply strictly only in the following JSON format:
|
||||
|{{
|
||||
| {format_prompt}
|
||||
|}}
|
||||
|
|
||||
|Ensure the response is valid JSON. Do not include any additional text outside of the JSON.
|
||||
|If you cannot provide all the keys, provide an empty string for the values you cannot answer.
|
||||
|Reply with pure JSON strictly following this JSON format:
|
||||
|{{
|
||||
| {format_prompt}
|
||||
|}}
|
||||
|
|
||||
|Ensure the response is valid JSON. DO NOT include any additional text (e.g. markdown code block fences) outside of the JSON.
|
||||
|If you cannot provide all the keys, provide an empty string for the values you cannot answer.
|
||||
"""
|
||||
)
|
||||
prompt.append({"role": "system", "content": sys_prompt})
|
||||
@@ -946,7 +948,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
return f"JSON decode error: {e}"
|
||||
|
||||
logger.debug(f"LLM request: {prompt}")
|
||||
retry_prompt = ""
|
||||
error_feedback_message = ""
|
||||
llm_model = input_data.model
|
||||
|
||||
for retry_count in range(input_data.retry):
|
||||
@@ -970,8 +972,25 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
|
||||
if input_data.expected_format:
|
||||
try:
|
||||
response_obj = json.loads(response_text)
|
||||
except JSONDecodeError as json_error:
|
||||
prompt.append({"role": "assistant", "content": response_text})
|
||||
|
||||
response_obj = json.loads(response_text)
|
||||
indented_json_error = str(json_error).replace("\n", "\n|")
|
||||
error_feedback_message = trim_prompt(
|
||||
f"""
|
||||
|Your previous response could not be parsed as valid JSON:
|
||||
|
|
||||
|{indented_json_error}
|
||||
|
|
||||
|Please provide a valid JSON response that matches the expected format.
|
||||
"""
|
||||
)
|
||||
prompt.append(
|
||||
{"role": "user", "content": error_feedback_message}
|
||||
)
|
||||
continue
|
||||
|
||||
if input_data.list_result and isinstance(response_obj, dict):
|
||||
if "results" in response_obj:
|
||||
@@ -979,7 +998,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
elif len(response_obj) == 1:
|
||||
response_obj = list(response_obj.values())
|
||||
|
||||
response_error = "\n".join(
|
||||
validation_errors = "\n".join(
|
||||
[
|
||||
validation_error
|
||||
for response_item in (
|
||||
@@ -991,7 +1010,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
]
|
||||
)
|
||||
|
||||
if not response_error:
|
||||
if not validation_errors:
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
llm_call_count=retry_count + 1,
|
||||
@@ -1001,6 +1020,16 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
yield "response", response_obj
|
||||
yield "prompt", self.prompt
|
||||
return
|
||||
|
||||
prompt.append({"role": "assistant", "content": response_text})
|
||||
error_feedback_message = trim_prompt(
|
||||
f"""
|
||||
|Your response did not match the expected format:
|
||||
|
|
||||
|{validation_errors}
|
||||
"""
|
||||
)
|
||||
prompt.append({"role": "user", "content": error_feedback_message})
|
||||
else:
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
@@ -1011,21 +1040,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
yield "response", {"response": response_text}
|
||||
yield "prompt", self.prompt
|
||||
return
|
||||
|
||||
retry_prompt = trim_prompt(
|
||||
f"""
|
||||
|This is your previous error response:
|
||||
|--
|
||||
|{response_text}
|
||||
|--
|
||||
|
|
||||
|And this is the error:
|
||||
|--
|
||||
|{response_error}
|
||||
|--
|
||||
"""
|
||||
)
|
||||
prompt.append({"role": "user", "content": retry_prompt})
|
||||
except Exception as e:
|
||||
logger.exception(f"Error calling LLM: {e}")
|
||||
if (
|
||||
@@ -1038,9 +1052,12 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
logger.debug(
|
||||
f"Reducing max_tokens to {input_data.max_tokens} for next attempt"
|
||||
)
|
||||
retry_prompt = f"Error calling LLM: {e}"
|
||||
# Don't add retry prompt for token limit errors,
|
||||
# just retry with lower maximum output tokens
|
||||
|
||||
raise RuntimeError(retry_prompt)
|
||||
error_feedback_message = f"Error calling LLM: {e}"
|
||||
|
||||
raise RuntimeError(error_feedback_message)
|
||||
|
||||
|
||||
class AITextGeneratorBlock(AIBlockBase):
|
||||
|
||||
@@ -0,0 +1,8 @@
from backend.sdk import BlockCostType, ProviderBuilder

stagehand = (
    ProviderBuilder("stagehand")
    .with_api_key("STAGEHAND_API_KEY", "Stagehand API Key")
    .with_base_cost(1, BlockCostType.RUN)
    .build()
)
autogpt_platform/backend/backend/blocks/stagehand/blocks.py (new file, 393 lines)
@@ -0,0 +1,393 @@
import logging
|
||||
import signal
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
|
||||
# Monkey patch Stagehands to prevent signal handling in worker threads
|
||||
import stagehand.main
|
||||
from stagehand import Stagehand
|
||||
|
||||
from backend.blocks.llm import (
|
||||
MODEL_METADATA,
|
||||
AICredentials,
|
||||
AICredentialsField,
|
||||
LlmModel,
|
||||
ModelMetadata,
|
||||
)
|
||||
from backend.blocks.stagehand._config import stagehand as stagehand_provider
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
# Store the original method
|
||||
original_register_signal_handlers = stagehand.main.Stagehand._register_signal_handlers
|
||||
|
||||
|
||||
def safe_register_signal_handlers(self):
|
||||
"""Only register signal handlers in the main thread"""
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
original_register_signal_handlers(self)
|
||||
else:
|
||||
# Skip signal handling in worker threads
|
||||
pass
|
||||
|
||||
|
||||
# Replace the method
|
||||
stagehand.main.Stagehand._register_signal_handlers = safe_register_signal_handlers
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_signal_handling():
|
||||
"""Context manager to temporarily disable signal handling"""
|
||||
if threading.current_thread() is not threading.main_thread():
|
||||
# In worker threads, temporarily replace signal.signal with a no-op
|
||||
original_signal = signal.signal
|
||||
|
||||
def noop_signal(*args, **kwargs):
|
||||
pass
|
||||
|
||||
signal.signal = noop_signal
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.signal = original_signal
|
||||
else:
|
||||
# In main thread, don't modify anything
|
||||
yield
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StagehandRecommendedLlmModel(str, Enum):
|
||||
"""
|
||||
This is subset of LLModel from autogpt_platform/backend/backend/blocks/llm.py
|
||||
|
||||
It contains only the models recommended by Stagehand
|
||||
"""
|
||||
|
||||
# OpenAI
|
||||
GPT41 = "gpt-4.1-2025-04-14"
|
||||
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
|
||||
|
||||
# Anthropic
|
||||
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
"""
|
||||
Returns the provider name for the model in the required format for Stagehand:
|
||||
provider/model_name
|
||||
"""
|
||||
model_metadata = MODEL_METADATA[LlmModel(self.value)]
|
||||
model_name = self.value
|
||||
|
||||
if len(model_name.split("/")) == 1 and not self.value.startswith(
|
||||
model_metadata.provider
|
||||
):
|
||||
assert (
|
||||
model_metadata.provider != "open_router"
|
||||
), "Logic failed and open_router provider attempted to be prepended to model name! in stagehand/_config.py"
|
||||
model_name = f"{model_metadata.provider}/{model_name}"
|
||||
|
||||
logger.error(f"Model name: {model_name}")
|
||||
return model_name
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
return MODEL_METADATA[LlmModel(self.value)].provider
|
||||
|
||||
@property
|
||||
def metadata(self) -> ModelMetadata:
|
||||
return MODEL_METADATA[LlmModel(self.value)]
|
||||
|
||||
@property
|
||||
def context_window(self) -> int:
|
||||
return MODEL_METADATA[LlmModel(self.value)].context_window
|
||||
|
||||
@property
|
||||
def max_output_tokens(self) -> int | None:
|
||||
return MODEL_METADATA[LlmModel(self.value)].max_output_tokens
|
||||
|
||||
|
||||
class StagehandObserveBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
# Browserbase credentials (Stagehand provider) or raw API key
|
||||
stagehand_credentials: CredentialsMetaInput = (
|
||||
stagehand_provider.credentials_field(
|
||||
description="Stagehand/Browserbase API key"
|
||||
)
|
||||
)
|
||||
browserbase_project_id: str = SchemaField(
|
||||
description="Browserbase project ID (required if using Browserbase)",
|
||||
)
|
||||
# Model selection and credentials (provider-discriminated like llm.py)
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
url: str = SchemaField(
|
||||
description="URL to navigate to.",
|
||||
)
|
||||
instruction: str = SchemaField(
|
||||
description="Natural language description of elements or actions to discover.",
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
selector: str = SchemaField(description="XPath selector to locate element.")
|
||||
description: str = SchemaField(description="Human-readable description")
|
||||
method: str | None = SchemaField(description="Suggested action method")
|
||||
arguments: list[str] | None = SchemaField(
|
||||
description="Additional action parameters"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d3863944-0eaf-45c4-a0c9-63e0fe1ee8b9",
|
||||
description="Find suggested actions for your workflows",
|
||||
categories={BlockCategory.AI, BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=StagehandObserveBlock.Input,
|
||||
output_schema=StagehandObserveBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
stagehand_credentials: APIKeyCredentials,
|
||||
model_credentials: APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
logger.info(f"OBSERVE: Stagehand credentials: {stagehand_credentials}")
|
||||
logger.info(
|
||||
f"OBSERVE: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
|
||||
)
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
|
||||
observe_results = await page.observe(
|
||||
input_data.instruction,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
)
|
||||
for result in observe_results:
|
||||
yield "selector", result.selector
|
||||
yield "description", result.description
|
||||
yield "method", result.method
|
||||
yield "arguments", result.arguments
|
||||
|
||||
|
||||
class StagehandActBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
# Browserbase credentials (Stagehand provider) or raw API key
|
||||
stagehand_credentials: CredentialsMetaInput = (
|
||||
stagehand_provider.credentials_field(
|
||||
description="Stagehand/Browserbase API key"
|
||||
)
|
||||
)
|
||||
browserbase_project_id: str = SchemaField(
|
||||
description="Browserbase project ID (required if using Browserbase)",
|
||||
)
|
||||
# Model selection and credentials (provider-discriminated like llm.py)
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
url: str = SchemaField(
|
||||
description="URL to navigate to.",
|
||||
)
|
||||
action: list[str] = SchemaField(
|
||||
description="Action to perform. Suggested actions are: click, fill, type, press, scroll, select from dropdown. For multi-step actions, add an entry for each step.",
|
||||
)
|
||||
variables: dict[str, str] = SchemaField(
|
||||
description="Variables to use in the action. Variables contains data you want the action to use.",
|
||||
default_factory=dict,
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
)
|
||||
timeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM ready. Extended timeout for slow-loading forms",
|
||||
default=60000,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the action was completed successfully"
|
||||
)
|
||||
message: str = SchemaField(description="Details about the action’s execution.")
|
||||
action: str = SchemaField(description="Action performed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="86eba68b-9549-4c0b-a0db-47d85a56cc27",
|
||||
description="Interact with a web page by performing actions on a web page. Use it to build self-healing and deterministic automations that adapt to website chang.",
|
||||
categories={BlockCategory.AI, BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=StagehandActBlock.Input,
|
||||
output_schema=StagehandActBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
stagehand_credentials: APIKeyCredentials,
|
||||
model_credentials: APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
logger.info(f"ACT: Stagehand credentials: {stagehand_credentials}")
|
||||
logger.info(
|
||||
f"ACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
|
||||
)
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
for action in input_data.action:
|
||||
action_results = await page.act(
|
||||
action,
|
||||
variables=input_data.variables,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
timeoutMs=input_data.timeoutMs,
|
||||
)
|
||||
yield "success", action_results.success
|
||||
yield "message", action_results.message
|
||||
yield "action", action_results.action
|
||||
|
||||
|
||||
class StagehandExtractBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
# Browserbase credentials (Stagehand provider) or raw API key
|
||||
stagehand_credentials: CredentialsMetaInput = (
|
||||
stagehand_provider.credentials_field(
|
||||
description="Stagehand/Browserbase API key"
|
||||
)
|
||||
)
|
||||
browserbase_project_id: str = SchemaField(
|
||||
description="Browserbase project ID (required if using Browserbase)",
|
||||
)
|
||||
# Model selection and credentials (provider-discriminated like llm.py)
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
url: str = SchemaField(
|
||||
description="URL to navigate to.",
|
||||
)
|
||||
instruction: str = SchemaField(
|
||||
description="Natural language description of elements or actions to discover.",
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
extraction: str = SchemaField(description="Extracted data from the page.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="fd3c0b18-2ba6-46ae-9339-fcb40537ad98",
|
||||
description="Extract structured data from a webpage.",
|
||||
categories={BlockCategory.AI, BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=StagehandExtractBlock.Input,
|
||||
output_schema=StagehandExtractBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
stagehand_credentials: APIKeyCredentials,
|
||||
model_credentials: APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
logger.info(f"EXTRACT: Stagehand credentials: {stagehand_credentials}")
|
||||
logger.info(
|
||||
f"EXTRACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
|
||||
)
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
extraction = await page.extract(
|
||||
input_data.instruction,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
)
|
||||
yield "extraction", str(extraction.model_dump()["extraction"])
|
||||
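
Both Stagehand blocks above share the same session lifecycle: construct the client, `init()`, take the page, navigate, then `act` or `extract`. A minimal sketch of that flow outside the block machinery, reusing only the calls seen above; the import path, model name, and all keys/IDs here are placeholders/assumptions, not real values:

```python
# Illustrative sketch only - mirrors the calls used in the blocks above.
import asyncio
from stagehand import Stagehand  # assumed import path for the client used above

async def demo():
    stagehand = Stagehand(
        api_key="bb-...",                 # Browserbase/Stagehand key (placeholder)
        project_id="proj_123",            # Browserbase project ID (placeholder)
        model_name="claude-3-7-sonnet",   # LLM model name (placeholder)
        model_api_key="sk-...",           # LLM provider key (placeholder)
    )
    await stagehand.init()
    page = stagehand.page
    await page.goto("https://example.com")

    # Perform a natural-language action, then extract data, as the blocks do.
    result = await page.act("Click the login button", iframes=True, domSettleTimeoutMs=45000)
    data = await page.extract("Collect the page title", iframes=True)
    print(result.success, result.message, data.model_dump())

asyncio.run(demo())
```
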
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,283 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Duplicate pydantic models for store data so we don't accidentally change the data shape in the blocks when editing the backend
|
||||
class LibraryAgent(BaseModel):
|
||||
"""Model representing an agent in the user's library."""
|
||||
|
||||
library_agent_id: str = ""
|
||||
agent_id: str = ""
|
||||
agent_version: int = 0
|
||||
agent_name: str = ""
|
||||
description: str = ""
|
||||
creator: str = ""
|
||||
is_archived: bool = False
|
||||
categories: list[str] = []
|
||||
|
||||
|
||||
class AddToLibraryFromStoreBlock(Block):
|
||||
"""
|
||||
Block that adds an agent from the store to the user's library.
|
||||
This enables users to easily import agents from the marketplace into their personal collection.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
store_listing_version_id: str = SchemaField(
|
||||
description="The ID of the store listing version to add to library"
|
||||
)
|
||||
agent_name: str | None = SchemaField(
|
||||
description="Optional custom name for the agent in your library",
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the agent was successfully added to library"
|
||||
)
|
||||
library_agent_id: str = SchemaField(
|
||||
description="The ID of the library agent entry"
|
||||
)
|
||||
agent_id: str = SchemaField(description="The ID of the agent graph")
|
||||
agent_version: int = SchemaField(
|
||||
description="The version number of the agent graph"
|
||||
)
|
||||
agent_name: str = SchemaField(description="The name of the agent")
|
||||
message: str = SchemaField(description="Success or error message")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2602a7b1-3f4d-4e5f-9c8b-1a2b3c4d5e6f",
|
||||
description="Add an agent from the store to your personal library",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=AddToLibraryFromStoreBlock.Input,
|
||||
output_schema=AddToLibraryFromStoreBlock.Output,
|
||||
test_input={
|
||||
"store_listing_version_id": "test-listing-id",
|
||||
"agent_name": "My Custom Agent",
|
||||
},
|
||||
test_output=[
|
||||
("success", True),
|
||||
("library_agent_id", "test-library-id"),
|
||||
("agent_id", "test-agent-id"),
|
||||
("agent_version", 1),
|
||||
("agent_name", "Test Agent"),
|
||||
("message", "Agent successfully added to library"),
|
||||
],
|
||||
test_mock={
|
||||
"_add_to_library": lambda *_, **__: LibraryAgent(
|
||||
library_agent_id="test-library-id",
|
||||
agent_id="test-agent-id",
|
||||
agent_version=1,
|
||||
agent_name="Test Agent",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
library_agent = await self._add_to_library(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=input_data.store_listing_version_id,
|
||||
custom_name=input_data.agent_name,
|
||||
)
|
||||
|
||||
yield "success", True
|
||||
yield "library_agent_id", library_agent.library_agent_id
|
||||
yield "agent_id", library_agent.agent_id
|
||||
yield "agent_version", library_agent.agent_version
|
||||
yield "agent_name", library_agent.agent_name
|
||||
yield "message", "Agent successfully added to library"
|
||||
|
||||
async def _add_to_library(
|
||||
self,
|
||||
user_id: str,
|
||||
store_listing_version_id: str,
|
||||
custom_name: str | None = None,
|
||||
) -> LibraryAgent:
|
||||
"""
|
||||
Add a store agent to the user's library using the existing library database function.
|
||||
"""
|
||||
library_agent = (
|
||||
await get_database_manager_async_client().add_store_agent_to_library(
|
||||
store_listing_version_id=store_listing_version_id, user_id=user_id
|
||||
)
|
||||
)
|
||||
|
||||
# If custom name is provided, we could update the library agent name here
|
||||
# For now, we'll just return the agent info
|
||||
agent_name = custom_name if custom_name else library_agent.name
|
||||
|
||||
return LibraryAgent(
|
||||
library_agent_id=library_agent.id,
|
||||
agent_id=library_agent.graph_id,
|
||||
agent_version=library_agent.graph_version,
|
||||
agent_name=agent_name,
|
||||
)
|
||||
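
Outside the graph executor, the block above can be driven directly as an async generator. A minimal sketch, assuming a configured database manager client and using placeholder IDs:

```python
# Illustrative sketch only; IDs are placeholders and a working
# database manager client is assumed to be available.
import asyncio

async def demo_add_to_library():
    block = AddToLibraryFromStoreBlock()
    input_data = block.Input(store_listing_version_id="store-listing-v1")
    async for name, value in block.run(input_data, user_id="user-123"):
        # e.g. ("success", True), ("library_agent_id", "..."), ("message", "...")
        print(name, value)

asyncio.run(demo_add_to_library())
```
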
|
||||
|
||||
class ListLibraryAgentsBlock(Block):
|
||||
"""
|
||||
Block that lists all agents in the user's library.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
search_query: str | None = SchemaField(
|
||||
description="Optional search query to filter agents", default=None
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of agents to return", default=50, ge=1, le=100
|
||||
)
|
||||
page: int = SchemaField(
|
||||
description="Page number for pagination", default=1, ge=1
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
agents: list[LibraryAgent] = SchemaField(
|
||||
description="List of agents in the library",
|
||||
default_factory=list,
|
||||
)
|
||||
agent: LibraryAgent = SchemaField(
|
||||
description="Individual library agent (yielded for each agent)"
|
||||
)
|
||||
total_count: int = SchemaField(
|
||||
description="Total number of agents in library", default=0
|
||||
)
|
||||
page: int = SchemaField(description="Current page number", default=1)
|
||||
total_pages: int = SchemaField(description="Total number of pages", default=1)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="082602d3-a74d-4600-9e9c-15b3af7eae98",
|
||||
description="List all agents in your personal library",
|
||||
categories={BlockCategory.BASIC, BlockCategory.DATA},
|
||||
input_schema=ListLibraryAgentsBlock.Input,
|
||||
output_schema=ListLibraryAgentsBlock.Output,
|
||||
test_input={
|
||||
"search_query": None,
|
||||
"limit": 10,
|
||||
"page": 1,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"agents",
|
||||
[
|
||||
LibraryAgent(
|
||||
library_agent_id="test-lib-id",
|
||||
agent_id="test-agent-id",
|
||||
agent_version=1,
|
||||
agent_name="Test Library Agent",
|
||||
description="A test agent in library",
|
||||
creator="Test User",
|
||||
),
|
||||
],
|
||||
),
|
||||
("total_count", 1),
|
||||
("page", 1),
|
||||
("total_pages", 1),
|
||||
(
|
||||
"agent",
|
||||
LibraryAgent(
|
||||
library_agent_id="test-lib-id",
|
||||
agent_id="test-agent-id",
|
||||
agent_version=1,
|
||||
agent_name="Test Library Agent",
|
||||
description="A test agent in library",
|
||||
creator="Test User",
|
||||
),
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"_list_library_agents": lambda *_, **__: {
|
||||
"agents": [
|
||||
LibraryAgent(
|
||||
library_agent_id="test-lib-id",
|
||||
agent_id="test-agent-id",
|
||||
agent_version=1,
|
||||
agent_name="Test Library Agent",
|
||||
description="A test agent in library",
|
||||
creator="Test User",
|
||||
)
|
||||
],
|
||||
"total": 1,
|
||||
"page": 1,
|
||||
"total_pages": 1,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
result = await self._list_library_agents(
|
||||
user_id=user_id,
|
||||
search_query=input_data.search_query,
|
||||
limit=input_data.limit,
|
||||
page=input_data.page,
|
||||
)
|
||||
|
||||
agents = result["agents"]
|
||||
|
||||
yield "agents", agents
|
||||
yield "total_count", result["total"]
|
||||
yield "page", result["page"]
|
||||
yield "total_pages", result["total_pages"]
|
||||
|
||||
# Yield each agent individually for better graph connectivity
|
||||
for agent in agents:
|
||||
yield "agent", agent
|
||||
|
||||
async def _list_library_agents(
|
||||
self,
|
||||
user_id: str,
|
||||
search_query: str | None = None,
|
||||
limit: int = 50,
|
||||
page: int = 1,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
List agents in the user's library using the database client.
|
||||
"""
|
||||
result = await get_database_manager_async_client().list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=search_query,
|
||||
page=page,
|
||||
page_size=limit,
|
||||
)
|
||||
|
||||
agents = [
|
||||
LibraryAgent(
|
||||
library_agent_id=agent.id,
|
||||
agent_id=agent.graph_id,
|
||||
agent_version=agent.graph_version,
|
||||
agent_name=agent.name,
|
||||
description=getattr(agent, "description", ""),
|
||||
creator=getattr(agent, "creator", ""),
|
||||
is_archived=getattr(agent, "is_archived", False),
|
||||
categories=getattr(agent, "categories", []),
|
||||
)
|
||||
for agent in result.agents
|
||||
]
|
||||
|
||||
return {
|
||||
"agents": agents,
|
||||
"total": result.pagination.total_items,
|
||||
"page": result.pagination.current_page,
|
||||
"total_pages": result.pagination.total_pages,
|
||||
}
|
||||
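
ListLibraryAgentsBlock yields the whole page once on "agents" and then re-yields each item on "agent" so downstream nodes can fan out per agent. A small consumer sketch (placeholder user ID, database client assumed configured):

```python
# Illustrative sketch only.
import asyncio

async def demo_list_library():
    block = ListLibraryAgentsBlock()
    input_data = block.Input(search_query="email", limit=10, page=1)
    page_of_agents, singles = [], []
    async for name, value in block.run(input_data, user_id="user-123"):
        if name == "agents":
            page_of_agents = value      # full page, yielded once
        elif name == "agent":
            singles.append(value)       # one yield per agent, for fan-out
    assert len(singles) == len(page_of_agents)

asyncio.run(demo_list_library())
```
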
@@ -0,0 +1,311 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Duplicate pydantic models for store data so we don't accidentally change the data shape in the blocks when editing the backend
|
||||
class StoreAgent(BaseModel):
|
||||
"""Model representing a store agent."""
|
||||
|
||||
slug: str = ""
|
||||
name: str = ""
|
||||
description: str = ""
|
||||
creator: str = ""
|
||||
rating: float = 0.0
|
||||
runs: int = 0
|
||||
categories: list[str] = []
|
||||
|
||||
|
||||
class StoreAgentDict(BaseModel):
|
||||
"""Dictionary representation of a store agent."""
|
||||
|
||||
slug: str
|
||||
name: str
|
||||
description: str
|
||||
creator: str
|
||||
rating: float
|
||||
runs: int
|
||||
|
||||
|
||||
class SearchAgentsResponse(BaseModel):
|
||||
"""Response from searching store agents."""
|
||||
|
||||
agents: list[StoreAgentDict]
|
||||
total_count: int
|
||||
|
||||
|
||||
class StoreAgentDetails(BaseModel):
|
||||
"""Detailed information about a store agent."""
|
||||
|
||||
found: bool
|
||||
store_listing_version_id: str = ""
|
||||
agent_name: str = ""
|
||||
description: str = ""
|
||||
creator: str = ""
|
||||
categories: list[str] = []
|
||||
runs: int = 0
|
||||
rating: float = 0.0
|
||||
|
||||
|
||||
class GetStoreAgentDetailsBlock(Block):
|
||||
"""
|
||||
Block that retrieves detailed information about an agent from the store.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
creator: str = SchemaField(description="The username of the agent creator")
|
||||
slug: str = SchemaField(description="The name of the agent")
|
||||
|
||||
class Output(BlockSchema):
|
||||
found: bool = SchemaField(
|
||||
description="Whether the agent was found in the store"
|
||||
)
|
||||
store_listing_version_id: str = SchemaField(
|
||||
description="The store listing version ID"
|
||||
)
|
||||
agent_name: str = SchemaField(description="Name of the agent")
|
||||
description: str = SchemaField(description="Description of the agent")
|
||||
creator: str = SchemaField(description="Creator of the agent")
|
||||
categories: list[str] = SchemaField(
|
||||
description="Categories the agent belongs to", default_factory=list
|
||||
)
|
||||
runs: int = SchemaField(
|
||||
description="Number of times the agent has been run", default=0
|
||||
)
|
||||
rating: float = SchemaField(
|
||||
description="Average rating of the agent", default=0.0
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b604f0ec-6e0d-40a7-bf55-9fd09997cced",
|
||||
description="Get detailed information about an agent from the store",
|
||||
categories={BlockCategory.BASIC, BlockCategory.DATA},
|
||||
input_schema=GetStoreAgentDetailsBlock.Input,
|
||||
output_schema=GetStoreAgentDetailsBlock.Output,
|
||||
test_input={"creator": "test-creator", "slug": "test-agent-slug"},
|
||||
test_output=[
|
||||
("found", True),
|
||||
("store_listing_version_id", "test-listing-id"),
|
||||
("agent_name", "Test Agent"),
|
||||
("description", "A test agent"),
|
||||
("creator", "Test Creator"),
|
||||
("categories", ["productivity", "automation"]),
|
||||
("runs", 100),
|
||||
("rating", 4.5),
|
||||
],
|
||||
test_mock={
|
||||
"_get_agent_details": lambda *_, **__: StoreAgentDetails(
|
||||
found=True,
|
||||
store_listing_version_id="test-listing-id",
|
||||
agent_name="Test Agent",
|
||||
description="A test agent",
|
||||
creator="Test Creator",
|
||||
categories=["productivity", "automation"],
|
||||
runs=100,
|
||||
rating=4.5,
|
||||
)
|
||||
},
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
details = await self._get_agent_details(
|
||||
creator=input_data.creator, slug=input_data.slug
|
||||
)
|
||||
yield "found", details.found
|
||||
yield "store_listing_version_id", details.store_listing_version_id
|
||||
yield "agent_name", details.agent_name
|
||||
yield "description", details.description
|
||||
yield "creator", details.creator
|
||||
yield "categories", details.categories
|
||||
yield "runs", details.runs
|
||||
yield "rating", details.rating
|
||||
|
||||
async def _get_agent_details(self, creator: str, slug: str) -> StoreAgentDetails:
|
||||
"""
|
||||
Retrieve detailed information about a store agent.
|
||||
"""
|
||||
# Look up the listing by creator username and agent slug
|
||||
agent_details = (
|
||||
await get_database_manager_async_client().get_store_agent_details(
|
||||
username=creator, agent_name=slug
|
||||
)
|
||||
)
|
||||
|
||||
return StoreAgentDetails(
|
||||
found=True,
|
||||
store_listing_version_id=agent_details.store_listing_version_id,
|
||||
agent_name=agent_details.agent_name,
|
||||
description=agent_details.description,
|
||||
creator=agent_details.creator,
|
||||
categories=(
|
||||
agent_details.categories if hasattr(agent_details, "categories") else []
|
||||
),
|
||||
runs=agent_details.runs,
|
||||
rating=agent_details.rating,
|
||||
)
|
||||
|
||||
|
||||
class SearchStoreAgentsBlock(Block):
|
||||
"""
|
||||
Block that searches for agents in the store based on various criteria.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
query: str | None = SchemaField(
|
||||
description="Search query to find agents", default=None
|
||||
)
|
||||
category: str | None = SchemaField(
|
||||
description="Filter by category", default=None
|
||||
)
|
||||
sort_by: Literal["rating", "runs", "name", "recent"] = SchemaField(
|
||||
description="How to sort the results", default="rating"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of results to return", default=10, ge=1, le=100
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
agents: list[StoreAgent] = SchemaField(
|
||||
description="List of agents matching the search criteria",
|
||||
default_factory=list,
|
||||
)
|
||||
agent: StoreAgent = SchemaField(description="Basic information of the agent")
|
||||
total_count: int = SchemaField(
|
||||
description="Total number of agents found", default=0
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="39524701-026c-4328-87cc-1b88c8e2cb4c",
|
||||
description="Search for agents in the store",
|
||||
categories={BlockCategory.BASIC, BlockCategory.DATA},
|
||||
input_schema=SearchStoreAgentsBlock.Input,
|
||||
output_schema=SearchStoreAgentsBlock.Output,
|
||||
test_input={
|
||||
"query": "productivity",
|
||||
"category": None,
|
||||
"sort_by": "rating",
|
||||
"limit": 10,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"agents",
|
||||
[
|
||||
{
|
||||
"slug": "test-agent",
|
||||
"name": "Test Agent",
|
||||
"description": "A test agent",
|
||||
"creator": "Test Creator",
|
||||
"rating": 4.5,
|
||||
"runs": 100,
|
||||
}
|
||||
],
|
||||
),
|
||||
("total_count", 1),
|
||||
(
|
||||
"agent",
|
||||
{
|
||||
"slug": "test-agent",
|
||||
"name": "Test Agent",
|
||||
"description": "A test agent",
|
||||
"creator": "Test Creator",
|
||||
"rating": 4.5,
|
||||
"runs": 100,
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"_search_agents": lambda *_, **__: SearchAgentsResponse(
|
||||
agents=[
|
||||
StoreAgentDict(
|
||||
slug="test-agent",
|
||||
name="Test Agent",
|
||||
description="A test agent",
|
||||
creator="Test Creator",
|
||||
rating=4.5,
|
||||
runs=100,
|
||||
)
|
||||
],
|
||||
total_count=1,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
result = await self._search_agents(
|
||||
query=input_data.query,
|
||||
category=input_data.category,
|
||||
sort_by=input_data.sort_by,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
|
||||
agents = result.agents
|
||||
total_count = result.total_count
|
||||
|
||||
# Convert to dict for output
|
||||
agents_as_dicts = [agent.model_dump() for agent in agents]
|
||||
|
||||
yield "agents", agents_as_dicts
|
||||
yield "total_count", total_count
|
||||
|
||||
for agent_dict in agents_as_dicts:
|
||||
yield "agent", agent_dict
|
||||
|
||||
async def _search_agents(
|
||||
self,
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
sort_by: str = "rating",
|
||||
limit: int = 10,
|
||||
) -> SearchAgentsResponse:
|
||||
"""
|
||||
Search for agents in the store using the existing store database function.
|
||||
"""
|
||||
# Map our sort_by to the store's sorted_by parameter
|
||||
sorted_by_map = {
|
||||
"rating": "most_popular",
|
||||
"runs": "most_runs",
|
||||
"name": "alphabetical",
|
||||
"recent": "recently_updated",
|
||||
}
|
||||
|
||||
result = await get_database_manager_async_client().get_store_agents(
|
||||
featured=False,
|
||||
creators=None,
|
||||
sorted_by=sorted_by_map.get(sort_by, "most_popular"),
|
||||
search_query=query,
|
||||
category=category,
|
||||
page=1,
|
||||
page_size=limit,
|
||||
)
|
||||
|
||||
agents = [
|
||||
StoreAgentDict(
|
||||
slug=agent.slug,
|
||||
name=agent.agent_name,
|
||||
description=agent.description,
|
||||
creator=agent.creator,
|
||||
rating=agent.rating,
|
||||
runs=agent.runs,
|
||||
)
|
||||
for agent in result.agents
|
||||
]
|
||||
|
||||
return SearchAgentsResponse(agents=agents, total_count=len(agents))
|
||||
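
Taken together, the store blocks above chain naturally: search the store, then resolve a listing to its store_listing_version_id, which AddToLibraryFromStoreBlock consumes. A rough sketch with placeholder creator/slug values:

```python
# Illustrative sketch only; creator and slug are placeholders.
import asyncio

async def demo_store_lookup():
    search = SearchStoreAgentsBlock()
    async for name, value in search.run(search.Input(query="productivity", limit=5)):
        if name == "total_count":
            print("matches:", value)

    details = GetStoreAgentDetailsBlock()
    async for name, value in details.run(details.Input(creator="some-creator", slug="some-agent")):
        if name == "store_listing_version_id":
            print("listing version:", value)  # feed this into AddToLibraryFromStoreBlock

asyncio.run(demo_store_lookup())
```
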
@@ -35,20 +35,19 @@ async def execute_graph(
|
||||
logger.info("Input data: %s", input_data)
|
||||
|
||||
# --- Test adding new executions --- #
|
||||
response = await agent_server.test_execute_graph(
|
||||
graph_exec = await agent_server.test_execute_graph(
|
||||
user_id=test_user.id,
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
node_input=input_data,
|
||||
)
|
||||
graph_exec_id = response.graph_exec_id
|
||||
logger.info("Created execution with ID: %s", graph_exec_id)
|
||||
logger.info("Created execution with ID: %s", graph_exec.id)
|
||||
|
||||
# Execution queue should be empty
|
||||
logger.info("Waiting for execution to complete...")
|
||||
result = await wait_execution(test_user.id, graph_exec_id, 30)
|
||||
result = await wait_execution(test_user.id, graph_exec.id, 30)
|
||||
logger.info("Execution completed with %d results", len(result))
|
||||
return graph_exec_id
|
||||
return graph_exec.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.system.library_operations import (
|
||||
AddToLibraryFromStoreBlock,
|
||||
LibraryAgent,
|
||||
)
|
||||
from backend.blocks.system.store_operations import (
|
||||
GetStoreAgentDetailsBlock,
|
||||
SearchAgentsResponse,
|
||||
SearchStoreAgentsBlock,
|
||||
StoreAgentDetails,
|
||||
StoreAgentDict,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_to_library_from_store_block_success(mocker):
|
||||
"""Test successful addition of agent from store to library."""
|
||||
block = AddToLibraryFromStoreBlock()
|
||||
|
||||
# Mock the library agent response
|
||||
mock_library_agent = MagicMock()
|
||||
mock_library_agent.id = "lib-agent-123"
|
||||
mock_library_agent.graph_id = "graph-456"
|
||||
mock_library_agent.graph_version = 1
|
||||
mock_library_agent.name = "Test Agent"
|
||||
|
||||
mocker.patch.object(
|
||||
block,
|
||||
"_add_to_library",
|
||||
return_value=LibraryAgent(
|
||||
library_agent_id="lib-agent-123",
|
||||
agent_id="graph-456",
|
||||
agent_version=1,
|
||||
agent_name="Test Agent",
|
||||
),
|
||||
)
|
||||
|
||||
input_data = block.Input(
|
||||
store_listing_version_id="store-listing-v1", agent_name="Custom Agent Name"
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, user_id="test-user"):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["success"] is True
|
||||
assert outputs["library_agent_id"] == "lib-agent-123"
|
||||
assert outputs["agent_id"] == "graph-456"
|
||||
assert outputs["agent_version"] == 1
|
||||
assert outputs["agent_name"] == "Test Agent"
|
||||
assert outputs["message"] == "Agent successfully added to library"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_store_agent_details_block_success(mocker):
|
||||
"""Test successful retrieval of store agent details."""
|
||||
block = GetStoreAgentDetailsBlock()
|
||||
|
||||
mocker.patch.object(
|
||||
block,
|
||||
"_get_agent_details",
|
||||
return_value=StoreAgentDetails(
|
||||
found=True,
|
||||
store_listing_version_id="version-123",
|
||||
agent_name="Test Agent",
|
||||
description="A test agent for testing",
|
||||
creator="Test Creator",
|
||||
categories=["productivity", "automation"],
|
||||
runs=100,
|
||||
rating=4.5,
|
||||
),
|
||||
)
|
||||
|
||||
input_data = block.Input(creator="Test Creator", slug="test-slug")
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["found"] is True
|
||||
assert outputs["store_listing_version_id"] == "version-123"
|
||||
assert outputs["agent_name"] == "Test Agent"
|
||||
assert outputs["description"] == "A test agent for testing"
|
||||
assert outputs["creator"] == "Test Creator"
|
||||
assert outputs["categories"] == ["productivity", "automation"]
|
||||
assert outputs["runs"] == 100
|
||||
assert outputs["rating"] == 4.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_store_agents_block(mocker):
|
||||
"""Test searching for store agents."""
|
||||
block = SearchStoreAgentsBlock()
|
||||
|
||||
mocker.patch.object(
|
||||
block,
|
||||
"_search_agents",
|
||||
return_value=SearchAgentsResponse(
|
||||
agents=[
|
||||
StoreAgentDict(
|
||||
slug="creator1/agent1",
|
||||
name="Agent One",
|
||||
description="First test agent",
|
||||
creator="Creator 1",
|
||||
rating=4.8,
|
||||
runs=500,
|
||||
),
|
||||
StoreAgentDict(
|
||||
slug="creator2/agent2",
|
||||
name="Agent Two",
|
||||
description="Second test agent",
|
||||
creator="Creator 2",
|
||||
rating=4.2,
|
||||
runs=200,
|
||||
),
|
||||
],
|
||||
total_count=2,
|
||||
),
|
||||
)
|
||||
|
||||
input_data = block.Input(
|
||||
query="test", category="productivity", sort_by="rating", limit=10
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data):
|
||||
outputs[name] = value
|
||||
|
||||
assert len(outputs["agents"]) == 2
|
||||
assert outputs["total_count"] == 2
|
||||
assert outputs["agents"][0]["name"] == "Agent One"
|
||||
assert outputs["agents"][0]["rating"] == 4.8
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_store_agents_block_empty_results(mocker):
|
||||
"""Test searching with no results."""
|
||||
block = SearchStoreAgentsBlock()
|
||||
|
||||
mocker.patch.object(
|
||||
block,
|
||||
"_search_agents",
|
||||
return_value=SearchAgentsResponse(agents=[], total_count=0),
|
||||
)
|
||||
|
||||
input_data = block.Input(query="nonexistent", limit=10)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["agents"] == []
|
||||
assert outputs["total_count"] == 0
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Literal, Union
|
||||
@@ -7,6 +8,7 @@ from zoneinfo import ZoneInfo
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.execution import UserContext
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
# Shared timezone literal type for all time/date blocks
|
||||
@@ -51,16 +53,80 @@ TimezoneLiteral = Literal[
|
||||
"Etc/GMT+12", # UTC-12:00
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_timezone(
|
||||
format_type: Any, # Any format type with timezone and use_user_timezone attributes
|
||||
user_timezone: str | None,
|
||||
) -> ZoneInfo:
|
||||
"""
|
||||
Determine which timezone to use based on format settings and user context.
|
||||
|
||||
Args:
|
||||
format_type: The format configuration containing timezone settings
|
||||
user_timezone: The user's timezone from context
|
||||
|
||||
Returns:
|
||||
ZoneInfo object for the determined timezone
|
||||
"""
|
||||
if format_type.use_user_timezone and user_timezone:
|
||||
tz = ZoneInfo(user_timezone)
|
||||
logger.debug(f"Using user timezone: {user_timezone}")
|
||||
else:
|
||||
tz = ZoneInfo(format_type.timezone)
|
||||
logger.debug(f"Using specified timezone: {format_type.timezone}")
|
||||
return tz
|
||||
|
||||
|
||||
def _format_datetime_iso8601(dt: datetime, include_microseconds: bool = False) -> str:
|
||||
"""
|
||||
Format a datetime object to ISO8601 string.
|
||||
|
||||
Args:
|
||||
dt: The datetime object to format
|
||||
include_microseconds: Whether to include microseconds in the output
|
||||
|
||||
Returns:
|
||||
ISO8601 formatted string
|
||||
"""
|
||||
if include_microseconds:
|
||||
return dt.isoformat()
|
||||
else:
|
||||
return dt.isoformat(timespec="seconds")
|
||||
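
A quick sketch of how the two helpers above compose; the format object here is a stand-in carrying the two attributes _get_timezone reads, and the printed value depends on when it runs:

```python
# Illustrative sketch only.
from datetime import datetime
from types import SimpleNamespace

# Stand-in for a *Format model with timezone + use_user_timezone fields.
fmt = SimpleNamespace(timezone="UTC", use_user_timezone=True)

tz = _get_timezone(fmt, user_timezone="America/New_York")   # user tz wins here
print(_format_datetime_iso8601(datetime.now(tz=tz)))        # e.g. 2024-01-01T07:00:00-05:00
print(_format_datetime_iso8601(datetime.now(tz=tz), include_microseconds=True))
```
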
|
||||
|
||||
# BACKWARDS COMPATIBILITY NOTE:
|
||||
# The timezone field is kept at the format level (not block level) for backwards compatibility.
|
||||
# Existing graphs have timezone saved within format_type; moving it would break them.
|
||||
#
|
||||
# The use_user_timezone flag was added to allow using the user's profile timezone.
|
||||
# Default is False to maintain backwards compatibility - existing graphs will continue
|
||||
# using their specified timezone.
|
||||
#
|
||||
# KNOWN ISSUE: If a user switches between format types (strftime <-> iso8601),
|
||||
# the timezone setting doesn't carry over. This is a UX issue but fixing it would
|
||||
# require either:
|
||||
# 1. Moving timezone to block level (breaking change, needs migration)
|
||||
# 2. Complex state management to sync timezone across format types
|
||||
#
|
||||
# Future migration path: When we do a major version bump, consider moving timezone
|
||||
# to the block Input level for better UX.
|
||||
|
||||
|
||||
class TimeStrftimeFormat(BaseModel):
|
||||
discriminator: Literal["strftime"]
|
||||
format: str = "%H:%M:%S"
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
# When True, overrides timezone with user's profile timezone
|
||||
use_user_timezone: bool = False
|
||||
|
||||
|
||||
class TimeISO8601Format(BaseModel):
|
||||
discriminator: Literal["iso8601"]
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
# When True, overrides timezone with user's profile timezone
|
||||
use_user_timezone: bool = False
|
||||
include_microseconds: bool = False
|
||||
|
||||
|
||||
@@ -115,25 +181,27 @@ class GetCurrentTimeBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
async def run(
|
||||
self, input_data: Input, *, user_context: UserContext, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Extract timezone from user_context (always present)
|
||||
effective_timezone = user_context.timezone
|
||||
|
||||
# Get the appropriate timezone
|
||||
tz = _get_timezone(input_data.format_type, effective_timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
|
||||
if isinstance(input_data.format_type, TimeISO8601Format):
|
||||
# ISO 8601 format for time only (extract time portion from full ISO datetime)
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
|
||||
# Get the full ISO format and extract just the time portion with timezone
|
||||
if input_data.format_type.include_microseconds:
|
||||
full_iso = dt.isoformat()
|
||||
else:
|
||||
full_iso = dt.isoformat(timespec="seconds")
|
||||
|
||||
full_iso = _format_datetime_iso8601(
|
||||
dt, input_data.format_type.include_microseconds
|
||||
)
|
||||
# Extract time portion (everything after 'T')
|
||||
current_time = full_iso.split("T")[1] if "T" in full_iso else full_iso
|
||||
current_time = f"T{current_time}" # Add T prefix for ISO 8601 time format
|
||||
else: # TimeStrftimeFormat
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
current_time = dt.strftime(input_data.format_type.format)
|
||||
|
||||
yield "time", current_time
|
||||
|
||||
|
||||
@@ -141,11 +209,15 @@ class DateStrftimeFormat(BaseModel):
|
||||
discriminator: Literal["strftime"]
|
||||
format: str = "%Y-%m-%d"
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
# When True, overrides timezone with user's profile timezone
|
||||
use_user_timezone: bool = False
|
||||
|
||||
|
||||
class DateISO8601Format(BaseModel):
|
||||
discriminator: Literal["iso8601"]
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
# When True, overrides timezone with user's profile timezone
|
||||
use_user_timezone: bool = False
|
||||
|
||||
|
||||
class GetCurrentDateBlock(Block):
|
||||
@@ -217,20 +289,23 @@ class GetCurrentDateBlock(Block):
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# Extract timezone from user_context (required keyword argument)
|
||||
user_context: UserContext = kwargs["user_context"]
|
||||
effective_timezone = user_context.timezone
|
||||
|
||||
try:
|
||||
offset = int(input_data.offset)
|
||||
except ValueError:
|
||||
offset = 0
|
||||
|
||||
# Get the appropriate timezone
|
||||
tz = _get_timezone(input_data.format_type, effective_timezone)
|
||||
current_date = datetime.now(tz=tz) - timedelta(days=offset)
|
||||
|
||||
if isinstance(input_data.format_type, DateISO8601Format):
|
||||
# ISO 8601 format for date only (YYYY-MM-DD)
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
current_date = datetime.now(tz=tz) - timedelta(days=offset)
|
||||
# ISO 8601 date format is YYYY-MM-DD
|
||||
date_str = current_date.date().isoformat()
|
||||
else: # DateStrftimeFormat
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
current_date = datetime.now(tz=tz) - timedelta(days=offset)
|
||||
date_str = current_date.strftime(input_data.format_type.format)
|
||||
|
||||
yield "date", date_str
|
||||
@@ -240,11 +315,15 @@ class StrftimeFormat(BaseModel):
|
||||
discriminator: Literal["strftime"]
|
||||
format: str = "%Y-%m-%d %H:%M:%S"
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
# When True, overrides timezone with user's profile timezone
|
||||
use_user_timezone: bool = False
|
||||
|
||||
|
||||
class ISO8601Format(BaseModel):
|
||||
discriminator: Literal["iso8601"]
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
# When True, overrides timezone with user's profile timezone
|
||||
use_user_timezone: bool = False
|
||||
include_microseconds: bool = False
|
||||
|
||||
|
||||
@@ -316,20 +395,22 @@ class GetCurrentDateAndTimeBlock(Block):
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# Extract timezone from user_context (required keyword argument)
|
||||
user_context: UserContext = kwargs["user_context"]
|
||||
effective_timezone = user_context.timezone
|
||||
|
||||
# Get the appropriate timezone
|
||||
tz = _get_timezone(input_data.format_type, effective_timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
|
||||
if isinstance(input_data.format_type, ISO8601Format):
|
||||
# ISO 8601 format with specified timezone (also RFC3339-compliant)
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
|
||||
# Format with or without microseconds
|
||||
if input_data.format_type.include_microseconds:
|
||||
current_date_time = dt.isoformat()
|
||||
else:
|
||||
current_date_time = dt.isoformat(timespec="seconds")
|
||||
current_date_time = _format_datetime_iso8601(
|
||||
dt, input_data.format_type.include_microseconds
|
||||
)
|
||||
else: # StrftimeFormat
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
current_date_time = dt.strftime(input_data.format_type.format)
|
||||
|
||||
yield "date_time", current_date_time
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ from logging import getLogger
|
||||
from typing import Any, Dict, List, Union
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from pydantic import field_serializer
|
||||
|
||||
from backend.sdk import BaseModel, Credentials, Requests
|
||||
|
||||
logger = getLogger(__name__)
|
||||
@@ -382,8 +384,9 @@ class CreatePostRequest(BaseModel):
|
||||
# Advanced
|
||||
metadata: List[Dict[str, Any]] | None = None
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
@field_serializer("date")
|
||||
def serialize_date(self, value: datetime | None) -> str | None:
|
||||
return value.isoformat() if value else None
|
||||
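
The change above swaps the pydantic v1-style json_encoders Config for a v2 field_serializer. A self-contained sketch of the same pattern:

```python
# Illustrative sketch only (pydantic v2).
from datetime import datetime
from pydantic import BaseModel, field_serializer

class Example(BaseModel):
    date: datetime | None = None

    @field_serializer("date")
    def serialize_date(self, value: datetime | None) -> str | None:
        return value.isoformat() if value else None

print(Example(date=datetime(2024, 1, 1)).model_dump_json())  # {"date":"2024-01-01T00:00:00"}
print(Example().model_dump_json())                           # {"date":null}
```
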
|
||||
|
||||
class PostAuthor(BaseModel):
|
||||
|
||||
@@ -6,8 +6,6 @@ from dotenv import load_dotenv
|
||||
|
||||
from backend.util.logging import configure_logging
|
||||
|
||||
os.environ["ENABLE_AUTH"] = "false"
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# NOTE: You can run tests like with the --log-cli-level=INFO to see the logs
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
|
||||
from .graph import NodeModel
|
||||
from .integrations import Webhook # noqa: F401
|
||||
|
||||
# Resolve Webhook <- NodeModel forward reference
|
||||
# Resolve Webhook forward references
|
||||
NodeModel.model_rebuild()
|
||||
LibraryAgentPreset.model_rebuild()
|
||||
|
||||
@@ -1,57 +1,31 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from autogpt_libs.api_key.key_manager import APIKeyManager
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from prisma.enums import APIKeyPermission, APIKeyStatus
|
||||
from prisma.errors import PrismaError
|
||||
from prisma.models import APIKey as PrismaAPIKey
|
||||
from prisma.types import (
|
||||
APIKeyCreateInput,
|
||||
APIKeyUpdateInput,
|
||||
APIKeyWhereInput,
|
||||
APIKeyWhereUniqueInput,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from prisma.types import APIKeyWhereUniqueInput
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.db import BaseDbModel
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
keysmith = APIKeySmith()
|
||||
|
||||
|
||||
# Some basic exceptions
|
||||
class APIKeyError(Exception):
|
||||
"""Base exception for API key operations"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyNotFoundError(APIKeyError):
|
||||
"""Raised when an API key is not found"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyPermissionError(APIKeyError):
|
||||
"""Raised when there are permission issues with API key operations"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyValidationError(APIKeyError):
|
||||
"""Raised when API key validation fails"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKey(BaseDbModel):
|
||||
class APIKeyInfo(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
prefix: str
|
||||
key: str
|
||||
status: APIKeyStatus = APIKeyStatus.ACTIVE
|
||||
permissions: List[APIKeyPermission]
|
||||
postfix: str
|
||||
head: str = Field(
|
||||
description=f"The first {APIKeySmith.HEAD_LENGTH} characters of the key"
|
||||
)
|
||||
tail: str = Field(
|
||||
description=f"The last {APIKeySmith.TAIL_LENGTH} characters of the key"
|
||||
)
|
||||
status: APIKeyStatus
|
||||
permissions: list[APIKeyPermission]
|
||||
created_at: datetime
|
||||
last_used_at: Optional[datetime] = None
|
||||
revoked_at: Optional[datetime] = None
|
||||
@@ -60,266 +34,211 @@ class APIKey(BaseDbModel):
|
||||
|
||||
@staticmethod
|
||||
def from_db(api_key: PrismaAPIKey):
|
||||
try:
|
||||
return APIKey(
|
||||
id=api_key.id,
|
||||
name=api_key.name,
|
||||
prefix=api_key.prefix,
|
||||
postfix=api_key.postfix,
|
||||
key=api_key.key,
|
||||
status=APIKeyStatus(api_key.status),
|
||||
permissions=[APIKeyPermission(p) for p in api_key.permissions],
|
||||
created_at=api_key.createdAt,
|
||||
last_used_at=api_key.lastUsedAt,
|
||||
revoked_at=api_key.revokedAt,
|
||||
description=api_key.description,
|
||||
user_id=api_key.userId,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating APIKey from db: {str(e)}")
|
||||
raise APIKeyError(f"Failed to create API key object: {str(e)}")
|
||||
return APIKeyInfo(
|
||||
id=api_key.id,
|
||||
name=api_key.name,
|
||||
head=api_key.head,
|
||||
tail=api_key.tail,
|
||||
status=APIKeyStatus(api_key.status),
|
||||
permissions=[APIKeyPermission(p) for p in api_key.permissions],
|
||||
created_at=api_key.createdAt,
|
||||
last_used_at=api_key.lastUsedAt,
|
||||
revoked_at=api_key.revokedAt,
|
||||
description=api_key.description,
|
||||
user_id=api_key.userId,
|
||||
)
|
||||
|
||||
|
||||
class APIKeyWithoutHash(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
prefix: str
|
||||
postfix: str
|
||||
status: APIKeyStatus
|
||||
permissions: List[APIKeyPermission]
|
||||
created_at: datetime
|
||||
last_used_at: Optional[datetime]
|
||||
revoked_at: Optional[datetime]
|
||||
description: Optional[str]
|
||||
user_id: str
|
||||
class APIKeyInfoWithHash(APIKeyInfo):
|
||||
hash: str
|
||||
salt: str | None = None # None for legacy keys
|
||||
|
||||
def match(self, plaintext_key: str) -> bool:
|
||||
"""Returns whether the given key matches this API key object."""
|
||||
return keysmith.verify_key(plaintext_key, self.hash, self.salt)
|
||||
|
||||
@staticmethod
|
||||
def from_db(api_key: PrismaAPIKey):
|
||||
try:
|
||||
return APIKeyWithoutHash(
|
||||
id=api_key.id,
|
||||
name=api_key.name,
|
||||
prefix=api_key.prefix,
|
||||
postfix=api_key.postfix,
|
||||
status=APIKeyStatus(api_key.status),
|
||||
permissions=[APIKeyPermission(p) for p in api_key.permissions],
|
||||
created_at=api_key.createdAt,
|
||||
last_used_at=api_key.lastUsedAt,
|
||||
revoked_at=api_key.revokedAt,
|
||||
description=api_key.description,
|
||||
user_id=api_key.userId,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating APIKeyWithoutHash from db: {str(e)}")
|
||||
raise APIKeyError(f"Failed to create API key object: {str(e)}")
|
||||
return APIKeyInfoWithHash(
|
||||
**APIKeyInfo.from_db(api_key).model_dump(),
|
||||
hash=api_key.hash,
|
||||
salt=api_key.salt,
|
||||
)
|
||||
|
||||
def without_hash(self) -> APIKeyInfo:
|
||||
return APIKeyInfo(**self.model_dump(exclude={"hash", "salt"}))
|
||||
|
||||
|
||||
async def generate_api_key(
|
||||
async def create_api_key(
|
||||
name: str,
|
||||
user_id: str,
|
||||
permissions: List[APIKeyPermission],
|
||||
permissions: list[APIKeyPermission],
|
||||
description: Optional[str] = None,
|
||||
) -> tuple[APIKeyWithoutHash, str]:
|
||||
) -> tuple[APIKeyInfo, str]:
|
||||
"""
|
||||
Generate a new API key and store it in the database.
|
||||
Returns the API key object (without hash) and the plain text key.
|
||||
"""
|
||||
try:
|
||||
api_manager = APIKeyManager()
|
||||
key = api_manager.generate_api_key()
|
||||
generated_key = keysmith.generate_key()
|
||||
|
||||
api_key = await PrismaAPIKey.prisma().create(
|
||||
data=APIKeyCreateInput(
|
||||
id=str(uuid.uuid4()),
|
||||
name=name,
|
||||
prefix=key.prefix,
|
||||
postfix=key.postfix,
|
||||
key=key.hash,
|
||||
permissions=[p for p in permissions],
|
||||
description=description,
|
||||
userId=user_id,
|
||||
)
|
||||
)
|
||||
saved_key_obj = await PrismaAPIKey.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": name,
|
||||
"head": generated_key.head,
|
||||
"tail": generated_key.tail,
|
||||
"hash": generated_key.hash,
|
||||
"salt": generated_key.salt,
|
||||
"permissions": [p for p in permissions],
|
||||
"description": description,
|
||||
"userId": user_id,
|
||||
}
|
||||
)
|
||||
|
||||
api_key_without_hash = APIKeyWithoutHash.from_db(api_key)
|
||||
return api_key_without_hash, key.raw
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while generating API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to generate API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while generating API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to generate API key: {str(e)}")
|
||||
return APIKeyInfo.from_db(saved_key_obj), generated_key.key
|
||||
|
||||
|
||||
async def validate_api_key(plain_text_key: str) -> Optional[APIKey]:
|
||||
async def get_active_api_keys_by_head(head: str) -> list[APIKeyInfoWithHash]:
|
||||
results = await PrismaAPIKey.prisma().find_many(
|
||||
where={"head": head, "status": APIKeyStatus.ACTIVE}
|
||||
)
|
||||
return [APIKeyInfoWithHash.from_db(key) for key in results]
|
||||
|
||||
|
||||
async def validate_api_key(plaintext_key: str) -> Optional[APIKeyInfo]:
|
||||
"""
|
||||
Validate an API key and return the API key object if valid.
|
||||
Validate an API key and return the API key object if valid and active.
|
||||
"""
|
||||
try:
|
||||
if not plain_text_key.startswith(APIKeyManager.PREFIX):
|
||||
if not plaintext_key.startswith(APIKeySmith.PREFIX):
|
||||
logger.warning("Invalid API key format")
|
||||
return None
|
||||
|
||||
prefix = plain_text_key[: APIKeyManager.PREFIX_LENGTH]
|
||||
api_manager = APIKeyManager()
|
||||
head = plaintext_key[: APIKeySmith.HEAD_LENGTH]
|
||||
potential_matches = await get_active_api_keys_by_head(head)
|
||||
|
||||
api_key = await PrismaAPIKey.prisma().find_first(
|
||||
where=APIKeyWhereInput(prefix=prefix, status=(APIKeyStatus.ACTIVE))
|
||||
matched_api_key = next(
|
||||
(pm for pm in potential_matches if pm.match(plaintext_key)),
|
||||
None,
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
logger.warning(f"No active API key found with prefix {prefix}")
|
||||
if not matched_api_key:
|
||||
# API key not found or invalid
|
||||
return None
|
||||
|
||||
is_valid = api_manager.verify_api_key(plain_text_key, api_key.key)
|
||||
if not is_valid:
|
||||
logger.warning("API key verification failed")
|
||||
return None
|
||||
|
||||
return APIKey.from_db(api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating API key: {str(e)}")
|
||||
raise APIKeyValidationError(f"Failed to validate API key: {str(e)}")
|
||||
|
||||
|
||||
async def revoke_api_key(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]:
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if not api_key:
|
||||
raise APIKeyNotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise APIKeyPermissionError(
|
||||
"You do not have permission to revoke this API key."
|
||||
# Migrate legacy keys to secure format on successful validation
|
||||
if matched_api_key.salt is None:
|
||||
matched_api_key = await _migrate_key_to_secure_hash(
|
||||
plaintext_key, matched_api_key
|
||||
)
|
||||
|
||||
where_clause: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=where_clause,
|
||||
data=APIKeyUpdateInput(
|
||||
status=APIKeyStatus.REVOKED, revokedAt=datetime.now(timezone.utc)
|
||||
),
|
||||
)
|
||||
return matched_api_key.without_hash()
|
||||
except Exception as e:
|
||||
logger.error(f"Error while validating API key: {e}")
|
||||
raise RuntimeError("Failed to validate API key") from e
|
||||
|
||||
if updated_api_key:
|
||||
return APIKeyWithoutHash.from_db(updated_api_key)
|
||||
|
||||
async def _migrate_key_to_secure_hash(
|
||||
plaintext_key: str, key_obj: APIKeyInfoWithHash
|
||||
) -> APIKeyInfoWithHash:
|
||||
"""Replace the SHA256 hash of a legacy API key with a salted Scrypt hash."""
|
||||
try:
|
||||
new_hash, new_salt = keysmith.hash_key(plaintext_key)
|
||||
await PrismaAPIKey.prisma().update(
|
||||
where={"id": key_obj.id}, data={"hash": new_hash, "salt": new_salt}
|
||||
)
|
||||
logger.info(f"Migrated legacy API key #{key_obj.id} to secure format")
|
||||
# Update the API key object with new values for return
|
||||
key_obj.hash = new_hash
|
||||
key_obj.salt = new_salt
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to migrate legacy API key #{key_obj.id}: {e}")
|
||||
|
||||
return key_obj
|
||||
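
End to end, the new scheme is: take the key's head, fetch active candidates with that head, verify the salted hash against each, and lazily re-hash legacy SHA256 keys on first successful use. A minimal caller sketch using only the functions defined above:

```python
# Illustrative sketch only.
async def authorize(plaintext_key: str) -> APIKeyInfo:
    # Head lookup + hash match, with transparent legacy-key migration.
    key_info = await validate_api_key(plaintext_key)
    if key_info is None:
        raise NotAuthorizedError("Invalid, revoked, or unknown API key")
    return key_info
```
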
|
||||
|
||||
async def revoke_api_key(key_id: str, user_id: str) -> APIKeyInfo:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if not api_key:
|
||||
raise NotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise NotAuthorizedError("You do not have permission to revoke this API key.")
|
||||
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where={"id": key_id},
|
||||
data={
|
||||
"status": APIKeyStatus.REVOKED,
|
||||
"revokedAt": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
if not updated_api_key:
|
||||
raise NotFoundError(f"API key #{key_id} vanished while trying to revoke.")
|
||||
|
||||
return APIKeyInfo.from_db(updated_api_key)
|
||||
|
||||
|
||||
async def list_user_api_keys(user_id: str) -> list[APIKeyInfo]:
|
||||
api_keys = await PrismaAPIKey.prisma().find_many(
|
||||
where={"userId": user_id}, order={"createdAt": "desc"}
|
||||
)
|
||||
|
||||
return [APIKeyInfo.from_db(key) for key in api_keys]
|
||||
|
||||
|
||||
async def suspend_api_key(key_id: str, user_id: str) -> APIKeyInfo:
|
||||
selector: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where=selector)
|
||||
|
||||
if not api_key:
|
||||
raise NotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise NotAuthorizedError("You do not have permission to suspend this API key.")
|
||||
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=selector, data={"status": APIKeyStatus.SUSPENDED}
|
||||
)
|
||||
if not updated_api_key:
|
||||
raise NotFoundError(f"API key #{key_id} vanished while trying to suspend.")
|
||||
|
||||
return APIKeyInfo.from_db(updated_api_key)
|
||||
|
||||
|
||||
def has_permission(api_key: APIKeyInfo, required_permission: APIKeyPermission) -> bool:
|
||||
return required_permission in api_key.permissions
|
||||
|
||||
|
||||
async def get_api_key_by_id(key_id: str, user_id: str) -> Optional[APIKeyInfo]:
|
||||
api_key = await PrismaAPIKey.prisma().find_first(
|
||||
where={"id": key_id, "userId": user_id}
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
return None
|
||||
except (APIKeyNotFoundError, APIKeyPermissionError) as e:
|
||||
raise e
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while revoking API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to revoke API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while revoking API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to revoke API key: {str(e)}")
|
||||
|
||||
|
||||
async def list_user_api_keys(user_id: str) -> List[APIKeyWithoutHash]:
|
||||
try:
|
||||
where_clause: APIKeyWhereInput = {"userId": user_id}
|
||||
|
||||
api_keys = await PrismaAPIKey.prisma().find_many(
|
||||
where=where_clause, order={"createdAt": "desc"}
|
||||
)
|
||||
|
||||
return [APIKeyWithoutHash.from_db(key) for key in api_keys]
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while listing API keys: {str(e)}")
|
||||
raise APIKeyError(f"Failed to list API keys: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while listing API keys: {str(e)}")
|
||||
raise APIKeyError(f"Failed to list API keys: {str(e)}")
|
||||
|
||||
|
||||
async def suspend_api_key(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]:
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if not api_key:
|
||||
raise APIKeyNotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise APIKeyPermissionError(
|
||||
"You do not have permission to suspend this API key."
|
||||
)
|
||||
|
||||
where_clause: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=where_clause,
|
||||
data=APIKeyUpdateInput(status=APIKeyStatus.SUSPENDED),
|
||||
)
|
||||
|
||||
if updated_api_key:
|
||||
return APIKeyWithoutHash.from_db(updated_api_key)
|
||||
return None
|
||||
except (APIKeyNotFoundError, APIKeyPermissionError) as e:
|
||||
raise e
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while suspending API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to suspend API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while suspending API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to suspend API key: {str(e)}")
|
||||
|
||||
|
||||
def has_permission(api_key: APIKey, required_permission: APIKeyPermission) -> bool:
|
||||
try:
|
||||
return required_permission in api_key.permissions
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking API key permissions: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_api_key_by_id(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]:
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_first(
|
||||
where=APIKeyWhereInput(id=key_id, userId=user_id)
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
return APIKeyWithoutHash.from_db(api_key)
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while getting API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to get API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while getting API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to get API key: {str(e)}")
|
||||
return APIKeyInfo.from_db(api_key)
|
||||
|
||||
|
||||
async def update_api_key_permissions(
|
||||
key_id: str, user_id: str, permissions: List[APIKeyPermission]
|
||||
) -> Optional[APIKeyWithoutHash]:
|
||||
key_id: str, user_id: str, permissions: list[APIKeyPermission]
|
||||
) -> APIKeyInfo:
|
||||
"""
|
||||
Update the permissions of an API key.
|
||||
"""
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if api_key is None:
|
||||
raise APIKeyNotFoundError("No such API key found.")
|
||||
if api_key is None:
|
||||
raise NotFoundError("No such API key found.")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise APIKeyPermissionError(
|
||||
"You do not have permission to update this API key."
|
||||
)
|
||||
if api_key.userId != user_id:
|
||||
raise NotAuthorizedError("You do not have permission to update this API key.")
|
||||
|
||||
where_clause: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=where_clause,
|
||||
data=APIKeyUpdateInput(permissions=permissions),
|
||||
)
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where={"id": key_id},
|
||||
data={"permissions": permissions},
|
||||
)
|
||||
if not updated_api_key:
|
||||
raise NotFoundError(f"API key #{key_id} vanished while trying to update.")
|
||||
|
||||
if updated_api_key:
|
||||
return APIKeyWithoutHash.from_db(updated_api_key)
|
||||
return None
|
||||
except (APIKeyNotFoundError, APIKeyPermissionError) as e:
|
||||
raise e
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while updating API key permissions: {str(e)}")
|
||||
raise APIKeyError(f"Failed to update API key permissions: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while updating API key permissions: {str(e)}")
|
||||
raise APIKeyError(f"Failed to update API key permissions: {str(e)}")
|
||||
return APIKeyInfo.from_db(updated_api_key)
|
||||
|
||||
@@ -8,6 +8,7 @@ from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Generic,
|
||||
Optional,
|
||||
@@ -44,9 +45,10 @@ if TYPE_CHECKING:
|
||||
|
||||
app_config = Config()
|
||||
|
||||
BlockData = tuple[str, Any] # Input & Output data should be a tuple of (name, data).
|
||||
BlockInput = dict[str, Any] # Input: 1 input pin consumes 1 data.
|
||||
BlockOutput = AsyncGen[BlockData, None] # Output: 1 output pin produces n data.
|
||||
BlockOutputEntry = tuple[str, Any] # Output data should be a tuple of (name, value).
|
||||
BlockOutput = AsyncGen[BlockOutputEntry, None] # Output: 1 output pin produces n data.
|
||||
BlockTestOutput = BlockOutputEntry | tuple[str, Callable[[Any], bool]]
|
||||
CompletedBlockOutput = dict[str, list[Any]] # Completed stream, collected as a dict.
|
||||
|
||||
|
||||
@@ -89,6 +91,45 @@ class BlockCategory(Enum):
|
||||
return {"category": self.name, "description": self.value}
|
||||
|
||||
|
||||
class BlockCostType(str, Enum):
|
||||
RUN = "run" # cost X credits per run
|
||||
BYTE = "byte" # cost X credits per byte
|
||||
SECOND = "second" # cost X credits per second
|
||||
|
||||
|
||||
class BlockCost(BaseModel):
|
||||
cost_amount: int
|
||||
cost_filter: BlockInput
|
||||
cost_type: BlockCostType
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cost_amount: int,
|
||||
cost_type: BlockCostType = BlockCostType.RUN,
|
||||
cost_filter: Optional[BlockInput] = None,
|
||||
**data: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
cost_amount=cost_amount,
|
||||
cost_filter=cost_filter or {},
|
||||
cost_type=cost_type,
|
||||
**data,
|
||||
)
|
||||
|
||||
|
||||
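
BlockCost pairs an amount with an optional cost_filter that is matched against the block's input; the BLOCK_COSTS table further down uses this to price specific model/credential combinations. A small sketch (the filter keys are examples, not a real block's inputs):

```python
# Illustrative sketch only.
flat = BlockCost(cost_amount=5)  # 5 credits per run, applies to any input
premium = BlockCost(
    cost_amount=18,
    cost_type=BlockCostType.RUN,
    cost_filter={"model": "V_3"},  # only charged when the input matches this filter
)
print(flat.cost_filter, premium.cost_amount)  # {} 18
```
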
class BlockInfo(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
inputSchema: dict[str, Any]
|
||||
outputSchema: dict[str, Any]
|
||||
costs: list[BlockCost]
|
||||
description: str
|
||||
categories: list[dict[str, str]]
|
||||
contributors: list[dict[str, Any]]
|
||||
staticOutput: bool
|
||||
uiType: str
|
||||
|
||||
|
||||
class BlockSchema(BaseModel):
|
||||
cached_jsonschema: ClassVar[dict[str, Any]]
|
||||
|
||||
@@ -306,7 +347,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
input_schema: Type[BlockSchemaInputType] = EmptySchema,
|
||||
output_schema: Type[BlockSchemaOutputType] = EmptySchema,
|
||||
test_input: BlockInput | list[BlockInput] | None = None,
|
||||
test_output: BlockData | list[BlockData] | None = None,
|
||||
test_output: BlockTestOutput | list[BlockTestOutput] | None = None,
|
||||
test_mock: dict[str, Any] | None = None,
|
||||
test_credentials: Optional[Credentials | dict[str, Credentials]] = None,
|
||||
disabled: bool = False,
|
||||
@@ -452,6 +493,24 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
"uiType": self.block_type.value,
|
||||
}
|
||||
|
||||
def get_info(self) -> BlockInfo:
|
||||
from backend.data.credit import get_block_cost
|
||||
|
||||
return BlockInfo(
|
||||
id=self.id,
|
||||
name=self.name,
|
||||
inputSchema=self.input_schema.jsonschema(),
|
||||
outputSchema=self.output_schema.jsonschema(),
|
||||
costs=get_block_cost(self),
|
||||
description=self.description,
|
||||
categories=[category.dict() for category in self.categories],
|
||||
contributors=[
|
||||
contributor.model_dump() for contributor in self.contributors
|
||||
],
|
||||
staticOutput=self.static_output,
|
||||
uiType=self.block_type.value,
|
||||
)
|
||||
|
||||
async def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise ValueError(
|
||||
|
||||
@@ -29,8 +29,7 @@ from backend.blocks.replicate.replicate_block import ReplicateModelBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
from backend.data.block import Block
|
||||
from backend.data.cost import BlockCost, BlockCostType
|
||||
from backend.data.block import Block, BlockCost, BlockCostType
|
||||
from backend.integrations.credentials_store import (
|
||||
aiml_api_credentials,
|
||||
anthropic_credentials,
|
||||
@@ -307,7 +306,18 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
"type": ideogram_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
),
|
||||
BlockCost(
|
||||
cost_amount=18,
|
||||
cost_filter={
|
||||
"ideogram_model_name": "V_3",
|
||||
"credentials": {
|
||||
"id": ideogram_credentials.id,
|
||||
"provider": ideogram_credentials.provider,
|
||||
"type": ideogram_credentials.type,
|
||||
},
|
||||
},
|
||||
),
|
||||
],
|
||||
AIShortformVideoCreatorBlock: [
|
||||
BlockCost(
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import BlockInput
|
||||
|
||||
|
||||
class BlockCostType(str, Enum):
|
||||
RUN = "run" # cost X credits per run
|
||||
BYTE = "byte" # cost X credits per byte
|
||||
SECOND = "second" # cost X credits per second
|
||||
|
||||
|
||||
class BlockCost(BaseModel):
|
||||
cost_amount: int
|
||||
cost_filter: BlockInput
|
||||
cost_type: BlockCostType
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cost_amount: int,
|
||||
cost_type: BlockCostType = BlockCostType.RUN,
|
||||
cost_filter: Optional[BlockInput] = None,
|
||||
**data: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
cost_amount=cost_amount,
|
||||
cost_filter=cost_filter or {},
|
||||
cost_type=cost_type,
|
||||
**data,
|
||||
)
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import stripe
|
||||
from prisma import Json
|
||||
@@ -23,7 +23,6 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCost
|
||||
from backend.data.model import (
|
||||
AutoTopUpConfig,
|
||||
RefundRequest,
|
||||
@@ -34,13 +33,16 @@ from backend.data.model import (
|
||||
from backend.data.notifications import NotificationEventModel, RefundRequestData
|
||||
from backend.data.user import get_user_by_id, get_user_email_by_id
|
||||
from backend.notifications.notifications import queue_notification_async
|
||||
from backend.server.model import Pagination
|
||||
from backend.server.v2.admin.model import UserHistoryResponse
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.block import Block, BlockCost
|
||||
|
||||
settings = Settings()
|
||||
stripe.api_key = settings.secrets.stripe_api_key
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -997,10 +999,14 @@ def get_user_credit_model() -> UserCreditBase:
|
||||
return UserCredit()
|
||||
|
||||
|
||||
def get_block_costs() -> dict[str, list[BlockCost]]:
|
||||
def get_block_costs() -> dict[str, list["BlockCost"]]:
|
||||
return {block().id: costs for block, costs in BLOCK_COSTS.items()}
|
||||
|
||||
|
||||
def get_block_cost(block: "Block") -> list["BlockCost"]:
|
||||
return BLOCK_COSTS.get(block.__class__, [])
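The quoted `"Block"` / `"BlockCost"` annotations above rely on the standard `typing.TYPE_CHECKING` pattern: the types are imported only for the type checker and referenced as string annotations at runtime, which breaks the import cycle between the credit and block modules. A generic illustration of the pattern; the module and class names here are hypothetical:

```python
# cost_registry.py (hypothetical module, showing the pattern only)
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only during static type checking; never executed at runtime.
    from heavy_module import Block, BlockCost  # hypothetical names


def get_cost(block: "Block") -> list["BlockCost"]:
    # String annotations are resolved lazily, so heavy_module is not
    # imported when this module is loaded at runtime.
    registry: dict[type, list["BlockCost"]] = {}
    return registry.get(type(block), [])
```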
|
||||
|
||||
|
||||
async def get_stripe_customer_id(user_id: str) -> str:
|
||||
user = await get_user_by_id(user_id)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from prisma.models import CreditTransaction
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.data.block import get_block
|
||||
from backend.data.credit import BetaUserCredit, UsageTransactionMetadata
|
||||
from backend.data.execution import NodeExecutionEntry
|
||||
from backend.data.execution import NodeExecutionEntry, UserContext
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.executor.utils import block_usage_cost
|
||||
from backend.integrations.credentials_store import openai_credentials
|
||||
@@ -75,6 +75,7 @@ async def test_block_credit_usage(server: SpinTestServer):
|
||||
"type": openai_credentials.type,
|
||||
},
|
||||
},
|
||||
user_context=UserContext(timezone="UTC"),
|
||||
),
|
||||
)
|
||||
assert spending_amount_1 > 0
|
||||
@@ -88,6 +89,7 @@ async def test_block_credit_usage(server: SpinTestServer):
|
||||
node_exec_id="test_node_exec",
|
||||
block_id=AITextGeneratorBlock().id,
|
||||
inputs={"model": "gpt-4-turbo", "api_key": "owned_api_key"},
|
||||
user_context=UserContext(timezone="UTC"),
|
||||
),
|
||||
)
|
||||
assert spending_amount_2 == 0
|
||||
|
||||
@@ -11,11 +11,14 @@ from typing import (
|
||||
Generator,
|
||||
Generic,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from prisma import Json
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
from prisma.models import (
|
||||
AgentGraphExecution,
|
||||
@@ -24,7 +27,6 @@ from prisma.models import (
|
||||
AgentNodeExecutionKeyValueData,
|
||||
)
|
||||
from prisma.types import (
|
||||
AgentGraphExecutionCreateInput,
|
||||
AgentGraphExecutionUpdateManyMutationInput,
|
||||
AgentGraphExecutionWhereInput,
|
||||
AgentNodeExecutionCreateInput,
|
||||
@@ -39,6 +41,7 @@ from pydantic.fields import Field
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
from backend.util import type as type_utils
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.settings import Config
|
||||
from backend.util.truncate import truncate
|
||||
@@ -59,7 +62,7 @@ from .includes import (
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
graph_execution_include,
|
||||
)
|
||||
from .model import GraphExecutionStats, NodeExecutionStats
|
||||
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -86,16 +89,49 @@ class BlockErrorStats(BaseModel):


ExecutionStatus = AgentExecutionStatus
NodeInputMask = Mapping[str, JsonValue]
NodesInputMasks = Mapping[str, NodeInputMask]

# dest: source
VALID_STATUS_TRANSITIONS = {
    ExecutionStatus.QUEUED: [
        ExecutionStatus.INCOMPLETE,
    ],
    ExecutionStatus.RUNNING: [
        ExecutionStatus.INCOMPLETE,
        ExecutionStatus.QUEUED,
        ExecutionStatus.TERMINATED,  # For resuming halted execution
    ],
    ExecutionStatus.COMPLETED: [
        ExecutionStatus.RUNNING,
    ],
    ExecutionStatus.FAILED: [
        ExecutionStatus.INCOMPLETE,
        ExecutionStatus.QUEUED,
        ExecutionStatus.RUNNING,
    ],
    ExecutionStatus.TERMINATED: [
        ExecutionStatus.INCOMPLETE,
        ExecutionStatus.QUEUED,
        ExecutionStatus.RUNNING,
    ],
}
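For reference, the table is keyed by the *destination* status, so a transition check reads "is the current status an allowed source for the target". A minimal sketch; the statuses are re-declared as a plain enum so the snippet runs standalone, and only part of the table is reproduced:

```python
from enum import Enum


class ExecutionStatus(str, Enum):  # stand-in for prisma's AgentExecutionStatus
    INCOMPLETE = "INCOMPLETE"
    QUEUED = "QUEUED"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"


# Subset of the table above, re-declared so the sketch runs on its own.
VALID_STATUS_TRANSITIONS: dict[ExecutionStatus, list[ExecutionStatus]] = {
    ExecutionStatus.RUNNING: [ExecutionStatus.INCOMPLETE, ExecutionStatus.QUEUED],
    ExecutionStatus.COMPLETED: [ExecutionStatus.RUNNING],
}


def can_transition(current: ExecutionStatus, target: ExecutionStatus) -> bool:
    """True if `current` is a valid source status for `target`."""
    return current in VALID_STATUS_TRANSITIONS.get(target, [])


assert can_transition(ExecutionStatus.QUEUED, ExecutionStatus.RUNNING)
assert not can_transition(ExecutionStatus.COMPLETED, ExecutionStatus.RUNNING)
```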
|
||||
|
||||
|
||||
class GraphExecutionMeta(BaseDbModel):
|
||||
id: str # type: ignore # Override base class to make this required
|
||||
user_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
preset_id: Optional[str] = None
|
||||
inputs: Optional[BlockInput] # no default -> required in the OpenAPI spec
|
||||
credential_inputs: Optional[dict[str, CredentialsMetaInput]]
|
||||
nodes_input_masks: Optional[dict[str, BlockInput]]
|
||||
preset_id: Optional[str]
|
||||
status: ExecutionStatus
|
||||
started_at: datetime
|
||||
ended_at: datetime
|
||||
is_shared: bool = False
|
||||
share_token: Optional[str] = None
|
||||
|
||||
class Stats(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
@@ -177,6 +213,18 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
user_id=_graph_exec.userId,
|
||||
graph_id=_graph_exec.agentGraphId,
|
||||
graph_version=_graph_exec.agentGraphVersion,
|
||||
inputs=cast(BlockInput | None, _graph_exec.inputs),
|
||||
credential_inputs=(
|
||||
{
|
||||
name: CredentialsMetaInput.model_validate(cmi)
|
||||
for name, cmi in cast(dict, _graph_exec.credentialInputs).items()
|
||||
}
|
||||
if _graph_exec.credentialInputs
|
||||
else None
|
||||
),
|
||||
nodes_input_masks=cast(
|
||||
dict[str, BlockInput] | None, _graph_exec.nodesInputMasks
|
||||
),
|
||||
preset_id=_graph_exec.agentPresetId,
|
||||
status=ExecutionStatus(_graph_exec.executionStatus),
|
||||
started_at=start_time,
|
||||
@@ -200,11 +248,13 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
if stats
|
||||
else None
|
||||
),
|
||||
is_shared=_graph_exec.isShared,
|
||||
share_token=_graph_exec.shareToken,
|
||||
)
|
||||
|
||||
|
||||
class GraphExecution(GraphExecutionMeta):
|
||||
inputs: BlockInput
|
||||
inputs: BlockInput # type: ignore - incompatible override is intentional
|
||||
outputs: CompletedBlockOutput
|
||||
|
||||
@staticmethod
|
||||
@@ -224,15 +274,18 @@ class GraphExecution(GraphExecutionMeta):
|
||||
)
|
||||
|
||||
inputs = {
|
||||
**{
|
||||
# inputs from Agent Input Blocks
|
||||
exec.input_data["name"]: exec.input_data.get("value")
|
||||
for exec in complete_node_executions
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
and block.block_type == BlockType.INPUT
|
||||
)
|
||||
},
|
||||
**(
|
||||
graph_exec.inputs
|
||||
or {
|
||||
# fallback: extract inputs from Agent Input Blocks
|
||||
exec.input_data["name"]: exec.input_data.get("value")
|
||||
for exec in complete_node_executions
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
and block.block_type == BlockType.INPUT
|
||||
)
|
||||
}
|
||||
),
|
||||
**{
|
||||
# input from webhook-triggered block
|
||||
"payload": exec.input_data["payload"]
|
||||
@@ -250,14 +303,13 @@ class GraphExecution(GraphExecutionMeta):
|
||||
if (
|
||||
block := get_block(exec.block_id)
|
||||
) and block.block_type == BlockType.OUTPUT:
|
||||
outputs[exec.input_data["name"]].append(
|
||||
exec.input_data.get("value", None)
|
||||
)
|
||||
outputs[exec.input_data["name"]].append(exec.input_data.get("value"))
|
||||
|
||||
return GraphExecution(
|
||||
**{
|
||||
field_name: getattr(graph_exec, field_name)
|
||||
for field_name in GraphExecutionMeta.model_fields
|
||||
if field_name != "inputs"
|
||||
},
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
@@ -290,13 +342,18 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
node_executions=node_executions,
|
||||
)
|
||||
|
||||
def to_graph_execution_entry(self):
|
||||
def to_graph_execution_entry(
|
||||
self,
|
||||
user_context: "UserContext",
|
||||
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
):
|
||||
return GraphExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
graph_id=self.graph_id,
|
||||
graph_version=self.graph_version or 0,
|
||||
graph_exec_id=self.id,
|
||||
nodes_input_masks={}, # FIXME: store credentials on AgentGraphExecution
|
||||
nodes_input_masks=compiled_nodes_input_masks,
|
||||
user_context=user_context,
|
||||
)
|
||||
|
||||
|
||||
@@ -332,7 +389,7 @@ class NodeExecutionResult(BaseModel):
|
||||
else:
|
||||
input_data: BlockInput = defaultdict()
|
||||
for data in _node_exec.Input or []:
|
||||
input_data[data.name] = type_utils.convert(data.data, type[Any])
|
||||
input_data[data.name] = type_utils.convert(data.data, JsonValue)
|
||||
|
||||
output_data: CompletedBlockOutput = defaultdict(list)
|
||||
|
||||
@@ -341,7 +398,7 @@ class NodeExecutionResult(BaseModel):
|
||||
output_data[name].extend(messages)
|
||||
else:
|
||||
for data in _node_exec.Output or []:
|
||||
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
|
||||
output_data[data.name].append(type_utils.convert(data.data, JsonValue))
|
||||
|
||||
graph_execution: AgentGraphExecution | None = _node_exec.GraphExecution
|
||||
if graph_execution:
|
||||
@@ -368,7 +425,9 @@ class NodeExecutionResult(BaseModel):
|
||||
end_time=_node_exec.endedTime,
|
||||
)
|
||||
|
||||
def to_node_execution_entry(self) -> "NodeExecutionEntry":
|
||||
def to_node_execution_entry(
|
||||
self, user_context: "UserContext"
|
||||
) -> "NodeExecutionEntry":
|
||||
return NodeExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
graph_exec_id=self.graph_exec_id,
|
||||
@@ -377,6 +436,7 @@ class NodeExecutionResult(BaseModel):
|
||||
node_id=self.node_id,
|
||||
block_id=self.block_id,
|
||||
inputs=self.input_data,
|
||||
user_context=user_context,
|
||||
)
|
||||
|
||||
|
||||
@@ -384,13 +444,13 @@ class NodeExecutionResult(BaseModel):
|
||||
|
||||
|
||||
async def get_graph_executions(
|
||||
graph_exec_id: str | None = None,
|
||||
graph_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
statuses: list[ExecutionStatus] | None = None,
|
||||
created_time_gte: datetime | None = None,
|
||||
created_time_lte: datetime | None = None,
|
||||
limit: int | None = None,
|
||||
graph_exec_id: Optional[str] = None,
|
||||
graph_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
statuses: Optional[list[ExecutionStatus]] = None,
|
||||
created_time_gte: Optional[datetime] = None,
|
||||
created_time_lte: Optional[datetime] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> list[GraphExecutionMeta]:
|
||||
"""⚠️ **Optional `user_id` check**: MUST USE check in user-facing endpoints."""
|
||||
where_filter: AgentGraphExecutionWhereInput = {
|
||||
@@ -418,6 +478,60 @@ async def get_graph_executions(
|
||||
return [GraphExecutionMeta.from_db(execution) for execution in executions]
|
||||
|
||||
|
||||
class GraphExecutionsPaginated(BaseModel):
|
||||
"""Response schema for paginated graph executions."""
|
||||
|
||||
executions: list[GraphExecutionMeta]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
async def get_graph_executions_paginated(
|
||||
user_id: str,
|
||||
graph_id: Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 25,
|
||||
statuses: Optional[list[ExecutionStatus]] = None,
|
||||
created_time_gte: Optional[datetime] = None,
|
||||
created_time_lte: Optional[datetime] = None,
|
||||
) -> GraphExecutionsPaginated:
|
||||
"""Get paginated graph executions for a specific graph."""
|
||||
where_filter: AgentGraphExecutionWhereInput = {
|
||||
"isDeleted": False,
|
||||
"userId": user_id,
|
||||
}
|
||||
|
||||
if graph_id:
|
||||
where_filter["agentGraphId"] = graph_id
|
||||
if created_time_gte or created_time_lte:
|
||||
where_filter["createdAt"] = {
|
||||
"gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
|
||||
"lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
|
||||
}
|
||||
if statuses:
|
||||
where_filter["OR"] = [{"executionStatus": status} for status in statuses]
|
||||
|
||||
total_count = await AgentGraphExecution.prisma().count(where=where_filter)
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
executions = await AgentGraphExecution.prisma().find_many(
|
||||
where=where_filter,
|
||||
order={"createdAt": "desc"},
|
||||
take=page_size,
|
||||
skip=offset,
|
||||
)
|
||||
|
||||
return GraphExecutionsPaginated(
|
||||
executions=[GraphExecutionMeta.from_db(execution) for execution in executions],
|
||||
pagination=Pagination(
|
||||
total_items=total_count,
|
||||
total_pages=total_pages,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
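The paging arithmetic above is ceiling division plus a zero-based offset; a tiny standalone check of the two formulas used by the query:

```python
def page_window(total_count: int, page: int, page_size: int) -> tuple[int, int]:
    """Return (total_pages, offset) computed the same way as above."""
    total_pages = (total_count + page_size - 1) // page_size  # ceiling division
    offset = (page - 1) * page_size
    return total_pages, offset


assert page_window(total_count=51, page=1, page_size=25) == (3, 0)
assert page_window(total_count=51, page=3, page_size=25) == (3, 50)
```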
|
||||
|
||||
|
||||
async def get_graph_execution_meta(
|
||||
user_id: str, execution_id: str
|
||||
) -> GraphExecutionMeta | None:
|
||||
@@ -479,9 +593,12 @@ async def get_graph_execution(
|
||||
async def create_graph_execution(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
starting_nodes_input: list[tuple[str, BlockInput]],
|
||||
starting_nodes_input: list[tuple[str, BlockInput]], # list[(node_id, BlockInput)]
|
||||
inputs: Mapping[str, JsonValue],
|
||||
user_id: str,
|
||||
preset_id: str | None = None,
|
||||
preset_id: Optional[str] = None,
|
||||
credential_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
Create a new AgentGraphExecution record.
|
||||
@@ -489,11 +606,18 @@ async def create_graph_execution(
|
||||
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
|
||||
"""
|
||||
result = await AgentGraphExecution.prisma().create(
|
||||
data=AgentGraphExecutionCreateInput(
|
||||
agentGraphId=graph_id,
|
||||
agentGraphVersion=graph_version,
|
||||
executionStatus=ExecutionStatus.QUEUED,
|
||||
NodeExecutions={
|
||||
data={
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": graph_version,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
"inputs": SafeJson(inputs),
|
||||
"credentialInputs": (
|
||||
SafeJson(credential_inputs) if credential_inputs else Json({})
|
||||
),
|
||||
"nodesInputMasks": (
|
||||
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
|
||||
),
|
||||
"NodeExecutions": {
|
||||
"create": [
|
||||
AgentNodeExecutionCreateInput(
|
||||
agentNodeId=node_id,
|
||||
@@ -509,9 +633,9 @@ async def create_graph_execution(
|
||||
for node_id, node_input in starting_nodes_input
|
||||
]
|
||||
},
|
||||
userId=user_id,
|
||||
agentPresetId=preset_id,
|
||||
),
|
||||
"userId": user_id,
|
||||
"agentPresetId": preset_id,
|
||||
},
|
||||
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
)
|
||||
|
||||
@@ -522,7 +646,7 @@ async def upsert_execution_input(
|
||||
node_id: str,
|
||||
graph_exec_id: str,
|
||||
input_name: str,
|
||||
input_data: Any,
|
||||
input_data: JsonValue,
|
||||
node_exec_id: str | None = None,
|
||||
) -> tuple[str, BlockInput]:
|
||||
"""
|
||||
@@ -571,7 +695,7 @@ async def upsert_execution_input(
|
||||
)
|
||||
return existing_execution.id, {
|
||||
**{
|
||||
input_data.name: type_utils.convert(input_data.data, type[Any])
|
||||
input_data.name: type_utils.convert(input_data.data, JsonValue)
|
||||
for input_data in existing_execution.Input or []
|
||||
},
|
||||
input_name: input_data,
|
||||
@@ -632,6 +756,11 @@ async def update_graph_execution_stats(
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> GraphExecution | None:
|
||||
if not status and not stats:
|
||||
raise ValueError(
|
||||
f"Must provide either status or stats to update for execution {graph_exec_id}"
|
||||
)
|
||||
|
||||
update_data: AgentGraphExecutionUpdateManyMutationInput = {}
|
||||
|
||||
if stats:
|
||||
@@ -643,20 +772,25 @@ async def update_graph_execution_stats(
|
||||
if status:
|
||||
update_data["executionStatus"] = status
|
||||
|
||||
updated_count = await AgentGraphExecution.prisma().update_many(
|
||||
where={
|
||||
"id": graph_exec_id,
|
||||
"OR": [
|
||||
{"executionStatus": ExecutionStatus.RUNNING},
|
||||
{"executionStatus": ExecutionStatus.QUEUED},
|
||||
# Terminated graph can be resumed.
|
||||
{"executionStatus": ExecutionStatus.TERMINATED},
|
||||
],
|
||||
},
|
||||
where_clause: AgentGraphExecutionWhereInput = {"id": graph_exec_id}
|
||||
|
||||
if status:
|
||||
if allowed_from := VALID_STATUS_TRANSITIONS.get(status, []):
|
||||
# Add OR clause to check if current status is one of the allowed source statuses
|
||||
where_clause["AND"] = [
|
||||
{"id": graph_exec_id},
|
||||
{"OR": [{"executionStatus": s} for s in allowed_from]},
|
||||
]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Status {status} cannot be set via update for execution {graph_exec_id}. "
|
||||
f"This status can only be set at creation or is not a valid target status."
|
||||
)
|
||||
|
||||
await AgentGraphExecution.prisma().update_many(
|
||||
where=where_clause,
|
||||
data=update_data,
|
||||
)
|
||||
if updated_count == 0:
|
||||
return None
|
||||
|
||||
graph_exec = await AgentGraphExecution.prisma().find_unique_or_raise(
|
||||
where={"id": graph_exec_id},
|
||||
@@ -664,6 +798,7 @@ async def update_graph_execution_stats(
|
||||
[*get_io_block_ids(), *get_webhook_block_ids()]
|
||||
),
|
||||
)
|
||||
|
||||
return GraphExecution.from_db(graph_exec)
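The guarded `update_many` above is an optimistic compare-and-set: the row is only touched while its current status is still an allowed source, and a zero affected-row count means another worker already moved the execution. A condensed illustration of how the where-clause is assembled, with plain dicts standing in for the Prisma types:

```python
from typing import Any


def status_guarded_where(graph_exec_id: str, allowed_from: list[str]) -> dict[str, Any]:
    """Match the execution row only while its status is an allowed source."""
    return {
        "AND": [
            {"id": graph_exec_id},
            {"OR": [{"executionStatus": status} for status in allowed_from]},
        ]
    }


# Example: only move to COMPLETED if the row is still RUNNING.
print(status_guarded_where("exec-123", ["RUNNING"]))
```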
|
||||
|
||||
|
||||
@@ -817,12 +952,19 @@ async def get_latest_node_execution(
# ----------------- Execution Infrastructure ----------------- #


class UserContext(BaseModel):
    """Generic user context for graph execution containing user-specific settings."""

    timezone: str


class GraphExecutionEntry(BaseModel):
    user_id: str
    graph_exec_id: str
    graph_id: str
    graph_version: int
    nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None
    nodes_input_masks: Optional[NodesInputMasks] = None
    user_context: UserContext
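`UserContext` currently carries only the user's IANA timezone string (the `User` model later in this diff defaults it to `'not-set'`). A hedged sketch of how an executor might resolve it; the fallback-to-UTC policy is an assumption, not something these hunks specify:

```python
from datetime import datetime
from zoneinfo import ZoneInfo

from pydantic import BaseModel


class UserContext(BaseModel):
    timezone: str


def now_for_user(ctx: UserContext) -> datetime:
    """Resolve the stored IANA name, assuming UTC for 'not-set' or bad values."""
    try:
        tz = ZoneInfo("UTC") if ctx.timezone == "not-set" else ZoneInfo(ctx.timezone)
    except Exception:
        tz = ZoneInfo("UTC")
    return datetime.now(tz)


print(now_for_user(UserContext(timezone="UTC")))
```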
|
||||
|
||||
|
||||
class NodeExecutionEntry(BaseModel):
|
||||
@@ -833,6 +975,7 @@ class NodeExecutionEntry(BaseModel):
|
||||
node_id: str
|
||||
block_id: str
|
||||
inputs: BlockInput
|
||||
user_context: UserContext
|
||||
|
||||
|
||||
class ExecutionQueue(Generic[T]):
|
||||
@@ -882,6 +1025,18 @@ class NodeExecutionEvent(NodeExecutionResult):
|
||||
)
|
||||
|
||||
|
||||
class SharedExecutionResponse(BaseModel):
|
||||
"""Public-safe response for shared executions"""
|
||||
|
||||
id: str
|
||||
graph_name: str
|
||||
graph_description: Optional[str]
|
||||
status: ExecutionStatus
|
||||
created_at: datetime
|
||||
outputs: CompletedBlockOutput # Only the final outputs, no intermediate data
|
||||
# Deliberately exclude: user_id, inputs, credentials, node details
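How `share_token` values are minted is not visible in these hunks; one reasonable approach is an unguessable URL-safe token plus a timezone-aware `shared_at` timestamp, which is the pair `update_graph_execution_share_status(...)` below accepts. This is an assumption, not the project's actual implementation:

```python
import secrets
from datetime import datetime, timezone


def new_share_state() -> tuple[str, datetime]:
    """Return (share_token, shared_at) suitable for marking an execution shared."""
    return secrets.token_urlsafe(32), datetime.now(timezone.utc)


token, shared_at = new_share_state()
assert shared_at.tzinfo is not None and len(token) >= 32
```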
|
||||
|
||||
|
||||
ExecutionEvent = Annotated[
|
||||
GraphExecutionEvent | NodeExecutionEvent, Field(discriminator="event_type")
|
||||
]
|
||||
@@ -1059,3 +1214,98 @@ async def get_block_error_stats(
|
||||
)
|
||||
for row in result
|
||||
]
|
||||
|
||||
|
||||
async def update_graph_execution_share_status(
|
||||
execution_id: str,
|
||||
user_id: str,
|
||||
is_shared: bool,
|
||||
share_token: str | None,
|
||||
shared_at: datetime | None,
|
||||
) -> None:
|
||||
"""Update the sharing status of a graph execution."""
|
||||
await AgentGraphExecution.prisma().update(
|
||||
where={"id": execution_id},
|
||||
data={
|
||||
"isShared": is_shared,
|
||||
"shareToken": share_token,
|
||||
"sharedAt": shared_at,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def get_graph_execution_by_share_token(
|
||||
share_token: str,
|
||||
) -> SharedExecutionResponse | None:
|
||||
"""Get a shared execution with limited public-safe data."""
|
||||
execution = await AgentGraphExecution.prisma().find_first(
|
||||
where={
|
||||
"shareToken": share_token,
|
||||
"isShared": True,
|
||||
"isDeleted": False,
|
||||
},
|
||||
include={
|
||||
"AgentGraph": True,
|
||||
"NodeExecutions": {
|
||||
"include": {
|
||||
"Output": True,
|
||||
"Node": {
|
||||
"include": {
|
||||
"AgentBlock": True,
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if not execution:
|
||||
return None
|
||||
|
||||
# Extract outputs from OUTPUT blocks only (consistent with GraphExecution.from_db)
|
||||
outputs: CompletedBlockOutput = defaultdict(list)
|
||||
if execution.NodeExecutions:
|
||||
for node_exec in execution.NodeExecutions:
|
||||
if node_exec.Node and node_exec.Node.agentBlockId:
|
||||
# Get the block definition to check its type
|
||||
block = get_block(node_exec.Node.agentBlockId)
|
||||
|
||||
if block and block.block_type == BlockType.OUTPUT:
|
||||
# For OUTPUT blocks, the data is stored in executionData or Input
|
||||
# The executionData contains the structured input with 'name' and 'value' fields
|
||||
if hasattr(node_exec, "executionData") and node_exec.executionData:
|
||||
exec_data = type_utils.convert(
|
||||
node_exec.executionData, dict[str, Any]
|
||||
)
|
||||
if "name" in exec_data:
|
||||
name = exec_data["name"]
|
||||
value = exec_data.get("value")
|
||||
outputs[name].append(value)
|
||||
elif node_exec.Input:
|
||||
# Build input_data from Input relation
|
||||
input_data = {}
|
||||
for data in node_exec.Input:
|
||||
if data.name and data.data is not None:
|
||||
input_data[data.name] = type_utils.convert(
|
||||
data.data, JsonValue
|
||||
)
|
||||
|
||||
if "name" in input_data:
|
||||
name = input_data["name"]
|
||||
value = input_data.get("value")
|
||||
outputs[name].append(value)
|
||||
|
||||
return SharedExecutionResponse(
|
||||
id=execution.id,
|
||||
graph_name=(
|
||||
execution.AgentGraph.name
|
||||
if (execution.AgentGraph and execution.AgentGraph.name)
|
||||
else "Untitled Agent"
|
||||
),
|
||||
graph_description=(
|
||||
execution.AgentGraph.description if execution.AgentGraph else None
|
||||
),
|
||||
status=ExecutionStatus(execution.executionStatus),
|
||||
created_at=execution.createdAt,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
from prisma.enums import SubmissionStatus
|
||||
@@ -12,7 +13,7 @@ from prisma.types import (
|
||||
AgentNodeLinkCreateInput,
|
||||
StoreListingVersionWhereInput,
|
||||
)
|
||||
from pydantic import Field, JsonValue, create_model
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
@@ -34,6 +35,7 @@ from .db import BaseDbModel, query_raw_with_schema, transaction
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .execution import NodesInputMasks
|
||||
from .integrations import Webhook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -159,6 +161,8 @@ class BaseGraph(BaseDbModel):
|
||||
is_active: bool = True
|
||||
name: str
|
||||
description: str
|
||||
instructions: str | None = None
|
||||
recommended_schedule_cron: str | None = None
|
||||
nodes: list[Node] = []
|
||||
links: list[Link] = []
|
||||
forked_from_id: str | None = None
|
||||
@@ -205,6 +209,35 @@ class BaseGraph(BaseDbModel):
|
||||
None,
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def trigger_setup_info(self) -> "GraphTriggerInfo | None":
|
||||
if not (
|
||||
self.webhook_input_node
|
||||
and (trigger_block := self.webhook_input_node.block).webhook_config
|
||||
):
|
||||
return None
|
||||
|
||||
return GraphTriggerInfo(
|
||||
provider=trigger_block.webhook_config.provider,
|
||||
config_schema={
|
||||
**(json_schema := trigger_block.input_schema.jsonschema()),
|
||||
"properties": {
|
||||
pn: sub_schema
|
||||
for pn, sub_schema in json_schema["properties"].items()
|
||||
if not is_credentials_field_name(pn)
|
||||
},
|
||||
"required": [
|
||||
pn
|
||||
for pn in json_schema.get("required", [])
|
||||
if not is_credentials_field_name(pn)
|
||||
],
|
||||
},
|
||||
credentials_input_name=next(
|
||||
iter(trigger_block.input_schema.get_credentials_fields()), None
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _generate_schema(
|
||||
*props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
|
||||
@@ -238,6 +271,14 @@ class BaseGraph(BaseDbModel):
|
||||
}
|
||||
|
||||
|
||||
class GraphTriggerInfo(BaseModel):
|
||||
provider: ProviderName
|
||||
config_schema: dict[str, Any] = Field(
|
||||
description="Input schema for the trigger block"
|
||||
)
|
||||
credentials_input_name: Optional[str]
|
||||
|
||||
|
||||
class Graph(BaseGraph):
|
||||
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs
|
||||
|
||||
@@ -342,6 +383,8 @@ class GraphModel(Graph):
|
||||
user_id: str
|
||||
nodes: list[NodeModel] = [] # type: ignore
|
||||
|
||||
created_at: datetime
|
||||
|
||||
@property
|
||||
def starting_nodes(self) -> list[NodeModel]:
|
||||
outbound_nodes = {link.sink_id for link in self.links}
|
||||
@@ -354,6 +397,10 @@ class GraphModel(Graph):
|
||||
if node.id not in outbound_nodes or node.id in input_nodes
|
||||
]
|
||||
|
||||
@property
|
||||
def webhook_input_node(self) -> NodeModel | None: # type: ignore
|
||||
return cast(NodeModel, super().webhook_input_node)
|
||||
|
||||
def meta(self) -> "GraphMeta":
|
||||
"""
|
||||
Returns a GraphMeta object with metadata about the graph.
|
||||
@@ -414,7 +461,7 @@ class GraphModel(Graph):
|
||||
def validate_graph(
|
||||
self,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
nodes_input_masks: Optional["NodesInputMasks"] = None,
|
||||
):
|
||||
"""
|
||||
Validate graph structure and raise `ValueError` on issues.
|
||||
@@ -428,7 +475,7 @@ class GraphModel(Graph):
|
||||
def _validate_graph(
|
||||
graph: BaseGraph,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
nodes_input_masks: Optional["NodesInputMasks"] = None,
|
||||
) -> None:
|
||||
errors = GraphModel._validate_graph_get_errors(
|
||||
graph, for_run, nodes_input_masks
|
||||
@@ -442,7 +489,7 @@ class GraphModel(Graph):
|
||||
def validate_graph_get_errors(
|
||||
self,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
nodes_input_masks: Optional["NodesInputMasks"] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Validate graph and return structured errors per node.
|
||||
@@ -464,7 +511,7 @@ class GraphModel(Graph):
|
||||
def _validate_graph_get_errors(
|
||||
graph: BaseGraph,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
nodes_input_masks: Optional["NodesInputMasks"] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Validate graph and return structured errors per node.
|
||||
@@ -655,9 +702,12 @@ class GraphModel(Graph):
|
||||
version=graph.version,
|
||||
forked_from_id=graph.forkedFromId,
|
||||
forked_from_version=graph.forkedFromVersion,
|
||||
created_at=graph.createdAt,
|
||||
is_active=graph.isActive,
|
||||
name=graph.name or "",
|
||||
description=graph.description or "",
|
||||
instructions=graph.instructions,
|
||||
recommended_schedule_cron=graph.recommendedScheduleCron,
|
||||
nodes=[NodeModel.from_db(node, for_export) for node in graph.Nodes or []],
|
||||
links=list(
|
||||
{
|
||||
@@ -1045,6 +1095,7 @@ async def __create_graph(tx, graph: Graph, user_id: str):
|
||||
version=graph.version,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
recommendedScheduleCron=graph.recommended_schedule_cron,
|
||||
isActive=graph.is_active,
|
||||
userId=user_id,
|
||||
forkedFromId=graph.forked_from_id,
|
||||
@@ -1103,6 +1154,7 @@ def make_graph_model(creatable_graph: Graph, user_id: str) -> GraphModel:
|
||||
return GraphModel(
|
||||
**creatable_graph.model_dump(exclude={"nodes"}),
|
||||
user_id=user_id,
|
||||
created_at=datetime.now(tz=timezone.utc),
|
||||
nodes=[
|
||||
NodeModel(
|
||||
**creatable_node.model_dump(),
|
||||
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import autogpt_libs.auth.models
|
||||
import fastapi.exceptions
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
@@ -317,12 +316,7 @@ async def test_access_store_listing_graph(server: SpinTestServer):
|
||||
is_approved=True,
|
||||
comments="Test comments",
|
||||
),
|
||||
autogpt_libs.auth.models.User(
|
||||
user_id=admin_user.id,
|
||||
role="admin",
|
||||
email=admin_user.email,
|
||||
phone_number="1234567890",
|
||||
),
|
||||
user_id=admin_user.id,
|
||||
)
|
||||
|
||||
# Now we check the graph can be accessed by a user that does not own the graph
|
||||
|
||||
@@ -59,9 +59,15 @@ def graph_execution_include(
|
||||
}
|
||||
|
||||
|
||||
AGENT_PRESET_INCLUDE: prisma.types.AgentPresetInclude = {
|
||||
"InputPresets": True,
|
||||
"Webhook": True,
|
||||
}
|
||||
|
||||
|
||||
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE},
|
||||
"AgentPresets": {"include": {"InputPresets": True}},
|
||||
"AgentPresets": {"include": AGENT_PRESET_INCLUDE},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -96,6 +96,12 @@ class User(BaseModel):
|
||||
default=True, description="Notify on monthly summary"
|
||||
)
|
||||
|
||||
# User timezone for scheduling and time display
|
||||
timezone: str = Field(
|
||||
default="not-set",
|
||||
description="User timezone (IANA timezone identifier or 'not-set')",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, prisma_user: "PrismaUser") -> "User":
|
||||
"""Convert a database User object to application User model."""
|
||||
@@ -149,6 +155,7 @@ class User(BaseModel):
|
||||
notify_on_daily_summary=prisma_user.notifyOnDailySummary or True,
|
||||
notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True,
|
||||
notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True,
|
||||
timezone=prisma_user.timezone or "not-set",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -54,19 +54,6 @@ class AgentRunData(BaseNotificationData):
|
||||
|
||||
|
||||
class ZeroBalanceData(BaseNotificationData):
|
||||
last_transaction: float
|
||||
last_transaction_time: datetime
|
||||
top_up_link: str
|
||||
|
||||
@field_validator("last_transaction_time")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: datetime):
|
||||
if value.tzinfo is None:
|
||||
raise ValueError("datetime must have timezone information")
|
||||
return value
|
||||
|
||||
|
||||
class LowBalanceData(BaseNotificationData):
|
||||
agent_name: str = Field(..., description="Name of the agent")
|
||||
current_balance: float = Field(
|
||||
..., description="Current balance in credits (100 = $1)"
|
||||
@@ -75,6 +62,13 @@ class LowBalanceData(BaseNotificationData):
|
||||
shortfall: float = Field(..., description="Amount of credits needed to continue")
|
||||
|
||||
|
||||
class LowBalanceData(BaseNotificationData):
|
||||
current_balance: float = Field(
|
||||
..., description="Current balance in credits (100 = $1)"
|
||||
)
|
||||
billing_page_link: str = Field(..., description="Link to billing page")
|
||||
|
||||
|
||||
class BlockExecutionFailedData(BaseNotificationData):
|
||||
block_name: str
|
||||
block_id: str
|
||||
@@ -181,6 +175,42 @@ class RefundRequestData(BaseNotificationData):
|
||||
balance: int
|
||||
|
||||
|
||||
class AgentApprovalData(BaseNotificationData):
|
||||
agent_name: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
reviewer_name: str
|
||||
reviewer_email: str
|
||||
comments: str
|
||||
reviewed_at: datetime
|
||||
store_url: str
|
||||
|
||||
@field_validator("reviewed_at")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: datetime):
|
||||
if value.tzinfo is None:
|
||||
raise ValueError("datetime must have timezone information")
|
||||
return value
|
||||
|
||||
|
||||
class AgentRejectionData(BaseNotificationData):
|
||||
agent_name: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
reviewer_name: str
|
||||
reviewer_email: str
|
||||
comments: str
|
||||
reviewed_at: datetime
|
||||
resubmit_url: str
|
||||
|
||||
@field_validator("reviewed_at")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: datetime):
|
||||
if value.tzinfo is None:
|
||||
raise ValueError("datetime must have timezone information")
|
||||
return value
|
||||
|
||||
|
||||
NotificationData = Annotated[
|
||||
Union[
|
||||
AgentRunData,
|
||||
@@ -240,6 +270,8 @@ def get_notif_data_type(
|
||||
NotificationType.MONTHLY_SUMMARY: MonthlySummaryData,
|
||||
NotificationType.REFUND_REQUEST: RefundRequestData,
|
||||
NotificationType.REFUND_PROCESSED: RefundRequestData,
|
||||
NotificationType.AGENT_APPROVED: AgentApprovalData,
|
||||
NotificationType.AGENT_REJECTED: AgentRejectionData,
|
||||
}[notification_type]
|
||||
|
||||
|
||||
@@ -274,7 +306,7 @@ class NotificationTypeOverride:
|
||||
# These are batched by the notification service
|
||||
NotificationType.AGENT_RUN: QueueType.BATCH,
|
||||
# These are batched by the notification service, but with a backoff strategy
|
||||
NotificationType.ZERO_BALANCE: QueueType.BACKOFF,
|
||||
NotificationType.ZERO_BALANCE: QueueType.IMMEDIATE,
|
||||
NotificationType.LOW_BALANCE: QueueType.IMMEDIATE,
|
||||
NotificationType.BLOCK_EXECUTION_FAILED: QueueType.BACKOFF,
|
||||
NotificationType.CONTINUOUS_AGENT_ERROR: QueueType.BACKOFF,
|
||||
@@ -283,6 +315,8 @@ class NotificationTypeOverride:
|
||||
NotificationType.MONTHLY_SUMMARY: QueueType.SUMMARY,
|
||||
NotificationType.REFUND_REQUEST: QueueType.ADMIN,
|
||||
NotificationType.REFUND_PROCESSED: QueueType.ADMIN,
|
||||
NotificationType.AGENT_APPROVED: QueueType.IMMEDIATE,
|
||||
NotificationType.AGENT_REJECTED: QueueType.IMMEDIATE,
|
||||
}
|
||||
return BATCHING_RULES.get(self.notification_type, QueueType.IMMEDIATE)
|
||||
|
||||
@@ -300,6 +334,8 @@ class NotificationTypeOverride:
|
||||
NotificationType.MONTHLY_SUMMARY: "monthly_summary.html",
|
||||
NotificationType.REFUND_REQUEST: "refund_request.html",
|
||||
NotificationType.REFUND_PROCESSED: "refund_processed.html",
|
||||
NotificationType.AGENT_APPROVED: "agent_approved.html",
|
||||
NotificationType.AGENT_REJECTED: "agent_rejected.html",
|
||||
}[self.notification_type]
|
||||
|
||||
@property
|
||||
@@ -315,6 +351,8 @@ class NotificationTypeOverride:
|
||||
NotificationType.MONTHLY_SUMMARY: "We did a lot this month!",
|
||||
NotificationType.REFUND_REQUEST: "[ACTION REQUIRED] You got a ${{data.amount / 100}} refund request from {{data.user_name}}",
|
||||
NotificationType.REFUND_PROCESSED: "Refund for ${{data.amount / 100}} to {{data.user_name}} has been processed",
|
||||
NotificationType.AGENT_APPROVED: "🎉 Your agent '{{data.agent_name}}' has been approved!",
|
||||
NotificationType.AGENT_REJECTED: "Your agent '{{data.agent_name}}' needs some updates",
|
||||
}[self.notification_type]
|
||||
|
||||
|
||||
|
||||
autogpt_platform/backend/backend/data/notifications_test.py (new file, 151 lines)
@@ -0,0 +1,151 @@
|
||||
"""Tests for notification data models."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.data.notifications import AgentApprovalData, AgentRejectionData
|
||||
|
||||
|
||||
class TestAgentApprovalData:
|
||||
"""Test cases for AgentApprovalData model."""
|
||||
|
||||
def test_valid_agent_approval_data(self):
|
||||
"""Test creating valid AgentApprovalData."""
|
||||
data = AgentApprovalData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="John Doe",
|
||||
reviewer_email="john@example.com",
|
||||
comments="Great agent, approved!",
|
||||
reviewed_at=datetime.now(timezone.utc),
|
||||
store_url="https://app.autogpt.com/store/test-agent-123",
|
||||
)
|
||||
|
||||
assert data.agent_name == "Test Agent"
|
||||
assert data.agent_id == "test-agent-123"
|
||||
assert data.agent_version == 1
|
||||
assert data.reviewer_name == "John Doe"
|
||||
assert data.reviewer_email == "john@example.com"
|
||||
assert data.comments == "Great agent, approved!"
|
||||
assert data.store_url == "https://app.autogpt.com/store/test-agent-123"
|
||||
assert data.reviewed_at.tzinfo is not None
|
||||
|
||||
def test_agent_approval_data_without_timezone_raises_error(self):
|
||||
"""Test that AgentApprovalData raises error without timezone."""
|
||||
with pytest.raises(
|
||||
ValidationError, match="datetime must have timezone information"
|
||||
):
|
||||
AgentApprovalData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="John Doe",
|
||||
reviewer_email="john@example.com",
|
||||
comments="Great agent, approved!",
|
||||
reviewed_at=datetime.now(), # No timezone
|
||||
store_url="https://app.autogpt.com/store/test-agent-123",
|
||||
)
|
||||
|
||||
def test_agent_approval_data_with_empty_comments(self):
|
||||
"""Test AgentApprovalData with empty comments."""
|
||||
data = AgentApprovalData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="John Doe",
|
||||
reviewer_email="john@example.com",
|
||||
comments="", # Empty comments
|
||||
reviewed_at=datetime.now(timezone.utc),
|
||||
store_url="https://app.autogpt.com/store/test-agent-123",
|
||||
)
|
||||
|
||||
assert data.comments == ""
|
||||
|
||||
|
||||
class TestAgentRejectionData:
|
||||
"""Test cases for AgentRejectionData model."""
|
||||
|
||||
def test_valid_agent_rejection_data(self):
|
||||
"""Test creating valid AgentRejectionData."""
|
||||
data = AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments="Please fix the security issues before resubmitting.",
|
||||
reviewed_at=datetime.now(timezone.utc),
|
||||
resubmit_url="https://app.autogpt.com/build/test-agent-123",
|
||||
)
|
||||
|
||||
assert data.agent_name == "Test Agent"
|
||||
assert data.agent_id == "test-agent-123"
|
||||
assert data.agent_version == 1
|
||||
assert data.reviewer_name == "Jane Doe"
|
||||
assert data.reviewer_email == "jane@example.com"
|
||||
assert data.comments == "Please fix the security issues before resubmitting."
|
||||
assert data.resubmit_url == "https://app.autogpt.com/build/test-agent-123"
|
||||
assert data.reviewed_at.tzinfo is not None
|
||||
|
||||
def test_agent_rejection_data_without_timezone_raises_error(self):
|
||||
"""Test that AgentRejectionData raises error without timezone."""
|
||||
with pytest.raises(
|
||||
ValidationError, match="datetime must have timezone information"
|
||||
):
|
||||
AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments="Please fix the security issues.",
|
||||
reviewed_at=datetime.now(), # No timezone
|
||||
resubmit_url="https://app.autogpt.com/build/test-agent-123",
|
||||
)
|
||||
|
||||
def test_agent_rejection_data_with_long_comments(self):
|
||||
"""Test AgentRejectionData with long comments."""
|
||||
long_comment = "A" * 1000 # Very long comment
|
||||
data = AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments=long_comment,
|
||||
reviewed_at=datetime.now(timezone.utc),
|
||||
resubmit_url="https://app.autogpt.com/build/test-agent-123",
|
||||
)
|
||||
|
||||
assert data.comments == long_comment
|
||||
|
||||
def test_model_serialization(self):
|
||||
"""Test that models can be serialized and deserialized."""
|
||||
original_data = AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments="Please fix the issues.",
|
||||
reviewed_at=datetime.now(timezone.utc),
|
||||
resubmit_url="https://app.autogpt.com/build/test-agent-123",
|
||||
)
|
||||
|
||||
# Serialize to dict
|
||||
data_dict = original_data.model_dump()
|
||||
|
||||
# Deserialize back
|
||||
restored_data = AgentRejectionData.model_validate(data_dict)
|
||||
|
||||
assert restored_data.agent_name == original_data.agent_name
|
||||
assert restored_data.agent_id == original_data.agent_id
|
||||
assert restored_data.agent_version == original_data.agent_version
|
||||
assert restored_data.reviewer_name == original_data.reviewer_name
|
||||
assert restored_data.reviewer_email == original_data.reviewer_email
|
||||
assert restored_data.comments == original_data.comments
|
||||
assert restored_data.reviewed_at == original_data.reviewed_at
|
||||
assert restored_data.resubmit_url == original_data.resubmit_url
|
||||
@@ -208,6 +208,8 @@ async def get_user_notification_preference(user_id: str) -> NotificationPreferen
|
||||
NotificationType.DAILY_SUMMARY: user.notifyOnDailySummary or False,
|
||||
NotificationType.WEEKLY_SUMMARY: user.notifyOnWeeklySummary or False,
|
||||
NotificationType.MONTHLY_SUMMARY: user.notifyOnMonthlySummary or False,
|
||||
NotificationType.AGENT_APPROVED: user.notifyOnAgentApproved or False,
|
||||
NotificationType.AGENT_REJECTED: user.notifyOnAgentRejected or False,
|
||||
}
|
||||
daily_limit = user.maxEmailsPerDay or 3
|
||||
notification_preference = NotificationPreference(
|
||||
@@ -266,6 +268,14 @@ async def update_user_notification_preference(
|
||||
update_data["notifyOnMonthlySummary"] = data.preferences[
|
||||
NotificationType.MONTHLY_SUMMARY
|
||||
]
|
||||
if NotificationType.AGENT_APPROVED in data.preferences:
|
||||
update_data["notifyOnAgentApproved"] = data.preferences[
|
||||
NotificationType.AGENT_APPROVED
|
||||
]
|
||||
if NotificationType.AGENT_REJECTED in data.preferences:
|
||||
update_data["notifyOnAgentRejected"] = data.preferences[
|
||||
NotificationType.AGENT_REJECTED
|
||||
]
|
||||
if data.daily_limit:
|
||||
update_data["maxEmailsPerDay"] = data.daily_limit
|
||||
|
||||
@@ -286,6 +296,8 @@ async def update_user_notification_preference(
|
||||
NotificationType.DAILY_SUMMARY: user.notifyOnDailySummary or True,
|
||||
NotificationType.WEEKLY_SUMMARY: user.notifyOnWeeklySummary or True,
|
||||
NotificationType.MONTHLY_SUMMARY: user.notifyOnMonthlySummary or True,
|
||||
NotificationType.AGENT_APPROVED: user.notifyOnAgentApproved or True,
|
||||
NotificationType.AGENT_REJECTED: user.notifyOnAgentRejected or True,
|
||||
}
|
||||
notification_preference = NotificationPreference(
|
||||
user_id=user.id,
|
||||
@@ -384,3 +396,17 @@ async def unsubscribe_user_by_token(token: str) -> None:
|
||||
)
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to unsubscribe user by token {token}: {e}") from e
|
||||
|
||||
|
||||
async def update_user_timezone(user_id: str, timezone: str) -> User:
|
||||
"""Update a user's timezone setting."""
|
||||
try:
|
||||
user = await PrismaUser.prisma().update(
|
||||
where={"id": user_id},
|
||||
data={"timezone": timezone},
|
||||
)
|
||||
if not user:
|
||||
raise ValueError(f"User not found with ID: {user_id}")
|
||||
return User.from_db(user)
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to update timezone for user {user_id}: {e}") from e
|
||||
|
||||
@@ -107,7 +107,7 @@ async def generate_activity_status_for_execution(
|
||||
# Check if we have OpenAI API key
|
||||
try:
|
||||
settings = Settings()
|
||||
if not settings.secrets.openai_api_key:
|
||||
if not settings.secrets.openai_internal_api_key:
|
||||
logger.debug(
|
||||
"OpenAI API key not configured, skipping activity status generation"
|
||||
)
|
||||
@@ -187,7 +187,7 @@ async def generate_activity_status_for_execution(
|
||||
credentials = APIKeyCredentials(
|
||||
id="openai",
|
||||
provider="openai",
|
||||
api_key=SecretStr(settings.secrets.openai_api_key),
|
||||
api_key=SecretStr(settings.secrets.openai_internal_api_key),
|
||||
title="System OpenAI",
|
||||
)
|
||||
|
||||
|
||||
@@ -468,7 +468,7 @@ class TestGenerateActivityStatusForExecution:
|
||||
):
|
||||
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
|
||||
mock_llm.return_value = (
|
||||
"I analyzed your data and provided the requested insights."
|
||||
)
|
||||
@@ -520,7 +520,7 @@ class TestGenerateActivityStatusForExecution:
|
||||
"backend.executor.activity_status_generator.is_feature_enabled",
|
||||
return_value=True,
|
||||
):
|
||||
mock_settings.return_value.secrets.openai_api_key = ""
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = ""
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
graph_exec_id="test_exec",
|
||||
@@ -546,7 +546,7 @@ class TestGenerateActivityStatusForExecution:
|
||||
"backend.executor.activity_status_generator.is_feature_enabled",
|
||||
return_value=True,
|
||||
):
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
graph_exec_id="test_exec",
|
||||
@@ -581,7 +581,7 @@ class TestGenerateActivityStatusForExecution:
|
||||
):
|
||||
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
|
||||
mock_llm.return_value = "Agent completed execution."
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
@@ -633,7 +633,7 @@ class TestIntegration:
|
||||
):
|
||||
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
|
||||
|
||||
mock_response = LLMResponse(
|
||||
raw_response={},
|
||||
|
||||
@@ -42,6 +42,8 @@ from backend.data.user import (
|
||||
get_user_notification_preference,
|
||||
update_user_integrations,
|
||||
)
|
||||
from backend.server.v2.library.db import add_store_agent_to_library, list_library_agents
|
||||
from backend.server.v2.store.db import get_store_agent_details, get_store_agents
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
@@ -145,6 +147,14 @@ class DatabaseManager(AppService):
|
||||
get_user_notification_oldest_message_in_batch
|
||||
)
|
||||
|
||||
# Library
|
||||
list_library_agents = _(list_library_agents)
|
||||
add_store_agent_to_library = _(add_store_agent_to_library)
|
||||
|
||||
# Store
|
||||
get_store_agents = _(get_store_agents)
|
||||
get_store_agent_details = _(get_store_agent_details)
|
||||
|
||||
# Summary data - async
|
||||
get_user_execution_summary_data = _(get_user_execution_summary_data)
|
||||
|
||||
@@ -173,12 +183,20 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
spend_credits = _(d.spend_credits)
|
||||
get_credits = _(d.get_credits)
|
||||
|
||||
# Summary data - async
|
||||
get_user_execution_summary_data = _(d.get_user_execution_summary_data)
|
||||
|
||||
# Block error monitoring
|
||||
get_block_error_stats = _(d.get_block_error_stats)
|
||||
|
||||
# User Emails
|
||||
get_user_email_by_id = _(d.get_user_email_by_id)
|
||||
|
||||
# Library
|
||||
list_library_agents = _(d.list_library_agents)
|
||||
add_store_agent_to_library = _(d.add_store_agent_to_library)
|
||||
|
||||
# Store
|
||||
get_store_agents = _(d.get_store_agents)
|
||||
get_store_agent_details = _(d.get_store_agent_details)
|
||||
|
||||
|
||||
class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
d = DatabaseManager
|
||||
@@ -223,5 +241,13 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
d.get_user_notification_oldest_message_in_batch
|
||||
)
|
||||
|
||||
# Library
|
||||
list_library_agents = d.list_library_agents
|
||||
add_store_agent_to_library = d.add_store_agent_to_library
|
||||
|
||||
# Store
|
||||
get_store_agents = d.get_store_agents
|
||||
get_store_agent_details = d.get_store_agent_details
|
||||
|
||||
# Summary data
|
||||
get_user_execution_summary_data = d.get_user_execution_summary_data
|
||||
|
||||
@@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
|
||||
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pika.spec import Basic, BasicProperties
|
||||
from pydantic import JsonValue
|
||||
from redis.asyncio.lock import Lock as RedisLock
|
||||
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
@@ -20,6 +19,7 @@ from backend.data.notifications import (
|
||||
LowBalanceData,
|
||||
NotificationEventModel,
|
||||
NotificationType,
|
||||
ZeroBalanceData,
|
||||
)
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.activity_status_generator import (
|
||||
@@ -37,9 +37,9 @@ from prometheus_client import Gauge, start_http_server
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.block import (
|
||||
BlockData,
|
||||
BlockInput,
|
||||
BlockOutput,
|
||||
BlockOutputEntry,
|
||||
BlockSchema,
|
||||
get_block,
|
||||
)
|
||||
@@ -51,6 +51,8 @@ from backend.data.execution import (
|
||||
GraphExecutionEntry,
|
||||
NodeExecutionEntry,
|
||||
NodeExecutionResult,
|
||||
NodesInputMasks,
|
||||
UserContext,
|
||||
)
|
||||
from backend.data.graph import Link, Node
|
||||
from backend.executor.utils import (
|
||||
@@ -74,6 +76,7 @@ from backend.util.clients import (
|
||||
get_database_manager_async_client,
|
||||
get_database_manager_client,
|
||||
get_execution_event_bus,
|
||||
get_notification_manager_client,
|
||||
)
|
||||
from backend.util.decorator import (
|
||||
async_error_logged,
|
||||
@@ -83,6 +86,7 @@ from backend.util.decorator import (
|
||||
)
|
||||
from backend.util.file import clean_exec_files
|
||||
from backend.util.logging import TruncatedLogger, configure_logging
|
||||
from backend.util.metrics import DiscordChannel
|
||||
from backend.util.process import AppProcess, set_service_name
|
||||
from backend.util.retry import continuous_retry, func_retry
|
||||
from backend.util.settings import Settings
|
||||
@@ -127,7 +131,7 @@ async def execute_node(
|
||||
creds_manager: IntegrationCredentialsManager,
|
||||
data: NodeExecutionEntry,
|
||||
execution_stats: NodeExecutionStats | None = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Execute a node in the graph. This will trigger a block execution on a node,
|
||||
@@ -190,6 +194,9 @@ async def execute_node(
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
# Add user context from NodeExecutionEntry
|
||||
extra_exec_kwargs["user_context"] = data.user_context
|
||||
|
||||
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
|
||||
# changes during execution. ⚠️ This means a set of credentials can only be used by
|
||||
# one (running) block at a time; simultaneous execution of blocks using same
|
||||
@@ -230,12 +237,13 @@ async def execute_node(
|
||||
async def _enqueue_next_nodes(
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
node: Node,
|
||||
output: BlockData,
|
||||
output: BlockOutputEntry,
|
||||
user_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
nodes_input_masks: Optional[NodesInputMasks],
|
||||
user_context: UserContext,
|
||||
) -> list[NodeExecutionEntry]:
|
||||
async def add_enqueued_execution(
|
||||
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
|
||||
@@ -254,6 +262,7 @@ async def _enqueue_next_nodes(
|
||||
node_id=node_id,
|
||||
block_id=block_id,
|
||||
inputs=data,
|
||||
user_context=user_context,
|
||||
)
|
||||
|
||||
async def register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
|
||||
@@ -410,7 +419,7 @@ class ExecutionProcessor:
|
||||
self,
|
||||
node_exec: NodeExecutionEntry,
|
||||
node_exec_progress: NodeExecutionProgress,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
nodes_input_masks: Optional[NodesInputMasks],
|
||||
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
|
||||
) -> NodeExecutionStats:
|
||||
log_metadata = LogMetadata(
|
||||
@@ -478,7 +487,7 @@ class ExecutionProcessor:
|
||||
stats: NodeExecutionStats,
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> ExecutionStatus:
|
||||
status = ExecutionStatus.RUNNING
|
||||
|
||||
@@ -596,7 +605,7 @@ class ExecutionProcessor:
|
||||
)
|
||||
return
|
||||
|
||||
if exec_meta.status == ExecutionStatus.QUEUED:
|
||||
if exec_meta.status in [ExecutionStatus.QUEUED, ExecutionStatus.INCOMPLETE]:
|
||||
log_metadata.info(f"⚙️ Starting graph execution #{graph_exec.graph_exec_id}")
|
||||
exec_meta.status = ExecutionStatus.RUNNING
|
||||
send_execution_update(
|
||||
@@ -678,19 +687,20 @@ class ExecutionProcessor:
|
||||
self,
|
||||
node_exec: NodeExecutionEntry,
|
||||
execution_count: int,
|
||||
) -> int:
|
||||
) -> tuple[int, int]:
|
||||
total_cost = 0
|
||||
remaining_balance = 0
|
||||
db_client = get_db_client()
|
||||
block = get_block(node_exec.block_id)
|
||||
if not block:
|
||||
logger.error(f"Block {node_exec.block_id} not found.")
|
||||
return total_cost
|
||||
return total_cost, 0
|
||||
|
||||
cost, matching_filter = block_usage_cost(
|
||||
block=block, input_data=node_exec.inputs
|
||||
)
|
||||
if cost > 0:
|
||||
db_client.spend_credits(
|
||||
remaining_balance = db_client.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
@@ -708,7 +718,7 @@ class ExecutionProcessor:
|
||||
|
||||
cost, usage_count = execution_usage_cost(execution_count)
|
||||
if cost > 0:
|
||||
db_client.spend_credits(
|
||||
remaining_balance = db_client.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
@@ -723,7 +733,7 @@ class ExecutionProcessor:
|
||||
)
|
||||
total_cost += cost
|
||||
|
||||
return total_cost
|
||||
return total_cost, remaining_balance
|
||||
|
||||
@time_measured
|
||||
def _on_graph_execution(
|
||||
@@ -787,7 +797,8 @@ class ExecutionProcessor:
|
||||
ExecutionStatus.TERMINATED,
|
||||
],
|
||||
):
|
||||
execution_queue.add(node_exec.to_node_execution_entry())
|
||||
node_entry = node_exec.to_node_execution_entry(graph_exec.user_context)
|
||||
execution_queue.add(node_entry)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Main dispatch / polling loop -----------------------------
|
||||
@@ -805,12 +816,19 @@ class ExecutionProcessor:
|
||||
|
||||
# Charge usage (may raise) ------------------------------
|
||||
try:
|
||||
cost = self._charge_usage(
|
||||
cost, remaining_balance = self._charge_usage(
|
||||
node_exec=queued_node_exec,
|
||||
execution_count=increment_execution_count(graph_exec.user_id),
|
||||
)
|
||||
with execution_stats_lock:
|
||||
execution_stats.cost += cost
|
||||
# Check if we crossed the low balance threshold
|
||||
self._handle_low_balance(
|
||||
db_client=db_client,
|
||||
user_id=graph_exec.user_id,
|
||||
current_balance=remaining_balance,
|
||||
transaction_cost=cost,
|
||||
)
|
||||
except InsufficientBalanceError as balance_error:
|
||||
error = balance_error # Set error to trigger FAILED status
|
||||
node_exec_id = queued_node_exec.node_exec_id
|
||||
@@ -825,11 +843,10 @@ class ExecutionProcessor:
|
||||
status=ExecutionStatus.FAILED,
|
||||
)
|
||||
|
||||
self._handle_low_balance_notif(
|
||||
self._handle_insufficient_funds_notif(
|
||||
db_client,
|
||||
graph_exec.user_id,
|
||||
graph_exec.graph_id,
|
||||
execution_stats,
|
||||
error,
|
||||
)
|
||||
# Gracefully stop the execution loop
|
||||
@@ -1036,7 +1053,7 @@ class ExecutionProcessor:
|
||||
node_id: str,
|
||||
graph_exec: GraphExecutionEntry,
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
nodes_input_masks: Optional[NodesInputMasks],
|
||||
execution_queue: ExecutionQueue[NodeExecutionEntry],
|
||||
) -> None:
|
||||
"""Process a node's output, update its status, and enqueue next nodes.
|
||||
@@ -1052,6 +1069,7 @@ class ExecutionProcessor:
|
||||
db_client = get_db_async_client()
|
||||
|
||||
log_metadata.debug(f"Enqueue nodes for {node_id}: {output}")
|
||||
|
||||
for next_execution in await _enqueue_next_nodes(
|
||||
db_client=db_client,
|
||||
node=output.node,
|
||||
@@ -1061,6 +1079,7 @@ class ExecutionProcessor:
|
||||
graph_id=graph_exec.graph_id,
|
||||
log_metadata=log_metadata,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
user_context=graph_exec.user_context,
|
||||
):
|
||||
execution_queue.add(next_execution)
|
||||
|
||||
@@ -1101,25 +1120,25 @@ class ExecutionProcessor:
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_low_balance_notif(
|
||||
def _handle_insufficient_funds_notif(
|
||||
self,
|
||||
db_client: "DatabaseManagerClient",
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
exec_stats: GraphExecutionStats,
|
||||
e: InsufficientBalanceError,
|
||||
):
|
||||
shortfall = e.balance - e.amount
|
||||
shortfall = abs(e.amount) - e.balance
|
||||
metadata = db_client.get_graph_metadata(graph_id)
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
|
||||
queue_notification(
|
||||
NotificationEventModel(
|
||||
user_id=user_id,
|
||||
type=NotificationType.LOW_BALANCE,
|
||||
data=LowBalanceData(
|
||||
current_balance=exec_stats.cost,
|
||||
type=NotificationType.ZERO_BALANCE,
|
||||
data=ZeroBalanceData(
|
||||
current_balance=e.balance,
|
||||
billing_page_link=f"{base_url}/profile/credits",
|
||||
shortfall=shortfall,
|
||||
agent_name=metadata.name if metadata else "Unknown Agent",
|
||||
@@ -1127,6 +1146,73 @@ class ExecutionProcessor:
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
user_email = db_client.get_user_email_by_id(user_id)
|
||||
|
||||
alert_message = (
|
||||
f"❌ **Insufficient Funds Alert**\n"
|
||||
f"User: {user_email or user_id}\n"
|
||||
f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
|
||||
f"Current balance: ${e.balance/100:.2f}\n"
|
||||
f"Attempted cost: ${abs(e.amount)/100:.2f}\n"
|
||||
f"Shortfall: ${abs(shortfall)/100:.2f}\n"
|
||||
f"[View User Details]({base_url}/admin/spending?search={user_email})"
|
||||
)
|
||||
|
||||
get_notification_manager_client().discord_system_alert(
|
||||
alert_message, DiscordChannel.PRODUCT
|
||||
)
|
||||
except Exception as alert_error:
|
||||
logger.error(
|
||||
f"Failed to send insufficient funds Discord alert: {alert_error}"
|
||||
)
|
||||
|
||||
def _handle_low_balance(
|
||||
self,
|
||||
db_client: "DatabaseManagerClient",
|
||||
user_id: str,
|
||||
current_balance: int,
|
||||
transaction_cost: int,
|
||||
):
|
||||
"""Check and handle low balance scenarios after a transaction"""
|
||||
LOW_BALANCE_THRESHOLD = settings.config.low_balance_threshold
|
||||
|
||||
balance_before = current_balance + transaction_cost
|
||||
|
||||
if (
|
||||
current_balance < LOW_BALANCE_THRESHOLD
|
||||
and balance_before >= LOW_BALANCE_THRESHOLD
|
||||
):
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
queue_notification(
|
||||
NotificationEventModel(
|
||||
user_id=user_id,
|
||||
type=NotificationType.LOW_BALANCE,
|
||||
data=LowBalanceData(
|
||||
current_balance=current_balance,
|
||||
billing_page_link=f"{base_url}/profile/credits",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
user_email = db_client.get_user_email_by_id(user_id)
|
||||
alert_message = (
|
||||
f"⚠️ **Low Balance Alert**\n"
|
||||
f"User: {user_email or user_id}\n"
|
||||
f"Balance dropped below ${LOW_BALANCE_THRESHOLD/100:.2f}\n"
|
||||
f"Current balance: ${current_balance/100:.2f}\n"
|
||||
f"Transaction cost: ${transaction_cost/100:.2f}\n"
|
||||
f"[View User Details]({base_url}/admin/spending?search={user_email})"
|
||||
)
|
||||
get_notification_manager_client().discord_system_alert(
|
||||
alert_message, DiscordChannel.PRODUCT
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send low balance Discord alert: {e}")
|
||||
|
||||
|
||||
class ExecutionManager(AppProcess):
|
||||
def __init__(self):
|
||||
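The `_handle_low_balance` logic above only notifies when a single charge moves the balance from at-or-above the threshold to below it. A small self-contained illustration of that predicate (amounts in cents, function name made up, not the platform's API); the three assertions mirror the test cases added further down:

```python
def crossed_low_balance(balance_after: int, cost: int, threshold: int = 500) -> bool:
    """True only for the transaction that drops the balance below the threshold."""
    balance_before = balance_after + cost
    return balance_after < threshold <= balance_before

assert crossed_low_balance(balance_after=400, cost=600) is True   # crosses the $5.00 line
assert crossed_low_balance(balance_after=600, cost=100) is False  # still above threshold
assert crossed_low_balance(balance_after=300, cost=100) is False  # already below, no duplicate alert
```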
@@ -1308,14 +1394,14 @@ class ExecutionManager(AppProcess):
|
||||
delivery_tag = method.delivery_tag
|
||||
|
||||
@func_retry
|
||||
def _ack_message(reject: bool = False):
|
||||
def _ack_message(reject: bool, requeue: bool):
|
||||
"""Acknowledge or reject the message based on execution status."""
|
||||
|
||||
# Connection can be lost, so always get a fresh channel
|
||||
channel = self.run_client.get_channel()
|
||||
if reject:
|
||||
channel.connection.add_callback_threadsafe(
|
||||
lambda: channel.basic_nack(delivery_tag, requeue=True)
|
||||
lambda: channel.basic_nack(delivery_tag, requeue=requeue)
|
||||
)
|
||||
else:
|
||||
channel.connection.add_callback_threadsafe(
|
||||
@@ -1327,13 +1413,13 @@ class ExecutionManager(AppProcess):
|
||||
logger.info(
|
||||
f"[{self.service_name}] Rejecting new execution during shutdown"
|
||||
)
|
||||
_ack_message(reject=True)
|
||||
_ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
# Check if we can accept more runs
|
||||
self._cleanup_completed_runs()
|
||||
if len(self.active_graph_runs) >= self.pool_size:
|
||||
_ack_message(reject=True)
|
||||
_ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -1342,7 +1428,7 @@ class ExecutionManager(AppProcess):
|
||||
logger.error(
|
||||
f"[{self.service_name}] Could not parse run message: {e}, body={body}"
|
||||
)
|
||||
_ack_message(reject=True)
|
||||
_ack_message(reject=True, requeue=False)
|
||||
return
|
||||
|
||||
graph_exec_id = graph_exec_entry.graph_exec_id
|
||||
@@ -1354,7 +1440,7 @@ class ExecutionManager(AppProcess):
|
||||
logger.error(
|
||||
f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
|
||||
)
|
||||
_ack_message(reject=True)
|
||||
_ack_message(reject=True, requeue=False)
|
||||
return
|
||||
|
||||
cancel_event = threading.Event()
|
||||
@@ -1370,9 +1456,9 @@ class ExecutionManager(AppProcess):
|
||||
logger.error(
|
||||
f"[{self.service_name}] Execution for {graph_exec_id} failed: {type(exec_error)} {exec_error}"
|
||||
)
|
||||
_ack_message(reject=True)
|
||||
_ack_message(reject=True, requeue=True)
|
||||
else:
|
||||
_ack_message(reject=False)
|
||||
_ack_message(reject=False, requeue=False)
|
||||
except BaseException as e:
|
||||
logger.exception(
|
||||
f"[{self.service_name}] Error in run completion callback: {e}"
|
||||
|
||||
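The `_ack_message(reject, requeue)` change above separates transient failures (requeue and retry later) from poison messages (drop so they don't loop forever). A hedged sketch of that decision using pika's blocking-connection API, with the surrounding handler wiring simplified away:

```python
from pika.adapters.blocking_connection import BlockingChannel

def settle_message(
    channel: BlockingChannel, delivery_tag: int, *, reject: bool, requeue: bool
) -> None:
    """Ack successful runs; nack transient failures (shutdown, pool full, crashed run)
    with requeue=True, and permanent failures (unparseable body, duplicate run)
    with requeue=False."""
    if reject:
        channel.basic_nack(delivery_tag, requeue=requeue)
    else:
        channel.basic_ack(delivery_tag)
```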
@@ -0,0 +1,149 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import NotificationType
|
||||
|
||||
from backend.data.notifications import LowBalanceData
|
||||
from backend.executor.manager import ExecutionProcessor
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
|
||||
"""Test that _handle_low_balance triggers notification when crossing threshold."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
current_balance = 400 # $4 - below $5 threshold
|
||||
transaction_cost = 600 # $6 transaction
|
||||
|
||||
# Mock dependencies
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.low_balance_threshold = 500 # $5 threshold
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the low balance handler
|
||||
execution_processor._handle_low_balance(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
current_balance=current_balance,
|
||||
transaction_cost=transaction_cost,
|
||||
)
|
||||
|
||||
# Verify notification was queued
|
||||
mock_queue_notif.assert_called_once()
|
||||
notification_call = mock_queue_notif.call_args[0][0]
|
||||
|
||||
# Verify notification details
|
||||
assert notification_call.type == NotificationType.LOW_BALANCE
|
||||
assert notification_call.user_id == user_id
|
||||
assert isinstance(notification_call.data, LowBalanceData)
|
||||
assert notification_call.data.current_balance == current_balance
|
||||
|
||||
# Verify Discord alert was sent
|
||||
mock_client.discord_system_alert.assert_called_once()
|
||||
discord_message = mock_client.discord_system_alert.call_args[0][0]
|
||||
assert "Low Balance Alert" in discord_message
|
||||
assert "test@example.com" in discord_message
|
||||
assert "$4.00" in discord_message
|
||||
assert "$6.00" in discord_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_low_balance_no_notification_when_not_crossing(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that no notification is sent when not crossing the threshold."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
current_balance = 600 # $6 - above $5 threshold
|
||||
transaction_cost = (
|
||||
100 # $1 transaction (balance before was $7, still above threshold)
|
||||
)
|
||||
|
||||
# Mock dependencies
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.low_balance_threshold = 500 # $5 threshold
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
|
||||
# Test the low balance handler
|
||||
execution_processor._handle_low_balance(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
current_balance=current_balance,
|
||||
transaction_cost=transaction_cost,
|
||||
)
|
||||
|
||||
# Verify no notification was sent
|
||||
mock_queue_notif.assert_not_called()
|
||||
mock_client.discord_system_alert.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_low_balance_no_duplicate_when_already_below(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that no notification is sent when already below threshold."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
current_balance = 300 # $3 - below $5 threshold
|
||||
transaction_cost = (
|
||||
100 # $1 transaction (balance before was $4, also below threshold)
|
||||
)
|
||||
|
||||
# Mock dependencies
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.low_balance_threshold = 500 # $5 threshold
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
|
||||
# Test the low balance handler
|
||||
execution_processor._handle_low_balance(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
current_balance=current_balance,
|
||||
transaction_cost=transaction_cost,
|
||||
)
|
||||
|
||||
# Verify no notification was sent (user was already below threshold)
|
||||
mock_queue_notif.assert_not_called()
|
||||
mock_client.discord_system_alert.assert_not_called()
|
||||
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
|
||||
import autogpt_libs.auth.models
|
||||
import fastapi.responses
|
||||
import pytest
|
||||
|
||||
@@ -36,21 +35,20 @@ async def execute_graph(
|
||||
logger.info(f"Input data: {input_data}")
|
||||
|
||||
# --- Test adding new executions --- #
|
||||
response = await agent_server.test_execute_graph(
|
||||
graph_exec = await agent_server.test_execute_graph(
|
||||
user_id=test_user.id,
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
node_input=input_data,
|
||||
)
|
||||
graph_exec_id = response.graph_exec_id
|
||||
logger.info(f"Created execution with ID: {graph_exec_id}")
|
||||
logger.info(f"Created execution with ID: {graph_exec.id}")
|
||||
|
||||
# Execution queue should be empty
|
||||
logger.info("Waiting for execution to complete...")
|
||||
result = await wait_execution(test_user.id, graph_exec_id, 30)
|
||||
result = await wait_execution(test_user.id, graph_exec.id, 30)
|
||||
logger.info(f"Execution completed with {len(result)} results")
|
||||
assert len(result) == num_execs
|
||||
return graph_exec_id
|
||||
return graph_exec.id
|
||||
|
||||
|
||||
async def assert_sample_graph_executions(
|
||||
@@ -380,7 +378,7 @@ async def test_execute_preset(server: SpinTestServer):
|
||||
|
||||
# Verify execution
|
||||
assert result is not None
|
||||
graph_exec_id = result["id"]
|
||||
graph_exec_id = result.id
|
||||
|
||||
# Wait for execution to complete
|
||||
executions = await wait_execution(test_user.id, graph_exec_id)
|
||||
@@ -469,7 +467,7 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
|
||||
|
||||
# Verify execution
|
||||
assert result is not None, "Result must not be None"
|
||||
graph_exec_id = result["id"]
|
||||
graph_exec_id = result.id
|
||||
|
||||
# Wait for execution to complete
|
||||
executions = await wait_execution(test_user.id, graph_exec_id)
|
||||
@@ -521,12 +519,7 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
is_approved=True,
|
||||
comments="Test comments",
|
||||
),
|
||||
autogpt_libs.auth.models.User(
|
||||
user_id=admin_user.id,
|
||||
role="admin",
|
||||
email=admin_user.email,
|
||||
phone_number="1234567890",
|
||||
),
|
||||
user_id=admin_user.id,
|
||||
)
|
||||
alt_test_user = admin_user
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from apscheduler.jobstores.memory import MemoryJobStore
|
||||
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.util import ZoneInfo
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import MetaData, create_engine
|
||||
@@ -190,15 +191,22 @@ class GraphExecutionJobInfo(GraphExecutionJobArgs):
|
||||
id: str
|
||||
name: str
|
||||
next_run_time: str
|
||||
timezone: str = Field(default="UTC", description="Timezone used for scheduling")
|
||||
|
||||
@staticmethod
|
||||
def from_db(
|
||||
job_args: GraphExecutionJobArgs, job_obj: JobObj
|
||||
) -> "GraphExecutionJobInfo":
|
||||
# Extract timezone from the trigger if it's a CronTrigger
|
||||
timezone_str = "UTC"
|
||||
if hasattr(job_obj.trigger, "timezone"):
|
||||
timezone_str = str(job_obj.trigger.timezone)
|
||||
|
||||
return GraphExecutionJobInfo(
|
||||
id=job_obj.id,
|
||||
name=job_obj.name,
|
||||
next_run_time=job_obj.next_run_time.isoformat(),
|
||||
timezone=timezone_str,
|
||||
**job_args.model_dump(),
|
||||
)
|
||||
|
||||
@@ -303,6 +311,7 @@ class Scheduler(AppService):
|
||||
Jobstores.WEEKLY_NOTIFICATIONS.value: MemoryJobStore(),
|
||||
},
|
||||
logger=apscheduler_logger,
|
||||
timezone=ZoneInfo("UTC"),
|
||||
)
|
||||
|
||||
if self.register_system_tasks:
|
||||
@@ -393,6 +402,7 @@ class Scheduler(AppService):
|
||||
input_data: BlockInput,
|
||||
input_credentials: dict[str, CredentialsMetaInput],
|
||||
name: Optional[str] = None,
|
||||
user_timezone: str | None = None,
|
||||
) -> GraphExecutionJobInfo:
|
||||
# Validate the graph before scheduling to prevent runtime failures
|
||||
# We don't need the return value, just want the validation to run
|
||||
@@ -406,6 +416,19 @@ class Scheduler(AppService):
|
||||
)
|
||||
)
|
||||
|
||||
# Use provided timezone or default to UTC
|
||||
# Note: Timezone should be passed from the client to avoid database lookups
|
||||
if not user_timezone:
|
||||
user_timezone = "UTC"
|
||||
logger.warning(
|
||||
f"No timezone provided for user {user_id}, using UTC for scheduling. "
|
||||
f"Client should pass user's timezone for correct scheduling."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Scheduling job for user {user_id} with timezone {user_timezone} (cron: {cron})"
|
||||
)
|
||||
|
||||
job_args = GraphExecutionJobArgs(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
@@ -418,12 +441,12 @@ class Scheduler(AppService):
|
||||
execute_graph,
|
||||
kwargs=job_args.model_dump(),
|
||||
name=name,
|
||||
trigger=CronTrigger.from_crontab(cron),
|
||||
trigger=CronTrigger.from_crontab(cron, timezone=user_timezone),
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
replace_existing=True,
|
||||
)
|
||||
logger.info(
|
||||
f"Added job {job.id} with cron schedule '{cron}' input data: {input_data}"
|
||||
f"Added job {job.id} with cron schedule '{cron}' in timezone {user_timezone}, input data: {input_data}"
|
||||
)
|
||||
return GraphExecutionJobInfo.from_db(job_args, job)
|
||||
|
||||
|
||||
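The scheduler change above passes the user's timezone into `CronTrigger.from_crontab`, so a crontab like `0 9 * * *` fires at 9 AM in the user's local time instead of a fixed server zone. A quick APScheduler 3.x sketch of the difference (timezone string chosen arbitrarily for illustration):

```python
from datetime import datetime, timezone

from apscheduler.triggers.cron import CronTrigger

now = datetime.now(timezone.utc)

# Same crontab string, two different wall clocks:
in_utc = CronTrigger.from_crontab("0 9 * * *", timezone="UTC")
in_ny = CronTrigger.from_crontab("0 9 * * *", timezone="America/New_York")

print(in_utc.get_next_fire_time(None, now))  # next 09:00 UTC
print(in_ny.get_next_fire_time(None, now))   # next 09:00 in New York (13:00/14:00 UTC)
```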
@@ -4,24 +4,33 @@ import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Mapping, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, JsonValue, ValidationError
|
||||
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCostType,
|
||||
BlockInput,
|
||||
BlockOutputEntry,
|
||||
BlockType,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCostType
|
||||
from backend.data.db import prisma
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
GraphExecutionStats,
|
||||
GraphExecutionWithNodes,
|
||||
NodesInputMasks,
|
||||
UserContext,
|
||||
)
|
||||
from backend.data.graph import GraphModel, Node
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
get_async_execution_queue,
|
||||
@@ -34,6 +43,27 @@ from backend.util.mock import MockObject
|
||||
from backend.util.settings import Config
|
||||
from backend.util.type import convert
|
||||
|
||||
|
||||
async def get_user_context(user_id: str) -> UserContext:
    """
    Get UserContext for a user, always returns a valid context with timezone.
    Defaults to UTC if user has no timezone set.
    """
    user_context = UserContext(timezone="UTC")  # Default to UTC
    try:
        user = await get_user_by_id(user_id)
        if user and user.timezone and user.timezone != "not-set":
            user_context.timezone = user.timezone
            logger.debug(f"Retrieved user context: timezone={user.timezone}")
        else:
            logger.debug("User has no timezone set, using UTC")
    except Exception as e:
        logger.warning(f"Could not fetch user timezone: {e}")
        # Continue with UTC as default

    return user_context
|
||||
|
||||
|
||||
config = Config()
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")
|
||||
|
||||
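`get_user_context` above never raises: a missing user, a lookup error, or the `"not-set"` placeholder all degrade to UTC. The same fallback idea as a standalone, dependency-free sketch (the helper name and the zone validation step are assumptions, not the project's code):

```python
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError

def resolve_timezone(raw: str | None) -> ZoneInfo:
    """Return a usable ZoneInfo, treating missing, placeholder, or invalid values as UTC."""
    if not raw or raw == "not-set":
        return ZoneInfo("UTC")
    try:
        return ZoneInfo(raw)
    except (ZoneInfoNotFoundError, ValueError):
        # Unknown or malformed zone name stored on the user record; degrade gracefully.
        return ZoneInfo("UTC")

assert str(resolve_timezone("Europe/Amsterdam")) == "Europe/Amsterdam"
assert str(resolve_timezone("not-set")) == "UTC"
```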
@@ -216,7 +246,7 @@ def _tokenise(path: str) -> list[tuple[str, str]] | None:
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def parse_execution_output(output: BlockData, name: str) -> Any | None:
|
||||
def parse_execution_output(output: BlockOutputEntry, name: str) -> JsonValue | None:
|
||||
"""
|
||||
Retrieve a nested value out of `output` using the flattened *name*.
|
||||
|
||||
@@ -240,7 +270,7 @@ def parse_execution_output(output: BlockData, name: str) -> Any | None:
|
||||
if tokens is None:
|
||||
return None
|
||||
|
||||
cur: Any = data
|
||||
cur: JsonValue = data
|
||||
for delim, ident in tokens:
|
||||
if delim == LIST_SPLIT:
|
||||
# list[index]
|
||||
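`parse_execution_output` walks a flattened pin name through nested dict keys and list indices in the block's output. A toy version of that walk, using made-up step markers rather than the module's real `LIST_SPLIT` and tokenizer constants:

```python
from typing import Any

def pluck(data: Any, path: list[tuple[str, str]]) -> Any | None:
    """Follow ('#', key) steps through dicts and ('$', index) steps through lists."""
    cur = data
    for kind, ident in path:
        if kind == "$":  # list[index]
            if not isinstance(cur, list) or not ident.isdigit() or int(ident) >= len(cur):
                return None
            cur = cur[int(ident)]
        else:  # dict[key]
            if not isinstance(cur, dict) or ident not in cur:
                return None
            cur = cur[ident]
    return cur

assert pluck({"items": [{"name": "a"}]}, [("#", "items"), ("$", "0"), ("#", "name")]) == "a"
assert pluck({"items": []}, [("#", "items"), ("$", "0")]) is None
```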
@@ -405,7 +435,7 @@ def validate_exec(
|
||||
async def _validate_node_input_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Checks all credentials for all nodes of the graph and returns structured errors.
|
||||
@@ -485,8 +515,8 @@ async def _validate_node_input_credentials(
|
||||
|
||||
def make_node_credentials_input_map(
|
||||
graph: GraphModel,
|
||||
graph_credentials_input: dict[str, CredentialsMetaInput],
|
||||
) -> dict[str, dict[str, JsonValue]]:
|
||||
graph_credentials_input: Mapping[str, CredentialsMetaInput],
|
||||
) -> NodesInputMasks:
|
||||
"""
|
||||
Maps credentials for an execution to the correct nodes.
|
||||
|
||||
@@ -521,8 +551,8 @@ def make_node_credentials_input_map(
|
||||
async def validate_graph_with_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> Mapping[str, Mapping[str, str]]:
|
||||
"""
|
||||
Validate graph including credentials and return structured errors per node.
|
||||
|
||||
@@ -552,7 +582,7 @@ async def _construct_starting_node_execution_input(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
graph_inputs: BlockInput,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> list[tuple[str, BlockInput]]:
|
||||
"""
|
||||
Validates and prepares the input data for executing a graph.
|
||||
@@ -593,7 +623,7 @@ async def _construct_starting_node_execution_input(
|
||||
|
||||
# Extract request input data, and assign it to the input pin.
|
||||
if block.block_type == BlockType.INPUT:
|
||||
input_name = node.input_default.get("name")
|
||||
input_name = cast(str | None, node.input_default.get("name"))
|
||||
if input_name and input_name in graph_inputs:
|
||||
input_data = {"value": graph_inputs[input_name]}
|
||||
|
||||
@@ -620,9 +650,9 @@ async def validate_and_construct_node_execution_input(
|
||||
user_id: str,
|
||||
graph_inputs: BlockInput,
|
||||
graph_version: Optional[int] = None,
|
||||
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], dict[str, dict[str, JsonValue]]]:
|
||||
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]:
|
||||
"""
|
||||
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
|
||||
This centralizes the logic used by both scheduler validation and actual execution.
|
||||
@@ -636,7 +666,9 @@ async def validate_and_construct_node_execution_input(
|
||||
nodes_input_masks: Node inputs to use.
|
||||
|
||||
Returns:
|
||||
tuple[GraphModel, list[tuple[str, BlockInput]]]: Graph model and list of tuples for node execution input.
|
||||
GraphModel: Full graph object for the given `graph_id`.
|
||||
list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
|
||||
dict[str, BlockInput]: Node input masks including all passed-in credentials.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the graph is not found.
|
||||
@@ -677,11 +709,11 @@ async def validate_and_construct_node_execution_input(
|
||||
|
||||
|
||||
def _merge_nodes_input_masks(
|
||||
overrides_map_1: dict[str, dict[str, JsonValue]],
|
||||
overrides_map_2: dict[str, dict[str, JsonValue]],
|
||||
) -> dict[str, dict[str, JsonValue]]:
|
||||
overrides_map_1: NodesInputMasks,
|
||||
overrides_map_2: NodesInputMasks,
|
||||
) -> NodesInputMasks:
|
||||
"""Perform a per-node merge of input overrides"""
|
||||
result = overrides_map_1.copy()
|
||||
result = dict(overrides_map_1).copy()
|
||||
for node_id, overrides2 in overrides_map_2.items():
|
||||
if node_id in result:
|
||||
result[node_id] = {**result[node_id], **overrides2}
|
||||
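`_merge_nodes_input_masks` merges override maps node by node, with the second map winning on key conflicts. A worked example of those semantics with plain dicts:

```python
first = {"node1": {"x": 1, "y": 2}}
second = {"node1": {"y": 9}, "node2": {"z": 3}}

merged = dict(first)
for node_id, overrides in second.items():
    merged[node_id] = {**merged.get(node_id, {}), **overrides}

assert merged == {"node1": {"x": 1, "y": 9}, "node2": {"z": 3}}
```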
@@ -831,8 +863,8 @@ async def add_graph_execution(
|
||||
inputs: Optional[BlockInput] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
Adds a graph execution to the queue and returns the execution entry.
|
||||
@@ -856,7 +888,7 @@ async def add_graph_execution(
|
||||
else:
|
||||
edb = get_database_manager_async_client()
|
||||
|
||||
graph, starting_nodes_input, nodes_input_masks = (
|
||||
graph, starting_nodes_input, compiled_nodes_input_masks = (
|
||||
await validate_and_construct_node_execution_input(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
@@ -869,34 +901,43 @@ async def add_graph_execution(
|
||||
graph_exec = None
|
||||
|
||||
try:
|
||||
# Sanity check: running add_graph_execution with the properties of
|
||||
# the graph_exec created here should create the same execution again.
|
||||
graph_exec = await edb.create_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
inputs=inputs or {},
|
||||
credential_inputs=graph_credentials_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
starting_nodes_input=starting_nodes_input,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
|
||||
queue = await get_async_execution_queue()
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry()
|
||||
if nodes_input_masks:
|
||||
graph_exec_entry.nodes_input_masks = nodes_input_masks
|
||||
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry(
|
||||
user_context=await get_user_context(user_id),
|
||||
compiled_nodes_input_masks=compiled_nodes_input_masks,
|
||||
)
|
||||
logger.info(
|
||||
f"Created graph execution #{graph_exec.id} for graph "
|
||||
f"#{graph_id} with {len(starting_nodes_input)} starting nodes. "
|
||||
f"Now publishing to execution queue."
|
||||
)
|
||||
|
||||
await queue.publish_message(
|
||||
exec_queue = await get_async_execution_queue()
|
||||
await exec_queue.publish_message(
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
message=graph_exec_entry.model_dump_json(),
|
||||
exchange=GRAPH_EXECUTION_EXCHANGE,
|
||||
)
|
||||
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
|
||||
|
||||
bus = get_async_execution_event_bus()
|
||||
await bus.publish(graph_exec)
|
||||
graph_exec.status = ExecutionStatus.QUEUED
|
||||
await edb.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.id,
|
||||
status=graph_exec.status,
|
||||
)
|
||||
await get_async_execution_event_bus().publish(graph_exec)
|
||||
|
||||
return graph_exec
|
||||
except BaseException as e:
|
||||
@@ -926,7 +967,7 @@ async def add_graph_execution(
|
||||
class ExecutionOutputEntry(BaseModel):
|
||||
node: Node
|
||||
node_exec_id: str
|
||||
data: BlockData
|
||||
data: BlockOutputEntry
|
||||
|
||||
|
||||
class NodeExecutionProgress:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from backend.executor.utils import merge_execution_input, parse_execution_output
|
||||
from backend.util.mock import MockObject
|
||||
@@ -276,3 +277,147 @@ def test_merge_execution_input():
|
||||
result = merge_execution_input(data)
|
||||
assert "mixed" in result
|
||||
assert result["mixed"].attr[0]["key"] == "value3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
"""
|
||||
Verify that calling the function with its own output creates the same execution again.
|
||||
"""
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
# Mock data
|
||||
graph_id = "test-graph-id"
|
||||
user_id = "test-user-id"
|
||||
inputs = {"test_input": "test_value"}
|
||||
preset_id = "test-preset-id"
|
||||
graph_version = 1
|
||||
graph_credentials_inputs = {
|
||||
"cred_key": CredentialsMetaInput(
|
||||
id="cred-id", provider=ProviderName("test_provider"), type="oauth2"
|
||||
)
|
||||
}
|
||||
nodes_input_masks = {"node1": {"input1": "masked_value"}}
|
||||
|
||||
# Mock the graph object returned by validate_and_construct_node_execution_input
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.version = graph_version
|
||||
|
||||
# Mock the starting nodes input and compiled nodes input masks
|
||||
starting_nodes_input = [
|
||||
("node1", {"input1": "value1"}),
|
||||
("node2", {"input1": "value2"}),
|
||||
]
|
||||
compiled_nodes_input_masks = {"node1": {"input1": "compiled_mask"}}
|
||||
|
||||
# Mock the graph execution object
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec.id = "execution-id-123"
|
||||
mock_graph_exec.node_executions = [] # Add this to avoid AttributeError
|
||||
mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()
|
||||
|
||||
# Mock user context
|
||||
mock_user_context = {"user_id": user_id, "context": "test_context"}
|
||||
|
||||
# Mock the queue and event bus
|
||||
mock_queue = mocker.AsyncMock()
|
||||
mock_event_bus = mocker.MagicMock()
|
||||
mock_event_bus.publish = mocker.AsyncMock()
|
||||
|
||||
# Setup mocks
|
||||
mock_validate = mocker.patch(
|
||||
"backend.executor.utils.validate_and_construct_node_execution_input"
|
||||
)
|
||||
mock_edb = mocker.patch("backend.executor.utils.execution_db")
|
||||
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||
mock_get_user_context = mocker.patch("backend.executor.utils.get_user_context")
|
||||
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
|
||||
# Setup mock returns
|
||||
mock_validate.return_value = (
|
||||
mock_graph,
|
||||
starting_nodes_input,
|
||||
compiled_nodes_input_masks,
|
||||
)
|
||||
mock_prisma.is_connected.return_value = True
|
||||
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
|
||||
mock_edb.update_graph_execution_stats = mocker.AsyncMock(
|
||||
return_value=mock_graph_exec
|
||||
)
|
||||
mock_edb.update_node_execution_status_batch = mocker.AsyncMock()
|
||||
mock_get_user_context.return_value = mock_user_context
|
||||
mock_get_queue.return_value = mock_queue
|
||||
mock_get_event_bus.return_value = mock_event_bus
|
||||
|
||||
# Call the function - first execution
|
||||
result1 = await add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
preset_id=preset_id,
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=graph_credentials_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
|
||||
# Store the parameters used in the first call to create_graph_execution
|
||||
first_call_kwargs = mock_edb.create_graph_execution.call_args[1]
|
||||
|
||||
# Verify the create_graph_execution was called with correct parameters
|
||||
mock_edb.create_graph_execution.assert_called_once_with(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=mock_graph.version,
|
||||
inputs=inputs,
|
||||
credential_inputs=graph_credentials_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
starting_nodes_input=starting_nodes_input,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
|
||||
# Set up the graph execution mock to have properties we can extract
|
||||
mock_graph_exec.graph_id = graph_id
|
||||
mock_graph_exec.user_id = user_id
|
||||
mock_graph_exec.graph_version = graph_version
|
||||
mock_graph_exec.inputs = inputs
|
||||
mock_graph_exec.credential_inputs = graph_credentials_inputs
|
||||
mock_graph_exec.nodes_input_masks = nodes_input_masks
|
||||
mock_graph_exec.preset_id = preset_id
|
||||
|
||||
# Create a second mock execution for the sanity check
|
||||
mock_graph_exec_2 = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec_2.id = "execution-id-456"
|
||||
mock_graph_exec_2.to_graph_execution_entry.return_value = mocker.MagicMock()
|
||||
|
||||
# Reset mocks and set up for second call
|
||||
mock_edb.create_graph_execution.reset_mock()
|
||||
mock_edb.create_graph_execution.return_value = mock_graph_exec_2
|
||||
mock_validate.reset_mock()
|
||||
|
||||
# Sanity check: call add_graph_execution with properties from first result
|
||||
# This should create the same execution parameters
|
||||
result2 = await add_graph_execution(
|
||||
graph_id=mock_graph_exec.graph_id,
|
||||
user_id=mock_graph_exec.user_id,
|
||||
inputs=mock_graph_exec.inputs,
|
||||
preset_id=mock_graph_exec.preset_id,
|
||||
graph_version=mock_graph_exec.graph_version,
|
||||
graph_credentials_inputs=mock_graph_exec.credential_inputs,
|
||||
nodes_input_masks=mock_graph_exec.nodes_input_masks,
|
||||
)
|
||||
|
||||
# Verify that create_graph_execution was called with identical parameters
|
||||
second_call_kwargs = mock_edb.create_graph_execution.call_args[1]
|
||||
|
||||
# The sanity check: both calls should use identical parameters
|
||||
assert first_call_kwargs == second_call_kwargs
|
||||
|
||||
# Both executions should succeed (though they create different objects)
|
||||
assert result1 == mock_graph_exec
|
||||
assert result2 == mock_graph_exec_2
|
||||
@@ -4,6 +4,7 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.integrations.oauth.todoist import TodoistOAuthHandler
|
||||
|
||||
from .discord import DiscordOAuthHandler
|
||||
from .github import GitHubOAuthHandler
|
||||
from .google import GoogleOAuthHandler
|
||||
from .notion import NotionOAuthHandler
|
||||
@@ -15,6 +16,7 @@ if TYPE_CHECKING:
|
||||
# --8<-- [start:HANDLERS_BY_NAMEExample]
|
||||
# Build handlers dict with string keys for compatibility with SDK auto-registration
|
||||
_ORIGINAL_HANDLERS = [
|
||||
DiscordOAuthHandler,
|
||||
GitHubOAuthHandler,
|
||||
GoogleOAuthHandler,
|
||||
NotionOAuthHandler,
|
||||
|
||||
175
autogpt_platform/backend/backend/integrations/oauth/discord.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import Requests
|
||||
|
||||
from .base import BaseOAuthHandler
|
||||
|
||||
|
||||
class DiscordOAuthHandler(BaseOAuthHandler):
|
||||
"""
|
||||
Discord OAuth2 handler implementation.
|
||||
|
||||
Based on the documentation at:
|
||||
- https://discord.com/developers/docs/topics/oauth2
|
||||
|
||||
Discord OAuth2 tokens expire after 7 days by default and include refresh tokens.
|
||||
"""
|
||||
|
||||
PROVIDER_NAME = ProviderName.DISCORD
|
||||
DEFAULT_SCOPES = ["identify"] # Basic user information
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
self.auth_base_url = "https://discord.com/oauth2/authorize"
|
||||
self.token_url = "https://discord.com/api/oauth2/token"
|
||||
self.revoke_url = "https://discord.com/api/oauth2/token/revoke"
|
||||
|
||||
def get_login_url(
|
||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||
) -> str:
|
||||
# Handle default scopes
|
||||
scopes = self.handle_default_scopes(scopes)
|
||||
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"response_type": "code",
|
||||
"scope": " ".join(scopes),
|
||||
"state": state,
|
||||
}
|
||||
|
||||
# Discord supports PKCE
|
||||
if code_challenge:
|
||||
params["code_challenge"] = code_challenge
|
||||
params["code_challenge_method"] = "S256"
|
||||
|
||||
return f"{self.auth_base_url}?{urlencode(params)}"
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||
) -> OAuth2Credentials:
|
||||
params = {
|
||||
"code": code,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"grant_type": "authorization_code",
|
||||
}
|
||||
|
||||
# Include PKCE verifier if provided
|
||||
if code_verifier:
|
||||
params["code_verifier"] = code_verifier
|
||||
|
||||
return await self._request_tokens(params)
|
||||
|
||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
||||
if not credentials.access_token:
|
||||
raise ValueError("No access token to revoke")
|
||||
|
||||
# Discord requires client authentication for token revocation
|
||||
data = {
|
||||
"token": credentials.access_token.get_secret_value(),
|
||||
"token_type_hint": "access_token",
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
}
|
||||
|
||||
response = await Requests().post(
|
||||
url=self.revoke_url,
|
||||
data=data,
|
||||
headers=headers,
|
||||
auth=(self.client_id, self.client_secret),
|
||||
)
|
||||
|
||||
# Discord returns 200 OK for successful revocation
|
||||
return response.status == 200
|
||||
|
||||
async def _refresh_tokens(
|
||||
self, credentials: OAuth2Credentials
|
||||
) -> OAuth2Credentials:
|
||||
if not credentials.refresh_token:
|
||||
return credentials
|
||||
|
||||
return await self._request_tokens(
|
||||
{
|
||||
"refresh_token": credentials.refresh_token.get_secret_value(),
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
current_credentials=credentials,
|
||||
)
|
||||
|
||||
async def _request_tokens(
|
||||
self,
|
||||
params: dict[str, str],
|
||||
current_credentials: Optional[OAuth2Credentials] = None,
|
||||
) -> OAuth2Credentials:
|
||||
request_body = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
**params,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
}
|
||||
|
||||
response = await Requests().post(
|
||||
self.token_url, data=request_body, headers=headers
|
||||
)
|
||||
token_data: dict = response.json()
|
||||
|
||||
# Get username if this is a new token request
|
||||
username = None
|
||||
if "access_token" in token_data:
|
||||
username = await self._request_username(token_data["access_token"])
|
||||
|
||||
now = int(time.time())
|
||||
new_credentials = OAuth2Credentials(
|
||||
provider=self.PROVIDER_NAME,
|
||||
title=current_credentials.title if current_credentials else None,
|
||||
username=username,
|
||||
access_token=token_data["access_token"],
|
||||
scopes=token_data.get("scope", "").split()
|
||||
or (current_credentials.scopes if current_credentials else []),
|
||||
refresh_token=token_data.get("refresh_token"),
|
||||
# Discord tokens expire after expires_in seconds (typically 7 days)
|
||||
access_token_expires_at=(
|
||||
now + expires_in
|
||||
if (expires_in := token_data.get("expires_in", None))
|
||||
else None
|
||||
),
|
||||
# Discord doesn't provide separate refresh token expiration
|
||||
refresh_token_expires_at=None,
|
||||
)
|
||||
|
||||
if current_credentials:
|
||||
new_credentials.id = current_credentials.id
|
||||
|
||||
return new_credentials
|
||||
|
||||
async def _request_username(self, access_token: str) -> str | None:
|
||||
"""
|
||||
Fetch the username using the Discord OAuth2 @me endpoint.
|
||||
"""
|
||||
url = "https://discord.com/api/oauth2/@me"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
response = await Requests().get(url, headers=headers)
|
||||
|
||||
if not response.ok:
|
||||
return None
|
||||
|
||||
# Get user info from the response
|
||||
data = response.json()
|
||||
user_info = data.get("user", {})
|
||||
|
||||
# Return username (without discriminator)
|
||||
return user_info.get("username")
|
||||
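`DiscordOAuthHandler.get_login_url` accepts an optional S256 `code_challenge`, but generating the verifier/challenge pair is left to the caller. A standard RFC 7636 recipe (not part of the diff) that would produce compatible values:

```python
import base64
import hashlib
import secrets

def make_pkce_pair() -> tuple[str, str]:
    """Return (code_verifier, code_challenge) using the S256 transform."""
    verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b"=").decode()
    digest = hashlib.sha256(verifier.encode("ascii")).digest()
    challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
    return verifier, challenge

verifier, challenge = make_pkce_pair()
# Pass code_challenge=challenge when building the login URL, and
# code_verifier=verifier when exchanging the authorization code for tokens.
```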
@@ -7,10 +7,9 @@ from backend.data.graph import set_node_webhook
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
|
||||
from . import get_webhook_manager, supports_webhooks
|
||||
from .utils import setup_webhook_for_block
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.graph import BaseGraph, GraphModel, Node, NodeModel
|
||||
from backend.data.graph import BaseGraph, GraphModel, NodeModel
|
||||
from backend.data.model import Credentials
|
||||
|
||||
from ._base import BaseWebhooksManager
|
||||
@@ -43,32 +42,19 @@ async def _on_graph_activate(graph: "BaseGraph", user_id: str) -> "BaseGraph": .
|
||||
|
||||
async def _on_graph_activate(graph: "BaseGraph | GraphModel", user_id: str):
|
||||
get_credentials = credentials_manager.cached_getter(user_id)
|
||||
updated_nodes = []
|
||||
for new_node in graph.nodes:
|
||||
block_input_schema = cast(BlockSchema, new_node.block.input_schema)
|
||||
|
||||
node_credentials = None
|
||||
if (
|
||||
# Webhook-triggered blocks are only allowed to have 1 credentials input
|
||||
(
|
||||
creds_field_name := next(
|
||||
iter(block_input_schema.get_credentials_fields()), None
|
||||
for creds_field_name in block_input_schema.get_credentials_fields().keys():
|
||||
# Prevent saving graph with non-existent credentials
|
||||
if (
|
||||
creds_meta := new_node.input_default.get(creds_field_name)
|
||||
) and not await get_credentials(creds_meta["id"]):
|
||||
raise ValueError(
|
||||
f"Node #{new_node.id} input '{creds_field_name}' updated with "
|
||||
f"non-existent credentials #{creds_meta['id']}"
|
||||
)
|
||||
)
|
||||
and (creds_meta := new_node.input_default.get(creds_field_name))
|
||||
and not (node_credentials := await get_credentials(creds_meta["id"]))
|
||||
):
|
||||
raise ValueError(
|
||||
f"Node #{new_node.id} input '{creds_field_name}' updated with "
|
||||
f"non-existent credentials #{creds_meta['id']}"
|
||||
)
|
||||
|
||||
updated_node = await on_node_activate(
|
||||
user_id, graph.id, new_node, credentials=node_credentials
|
||||
)
|
||||
updated_nodes.append(updated_node)
|
||||
|
||||
graph.nodes = updated_nodes
|
||||
return graph
|
||||
|
||||
|
||||
@@ -85,20 +71,14 @@ async def on_graph_deactivate(graph: "GraphModel", user_id: str):
|
||||
block_input_schema = cast(BlockSchema, node.block.input_schema)
|
||||
|
||||
node_credentials = None
|
||||
if (
|
||||
# Webhook-triggered blocks are only allowed to have 1 credentials input
|
||||
(
|
||||
creds_field_name := next(
|
||||
iter(block_input_schema.get_credentials_fields()), None
|
||||
for creds_field_name in block_input_schema.get_credentials_fields().keys():
|
||||
if (creds_meta := node.input_default.get(creds_field_name)) and not (
|
||||
node_credentials := await get_credentials(creds_meta["id"])
|
||||
):
|
||||
logger.warning(
|
||||
f"Node #{node.id} input '{creds_field_name}' referenced "
|
||||
f"non-existent credentials #{creds_meta['id']}"
|
||||
)
|
||||
)
|
||||
and (creds_meta := node.input_default.get(creds_field_name))
|
||||
and not (node_credentials := await get_credentials(creds_meta["id"]))
|
||||
):
|
||||
logger.error(
|
||||
f"Node #{node.id} input '{creds_field_name}' referenced non-existent "
|
||||
f"credentials #{creds_meta['id']}"
|
||||
)
|
||||
|
||||
updated_node = await on_node_deactivate(
|
||||
user_id, node, credentials=node_credentials
|
||||
@@ -109,32 +89,6 @@ async def on_graph_deactivate(graph: "GraphModel", user_id: str):
|
||||
return graph
|
||||
|
||||
|
||||
async def on_node_activate(
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
node: "Node",
|
||||
*,
|
||||
credentials: Optional["Credentials"] = None,
|
||||
) -> "Node":
|
||||
"""Hook to be called when the node is activated/created"""
|
||||
|
||||
if node.block.webhook_config:
|
||||
new_webhook, feedback = await setup_webhook_for_block(
|
||||
user_id=user_id,
|
||||
trigger_block=node.block,
|
||||
trigger_config=node.input_default,
|
||||
for_graph_id=graph_id,
|
||||
)
|
||||
if new_webhook:
|
||||
node = await set_node_webhook(node.id, new_webhook.id)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Node #{node.id} does not have everything for a webhook: {feedback}"
|
||||
)
|
||||
|
||||
return node
|
||||
|
||||
|
||||
async def on_node_deactivate(
|
||||
user_id: str,
|
||||
node: "NodeModel",
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Optional, cast
|
||||
from pydantic import JsonValue
|
||||
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.settings import Config
|
||||
|
||||
from . import get_webhook_manager, supports_webhooks
|
||||
@@ -13,6 +12,7 @@ if TYPE_CHECKING:
|
||||
from backend.data.block import Block, BlockSchema
|
||||
from backend.data.integrations import Webhook
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
app_config = Config()
|
||||
@@ -20,7 +20,7 @@ credentials_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
# TODO: add test to assert this matches the actual API route
|
||||
def webhook_ingress_url(provider_name: ProviderName, webhook_id: str) -> str:
|
||||
def webhook_ingress_url(provider_name: "ProviderName", webhook_id: str) -> str:
|
||||
return (
|
||||
f"{app_config.platform_base_url}/api/integrations/{provider_name.value}"
|
||||
f"/webhooks/{webhook_id}/ingress"
|
||||
@@ -144,3 +144,62 @@ async def setup_webhook_for_block(
|
||||
)
|
||||
logger.debug(f"Acquired webhook: {webhook}")
|
||||
return webhook, None
|
||||
|
||||
|
||||
async def migrate_legacy_triggered_graphs():
|
||||
from prisma.models import AgentGraph
|
||||
|
||||
from backend.data.graph import AGENT_GRAPH_INCLUDE, GraphModel, set_node_webhook
|
||||
from backend.data.model import is_credentials_field_name
|
||||
from backend.server.v2.library.db import create_preset
|
||||
from backend.server.v2.library.model import LibraryAgentPresetCreatable
|
||||
|
||||
triggered_graphs = [
|
||||
GraphModel.from_db(_graph)
|
||||
for _graph in await AgentGraph.prisma().find_many(
|
||||
where={
|
||||
"isActive": True,
|
||||
"Nodes": {"some": {"NOT": [{"webhookId": None}]}},
|
||||
},
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
)
|
||||
]
|
||||
|
||||
n_migrated_webhooks = 0
|
||||
|
||||
for graph in triggered_graphs:
|
||||
if not ((trigger_node := graph.webhook_input_node) and trigger_node.webhook_id):
|
||||
continue
|
||||
|
||||
# Use trigger node's inputs for the preset
|
||||
preset_credentials = {
|
||||
field_name: creds_meta
|
||||
for field_name, creds_meta in trigger_node.input_default.items()
|
||||
if is_credentials_field_name(field_name)
|
||||
}
|
||||
preset_inputs = {
|
||||
field_name: value
|
||||
for field_name, value in trigger_node.input_default.items()
|
||||
if not is_credentials_field_name(field_name)
|
||||
}
|
||||
|
||||
# Create a triggered preset for the graph
|
||||
await create_preset(
|
||||
graph.user_id,
|
||||
LibraryAgentPresetCreatable(
|
||||
graph_id=graph.id,
|
||||
graph_version=graph.version,
|
||||
inputs=preset_inputs,
|
||||
credentials=preset_credentials,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
webhook_id=trigger_node.webhook_id,
|
||||
is_active=True,
|
||||
),
|
||||
)
|
||||
# Detach webhook from the graph node
|
||||
await set_node_webhook(trigger_node.id, None)
|
||||
|
||||
n_migrated_webhooks += 1
|
||||
|
||||
logger.info(f"Migrated {n_migrated_webhooks} node triggers to triggered presets")
|
||||
|
||||
287
autogpt_platform/backend/backend/monitoring/instrumentation.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""
|
||||
Prometheus instrumentation for FastAPI services.
|
||||
|
||||
This module provides centralized metrics collection and instrumentation
|
||||
for all FastAPI services in the AutoGPT platform.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI
|
||||
from prometheus_client import Counter, Gauge, Histogram, Info
|
||||
from prometheus_fastapi_instrumentator import Instrumentator, metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Custom business metrics with controlled cardinality
GRAPH_EXECUTIONS = Counter(
    "autogpt_graph_executions_total",
    "Total number of graph executions",
    labelnames=[
        "status"
    ],  # Removed graph_id and user_id to prevent cardinality explosion
)

GRAPH_EXECUTIONS_BY_USER = Counter(
    "autogpt_graph_executions_by_user_total",
    "Total number of graph executions by user (sampled)",
    labelnames=["status"],  # Only status, user_id tracked separately when needed
)

BLOCK_EXECUTIONS = Counter(
    "autogpt_block_executions_total",
    "Total number of block executions",
    labelnames=["block_type", "status"],  # block_type is bounded
)

BLOCK_DURATION = Histogram(
    "autogpt_block_duration_seconds",
    "Duration of block executions in seconds",
    labelnames=["block_type"],
    buckets=[0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30, 60],
)

WEBSOCKET_CONNECTIONS = Gauge(
    "autogpt_websocket_connections_total",
    "Total number of active WebSocket connections",
    # Removed user_id label - track total only to prevent cardinality explosion
)

SCHEDULER_JOBS = Gauge(
    "autogpt_scheduler_jobs",
    "Current number of scheduled jobs",
    labelnames=["job_type", "status"],
)

DATABASE_QUERIES = Histogram(
    "autogpt_database_query_duration_seconds",
    "Duration of database queries in seconds",
    labelnames=["operation", "table"],
    buckets=[0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5],
)

RABBITMQ_MESSAGES = Counter(
    "autogpt_rabbitmq_messages_total",
    "Total number of RabbitMQ messages",
    labelnames=["queue", "status"],
)

AUTHENTICATION_ATTEMPTS = Counter(
    "autogpt_auth_attempts_total",
    "Total number of authentication attempts",
    labelnames=["method", "status"],
)

API_KEY_USAGE = Counter(
    "autogpt_api_key_usage_total",
    "API key usage by provider",
    labelnames=["provider", "block_type", "status"],
)

# Function/operation level metrics with controlled cardinality
GRAPH_OPERATIONS = Counter(
    "autogpt_graph_operations_total",
    "Graph operations by type",
    labelnames=["operation", "status"],  # create, update, delete, execute, etc.
)

USER_OPERATIONS = Counter(
    "autogpt_user_operations_total",
    "User operations by type",
    labelnames=["operation", "status"],  # login, register, update_profile, etc.
)

RATE_LIMIT_HITS = Counter(
    "autogpt_rate_limit_hits_total",
    "Number of rate limit hits",
    labelnames=["endpoint"],  # Removed user_id to prevent cardinality explosion
)

SERVICE_INFO = Info(
    "autogpt_service",
    "Service information",
)

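The metric objects above register with prometheus_client's default registry at import time. A minimal sketch of a local sanity check, assuming the module path `backend.monitoring.instrumentation` matches the new file's location shown above:

```python
# Local sanity check (sketch): exercise a couple of metrics and render the
# default registry. Label values below are illustrative only.
from prometheus_client import generate_latest

from backend.monitoring.instrumentation import BLOCK_DURATION, GRAPH_EXECUTIONS

GRAPH_EXECUTIONS.labels(status="success").inc()
BLOCK_DURATION.labels(block_type="ExampleBlock").observe(0.42)

exposition = generate_latest().decode()
print("\n".join(line for line in exposition.splitlines() if line.startswith("autogpt_")))
```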
def instrument_fastapi(
    app: FastAPI,
    service_name: str,
    expose_endpoint: bool = True,
    endpoint: str = "/metrics",
    include_in_schema: bool = False,
    excluded_handlers: Optional[list] = None,
) -> Instrumentator:
    """
    Instrument a FastAPI application with Prometheus metrics.

    Args:
        app: FastAPI application instance
        service_name: Name of the service for metrics labeling
        expose_endpoint: Whether to expose /metrics endpoint
        endpoint: Path for metrics endpoint
        include_in_schema: Whether to include metrics endpoint in OpenAPI schema
        excluded_handlers: List of paths to exclude from metrics

    Returns:
        Configured Instrumentator instance
    """

    # Set service info
    try:
        from importlib.metadata import version

        service_version = version("autogpt-platform-backend")
    except Exception:
        service_version = "unknown"

    SERVICE_INFO.info(
        {
            "service": service_name,
            "version": service_version,
        }
    )

    # Create instrumentator with default metrics
    instrumentator = Instrumentator(
        should_group_status_codes=True,
        should_ignore_untemplated=True,
        should_respect_env_var=True,
        should_instrument_requests_inprogress=True,
        excluded_handlers=excluded_handlers or ["/health", "/readiness"],
        env_var_name="ENABLE_METRICS",
        inprogress_name="autogpt_http_requests_inprogress",
        inprogress_labels=True,
    )

    # Add default HTTP metrics
    instrumentator.add(
        metrics.default(
            metric_namespace="autogpt",
            metric_subsystem=service_name.replace("-", "_"),
        )
    )

    # Add request size metrics
    instrumentator.add(
        metrics.request_size(
            metric_namespace="autogpt",
            metric_subsystem=service_name.replace("-", "_"),
        )
    )

    # Add response size metrics
    instrumentator.add(
        metrics.response_size(
            metric_namespace="autogpt",
            metric_subsystem=service_name.replace("-", "_"),
        )
    )

    # Add latency metrics with custom buckets for better granularity
    instrumentator.add(
        metrics.latency(
            metric_namespace="autogpt",
            metric_subsystem=service_name.replace("-", "_"),
            buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30, 60],
        )
    )

    # Add combined metrics (requests by method and status)
    instrumentator.add(
        metrics.combined_size(
            metric_namespace="autogpt",
            metric_subsystem=service_name.replace("-", "_"),
        )
    )

    # Instrument the app
    instrumentator.instrument(app)

    # Expose metrics endpoint if requested
    if expose_endpoint:
        instrumentator.expose(
            app,
            endpoint=endpoint,
            include_in_schema=include_in_schema,
            tags=["monitoring"] if include_in_schema else None,
        )
        logger.info(f"Metrics endpoint exposed at {endpoint} for {service_name}")

    return instrumentator

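A minimal sketch of how a service entry point might wire this in; the service name, app title, and extra excluded paths are illustrative assumptions, not taken from the diff:

```python
# Sketch: attaching instrumentation to a service's FastAPI app.
from fastapi import FastAPI

from backend.monitoring.instrumentation import instrument_fastapi

app = FastAPI(title="AutoGPT REST API")  # illustrative title

instrument_fastapi(
    app,
    service_name="rest-api",  # assumed name; hyphens become underscores in metric names
    excluded_handlers=["/health", "/readiness", "/docs"],  # assumption beyond the defaults
)
```

Because the instrumentator is created with `should_respect_env_var=True` and `env_var_name="ENABLE_METRICS"`, the library only applies HTTP instrumentation when that environment variable is set to "true".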
def record_graph_execution(graph_id: str, status: str, user_id: str):
    """Record a graph execution event.

    Args:
        graph_id: Graph identifier (kept for future sampling/debugging)
        status: Execution status (success/error/validation_error)
        user_id: User identifier (kept for future sampling/debugging)
    """
    # Track overall executions without high-cardinality labels
    GRAPH_EXECUTIONS.labels(status=status).inc()

    # Optionally track per-user executions (implement sampling if needed)
    # For now, just track status to avoid cardinality explosion
    GRAPH_EXECUTIONS_BY_USER.labels(status=status).inc()


def record_block_execution(block_type: str, status: str, duration: float):
    """Record a block execution event with duration."""
    BLOCK_EXECUTIONS.labels(block_type=block_type, status=status).inc()
    BLOCK_DURATION.labels(block_type=block_type).observe(duration)

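As a sketch of the intended call pattern for the two helpers above, an executor could wrap a block run like this; the block name and the `run_block` stand-in are assumptions for illustration:

```python
# Sketch: timing one block run and reporting it through the helpers above.
import time

from backend.monitoring.instrumentation import record_block_execution


def run_block() -> None:
    """Stand-in for the real block invocation (assumption for this sketch)."""
    time.sleep(0.01)


start = time.perf_counter()
status = "success"
try:
    run_block()
except Exception:
    status = "error"
    raise
finally:
    # block_type must stay low-cardinality; use the block class name, not IDs
    record_block_execution(
        block_type="HTTPRequestBlock",  # assumed example name
        status=status,
        duration=time.perf_counter() - start,
    )
```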
def update_websocket_connections(user_id: str, delta: int):
    """Update the number of active WebSocket connections.

    Args:
        user_id: User identifier (kept for future sampling/debugging)
        delta: Change in connection count (+1 for connect, -1 for disconnect)
    """
    # Track total connections without user_id to prevent cardinality explosion
    if delta > 0:
        WEBSOCKET_CONNECTIONS.inc(delta)
    else:
        WEBSOCKET_CONNECTIONS.dec(abs(delta))


def record_database_query(operation: str, table: str, duration: float):
    """Record a database query with duration."""
    DATABASE_QUERIES.labels(operation=operation, table=table).observe(duration)

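One way to keep query timing in a single place is a small context manager around `record_database_query`; a sketch, where the operation and table names below are illustrative only:

```python
# Sketch: a reusable timer that reports through record_database_query.
import time
from contextlib import contextmanager

from backend.monitoring.instrumentation import record_database_query


@contextmanager
def timed_query(operation: str, table: str):
    start = time.perf_counter()
    try:
        yield
    finally:
        record_database_query(operation, table, time.perf_counter() - start)


# Illustrative usage:
with timed_query("select", "AgentGraphExecution"):  # assumed table name
    pass  # run the actual query here
```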
def record_rabbitmq_message(queue: str, status: str):
    """Record a RabbitMQ message event."""
    RABBITMQ_MESSAGES.labels(queue=queue, status=status).inc()


def record_authentication_attempt(method: str, status: str):
    """Record an authentication attempt."""
    AUTHENTICATION_ATTEMPTS.labels(method=method, status=status).inc()


def record_api_key_usage(provider: str, block_type: str, status: str):
    """Record API key usage by provider and block."""
    API_KEY_USAGE.labels(provider=provider, block_type=block_type, status=status).inc()


def record_rate_limit_hit(endpoint: str, user_id: str):
    """Record a rate limit hit.

    Args:
        endpoint: API endpoint that was rate limited
        user_id: User identifier (kept for future sampling/debugging)
    """
    RATE_LIMIT_HITS.labels(endpoint=endpoint).inc()


def record_graph_operation(operation: str, status: str):
    """Record a graph operation (create, update, delete, execute, etc.)."""
    GRAPH_OPERATIONS.labels(operation=operation, status=status).inc()


def record_user_operation(operation: str, status: str):
    """Record a user operation (login, register, etc.)."""
    USER_OPERATIONS.labels(operation=operation, status=status).inc()

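The remaining helpers follow the same shape: bounded label values in, one `inc()` out. A sketch of how an authentication layer might report outcomes through them; the wrapper function and its status strings are assumptions:

```python
# Sketch: reporting login outcomes via the helpers above.
from backend.monitoring.instrumentation import (
    record_authentication_attempt,
    record_user_operation,
)


def report_login(method: str, succeeded: bool) -> None:
    status = "success" if succeeded else "failure"  # assumed status vocabulary
    record_authentication_attempt(method=method, status=status)
    record_user_operation(operation="login", status=status)


report_login("api_key", succeeded=True)
report_login("password", succeeded=False)
```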
@@ -29,7 +29,7 @@ from backend.data.user import generate_unsubscribe_link
 from backend.notifications.email import EmailSender
 from backend.util.clients import get_database_manager_async_client
 from backend.util.logging import TruncatedLogger
-from backend.util.metrics import discord_send_alert
+from backend.util.metrics import DiscordChannel, discord_send_alert
 from backend.util.retry import continuous_retry
 from backend.util.service import (
     AppService,
@@ -382,8 +382,10 @@ class NotificationManager(AppService):
         }

     @expose
-    async def discord_system_alert(self, content: str):
-        await discord_send_alert(content)
+    async def discord_system_alert(
+        self, content: str, channel: DiscordChannel = DiscordChannel.PLATFORM
+    ):
+        await discord_send_alert(content, channel)

     async def _queue_scheduled_notification(self, event: SummaryParamsEventModel):
         """Queue a scheduled notification - exposed method for other services to call"""
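The hunk above adds an optional `channel` argument so callers can route alerts to a specific Discord channel, defaulting to `DiscordChannel.PLATFORM`. A sketch of a caller after this change; the `notification_client` accessor is an assumption, and only the method signature and default channel come from the diff:

```python
# Sketch: invoking the updated exposed method with an explicit channel.
from backend.util.metrics import DiscordChannel


async def send_platform_alert(notification_client, message: str) -> None:
    # channel defaults to DiscordChannel.PLATFORM if omitted (per the new signature)
    await notification_client.discord_system_alert(
        message, channel=DiscordChannel.PLATFORM
    )
```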
Some files were not shown because too many files have changed in this diff.